Refactorings on IncrementalEnsemble
This commit is contained in:
parent
d05ed762fe
commit
10c1ca5437
@ -65,17 +65,18 @@ class IncrementalEnsembleFTS(ensemble.EnsembleFTS):
|
||||
|
||||
ret = []
|
||||
|
||||
for k in np.arange(self.max_lag, l):
|
||||
for k in np.arange(0, l):
|
||||
|
||||
data_window.append(data[k - self.max_lag])
|
||||
data_window.append(data[k])
|
||||
|
||||
if k >= self.window_length:
|
||||
data_window.pop(0)
|
||||
|
||||
if k % self.batch_size == 0 and k - self.max_lag >= self.window_length:
|
||||
if k % self.batch_size == 0 and k >= self.window_length:
|
||||
self.train(data_window, **kwargs)
|
||||
|
||||
sample = data[k - self.max_lag: k]
|
||||
if len(self.models) > 0:
|
||||
sample = data[k: k + self.max_lag]
|
||||
tmp = self.get_models_forecasts(sample)
|
||||
point = self.get_point(tmp)
|
||||
ret.append(point)
|
||||
|
@ -7,19 +7,30 @@ import numpy as np
|
||||
import pandas as pd
|
||||
from pyFTS.partitioners import Grid
|
||||
from pyFTS.common import Transformations
|
||||
from pyFTS.models import chen
|
||||
from pyFTS.models import chen, hofts
|
||||
from pyFTS.models.incremental import IncrementalEnsemble, TimeVariant
|
||||
|
||||
from pyFTS.data import AirPassengers
|
||||
from pyFTS.data import AirPassengers, artificial
|
||||
|
||||
mu_local = 5
|
||||
sigma_local = 0.25
|
||||
mu_drift = 10
|
||||
sigma_drift = 1.
|
||||
deflen = 100
|
||||
totlen = deflen * 10
|
||||
order = 10
|
||||
|
||||
|
||||
passengers = AirPassengers.get_data()
|
||||
signal = artificial.SignalEmulator()\
|
||||
.stationary_gaussian(mu_local,sigma_local,length=deflen//2,it=10)\
|
||||
.stationary_gaussian(mu_drift,sigma_drift,length=deflen//2,it=10, additive=False)\
|
||||
.run()
|
||||
|
||||
model = IncrementalEnsemble.IncrementalEnsembleFTS(order=2, window_length=20, batch_size=5)
|
||||
model2 = IncrementalEnsemble.IncrementalEnsembleFTS(partitioner_method=Grid.GridPartitioner, partitioner_params={'npart': 15},
|
||||
fts_method=hofts.WeightedHighOrderFTS, fts_params={}, order=2 ,
|
||||
batch_size=20, window_length=100, num_models=5)
|
||||
|
||||
model.fit(passengers[:40])
|
||||
|
||||
forecasts = model.predict(passengers[40:])
|
||||
forecasts = model2.predict(signal)
|
||||
|
||||
print(forecasts)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user