Introducing batch_size on Retraining

This commit is contained in:
Petrônio Cândido 2018-11-01 15:55:35 -03:00
parent e4ee163660
commit b677a0e627
2 changed files with 14 additions and 45 deletions

View File

@ -32,7 +32,10 @@ class Retrainer(fts.FTS):
"""The memory window length"""
self.auto_update = False
"""If true the model is updated at each time and not recreated"""
self.batch_size = kwargs.get('batch_size', 10)
"""The batch interval between each retraining"""
self.is_high_order = True
self.uod_clip = False
def train(self, data, **kwargs):
self.partitioner = self.partitioner_method(data=data, **self.partitioner_params)
@ -50,10 +53,12 @@ class Retrainer(fts.FTS):
_train = data[k - horizon: k - self.order]
_test = data[k - self.order: k]
if self.auto_update:
self.model.train(_train)
else:
self.train(_train, **kwargs)
if k % self.batch_size == 0 or self.model is None:
print("Treinando {}".format(k))
if self.auto_update:
self.model.train(_train)
else:
self.train(_train, **kwargs)
ret.extend(self.model.predict(_test, **kwargs))

View File

@ -16,55 +16,19 @@ from pyFTS.common import Transformations
tdiff = Transformations.Differential(1)
dataset = pd.read_csv('/home/petronio/Downloads/priceHong')
dataset['hour'] = dataset.index.values % 24
split = 24 * 800
#train = data[:split].flatten()
#test = data[split:].flatten()
#print(train)
from pyFTS.models.multivariate import common, variable, mvfts
from pyFTS.partitioners import Grid
from pyFTS.models.seasonal.common import DateTime
from pyFTS.models.seasonal import partitioner as seasonal
vhour = variable.Variable("Hour", data_label="hour", partitioner=seasonal.TimeGridPartitioner, npart=24,
data=dataset, partitioner_specific={'seasonality': DateTime.hour_of_day, 'type': 'common'})
vprice = variable.Variable("Price", data_label="price", partitioner=Grid.GridPartitioner, npart=25,
data=dataset)
fig, ax = plt.subplots(nrows=2, ncols=1,figsize=[15,5])
vhour.partitioner.plot(ax[0])
vprice.partitioner.plot(ax[1])
model = mvfts.MVFTS()
#model.shortname += ' ' + key
model.append_variable(vhour)
model.append_variable(vprice)
# model.shortname += ' ' + w
model.target_variable = vprice
model.fit(dataset.iloc[:split])
'''
from pyFTS.data import TAIEX, SP500, NASDAQ
dataset = TAIEX.get_data()
partitioner = Grid.GridPartitioner(data=dataset[:800], npart=30) #, transformation=tdiff)
from pyFTS.models.incremental import Retrainer
model = pwfts.ProbabilisticWeightedFTS(partitioner=partitioner, order=2)
#model.append_transformation(tdiff)
model = Retrainer.Retrainer(partitioner_params = {'npart': 30},
fts_method=hofts.HighOrderFTS, order = 2,
window_length = 500, batch_size = 100)
model.fit(dataset[:800])
model.predict(dataset)
print(model)
ret = model.predict([5000.00, 5200.00, 5400.00], explain=True)
'''
'''
#dataset = SP500.get_data()[11500:16000]
#dataset = NASDAQ.get_data()