test save_net
This commit is contained in:
parent
6aa2a6c92e
commit
5e4eb03b89
@ -503,7 +503,7 @@ class FTS(object):
|
||||
params = [ None for k in self.transformations]
|
||||
|
||||
for c, t in enumerate(self.transformations, start=0):
|
||||
ndata = t.apply(ndata,params[c])
|
||||
ndata = t.apply(ndata, params[c], )
|
||||
|
||||
return ndata
|
||||
|
||||
|
72
pyFTS/common/transformations/som.py
Normal file
72
pyFTS/common/transformations/som.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""
|
||||
Kohonen Self Organizing Maps for Fuzzy Time Series
|
||||
"""
|
||||
import pandas as pd
|
||||
import SimpSOM as sps
|
||||
from pyFTS.models.multivariate import wmvfts
|
||||
from typing import Tuple
|
||||
from pyFTS.common.Transformations import Transformation
|
||||
|
||||
|
||||
class SOMTransformation(Transformation):
|
||||
def __init__(self,
|
||||
grid_dimension: Tuple,
|
||||
**kwargs):
|
||||
# SOM attributes
|
||||
self.net: sps.somNet = None
|
||||
self.data: pd.DataFrame = None
|
||||
self.grid_dimension: Tuple = grid_dimension
|
||||
self.pbc = kwargs.get('PBC', True)
|
||||
|
||||
# debug attributes
|
||||
self.name = 'Kohonen Self Organizing Maps FTS'
|
||||
self.shortname = 'SOM-FTS'
|
||||
|
||||
# def apply(self, data, endogen_variable, param, **kwargs): #TODO(CASCALHO) MELHORAR DOCSTRING
|
||||
# """
|
||||
# Transform dataset from M-DIMENSION to 3-dimension
|
||||
# """
|
||||
# pass
|
||||
|
||||
def __repr__(self):
|
||||
status = "is trained" if self.is_trained else "not trained"
|
||||
return f'{self.name}-{status}'
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def __del__(self):
|
||||
del self.net
|
||||
|
||||
def train(self,
|
||||
data: pd.DataFrame,
|
||||
percentage_train: float = .7,
|
||||
leaning_rate: float = 0.01,
|
||||
epochs: int = 10000):
|
||||
self.data = data.values
|
||||
limit = round(len(self.data) * percentage_train)
|
||||
train = self.data[:limit]
|
||||
x, y = self.grid_dimension
|
||||
self.net = sps.somNet(x, y, train, PBC=self.pbc)
|
||||
self.net.train(startLearnRate=leaning_rate,
|
||||
epochs=epochs)
|
||||
|
||||
def save_net(self,
|
||||
filename: str = "SomNet trained"):
|
||||
self.net.save(filename)
|
||||
|
||||
def show_grid(self,
|
||||
graph_type: str = 'nodes_graph',
|
||||
**kwargs):
|
||||
if graph_type == 'nodes_graph':
|
||||
colnum = kwargs.get('colnum', 0)
|
||||
self.net.nodes_graph(colnum=colnum)
|
||||
else:
|
||||
self.net.diff_graph()
|
||||
|
||||
|
||||
"""
|
||||
Requisitos
|
||||
- apply(herdado de transformations): transforma os conjunto de dados
|
||||
- inverse - não é necessária
|
||||
"""
|
@ -7,7 +7,7 @@ from pyFTS.models.multivariate import wmvfts
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class SOMFTS:
|
||||
class SOMPartitioner:
|
||||
def __init__(self,
|
||||
grid_dimension: Tuple,
|
||||
**kwargs):
|
||||
@ -17,10 +17,6 @@ class SOMFTS:
|
||||
self.grid_dimension: Tuple = grid_dimension
|
||||
self.pbc = kwargs.get('PBC', True)
|
||||
|
||||
# fts attributes
|
||||
self.fts_method = kwargs.get('fts_method', wmvfts.WeightedMVFTS)
|
||||
self.order = kwargs.get('order', 2)
|
||||
self.is_trained = False
|
||||
|
||||
# debug attributes
|
||||
self.name = 'Kohonen Self Organizing Maps FTS'
|
||||
@ -60,4 +56,12 @@ class SOMFTS:
|
||||
colnum = kwargs.get('colnum', 0)
|
||||
self.net.nodes_graph(colnum=colnum)
|
||||
else:
|
||||
self.net.diff_graph()
|
||||
self.net.diff_graph()
|
||||
|
||||
|
||||
|
||||
"""
|
||||
Requisitos
|
||||
|
||||
|
||||
"""
|
47
pyFTS/tests/test_SOMTransformation.py
Normal file
47
pyFTS/tests/test_SOMTransformation.py
Normal file
@ -0,0 +1,47 @@
|
||||
import unittest
|
||||
from pyFTS.common.transformations.som import SOMTransformation
|
||||
import pandas as pd
|
||||
import os
|
||||
|
||||
class MyTestCase(unittest.TestCase):
|
||||
def test_apply(self):
|
||||
self.assertEqual(True, False)
|
||||
|
||||
def test_save_net(self):
|
||||
som_transformer = self.som_transformer_trained()
|
||||
|
||||
filename = 'test_net.npy'
|
||||
som_transformer.save_net(filename)
|
||||
files = os.listdir()
|
||||
|
||||
if filename in files:
|
||||
is_in_files = True
|
||||
os.remove(filename)
|
||||
else:
|
||||
is_in_files = False
|
||||
|
||||
self.assertEqual(True, is_in_files)
|
||||
|
||||
def test_train(self):
|
||||
self.assertEqual()
|
||||
|
||||
@staticmethod
|
||||
def simple_dataset():
|
||||
data = [
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 0, 0],
|
||||
[1, 1, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0],
|
||||
]
|
||||
df = pd.DataFrame(data)
|
||||
return df
|
||||
|
||||
def som_transformer_trained(self):
|
||||
data = self.simple_dataset()
|
||||
som_transformer = SOMTransformation(grid_dimension=(2, 2))
|
||||
som_transformer.train(data=data, epochs=100)
|
||||
return som_transformer
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user