Introducing batch_size on Retraining
This commit is contained in:
parent
e4ee163660
commit
b677a0e627
@ -32,7 +32,10 @@ class Retrainer(fts.FTS):
|
||||
"""The memory window length"""
|
||||
self.auto_update = False
|
||||
"""If true the model is updated at each time and not recreated"""
|
||||
self.batch_size = kwargs.get('batch_size', 10)
|
||||
"""The batch interval between each retraining"""
|
||||
self.is_high_order = True
|
||||
self.uod_clip = False
|
||||
|
||||
def train(self, data, **kwargs):
|
||||
self.partitioner = self.partitioner_method(data=data, **self.partitioner_params)
|
||||
@ -50,6 +53,8 @@ class Retrainer(fts.FTS):
|
||||
_train = data[k - horizon: k - self.order]
|
||||
_test = data[k - self.order: k]
|
||||
|
||||
if k % self.batch_size == 0 or self.model is None:
|
||||
print("Treinando {}".format(k))
|
||||
if self.auto_update:
|
||||
self.model.train(_train)
|
||||
else:
|
||||
|
@ -16,55 +16,19 @@ from pyFTS.common import Transformations
|
||||
|
||||
tdiff = Transformations.Differential(1)
|
||||
|
||||
dataset = pd.read_csv('/home/petronio/Downloads/priceHong')
|
||||
dataset['hour'] = dataset.index.values % 24
|
||||
|
||||
split = 24 * 800
|
||||
#train = data[:split].flatten()
|
||||
#test = data[split:].flatten()
|
||||
|
||||
#print(train)
|
||||
|
||||
from pyFTS.models.multivariate import common, variable, mvfts
|
||||
from pyFTS.partitioners import Grid
|
||||
from pyFTS.models.seasonal.common import DateTime
|
||||
from pyFTS.models.seasonal import partitioner as seasonal
|
||||
|
||||
vhour = variable.Variable("Hour", data_label="hour", partitioner=seasonal.TimeGridPartitioner, npart=24,
|
||||
data=dataset, partitioner_specific={'seasonality': DateTime.hour_of_day, 'type': 'common'})
|
||||
vprice = variable.Variable("Price", data_label="price", partitioner=Grid.GridPartitioner, npart=25,
|
||||
data=dataset)
|
||||
|
||||
|
||||
fig, ax = plt.subplots(nrows=2, ncols=1,figsize=[15,5])
|
||||
|
||||
vhour.partitioner.plot(ax[0])
|
||||
vprice.partitioner.plot(ax[1])
|
||||
|
||||
model = mvfts.MVFTS()
|
||||
#model.shortname += ' ' + key
|
||||
model.append_variable(vhour)
|
||||
model.append_variable(vprice)
|
||||
# model.shortname += ' ' + w
|
||||
model.target_variable = vprice
|
||||
model.fit(dataset.iloc[:split])
|
||||
|
||||
'''
|
||||
from pyFTS.data import TAIEX, SP500, NASDAQ
|
||||
|
||||
dataset = TAIEX.get_data()
|
||||
|
||||
partitioner = Grid.GridPartitioner(data=dataset[:800], npart=30) #, transformation=tdiff)
|
||||
from pyFTS.models.incremental import Retrainer
|
||||
|
||||
model = pwfts.ProbabilisticWeightedFTS(partitioner=partitioner, order=2)
|
||||
#model.append_transformation(tdiff)
|
||||
model = Retrainer.Retrainer(partitioner_params = {'npart': 30},
|
||||
fts_method=hofts.HighOrderFTS, order = 2,
|
||||
window_length = 500, batch_size = 100)
|
||||
|
||||
model.fit(dataset[:800])
|
||||
model.predict(dataset)
|
||||
|
||||
print(model)
|
||||
|
||||
ret = model.predict([5000.00, 5200.00, 5400.00], explain=True)
|
||||
'''
|
||||
'''
|
||||
#dataset = SP500.get_data()[11500:16000]
|
||||
#dataset = NASDAQ.get_data()
|
||||
|
Loading…
Reference in New Issue
Block a user