Bugfix in cmvfts

This commit is contained in:
Petrônio Cândido 2018-11-13 00:00:22 -02:00
parent 75e69a1ae1
commit ef71a86a7f
5 changed files with 9 additions and 4 deletions
pyFTS

View File

@ -345,7 +345,10 @@ def get_point_statistics(data, model, **kwargs):
if not isinstance(forecasts, (list, np.ndarray)): if not isinstance(forecasts, (list, np.ndarray)):
forecasts = [forecasts] forecasts = [forecasts]
if len(forecasts) != len(ndata) - model.max_lag:
forecasts = np.array(forecasts[:-1]) forecasts = np.array(forecasts[:-1])
else:
forecasts = np.array(forecasts)
ret.append(np.round(rmse(ndata[model.max_lag:], forecasts), 2)) ret.append(np.round(rmse(ndata[model.max_lag:], forecasts), 2))
ret.append(np.round(mape(ndata[model.max_lag:], forecasts), 2)) ret.append(np.round(mape(ndata[model.max_lag:], forecasts), 2))

View File

@ -175,7 +175,7 @@ class HighOrderFTS(fts.FTS):
for flrg in flrgs: for flrg in flrgs:
if flrg.get_key() not in self.flrgs: if flrg.get_key() not in self.flrgs:
self.flrgs[flrg.get_key()] = flrg; self.flrgs[flrg.get_key()] = flrg
for st in rhs: for st in rhs:
self.flrgs[flrg.get_key()].append_rhs(st) self.flrgs[flrg.get_key()].append_rhs(st)

View File

@ -19,7 +19,7 @@ class ClusteredMVFTS(mvfts.MVFTS):
self.cluster = None self.cluster = None
"""The most recent trained clusterer""" """The most recent trained clusterer"""
self.fts_method = kwargs.get('fts_method', hofts.HighOrderFTS) self.fts_method = kwargs.get('fts_method', hofts.WeightedHighOrderFTS)
"""The FTS method to be called when a new model is build""" """The FTS method to be called when a new model is build"""
self.fts_params = kwargs.get('fts_params', {}) self.fts_params = kwargs.get('fts_params', {})
"""The FTS method specific parameters""" """The FTS method specific parameters"""

View File

@ -31,6 +31,8 @@ model.fit(dataset) #[22, 22, 23, 23, 24])
print(model) print(model)
Measures.get_point_statistics(dataset, model)
''' '''
#dataset = SP500.get_data()[11500:16000] #dataset = SP500.get_data()[11500:16000]
#dataset = NASDAQ.get_data() #dataset = NASDAQ.get_data()

View File

@ -100,7 +100,7 @@ model1.append_variable(vprice)
model1.target_variable = vprice model1.target_variable = vprice
model1.fit(train_mv) model1.fit(train_mv)
#print(model1) print(model1)
print(Measures.get_point_statistics(test_mv, model1)) print(Measures.get_point_statistics(test_mv, model1))
#""" #"""