From 41818258b227e8b1af8e50546a85d00e5188ecda Mon Sep 17 00:00:00 2001 From: matheus_cascalho Date: Wed, 2 Dec 2020 18:19:54 -0300 Subject: [PATCH] teste apply --- pyFTS/common/transformations/som.py | 27 ++++++++++++++++++++------- pyFTS/tests/test_SOMTransformation.py | 16 +++++++++++++++- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/pyFTS/common/transformations/som.py b/pyFTS/common/transformations/som.py index af66c26..84e4773 100644 --- a/pyFTS/common/transformations/som.py +++ b/pyFTS/common/transformations/som.py @@ -6,7 +6,7 @@ import SimpSOM as sps from pyFTS.models.multivariate import wmvfts from typing import Tuple from pyFTS.common.Transformations import Transformation - +from typing import List class SOMTransformation(Transformation): def __init__(self, @@ -22,14 +22,27 @@ class SOMTransformation(Transformation): self.name = 'Kohonen Self Organizing Maps FTS' self.shortname = 'SOM-FTS' - # def apply(self, data, endogen_variable, param, **kwargs): #TODO(CASCALHO) MELHORAR DOCSTRING - # """ - # Transform dataset from M-DIMENSION to 3-dimension - # """ - # pass + def apply(self, + data: pd.DataFrame, + endogen_variable=None, + names: List[str] = ['x', 'y'], + param=None, + **kwargs): #TODO(CASCALHO) MELHORAR DOCSTRING + """ + Transform dataset from M-DIMENSION to 3-dimension + """ + if self.net is None: + cols = data.columns[:-1] + train = data[cols] + self.train(data=train) + new_data = self.net.project(data.values) + new_data = pd.DataFrame(new_data, columns=names) + endogen = endogen_variable if endogen_variable is not None else data.columns[-1] + new_data[endogen] = data[endogen] + return new_data def __repr__(self): - status = "is trained" if self.is_trained else "not trained" + status = "is trained" if self.net is not None else "not trained" return f'{self.name}-{status}' def __str__(self): diff --git a/pyFTS/tests/test_SOMTransformation.py b/pyFTS/tests/test_SOMTransformation.py index e1fab22..253263f 100644 --- a/pyFTS/tests/test_SOMTransformation.py +++ b/pyFTS/tests/test_SOMTransformation.py @@ -2,10 +2,22 @@ import unittest from pyFTS.common.transformations.som import SOMTransformation import pandas as pd import os +import numpy as np class MyTestCase(unittest.TestCase): def test_apply(self): - self.assertEqual(True, False) + data = [ + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + ] + som = self.som_transformer_trained() + transformed = som.apply(data=pd.DataFrame(data)) + uniques = np.unique(transformed) + + self.assertEqual(1, len(uniques.shape)) + self.assertEqual(3, transformed.values.shape[1]) def test_save_net(self): som_transformer = self.som_transformer_trained() @@ -22,6 +34,8 @@ class MyTestCase(unittest.TestCase): self.assertEqual(True, is_in_files) + # def + def test_train(self): self.assertEqual()