adaptação da FTS para receber métodos de transformação multivariada

This commit is contained in:
matheus_cascalho 2020-12-09 19:35:22 -03:00
parent 81036e506f
commit ba8bf1c4ea
3 changed files with 15 additions and 2 deletions

View File

@ -15,6 +15,8 @@ class Transformation(object):
def __init__(self, **kwargs):
self.is_invertible = True
self.is_multivariate = False
"""detemine if this transformation can be applied to multivariate data"""
self.minimal_length = 1
self.name = ''

View File

@ -18,6 +18,7 @@ class SOMTransformation(Transformation):
self.data: pd.DataFrame = None
self.grid_dimension: Tuple = grid_dimension
self.pbc = kwargs.get('PBC', True)
self.is_multivariate = True
# debug attributes
self.name = 'Kohonen Self Organizing Maps FTS'
@ -25,8 +26,6 @@ class SOMTransformation(Transformation):
def apply(self,
data: pd.DataFrame,
endogen_variable=None,
names: Tuple[str] = ('x', 'y'),
param=None,
**kwargs):
"""
@ -45,6 +44,9 @@ class SOMTransformation(Transformation):
"""
endogen_variable = kwargs.get('endogen_variable', None)
names = kwargs.get('names', ('x', 'y'))
if endogen_variable not in data.columns:
endogen_variable = None
cols = data.columns[:-1] if endogen_variable is None else [col for col in data.columns if

View File

@ -35,6 +35,12 @@ class MVFTS(fts.FTS):
self.name = "Multivariate FTS"
self.uod_clip = False
def append_transformation(self, transformation, **kwargs):
if not transformation.is_multivariate:
raise Exception('The transformation is not multivariate')
self.transformations.append(transformation)
self.transformations_param.append(kwargs)
def append_variable(self, var):
"""
Append a new endogenous variable to the model
@ -53,6 +59,9 @@ class MVFTS(fts.FTS):
def apply_transformations(self, data, params=None, updateUoD=False, **kwargs):
ndata = data.copy(deep=True)
for ct, transformation in enumerate(self.transformations):
ndata = transformation.apply(ndata, **self.transformations_param[ct])
for var in self.explanatory_variables:
try:
values = ndata[var.data_label].values #if isinstance(ndata, pd.DataFrame) else ndata[var.data_label]