Improvements on IFTS and WIFTS (forecast_interval_ahead)
This commit is contained in:
parent
e10aa1e872
commit
2149b4d041
@ -91,6 +91,12 @@ class FTS(object):
|
||||
|
||||
return best
|
||||
|
||||
def clip_uod(self, ndata):
|
||||
if self.uod_clip:
|
||||
ndata = np.clip(ndata, self.original_min, self.original_max)
|
||||
return ndata
|
||||
|
||||
|
||||
def predict(self, data, **kwargs):
|
||||
"""
|
||||
Forecast using trained model
|
||||
@ -116,8 +122,7 @@ class FTS(object):
|
||||
else:
|
||||
ndata = self.apply_transformations(data)
|
||||
|
||||
if self.uod_clip:
|
||||
ndata = np.clip(ndata, self.original_min, self.original_max)
|
||||
ndata = self.clip_uod(ndata)
|
||||
|
||||
if 'distributed' in kwargs:
|
||||
distributed = kwargs.pop('distributed')
|
||||
|
@ -52,17 +52,16 @@ class IntervalFTS(hofts.HighOrderFTS):
|
||||
mb = [fuzzySets[k].membership(data[k]) for k in np.arange(0, len(data))]
|
||||
return mb
|
||||
|
||||
|
||||
def forecast_interval(self, ndata, **kwargs):
|
||||
|
||||
ret = []
|
||||
|
||||
l = len(ndata)
|
||||
|
||||
if l <= self.order:
|
||||
if l < self.order:
|
||||
return ndata
|
||||
|
||||
for k in np.arange(self.max_lag, l):
|
||||
for k in np.arange(self.max_lag, l+1):
|
||||
|
||||
sample = ndata[k - self.max_lag: k]
|
||||
|
||||
@ -88,6 +87,16 @@ class IntervalFTS(hofts.HighOrderFTS):
|
||||
|
||||
return ret
|
||||
|
||||
def forecast_ahead_interval(self, data, steps, **kwargs):
|
||||
ret = [[x, x] for x in data[:self.max_lag]]
|
||||
for k in np.arange(self.max_lag, self.max_lag + steps):
|
||||
interval_lower = self.clip_uod(self.forecast_interval([x[0] for x in ret[k - self.max_lag: k]])[0])
|
||||
interval_upper = self.clip_uod(self.forecast_interval([x[1] for x in ret[k - self.max_lag: k]])[0])
|
||||
interval = [np.nanmin(interval_lower), np.nanmax(interval_upper)]
|
||||
ret.append(interval)
|
||||
|
||||
return ret[-steps:]
|
||||
|
||||
|
||||
class WeightedIntervalFTS(hofts.WeightedHighOrderFTS):
|
||||
"""
|
||||
@ -128,17 +137,15 @@ class WeightedIntervalFTS(hofts.WeightedHighOrderFTS):
|
||||
mb = [fuzzySets[k].membership(data[k]) for k in np.arange(0, len(data))]
|
||||
return mb
|
||||
|
||||
|
||||
def forecast_interval(self, ndata, **kwargs):
|
||||
|
||||
ret = []
|
||||
|
||||
l = len(ndata)
|
||||
|
||||
if l <= self.order:
|
||||
if l < self.order:
|
||||
return ndata
|
||||
|
||||
for k in np.arange(self.max_lag, l):
|
||||
for k in np.arange(self.max_lag, l+1):
|
||||
|
||||
sample = ndata[k - self.max_lag: k]
|
||||
|
||||
@ -163,3 +170,15 @@ class WeightedIntervalFTS(hofts.WeightedHighOrderFTS):
|
||||
ret.append([lo_, up_])
|
||||
|
||||
return ret
|
||||
|
||||
def forecast_ahead_interval(self, data, steps, **kwargs):
|
||||
ret = [[x, x] for x in data[:self.max_lag]]
|
||||
for k in np.arange(self.max_lag, self.max_lag + steps):
|
||||
interval_lower = self.clip_uod(self.forecast_interval([x[0] for x in ret[k - self.max_lag: k]])[0])
|
||||
interval_upper = self.clip_uod(self.forecast_interval([x[1] for x in ret[k - self.max_lag: k]])[0])
|
||||
interval = [np.nanmin(interval_lower), np.nanmax(interval_upper)]
|
||||
ret.append(interval)
|
||||
|
||||
return ret[-steps:]
|
||||
|
||||
|
||||
|
@ -61,18 +61,19 @@ methods_parameters = [
|
||||
{'order':2 }
|
||||
]
|
||||
|
||||
for dataset_name, dataset in datasets.items():
|
||||
bchmk.sliding_window_benchmarks2(dataset, 1000, train=0.8, inc=0.2,
|
||||
benchmark_models=True,
|
||||
#for dataset_name, dataset in datasets.items():
|
||||
bchmk.sliding_window_benchmarks2(TAIEX.get_data()[:5000], 1000, train=0.8, inc=0.2,
|
||||
benchmark_models=False,
|
||||
benchmark_methods=methods,
|
||||
benchmark_methods_parameters=methods_parameters,
|
||||
methods=[],
|
||||
methods_parameters=[],
|
||||
methods=[ifts.IntervalFTS, ifts.WeightedIntervalFTS],
|
||||
methods_parameters=[{},{}],
|
||||
transformations=[None],
|
||||
orders=[3],
|
||||
orders=[1,2,3],
|
||||
steps_ahead=[10],
|
||||
partitions=[None],
|
||||
partitions=[33],
|
||||
type='interval',
|
||||
#distributed=True, nodes=['192.168.0.110', '192.168.0.107','192.168.0.106'],
|
||||
file="tmp.db", dataset=dataset_name, tag="experiments")
|
||||
#file="tmp.db", dataset=dataset_name, tag="experiments")
|
||||
file="tmp.db", dataset='TAIEX', tag="experiments")
|
||||
#'''
|
Loading…
Reference in New Issue
Block a user