Refactorings on IncrementalEnsemble

This commit is contained in:
Petrônio Cândido 2019-02-27 14:57:36 -03:00
parent d05ed762fe
commit 10c1ca5437
2 changed files with 26 additions and 14 deletions

View File

@ -65,20 +65,21 @@ class IncrementalEnsembleFTS(ensemble.EnsembleFTS):
ret = []
for k in np.arange(self.max_lag, l):
for k in np.arange(0, l):
data_window.append(data[k - self.max_lag])
data_window.append(data[k])
if k >= self.window_length:
data_window.pop(0)
if k % self.batch_size == 0 and k - self.max_lag >= self.window_length:
if k % self.batch_size == 0 and k >= self.window_length:
self.train(data_window, **kwargs)
sample = data[k - self.max_lag: k]
tmp = self.get_models_forecasts(sample)
point = self.get_point(tmp)
ret.append(point)
if len(self.models) > 0:
sample = data[k: k + self.max_lag]
tmp = self.get_models_forecasts(sample)
point = self.get_point(tmp)
ret.append(point)
return ret

View File

@ -7,19 +7,30 @@ import numpy as np
import pandas as pd
from pyFTS.partitioners import Grid
from pyFTS.common import Transformations
from pyFTS.models import chen
from pyFTS.models import chen, hofts
from pyFTS.models.incremental import IncrementalEnsemble, TimeVariant
from pyFTS.data import AirPassengers
from pyFTS.data import AirPassengers, artificial
mu_local = 5
sigma_local = 0.25
mu_drift = 10
sigma_drift = 1.
deflen = 100
totlen = deflen * 10
order = 10
passengers = AirPassengers.get_data()
signal = artificial.SignalEmulator()\
.stationary_gaussian(mu_local,sigma_local,length=deflen//2,it=10)\
.stationary_gaussian(mu_drift,sigma_drift,length=deflen//2,it=10, additive=False)\
.run()
model = IncrementalEnsemble.IncrementalEnsembleFTS(order=2, window_length=20, batch_size=5)
model2 = IncrementalEnsemble.IncrementalEnsembleFTS(partitioner_method=Grid.GridPartitioner, partitioner_params={'npart': 15},
fts_method=hofts.WeightedHighOrderFTS, fts_params={}, order=2 ,
batch_size=20, window_length=100, num_models=5)
model.fit(passengers[:40])
forecasts = model.predict(passengers[40:])
forecasts = model2.predict(signal)
print(forecasts)