Bugfixes on Retrainer
This commit is contained in:
parent
96651d1059
commit
9bbb5d4c4d
@ -40,7 +40,10 @@ class Retrainer(fts.FTS):
|
|||||||
|
|
||||||
def train(self, data, **kwargs):
|
def train(self, data, **kwargs):
|
||||||
self.partitioner = self.partitioner_method(data=data, **self.partitioner_params)
|
self.partitioner = self.partitioner_method(data=data, **self.partitioner_params)
|
||||||
self.model = self.fts_method(partitioner=self.partitioner, order=self.order, **self.fts_params)
|
self.model = self.fts_method(partitioner=self.partitioner, **self.fts_params)
|
||||||
|
if self.model.is_high_order:
|
||||||
|
self.model.order = self.model = self.fts_method(partitioner=self.partitioner,
|
||||||
|
order=self.order, **self.fts_params)
|
||||||
self.model.fit(data, **kwargs)
|
self.model.fit(data, **kwargs)
|
||||||
self.shortname = self.model.shortname
|
self.shortname = self.model.shortname
|
||||||
|
|
||||||
|
@ -23,13 +23,30 @@ dataset = TAIEX.get_data()
|
|||||||
|
|
||||||
from pyFTS.models.incremental import Retrainer
|
from pyFTS.models.incremental import Retrainer
|
||||||
|
|
||||||
model = Retrainer.Retrainer(partitioner_params = {'npart': 30},
|
from pyFTS.models.incremental import Retrainer
|
||||||
fts_method=hofts.HighOrderFTS, order = 2,
|
from pyFTS.benchmarks import benchmarks as bchmk
|
||||||
|
|
||||||
|
models = []
|
||||||
|
for method in bchmk.get_point_methods():
|
||||||
|
model = Retrainer.Retrainer(partitioner_params = {'npart': 30},
|
||||||
|
fts_method=method,
|
||||||
window_length = 500, batch_size = 100)
|
window_length = 500, batch_size = 100)
|
||||||
|
models.append(model)
|
||||||
|
|
||||||
#model.predict(dataset)
|
#model.predict(dataset)
|
||||||
|
|
||||||
Measures.get_point_statistics(dataset, model)
|
from pyFTS.partitioners import Grid, Util as pUtil
|
||||||
|
from pyFTS.benchmarks import benchmarks as bchmk, naive
|
||||||
|
|
||||||
|
tag = 'benchmarks_retrainer'
|
||||||
|
|
||||||
|
bchmk.sliding_window_benchmarks(dataset, 2000, train=.1, inc=0.1,
|
||||||
|
models=[model],
|
||||||
|
build_methods = False,
|
||||||
|
benchmark_models=False,
|
||||||
|
partitions=[35],
|
||||||
|
progress=False, type='point',
|
||||||
|
file="nsfts_benchmarks.db", dataset='teste', tag=tag)
|
||||||
|
|
||||||
'''
|
'''
|
||||||
#dataset = SP500.get_data()[11500:16000]
|
#dataset = SP500.get_data()[11500:16000]
|
||||||
|
Loading…
Reference in New Issue
Block a user