Refactorings on IncrementalEnsemble

This commit is contained in:
Petrônio Cândido 2019-02-27 14:57:36 -03:00
parent d05ed762fe
commit 10c1ca5437
2 changed files with 26 additions and 14 deletions

View File

@ -65,17 +65,18 @@ class IncrementalEnsembleFTS(ensemble.EnsembleFTS):
ret = [] ret = []
for k in np.arange(self.max_lag, l): for k in np.arange(0, l):
data_window.append(data[k - self.max_lag]) data_window.append(data[k])
if k >= self.window_length: if k >= self.window_length:
data_window.pop(0) data_window.pop(0)
if k % self.batch_size == 0 and k - self.max_lag >= self.window_length: if k % self.batch_size == 0 and k >= self.window_length:
self.train(data_window, **kwargs) self.train(data_window, **kwargs)
sample = data[k - self.max_lag: k] if len(self.models) > 0:
sample = data[k: k + self.max_lag]
tmp = self.get_models_forecasts(sample) tmp = self.get_models_forecasts(sample)
point = self.get_point(tmp) point = self.get_point(tmp)
ret.append(point) ret.append(point)

View File

@ -7,19 +7,30 @@ import numpy as np
import pandas as pd import pandas as pd
from pyFTS.partitioners import Grid from pyFTS.partitioners import Grid
from pyFTS.common import Transformations from pyFTS.common import Transformations
from pyFTS.models import chen from pyFTS.models import chen, hofts
from pyFTS.models.incremental import IncrementalEnsemble, TimeVariant from pyFTS.models.incremental import IncrementalEnsemble, TimeVariant
from pyFTS.data import AirPassengers from pyFTS.data import AirPassengers, artificial
mu_local = 5
sigma_local = 0.25
mu_drift = 10
sigma_drift = 1.
deflen = 100
totlen = deflen * 10
order = 10
passengers = AirPassengers.get_data() signal = artificial.SignalEmulator()\
.stationary_gaussian(mu_local,sigma_local,length=deflen//2,it=10)\
.stationary_gaussian(mu_drift,sigma_drift,length=deflen//2,it=10, additive=False)\
.run()
model = IncrementalEnsemble.IncrementalEnsembleFTS(order=2, window_length=20, batch_size=5) model2 = IncrementalEnsemble.IncrementalEnsembleFTS(partitioner_method=Grid.GridPartitioner, partitioner_params={'npart': 15},
fts_method=hofts.WeightedHighOrderFTS, fts_params={}, order=2 ,
batch_size=20, window_length=100, num_models=5)
model.fit(passengers[:40]) forecasts = model2.predict(signal)
forecasts = model.predict(passengers[40:])
print(forecasts) print(forecasts)