From c7ee8c3cfe957926f644f1c545ffd99fb4d10633 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petr=C3=B4nio=20C=C3=A2ndido?= Date: Fri, 29 Jun 2018 16:32:21 -0300 Subject: [PATCH] Mixing univariate and multivariate models in EnsembleFTS --- pyFTS/models/ensemble/ensemble.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pyFTS/models/ensemble/ensemble.py b/pyFTS/models/ensemble/ensemble.py index b2b8715..2bfd0ff 100644 --- a/pyFTS/models/ensemble/ensemble.py +++ b/pyFTS/models/ensemble/ensemble.py @@ -31,21 +31,33 @@ class EnsembleFTS(fts.FTS): self.alpha = kwargs.get("alpha", 0.05) self.point_method = kwargs.get('point_method', 'mean') self.interval_method = kwargs.get('interval_method', 'quantile') + self.order = 1 def append_model(self, model): self.models.append(model) if model.order > self.order: self.order = model.order + if model.is_multivariate: + self.is_multivariate = True + + if model.has_seasonality: + self.has_seasonality = True + + def train(self, data, **kwargs): pass def get_models_forecasts(self,data): tmp = [] for model in self.models: - if self.is_multivariate or self.has_seasonality: + if model.is_multivariate or model.has_seasonality: forecast = model.forecast(data) else: + + if isinstance(data, pd.DataFrame) and self.indexer is not None: + data = self.indexer.get_data(data) + sample = data[-model.order:] forecast = model.forecast(sample) if isinstance(forecast, (list,np.ndarray)) and len(forecast) > 0: