diff --git a/pyFTS/benchmarks/Measures.py b/pyFTS/benchmarks/Measures.py index 72bf9a7..77e3de6 100644 --- a/pyFTS/benchmarks/Measures.py +++ b/pyFTS/benchmarks/Measures.py @@ -324,7 +324,9 @@ def get_point_statistics(data, model, **kwargs): if steps_ahead == 1: forecasts = model.predict(ndata, **kwargs) - if model.is_multivariate: + if model.is_multivariate and model.has_seasonality: + ndata = model.indexer.get_data(ndata) + elif model.is_multivariate: ndata = ndata[model.target_variable.data_label].values if not isinstance(forecasts, (list, np.ndarray)): diff --git a/pyFTS/models/seasonal/SeasonalIndexer.py b/pyFTS/models/seasonal/SeasonalIndexer.py index d97878b..470d90f 100644 --- a/pyFTS/models/seasonal/SeasonalIndexer.py +++ b/pyFTS/models/seasonal/SeasonalIndexer.py @@ -31,7 +31,15 @@ class SeasonalIndexer(object): class LinearSeasonalIndexer(SeasonalIndexer): + """Use the data array/list position to index the seasonality """ def __init__(self, seasons, units, ignore=None, **kwargs): + """ + Indexer for array/list position + :param seasons: A list with the season group (i.e: 7 for week, 30 for month, etc) + :param units: A list with the units used for each season group, the default is 1 for each + :param ignore: + :param kwargs: + """ super(LinearSeasonalIndexer, self).__init__(len(seasons), **kwargs) self.seasons = seasons self.units = units @@ -81,11 +89,19 @@ class LinearSeasonalIndexer(SeasonalIndexer): class DataFrameSeasonalIndexer(SeasonalIndexer): - def __init__(self,index_fields,index_seasons, data_fields,**kwargs): + """Use the Pandas.DataFrame index position to index the seasonality """ + def __init__(self,index_fields,index_seasons, data_field,**kwargs): + """ + + :param index_fields: DataFrame field to use as index + :param index_seasons: A list with the season group, i. e., multiples of positions that are considered a season (i.e: 7 for week, 30 for month, etc) + :param data_fields: DataFrame field to use as data + :param kwargs: + """ super(DataFrameSeasonalIndexer, self).__init__(len(index_seasons), **kwargs) self.fields = index_fields self.seasons = index_seasons - self.data_fields = data_fields + self.data_field = data_field def get_season_of_data(self,data): #data = data.copy() @@ -111,25 +127,34 @@ class DataFrameSeasonalIndexer(SeasonalIndexer): data = data[data[f]== season[c]] else: data = data[(data[f] // self.seasons[c]) == season[c]] - return data[self.data_fields] + return data[self.data_field] def get_index_by_season(self, indexes): raise Exception("Operation not available!") def get_data(self, data): - return data[self.data_fields].tolist() + return data[self.data_field].tolist() def set_data(self, data, value): - data.loc[:,self.data_fields] = value + data.loc[:,self.data_field] = value return data class DateTimeSeasonalIndexer(SeasonalIndexer): - def __init__(self,date_field, index_fields, index_seasons, data_fields,**kwargs): + """Use a Pandas.DataFrame date field to index the seasonality """ + def __init__(self,date_field, index_fields, index_seasons, data_field,**kwargs): + """ + + :param date_field: DataFrame field that contains the datetime field used on index + :param index_fields: List with commom.DataTime fields + :param index_seasons: Multiples of index_fields, the default is 1 + :param data_field: DataFrame field with the time series data + :param kwargs: + """ super(DateTimeSeasonalIndexer, self).__init__(len(index_seasons), **kwargs) self.fields = index_fields self.seasons = index_seasons - self.data_fields = data_fields + self.data_field = data_field self.date_field = date_field def get_season_of_data(self, data): @@ -151,7 +176,7 @@ class DateTimeSeasonalIndexer(SeasonalIndexer): date = data[self.date_field] season = [] for c, f in enumerate(self.fields, start=0): - season.append(self.strip_datepart(date, f, self.seasons[c])) + season.append(common.strip_datepart(date, f, self.seasons[c])) ret.append(season) return ret @@ -166,10 +191,12 @@ class DateTimeSeasonalIndexer(SeasonalIndexer): raise Exception("Operation not available!") def get_data(self, data): - return data[self.data_fields].tolist() + return data[self.data_field].tolist() def get_index(self, data): return data[self.date_field].tolist() if isinstance(data, pd.DataFrame) else data[self.date_field] def set_data(self, data, value): - raise Exception("Operation not available!") \ No newline at end of file + raise Exception("Operation not available!") + +