Introducing batch_size on Retraining
This commit is contained in:
parent
e4ee163660
commit
b677a0e627
@ -32,7 +32,10 @@ class Retrainer(fts.FTS):
|
|||||||
"""The memory window length"""
|
"""The memory window length"""
|
||||||
self.auto_update = False
|
self.auto_update = False
|
||||||
"""If true the model is updated at each time and not recreated"""
|
"""If true the model is updated at each time and not recreated"""
|
||||||
|
self.batch_size = kwargs.get('batch_size', 10)
|
||||||
|
"""The batch interval between each retraining"""
|
||||||
self.is_high_order = True
|
self.is_high_order = True
|
||||||
|
self.uod_clip = False
|
||||||
|
|
||||||
def train(self, data, **kwargs):
|
def train(self, data, **kwargs):
|
||||||
self.partitioner = self.partitioner_method(data=data, **self.partitioner_params)
|
self.partitioner = self.partitioner_method(data=data, **self.partitioner_params)
|
||||||
@ -50,10 +53,12 @@ class Retrainer(fts.FTS):
|
|||||||
_train = data[k - horizon: k - self.order]
|
_train = data[k - horizon: k - self.order]
|
||||||
_test = data[k - self.order: k]
|
_test = data[k - self.order: k]
|
||||||
|
|
||||||
if self.auto_update:
|
if k % self.batch_size == 0 or self.model is None:
|
||||||
self.model.train(_train)
|
print("Treinando {}".format(k))
|
||||||
else:
|
if self.auto_update:
|
||||||
self.train(_train, **kwargs)
|
self.model.train(_train)
|
||||||
|
else:
|
||||||
|
self.train(_train, **kwargs)
|
||||||
|
|
||||||
ret.extend(self.model.predict(_test, **kwargs))
|
ret.extend(self.model.predict(_test, **kwargs))
|
||||||
|
|
||||||
|
@ -16,55 +16,19 @@ from pyFTS.common import Transformations
|
|||||||
|
|
||||||
tdiff = Transformations.Differential(1)
|
tdiff = Transformations.Differential(1)
|
||||||
|
|
||||||
dataset = pd.read_csv('/home/petronio/Downloads/priceHong')
|
|
||||||
dataset['hour'] = dataset.index.values % 24
|
|
||||||
|
|
||||||
split = 24 * 800
|
|
||||||
#train = data[:split].flatten()
|
|
||||||
#test = data[split:].flatten()
|
|
||||||
|
|
||||||
#print(train)
|
|
||||||
|
|
||||||
from pyFTS.models.multivariate import common, variable, mvfts
|
|
||||||
from pyFTS.partitioners import Grid
|
|
||||||
from pyFTS.models.seasonal.common import DateTime
|
|
||||||
from pyFTS.models.seasonal import partitioner as seasonal
|
|
||||||
|
|
||||||
vhour = variable.Variable("Hour", data_label="hour", partitioner=seasonal.TimeGridPartitioner, npart=24,
|
|
||||||
data=dataset, partitioner_specific={'seasonality': DateTime.hour_of_day, 'type': 'common'})
|
|
||||||
vprice = variable.Variable("Price", data_label="price", partitioner=Grid.GridPartitioner, npart=25,
|
|
||||||
data=dataset)
|
|
||||||
|
|
||||||
|
|
||||||
fig, ax = plt.subplots(nrows=2, ncols=1,figsize=[15,5])
|
|
||||||
|
|
||||||
vhour.partitioner.plot(ax[0])
|
|
||||||
vprice.partitioner.plot(ax[1])
|
|
||||||
|
|
||||||
model = mvfts.MVFTS()
|
|
||||||
#model.shortname += ' ' + key
|
|
||||||
model.append_variable(vhour)
|
|
||||||
model.append_variable(vprice)
|
|
||||||
# model.shortname += ' ' + w
|
|
||||||
model.target_variable = vprice
|
|
||||||
model.fit(dataset.iloc[:split])
|
|
||||||
|
|
||||||
'''
|
|
||||||
from pyFTS.data import TAIEX, SP500, NASDAQ
|
from pyFTS.data import TAIEX, SP500, NASDAQ
|
||||||
|
|
||||||
dataset = TAIEX.get_data()
|
dataset = TAIEX.get_data()
|
||||||
|
|
||||||
partitioner = Grid.GridPartitioner(data=dataset[:800], npart=30) #, transformation=tdiff)
|
from pyFTS.models.incremental import Retrainer
|
||||||
|
|
||||||
model = pwfts.ProbabilisticWeightedFTS(partitioner=partitioner, order=2)
|
model = Retrainer.Retrainer(partitioner_params = {'npart': 30},
|
||||||
#model.append_transformation(tdiff)
|
fts_method=hofts.HighOrderFTS, order = 2,
|
||||||
|
window_length = 500, batch_size = 100)
|
||||||
|
|
||||||
model.fit(dataset[:800])
|
model.predict(dataset)
|
||||||
|
|
||||||
print(model)
|
|
||||||
|
|
||||||
ret = model.predict([5000.00, 5200.00, 5400.00], explain=True)
|
|
||||||
'''
|
|
||||||
'''
|
'''
|
||||||
#dataset = SP500.get_data()[11500:16000]
|
#dataset = SP500.get_data()[11500:16000]
|
||||||
#dataset = NASDAQ.get_data()
|
#dataset = NASDAQ.get_data()
|
||||||
|
Loading…
Reference in New Issue
Block a user