Small bugfix in GranularWMVFTS

This commit is contained in:
Petrônio Cândido 2019-04-12 12:18:45 -03:00
parent dca52fb9ed
commit 7319bce515
3 changed files with 5 additions and 9 deletions

View File

@ -54,7 +54,7 @@ def fuzzyfy_instance_clustered(data_point, cluster, **kwargs):
alpha_cut = kwargs.get('alpha_cut', 0.0) alpha_cut = kwargs.get('alpha_cut', 0.0)
mode = kwargs.get('mode', 'sets') mode = kwargs.get('mode', 'sets')
fsets = [] fsets = []
for fset in cluster.search(data_point): for fset in cluster.search(data_point, type='name'):
if cluster.sets[fset].membership(data_point) > alpha_cut: if cluster.sets[fset].membership(data_point) > alpha_cut:
if mode == 'sets': if mode == 'sets':
fsets.append(fset) fsets.append(fset)

View File

@ -48,13 +48,9 @@ class IncrementalGridCluster(partitioner.MultivariatePartitioner):
return ret return ret
if self.kdtree is not None: if self.kdtree is not None:
fsets = self.search(data, **kwargs) fsets = self.search(data, type='name')
else: else:
fsets = self.incremental_search(data, **kwargs) fsets = self.incremental_search(data, type='name')
if len(fsets) == 0:
fsets = self.incremental_search(data, **kwargs)
raise Exception("{}".format(data))
mode = kwargs.get('mode', 'sets') mode = kwargs.get('mode', 'sets')
if mode == 'sets': if mode == 'sets':

View File

@ -188,7 +188,7 @@ vavg = variable.Variable("Radiation", data_label="glo_avg", alias='rad',
from pyFTS.models.multivariate import mvfts, wmvfts, cmvfts, grid from pyFTS.models.multivariate import mvfts, wmvfts, cmvfts, grid
fs = grid.IncrementalGridCluster(explanatory_variables=[vmonth, vhour, vavg], target_variable=vavg) fs = grid.GridCluster(explanatory_variables=[vmonth, vhour, vavg], target_variable=vavg)
model = cmvfts.ClusteredMVFTS(explanatory_variables=[vmonth, vhour, vavg], target_variable=vavg, model = cmvfts.ClusteredMVFTS(explanatory_variables=[vmonth, vhour, vavg], target_variable=vavg,
@ -196,4 +196,4 @@ model = cmvfts.ClusteredMVFTS(explanatory_variables=[vmonth, vhour, vavg], targe
model.fit(train) model.fit(train)
print(len(model)) model.predict(test)