Small bugfix in GranularWMVFTS

This commit is contained in:
Petrônio Cândido 2019-04-12 14:06:03 -03:00
parent 7319bce515
commit d2725a94aa
3 changed files with 10 additions and 10 deletions

View File

@ -44,9 +44,9 @@ class ClusteredMVFTS(mvfts.MVFTS):
def train(self, data, **kwargs): def train(self, data, **kwargs):
self.fts_params['order'] = self.order
self.model = self.fts_method(partitioner=self.partitioner, **self.fts_params) self.model = self.fts_method(partitioner=self.partitioner, **self.fts_params)
if self.model.is_high_order:
self.model.order = self.order
ndata = self.check_data(data) ndata = self.check_data(data)

View File

@ -21,6 +21,7 @@ class GranularWMVFTS(cmvfts.ClusteredMVFTS):
def train(self, data, **kwargs): def train(self, data, **kwargs):
self.partitioner = grid.IncrementalGridCluster( self.partitioner = grid.IncrementalGridCluster(
explanatory_variables=self.explanatory_variables, explanatory_variables=self.explanatory_variables,
target_variable=self.target_variable) target_variable=self.target_variable,
neighbors=self.knn)
super(GranularWMVFTS, self).train(data,**kwargs) super(GranularWMVFTS, self).train(data,**kwargs)

View File

@ -186,14 +186,13 @@ vavg = variable.Variable("Radiation", data_label="glo_avg", alias='rad',
partitioner=Grid.GridPartitioner, npart=25, alpha_cut=.3, partitioner=Grid.GridPartitioner, npart=25, alpha_cut=.3,
data=train) data=train)
from pyFTS.models.multivariate import mvfts, wmvfts, cmvfts, grid from pyFTS.models.multivariate import mvfts, wmvfts, cmvfts, grid, granular
fs = grid.GridCluster(explanatory_variables=[vmonth, vhour, vavg], target_variable=vavg) model = granular.GranularWMVFTS(explanatory_variables=[vmonth, vhour, vavg], target_variable=vavg,
order=2, knn=7)
model = cmvfts.ClusteredMVFTS(explanatory_variables=[vmonth, vhour, vavg], target_variable=vavg,
partitioner=fs, knn=3)
model.fit(train) model.fit(train)
model.predict(test) print(model)
#model.predict(test)