Refactorings on IncrementalEnsemble
This commit is contained in:
parent
d05ed762fe
commit
10c1ca5437
@ -65,17 +65,18 @@ class IncrementalEnsembleFTS(ensemble.EnsembleFTS):
|
|||||||
|
|
||||||
ret = []
|
ret = []
|
||||||
|
|
||||||
for k in np.arange(self.max_lag, l):
|
for k in np.arange(0, l):
|
||||||
|
|
||||||
data_window.append(data[k - self.max_lag])
|
data_window.append(data[k])
|
||||||
|
|
||||||
if k >= self.window_length:
|
if k >= self.window_length:
|
||||||
data_window.pop(0)
|
data_window.pop(0)
|
||||||
|
|
||||||
if k % self.batch_size == 0 and k - self.max_lag >= self.window_length:
|
if k % self.batch_size == 0 and k >= self.window_length:
|
||||||
self.train(data_window, **kwargs)
|
self.train(data_window, **kwargs)
|
||||||
|
|
||||||
sample = data[k - self.max_lag: k]
|
if len(self.models) > 0:
|
||||||
|
sample = data[k: k + self.max_lag]
|
||||||
tmp = self.get_models_forecasts(sample)
|
tmp = self.get_models_forecasts(sample)
|
||||||
point = self.get_point(tmp)
|
point = self.get_point(tmp)
|
||||||
ret.append(point)
|
ret.append(point)
|
||||||
|
@ -7,19 +7,30 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pyFTS.partitioners import Grid
|
from pyFTS.partitioners import Grid
|
||||||
from pyFTS.common import Transformations
|
from pyFTS.common import Transformations
|
||||||
from pyFTS.models import chen
|
from pyFTS.models import chen, hofts
|
||||||
from pyFTS.models.incremental import IncrementalEnsemble, TimeVariant
|
from pyFTS.models.incremental import IncrementalEnsemble, TimeVariant
|
||||||
|
|
||||||
from pyFTS.data import AirPassengers
|
from pyFTS.data import AirPassengers, artificial
|
||||||
|
|
||||||
|
mu_local = 5
|
||||||
|
sigma_local = 0.25
|
||||||
|
mu_drift = 10
|
||||||
|
sigma_drift = 1.
|
||||||
|
deflen = 100
|
||||||
|
totlen = deflen * 10
|
||||||
|
order = 10
|
||||||
|
|
||||||
|
|
||||||
passengers = AirPassengers.get_data()
|
signal = artificial.SignalEmulator()\
|
||||||
|
.stationary_gaussian(mu_local,sigma_local,length=deflen//2,it=10)\
|
||||||
|
.stationary_gaussian(mu_drift,sigma_drift,length=deflen//2,it=10, additive=False)\
|
||||||
|
.run()
|
||||||
|
|
||||||
model = IncrementalEnsemble.IncrementalEnsembleFTS(order=2, window_length=20, batch_size=5)
|
model2 = IncrementalEnsemble.IncrementalEnsembleFTS(partitioner_method=Grid.GridPartitioner, partitioner_params={'npart': 15},
|
||||||
|
fts_method=hofts.WeightedHighOrderFTS, fts_params={}, order=2 ,
|
||||||
|
batch_size=20, window_length=100, num_models=5)
|
||||||
|
|
||||||
model.fit(passengers[:40])
|
forecasts = model2.predict(signal)
|
||||||
|
|
||||||
forecasts = model.predict(passengers[40:])
|
|
||||||
|
|
||||||
print(forecasts)
|
print(forecasts)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user