From 0b03fbfa57ff361248a5eb4e69e5e1b2d3fb4e48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petr=C3=B4nio=20C=C3=A2ndido?= Date: Fri, 5 Jul 2019 08:02:33 -0300 Subject: [PATCH] Bugfix in forecast_ahead --- pyFTS/common/fts.py | 6 +++--- pyFTS/tests/multivariate.py | 15 +++++++++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/pyFTS/common/fts.py b/pyFTS/common/fts.py index e1f61e8..25a2d9f 100644 --- a/pyFTS/common/fts.py +++ b/pyFTS/common/fts.py @@ -239,9 +239,9 @@ class FTS(object): start = kwargs.get('start_at',0) - ret = [] + ret = data[:start+self.max_lag] for k in np.arange(start+self.max_lag, steps+start+self.max_lag): - tmp = self.forecast(data[k-self.max_lag:k], **kwargs) + tmp = self.forecast(ret[k-self.max_lag:k], **kwargs) if isinstance(tmp,(list, np.ndarray)): tmp = tmp[-1] @@ -364,7 +364,7 @@ class FTS(object): if dump == 'time': print("[{0: %H:%M:%S}] Start training".format(datetime.datetime.now())) - if num_batches is not None: + if num_batches is not None and not self.is_wrapper: n = len(data) batch_size = int(n / num_batches) bcount = 1 diff --git a/pyFTS/tests/multivariate.py b/pyFTS/tests/multivariate.py index 4ceb69e..ec90ae5 100644 --- a/pyFTS/tests/multivariate.py +++ b/pyFTS/tests/multivariate.py @@ -29,12 +29,19 @@ from pyFTS.partitioners import Grid partitioner = Grid.GridPartitioner(data=train_data, npart=35) -from pyFTS.models import pwfts +from pyFTS.models import pwfts, hofts -model = pwfts.ProbabilisticWeightedFTS(partitioner=partitioner, order=2) -model.train(train_data) +#model = pwfts.ProbabilisticWeightedFTS(partitioner=partitioner, order=2) +#from pyFTS.models.incremental import TimeVariant -print(model.predict(test_data[:100])) +#model = TimeVariant.Retrainer(partitioner_method=Grid.GridPartitioner, partitioner_params={'npart': 35}, +# fts_method=pwfts.ProbabilisticWeightedFTS, fts_params={}, order=2 , +# batch_size=100, window_length=500) + +model = hofts.HighOrderFTS(partitioner=partitioner, order=2) +model.fit(train_data) + +print(model.predict(test_data, steps_ahead=10)) '''