Bugfixes on Retrainer

This commit is contained in:
Petrônio Cândido 2018-11-05 18:40:26 -02:00
parent 96651d1059
commit 9bbb5d4c4d
2 changed files with 25 additions and 5 deletions

View File

@ -40,7 +40,10 @@ class Retrainer(fts.FTS):
def train(self, data, **kwargs): def train(self, data, **kwargs):
self.partitioner = self.partitioner_method(data=data, **self.partitioner_params) self.partitioner = self.partitioner_method(data=data, **self.partitioner_params)
self.model = self.fts_method(partitioner=self.partitioner, order=self.order, **self.fts_params) self.model = self.fts_method(partitioner=self.partitioner, **self.fts_params)
if self.model.is_high_order:
self.model.order = self.model = self.fts_method(partitioner=self.partitioner,
order=self.order, **self.fts_params)
self.model.fit(data, **kwargs) self.model.fit(data, **kwargs)
self.shortname = self.model.shortname self.shortname = self.model.shortname

View File

@ -23,13 +23,30 @@ dataset = TAIEX.get_data()
from pyFTS.models.incremental import Retrainer from pyFTS.models.incremental import Retrainer
model = Retrainer.Retrainer(partitioner_params = {'npart': 30}, from pyFTS.models.incremental import Retrainer
fts_method=hofts.HighOrderFTS, order = 2, from pyFTS.benchmarks import benchmarks as bchmk
window_length = 500, batch_size = 100)
models = []
for method in bchmk.get_point_methods():
model = Retrainer.Retrainer(partitioner_params = {'npart': 30},
fts_method=method,
window_length = 500, batch_size = 100)
models.append(model)
#model.predict(dataset) #model.predict(dataset)
Measures.get_point_statistics(dataset, model) from pyFTS.partitioners import Grid, Util as pUtil
from pyFTS.benchmarks import benchmarks as bchmk, naive
tag = 'benchmarks_retrainer'
bchmk.sliding_window_benchmarks(dataset, 2000, train=.1, inc=0.1,
models=[model],
build_methods = False,
benchmark_models=False,
partitions=[35],
progress=False, type='point',
file="nsfts_benchmarks.db", dataset='teste', tag=tag)
''' '''
#dataset = SP500.get_data()[11500:16000] #dataset = SP500.get_data()[11500:16000]