Bugfixes on Retrainer

This commit is contained in:
Petrônio Cândido 2018-11-03 09:35:31 -03:00
parent b677a0e627
commit 932ab9168e
2 changed files with 6 additions and 4 deletions

View File

@ -36,6 +36,7 @@ class Retrainer(fts.FTS):
"""The batch interval between each retraining""" """The batch interval between each retraining"""
self.is_high_order = True self.is_high_order = True
self.uod_clip = False self.uod_clip = False
self.max_lag = self.window_length + self.order
def train(self, data, **kwargs): def train(self, data, **kwargs):
self.partitioner = self.partitioner_method(data=data, **self.partitioner_params) self.partitioner = self.partitioner_method(data=data, **self.partitioner_params)
@ -49,12 +50,11 @@ class Retrainer(fts.FTS):
ret = [] ret = []
for k in np.arange(horizon, l): for k in np.arange(horizon, l+1):
_train = data[k - horizon: k - self.order] _train = data[k - horizon: k - self.order]
_test = data[k - self.order: k] _test = data[k - self.order: k]
if k % self.batch_size == 0 or self.model is None: if k % self.batch_size == 0 or self.model is None:
print("Treinando {}".format(k))
if self.auto_update: if self.auto_update:
self.model.train(_train) self.model.train(_train)
else: else:

View File

@ -10,7 +10,7 @@ import pandas as pd
from pyFTS.common import Util as cUtil, FuzzySet from pyFTS.common import Util as cUtil, FuzzySet
from pyFTS.partitioners import Grid, Entropy, Util as pUtil from pyFTS.partitioners import Grid, Entropy, Util as pUtil
from pyFTS.benchmarks import benchmarks as bchmk from pyFTS.benchmarks import benchmarks as bchmk, Measures
from pyFTS.models import chen, yu, cheng, ismailefendi, hofts, pwfts from pyFTS.models import chen, yu, cheng, ismailefendi, hofts, pwfts
from pyFTS.common import Transformations from pyFTS.common import Transformations
@ -27,7 +27,9 @@ model = Retrainer.Retrainer(partitioner_params = {'npart': 30},
fts_method=hofts.HighOrderFTS, order = 2, fts_method=hofts.HighOrderFTS, order = 2,
window_length = 500, batch_size = 100) window_length = 500, batch_size = 100)
model.predict(dataset) #model.predict(dataset)
Measures.get_point_statistics(dataset, model)
''' '''
#dataset = SP500.get_data()[11500:16000] #dataset = SP500.get_data()[11500:16000]