From 9bbb5d4c4dedff2d6ceb25a4ba88b3660567809c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petr=C3=B4nio=20C=C3=A2ndido?= Date: Mon, 5 Nov 2018 18:40:26 -0200 Subject: [PATCH] Bugfixes on Retrainer --- pyFTS/models/incremental/Retrainer.py | 5 ++++- pyFTS/tests/general.py | 25 +++++++++++++++++++++---- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/pyFTS/models/incremental/Retrainer.py b/pyFTS/models/incremental/Retrainer.py index 50af370..455756e 100644 --- a/pyFTS/models/incremental/Retrainer.py +++ b/pyFTS/models/incremental/Retrainer.py @@ -40,7 +40,10 @@ class Retrainer(fts.FTS): def train(self, data, **kwargs): self.partitioner = self.partitioner_method(data=data, **self.partitioner_params) - self.model = self.fts_method(partitioner=self.partitioner, order=self.order, **self.fts_params) + self.model = self.fts_method(partitioner=self.partitioner, **self.fts_params) + if self.model.is_high_order: + self.model.order = self.model = self.fts_method(partitioner=self.partitioner, + order=self.order, **self.fts_params) self.model.fit(data, **kwargs) self.shortname = self.model.shortname diff --git a/pyFTS/tests/general.py b/pyFTS/tests/general.py index 1483cb8..1aac53b 100644 --- a/pyFTS/tests/general.py +++ b/pyFTS/tests/general.py @@ -23,13 +23,30 @@ dataset = TAIEX.get_data() from pyFTS.models.incremental import Retrainer -model = Retrainer.Retrainer(partitioner_params = {'npart': 30}, - fts_method=hofts.HighOrderFTS, order = 2, - window_length = 500, batch_size = 100) +from pyFTS.models.incremental import Retrainer +from pyFTS.benchmarks import benchmarks as bchmk + +models = [] +for method in bchmk.get_point_methods(): + model = Retrainer.Retrainer(partitioner_params = {'npart': 30}, + fts_method=method, + window_length = 500, batch_size = 100) + models.append(model) #model.predict(dataset) -Measures.get_point_statistics(dataset, model) +from pyFTS.partitioners import Grid, Util as pUtil +from pyFTS.benchmarks import benchmarks as bchmk, naive + +tag = 'benchmarks_retrainer' + +bchmk.sliding_window_benchmarks(dataset, 2000, train=.1, inc=0.1, + models=[model], + build_methods = False, + benchmark_models=False, + partitions=[35], + progress=False, type='point', + file="nsfts_benchmarks.db", dataset='teste', tag=tag) ''' #dataset = SP500.get_data()[11500:16000]