diff --git a/pyFTS/models/incremental/Retrainer.py b/pyFTS/models/incremental/Retrainer.py index e211d9d..044c2b6 100644 --- a/pyFTS/models/incremental/Retrainer.py +++ b/pyFTS/models/incremental/Retrainer.py @@ -32,7 +32,10 @@ class Retrainer(fts.FTS): """The memory window length""" self.auto_update = False """If true the model is updated at each time and not recreated""" + self.batch_size = kwargs.get('batch_size', 10) + """The batch interval between each retraining""" self.is_high_order = True + self.uod_clip = False def train(self, data, **kwargs): self.partitioner = self.partitioner_method(data=data, **self.partitioner_params) @@ -50,10 +53,12 @@ class Retrainer(fts.FTS): _train = data[k - horizon: k - self.order] _test = data[k - self.order: k] - if self.auto_update: - self.model.train(_train) - else: - self.train(_train, **kwargs) + if k % self.batch_size == 0 or self.model is None: + print("Treinando {}".format(k)) + if self.auto_update: + self.model.train(_train) + else: + self.train(_train, **kwargs) ret.extend(self.model.predict(_test, **kwargs)) diff --git a/pyFTS/tests/general.py b/pyFTS/tests/general.py index 647b08d..6021eda 100644 --- a/pyFTS/tests/general.py +++ b/pyFTS/tests/general.py @@ -16,55 +16,19 @@ from pyFTS.common import Transformations tdiff = Transformations.Differential(1) -dataset = pd.read_csv('/home/petronio/Downloads/priceHong') -dataset['hour'] = dataset.index.values % 24 -split = 24 * 800 -#train = data[:split].flatten() -#test = data[split:].flatten() - -#print(train) - -from pyFTS.models.multivariate import common, variable, mvfts -from pyFTS.partitioners import Grid -from pyFTS.models.seasonal.common import DateTime -from pyFTS.models.seasonal import partitioner as seasonal - -vhour = variable.Variable("Hour", data_label="hour", partitioner=seasonal.TimeGridPartitioner, npart=24, - data=dataset, partitioner_specific={'seasonality': DateTime.hour_of_day, 'type': 'common'}) -vprice = variable.Variable("Price", data_label="price", partitioner=Grid.GridPartitioner, npart=25, - data=dataset) - - -fig, ax = plt.subplots(nrows=2, ncols=1,figsize=[15,5]) - -vhour.partitioner.plot(ax[0]) -vprice.partitioner.plot(ax[1]) - -model = mvfts.MVFTS() -#model.shortname += ' ' + key -model.append_variable(vhour) -model.append_variable(vprice) -# model.shortname += ' ' + w -model.target_variable = vprice -model.fit(dataset.iloc[:split]) - -''' from pyFTS.data import TAIEX, SP500, NASDAQ dataset = TAIEX.get_data() -partitioner = Grid.GridPartitioner(data=dataset[:800], npart=30) #, transformation=tdiff) +from pyFTS.models.incremental import Retrainer -model = pwfts.ProbabilisticWeightedFTS(partitioner=partitioner, order=2) -#model.append_transformation(tdiff) +model = Retrainer.Retrainer(partitioner_params = {'npart': 30}, + fts_method=hofts.HighOrderFTS, order = 2, + window_length = 500, batch_size = 100) -model.fit(dataset[:800]) +model.predict(dataset) -print(model) - -ret = model.predict([5000.00, 5200.00, 5400.00], explain=True) -''' ''' #dataset = SP500.get_data()[11500:16000] #dataset = NASDAQ.get_data()