diff --git a/pyFTS/common/Transformations.py b/pyFTS/common/Transformations.py index aa18b71..ad78545 100644 --- a/pyFTS/common/Transformations.py +++ b/pyFTS/common/Transformations.py @@ -15,6 +15,8 @@ class Transformation(object): def __init__(self, **kwargs): self.is_invertible = True + self.is_multivariate = False + """detemine if this transformation can be applied to multivariate data""" self.minimal_length = 1 self.name = '' diff --git a/pyFTS/common/transformations/som.py b/pyFTS/common/transformations/som.py index a5aee39..a965ae6 100644 --- a/pyFTS/common/transformations/som.py +++ b/pyFTS/common/transformations/som.py @@ -18,6 +18,7 @@ class SOMTransformation(Transformation): self.data: pd.DataFrame = None self.grid_dimension: Tuple = grid_dimension self.pbc = kwargs.get('PBC', True) + self.is_multivariate = True # debug attributes self.name = 'Kohonen Self Organizing Maps FTS' @@ -25,8 +26,6 @@ class SOMTransformation(Transformation): def apply(self, data: pd.DataFrame, - endogen_variable=None, - names: Tuple[str] = ('x', 'y'), param=None, **kwargs): """ @@ -45,6 +44,9 @@ class SOMTransformation(Transformation): """ + endogen_variable = kwargs.get('endogen_variable', None) + names = kwargs.get('names', ('x', 'y')) + if endogen_variable not in data.columns: endogen_variable = None cols = data.columns[:-1] if endogen_variable is None else [col for col in data.columns if diff --git a/pyFTS/models/multivariate/mvfts.py b/pyFTS/models/multivariate/mvfts.py index 632b5db..aba3afb 100644 --- a/pyFTS/models/multivariate/mvfts.py +++ b/pyFTS/models/multivariate/mvfts.py @@ -35,6 +35,12 @@ class MVFTS(fts.FTS): self.name = "Multivariate FTS" self.uod_clip = False + def append_transformation(self, transformation, **kwargs): + if not transformation.is_multivariate: + raise Exception('The transformation is not multivariate') + self.transformations.append(transformation) + self.transformations_param.append(kwargs) + def append_variable(self, var): """ Append a new endogenous variable to the model @@ -53,6 +59,9 @@ class MVFTS(fts.FTS): def apply_transformations(self, data, params=None, updateUoD=False, **kwargs): ndata = data.copy(deep=True) + for ct, transformation in enumerate(self.transformations): + ndata = transformation.apply(ndata, **self.transformations_param[ct]) + for var in self.explanatory_variables: try: values = ndata[var.data_label].values #if isinstance(ndata, pd.DataFrame) else ndata[var.data_label]