Improvements on MVFTS.forecast_ahead for supporting CMVFTS many steps ahead forecasts

This commit is contained in:
Petrônio Cândido 2019-06-17 11:40:24 -03:00
parent 8818cfe529
commit 6e32f1ecb9
2 changed files with 4 additions and 6 deletions

View File

@ -55,7 +55,7 @@ def fuzzyfy_instance_clustered(data_point, cluster, **kwargs):
mode = kwargs.get('mode', 'sets') mode = kwargs.get('mode', 'sets')
fsets = [] fsets = []
for fset in cluster.search(data_point, type='name'): for fset in cluster.search(data_point, type='name'):
if cluster.sets[fset].membership(data_point) > alpha_cut: if cluster.sets[fset].membership(data_point) >= alpha_cut:
if mode == 'sets': if mode == 'sets':
fsets.append(fset) fsets.append(fset)
elif mode =='both': elif mode =='both':

View File

@ -172,6 +172,7 @@ class MVFTS(fts.FTS):
start = kwargs.get('start_at', self.max_lag) start = kwargs.get('start_at', self.max_lag)
ndata = ndata.loc[ndata.index[start-self.max_lag:start]]
ret = [] ret = []
for k in np.arange(start, start+steps): for k in np.arange(start, start+steps):
ix = ndata.index[k-self.max_lag:k] ix = ndata.index[k-self.max_lag:k]
@ -188,11 +189,11 @@ class MVFTS(fts.FTS):
for data_label in generators.keys(): for data_label in generators.keys():
if data_label != self.target_variable.data_label: if data_label != self.target_variable.data_label:
if isinstance(generators[data_label], LambdaType): if isinstance(generators[data_label], LambdaType):
last_data_point = ndata.loc[sample.index[-1]] last_data_point = ndata.loc[ndata.index[-1]]
new_data_point[data_label] = generators[data_label](last_data_point[data_label]) new_data_point[data_label] = generators[data_label](last_data_point[data_label])
elif isinstance(generators[data_label], fts.FTS): elif isinstance(generators[data_label], fts.FTS):
model = generators[data_label] model = generators[data_label]
last_data_point = ndata.loc[[sample.index[-model.order]]] last_data_point = ndata.loc[[ndata.index[-model.order]]]
if not model.is_multivariate: if not model.is_multivariate:
last_data_point = last_data_point[data_label].values last_data_point = last_data_point[data_label].values
@ -200,9 +201,6 @@ class MVFTS(fts.FTS):
new_data_point[self.target_variable.data_label] = tmp new_data_point[self.target_variable.data_label] = tmp
print(k)
print(new_data_point)
ndata = ndata.append(new_data_point, ignore_index=True) ndata = ndata.append(new_data_point, ignore_index=True)
return ret[-steps:] return ret[-steps:]