diff --git a/pyFTS/models/ensemble/ensemble.py b/pyFTS/models/ensemble/ensemble.py index b2b8715..2bfd0ff 100644 --- a/pyFTS/models/ensemble/ensemble.py +++ b/pyFTS/models/ensemble/ensemble.py @@ -31,21 +31,33 @@ class EnsembleFTS(fts.FTS): self.alpha = kwargs.get("alpha", 0.05) self.point_method = kwargs.get('point_method', 'mean') self.interval_method = kwargs.get('interval_method', 'quantile') + self.order = 1 def append_model(self, model): self.models.append(model) if model.order > self.order: self.order = model.order + if model.is_multivariate: + self.is_multivariate = True + + if model.has_seasonality: + self.has_seasonality = True + + def train(self, data, **kwargs): pass def get_models_forecasts(self,data): tmp = [] for model in self.models: - if self.is_multivariate or self.has_seasonality: + if model.is_multivariate or model.has_seasonality: forecast = model.forecast(data) else: + + if isinstance(data, pd.DataFrame) and self.indexer is not None: + data = self.indexer.get_data(data) + sample = data[-model.order:] forecast = model.forecast(sample) if isinstance(forecast, (list,np.ndarray)) and len(forecast) > 0: