Mixing univariate and multivariate models in EnsembleFTS
This commit is contained in:
parent
3580f0b4b3
commit
c7ee8c3cfe
@ -31,21 +31,33 @@ class EnsembleFTS(fts.FTS):
|
|||||||
self.alpha = kwargs.get("alpha", 0.05)
|
self.alpha = kwargs.get("alpha", 0.05)
|
||||||
self.point_method = kwargs.get('point_method', 'mean')
|
self.point_method = kwargs.get('point_method', 'mean')
|
||||||
self.interval_method = kwargs.get('interval_method', 'quantile')
|
self.interval_method = kwargs.get('interval_method', 'quantile')
|
||||||
|
self.order = 1
|
||||||
|
|
||||||
def append_model(self, model):
|
def append_model(self, model):
|
||||||
self.models.append(model)
|
self.models.append(model)
|
||||||
if model.order > self.order:
|
if model.order > self.order:
|
||||||
self.order = model.order
|
self.order = model.order
|
||||||
|
|
||||||
|
if model.is_multivariate:
|
||||||
|
self.is_multivariate = True
|
||||||
|
|
||||||
|
if model.has_seasonality:
|
||||||
|
self.has_seasonality = True
|
||||||
|
|
||||||
|
|
||||||
def train(self, data, **kwargs):
|
def train(self, data, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_models_forecasts(self,data):
|
def get_models_forecasts(self,data):
|
||||||
tmp = []
|
tmp = []
|
||||||
for model in self.models:
|
for model in self.models:
|
||||||
if self.is_multivariate or self.has_seasonality:
|
if model.is_multivariate or model.has_seasonality:
|
||||||
forecast = model.forecast(data)
|
forecast = model.forecast(data)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
|
if isinstance(data, pd.DataFrame) and self.indexer is not None:
|
||||||
|
data = self.indexer.get_data(data)
|
||||||
|
|
||||||
sample = data[-model.order:]
|
sample = data[-model.order:]
|
||||||
forecast = model.forecast(sample)
|
forecast = model.forecast(sample)
|
||||||
if isinstance(forecast, (list,np.ndarray)) and len(forecast) > 0:
|
if isinstance(forecast, (list,np.ndarray)) and len(forecast) > 0:
|
||||||
|
Loading…
Reference in New Issue
Block a user