bugfix in Measures.get_point_statistics for supporting multivariate models
This commit is contained in:
parent
d15814b545
commit
393388f722
@ -324,7 +324,9 @@ def get_point_statistics(data, model, **kwargs):
|
|||||||
if steps_ahead == 1:
|
if steps_ahead == 1:
|
||||||
forecasts = model.predict(ndata, **kwargs)
|
forecasts = model.predict(ndata, **kwargs)
|
||||||
|
|
||||||
if model.is_multivariate:
|
if model.is_multivariate and model.has_seasonality:
|
||||||
|
ndata = model.indexer.get_data(ndata)
|
||||||
|
elif model.is_multivariate:
|
||||||
ndata = ndata[model.target_variable.data_label].values
|
ndata = ndata[model.target_variable.data_label].values
|
||||||
|
|
||||||
if not isinstance(forecasts, (list, np.ndarray)):
|
if not isinstance(forecasts, (list, np.ndarray)):
|
||||||
|
@ -31,7 +31,15 @@ class SeasonalIndexer(object):
|
|||||||
|
|
||||||
|
|
||||||
class LinearSeasonalIndexer(SeasonalIndexer):
|
class LinearSeasonalIndexer(SeasonalIndexer):
|
||||||
|
"""Use the data array/list position to index the seasonality """
|
||||||
def __init__(self, seasons, units, ignore=None, **kwargs):
|
def __init__(self, seasons, units, ignore=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Indexer for array/list position
|
||||||
|
:param seasons: A list with the season group (i.e: 7 for week, 30 for month, etc)
|
||||||
|
:param units: A list with the units used for each season group, the default is 1 for each
|
||||||
|
:param ignore:
|
||||||
|
:param kwargs:
|
||||||
|
"""
|
||||||
super(LinearSeasonalIndexer, self).__init__(len(seasons), **kwargs)
|
super(LinearSeasonalIndexer, self).__init__(len(seasons), **kwargs)
|
||||||
self.seasons = seasons
|
self.seasons = seasons
|
||||||
self.units = units
|
self.units = units
|
||||||
@ -81,11 +89,19 @@ class LinearSeasonalIndexer(SeasonalIndexer):
|
|||||||
|
|
||||||
|
|
||||||
class DataFrameSeasonalIndexer(SeasonalIndexer):
|
class DataFrameSeasonalIndexer(SeasonalIndexer):
|
||||||
def __init__(self,index_fields,index_seasons, data_fields,**kwargs):
|
"""Use the Pandas.DataFrame index position to index the seasonality """
|
||||||
|
def __init__(self,index_fields,index_seasons, data_field,**kwargs):
|
||||||
|
"""
|
||||||
|
|
||||||
|
:param index_fields: DataFrame field to use as index
|
||||||
|
:param index_seasons: A list with the season group, i. e., multiples of positions that are considered a season (i.e: 7 for week, 30 for month, etc)
|
||||||
|
:param data_fields: DataFrame field to use as data
|
||||||
|
:param kwargs:
|
||||||
|
"""
|
||||||
super(DataFrameSeasonalIndexer, self).__init__(len(index_seasons), **kwargs)
|
super(DataFrameSeasonalIndexer, self).__init__(len(index_seasons), **kwargs)
|
||||||
self.fields = index_fields
|
self.fields = index_fields
|
||||||
self.seasons = index_seasons
|
self.seasons = index_seasons
|
||||||
self.data_fields = data_fields
|
self.data_field = data_field
|
||||||
|
|
||||||
def get_season_of_data(self,data):
|
def get_season_of_data(self,data):
|
||||||
#data = data.copy()
|
#data = data.copy()
|
||||||
@ -111,25 +127,34 @@ class DataFrameSeasonalIndexer(SeasonalIndexer):
|
|||||||
data = data[data[f]== season[c]]
|
data = data[data[f]== season[c]]
|
||||||
else:
|
else:
|
||||||
data = data[(data[f] // self.seasons[c]) == season[c]]
|
data = data[(data[f] // self.seasons[c]) == season[c]]
|
||||||
return data[self.data_fields]
|
return data[self.data_field]
|
||||||
|
|
||||||
def get_index_by_season(self, indexes):
|
def get_index_by_season(self, indexes):
|
||||||
raise Exception("Operation not available!")
|
raise Exception("Operation not available!")
|
||||||
|
|
||||||
def get_data(self, data):
|
def get_data(self, data):
|
||||||
return data[self.data_fields].tolist()
|
return data[self.data_field].tolist()
|
||||||
|
|
||||||
def set_data(self, data, value):
|
def set_data(self, data, value):
|
||||||
data.loc[:,self.data_fields] = value
|
data.loc[:,self.data_field] = value
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class DateTimeSeasonalIndexer(SeasonalIndexer):
|
class DateTimeSeasonalIndexer(SeasonalIndexer):
|
||||||
def __init__(self,date_field, index_fields, index_seasons, data_fields,**kwargs):
|
"""Use a Pandas.DataFrame date field to index the seasonality """
|
||||||
|
def __init__(self,date_field, index_fields, index_seasons, data_field,**kwargs):
|
||||||
|
"""
|
||||||
|
|
||||||
|
:param date_field: DataFrame field that contains the datetime field used on index
|
||||||
|
:param index_fields: List with commom.DataTime fields
|
||||||
|
:param index_seasons: Multiples of index_fields, the default is 1
|
||||||
|
:param data_field: DataFrame field with the time series data
|
||||||
|
:param kwargs:
|
||||||
|
"""
|
||||||
super(DateTimeSeasonalIndexer, self).__init__(len(index_seasons), **kwargs)
|
super(DateTimeSeasonalIndexer, self).__init__(len(index_seasons), **kwargs)
|
||||||
self.fields = index_fields
|
self.fields = index_fields
|
||||||
self.seasons = index_seasons
|
self.seasons = index_seasons
|
||||||
self.data_fields = data_fields
|
self.data_field = data_field
|
||||||
self.date_field = date_field
|
self.date_field = date_field
|
||||||
|
|
||||||
def get_season_of_data(self, data):
|
def get_season_of_data(self, data):
|
||||||
@ -151,7 +176,7 @@ class DateTimeSeasonalIndexer(SeasonalIndexer):
|
|||||||
date = data[self.date_field]
|
date = data[self.date_field]
|
||||||
season = []
|
season = []
|
||||||
for c, f in enumerate(self.fields, start=0):
|
for c, f in enumerate(self.fields, start=0):
|
||||||
season.append(self.strip_datepart(date, f, self.seasons[c]))
|
season.append(common.strip_datepart(date, f, self.seasons[c]))
|
||||||
ret.append(season)
|
ret.append(season)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
@ -166,10 +191,12 @@ class DateTimeSeasonalIndexer(SeasonalIndexer):
|
|||||||
raise Exception("Operation not available!")
|
raise Exception("Operation not available!")
|
||||||
|
|
||||||
def get_data(self, data):
|
def get_data(self, data):
|
||||||
return data[self.data_fields].tolist()
|
return data[self.data_field].tolist()
|
||||||
|
|
||||||
def get_index(self, data):
|
def get_index(self, data):
|
||||||
return data[self.date_field].tolist() if isinstance(data, pd.DataFrame) else data[self.date_field]
|
return data[self.date_field].tolist() if isinstance(data, pd.DataFrame) else data[self.date_field]
|
||||||
|
|
||||||
def set_data(self, data, value):
|
def set_data(self, data, value):
|
||||||
raise Exception("Operation not available!")
|
raise Exception("Operation not available!")
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user