Refactorings on IncrementalEnsemble
This commit is contained in:
parent
10c1ca5437
commit
8a01256185
@ -65,18 +65,20 @@ class IncrementalEnsembleFTS(ensemble.EnsembleFTS):
|
|||||||
|
|
||||||
ret = []
|
ret = []
|
||||||
|
|
||||||
for k in np.arange(0, l):
|
for k in np.arange(self.max_lag, l):
|
||||||
|
|
||||||
data_window.append(data[k])
|
k2 = k - self.max_lag
|
||||||
|
|
||||||
if k >= self.window_length:
|
data_window.append(data[k2])
|
||||||
|
|
||||||
|
if k2 >= self.window_length:
|
||||||
data_window.pop(0)
|
data_window.pop(0)
|
||||||
|
|
||||||
if k % self.batch_size == 0 and k >= self.window_length:
|
if k % self.batch_size == 0 and k2 >= self.window_length:
|
||||||
self.train(data_window, **kwargs)
|
self.train(data_window, **kwargs)
|
||||||
|
|
||||||
if len(self.models) > 0:
|
if len(self.models) > 0:
|
||||||
sample = data[k: k + self.max_lag]
|
sample = data[k2: k]
|
||||||
tmp = self.get_models_forecasts(sample)
|
tmp = self.get_models_forecasts(sample)
|
||||||
point = self.get_point(tmp)
|
point = self.get_point(tmp)
|
||||||
ret.append(point)
|
ret.append(point)
|
||||||
|
Loading…
Reference in New Issue
Block a user