treinamento da rede
This commit is contained in:
parent
e2438afee3
commit
6aa2a6c92e
63
pyFTS/models/multivariate/som.py
Normal file
63
pyFTS/models/multivariate/som.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
"""
|
||||||
|
Kohonen Self Organizing Maps for Fuzzy Time Series
|
||||||
|
"""
|
||||||
|
import pandas as pd
|
||||||
|
import SimpSOM as sps
|
||||||
|
from pyFTS.models.multivariate import wmvfts
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class SOMFTS:
|
||||||
|
def __init__(self,
|
||||||
|
grid_dimension: Tuple,
|
||||||
|
**kwargs):
|
||||||
|
# SOM attributes
|
||||||
|
self.net: sps.somNet = None
|
||||||
|
self.data: pd.DataFrame = None
|
||||||
|
self.grid_dimension: Tuple = grid_dimension
|
||||||
|
self.pbc = kwargs.get('PBC', True)
|
||||||
|
|
||||||
|
# fts attributes
|
||||||
|
self.fts_method = kwargs.get('fts_method', wmvfts.WeightedMVFTS)
|
||||||
|
self.order = kwargs.get('order', 2)
|
||||||
|
self.is_trained = False
|
||||||
|
|
||||||
|
# debug attributes
|
||||||
|
self.name = 'Kohonen Self Organizing Maps FTS'
|
||||||
|
self.shortname = 'SOM-FTS'
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
status = "is trained" if self.is_trained else "not trained"
|
||||||
|
return f'{self.name}-{status}'
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
del self.net
|
||||||
|
|
||||||
|
def train(self,
|
||||||
|
data: pd.DataFrame,
|
||||||
|
percentage_train: float = .7,
|
||||||
|
leaning_rate: float = 0.01,
|
||||||
|
epochs: int = 10000):
|
||||||
|
self.data = data
|
||||||
|
limit = len(self.data) * percentage_train
|
||||||
|
train = data[:limit]
|
||||||
|
x, y = self.grid_dimension
|
||||||
|
self.net = sps.somNet(x, y, train, self.pbc)
|
||||||
|
self.net.train(startLearnRate=leaning_rate,
|
||||||
|
epochs=epochs)
|
||||||
|
|
||||||
|
def save_net(self,
|
||||||
|
filename: str = "SomNet trained"):
|
||||||
|
self.net.save(filename)
|
||||||
|
|
||||||
|
def show_grid(self,
|
||||||
|
graph_type: str = 'nodes_graph',
|
||||||
|
**kwargs):
|
||||||
|
if graph_type == 'nodes_graph':
|
||||||
|
colnum = kwargs.get('colnum', 0)
|
||||||
|
self.net.nodes_graph(colnum=colnum)
|
||||||
|
else:
|
||||||
|
self.net.diff_graph()
|
Loading…
Reference in New Issue
Block a user