Bugfix in cmvfts
This commit is contained in:
parent
75e69a1ae1
commit
ef71a86a7f
@ -345,7 +345,10 @@ def get_point_statistics(data, model, **kwargs):
|
|||||||
if not isinstance(forecasts, (list, np.ndarray)):
|
if not isinstance(forecasts, (list, np.ndarray)):
|
||||||
forecasts = [forecasts]
|
forecasts = [forecasts]
|
||||||
|
|
||||||
|
if len(forecasts) != len(ndata) - model.max_lag:
|
||||||
forecasts = np.array(forecasts[:-1])
|
forecasts = np.array(forecasts[:-1])
|
||||||
|
else:
|
||||||
|
forecasts = np.array(forecasts)
|
||||||
|
|
||||||
ret.append(np.round(rmse(ndata[model.max_lag:], forecasts), 2))
|
ret.append(np.round(rmse(ndata[model.max_lag:], forecasts), 2))
|
||||||
ret.append(np.round(mape(ndata[model.max_lag:], forecasts), 2))
|
ret.append(np.round(mape(ndata[model.max_lag:], forecasts), 2))
|
||||||
|
@ -175,7 +175,7 @@ class HighOrderFTS(fts.FTS):
|
|||||||
for flrg in flrgs:
|
for flrg in flrgs:
|
||||||
|
|
||||||
if flrg.get_key() not in self.flrgs:
|
if flrg.get_key() not in self.flrgs:
|
||||||
self.flrgs[flrg.get_key()] = flrg;
|
self.flrgs[flrg.get_key()] = flrg
|
||||||
|
|
||||||
for st in rhs:
|
for st in rhs:
|
||||||
self.flrgs[flrg.get_key()].append_rhs(st)
|
self.flrgs[flrg.get_key()].append_rhs(st)
|
||||||
|
@ -19,7 +19,7 @@ class ClusteredMVFTS(mvfts.MVFTS):
|
|||||||
self.cluster = None
|
self.cluster = None
|
||||||
"""The most recent trained clusterer"""
|
"""The most recent trained clusterer"""
|
||||||
|
|
||||||
self.fts_method = kwargs.get('fts_method', hofts.HighOrderFTS)
|
self.fts_method = kwargs.get('fts_method', hofts.WeightedHighOrderFTS)
|
||||||
"""The FTS method to be called when a new model is build"""
|
"""The FTS method to be called when a new model is build"""
|
||||||
self.fts_params = kwargs.get('fts_params', {})
|
self.fts_params = kwargs.get('fts_params', {})
|
||||||
"""The FTS method specific parameters"""
|
"""The FTS method specific parameters"""
|
||||||
|
@ -31,6 +31,8 @@ model.fit(dataset) #[22, 22, 23, 23, 24])
|
|||||||
|
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
|
Measures.get_point_statistics(dataset, model)
|
||||||
|
|
||||||
'''
|
'''
|
||||||
#dataset = SP500.get_data()[11500:16000]
|
#dataset = SP500.get_data()[11500:16000]
|
||||||
#dataset = NASDAQ.get_data()
|
#dataset = NASDAQ.get_data()
|
||||||
|
@ -100,7 +100,7 @@ model1.append_variable(vprice)
|
|||||||
model1.target_variable = vprice
|
model1.target_variable = vprice
|
||||||
model1.fit(train_mv)
|
model1.fit(train_mv)
|
||||||
|
|
||||||
#print(model1)
|
print(model1)
|
||||||
|
|
||||||
print(Measures.get_point_statistics(test_mv, model1))
|
print(Measures.get_point_statistics(test_mv, model1))
|
||||||
#"""
|
#"""
|
Loading…
Reference in New Issue
Block a user