From d2725a94aae86f52127491cdc9df0366d3b01458 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petr=C3=B4nio=20C=C3=A2ndido?= Date: Fri, 12 Apr 2019 14:06:03 -0300 Subject: [PATCH] Small bugfix in GranularWMVFTS --- pyFTS/models/multivariate/cmvfts.py | 4 ++-- pyFTS/models/multivariate/granular.py | 3 ++- pyFTS/tests/multivariate.py | 13 ++++++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pyFTS/models/multivariate/cmvfts.py b/pyFTS/models/multivariate/cmvfts.py index 173be22..4696174 100644 --- a/pyFTS/models/multivariate/cmvfts.py +++ b/pyFTS/models/multivariate/cmvfts.py @@ -44,9 +44,9 @@ class ClusteredMVFTS(mvfts.MVFTS): def train(self, data, **kwargs): + self.fts_params['order'] = self.order + self.model = self.fts_method(partitioner=self.partitioner, **self.fts_params) - if self.model.is_high_order: - self.model.order = self.order ndata = self.check_data(data) diff --git a/pyFTS/models/multivariate/granular.py b/pyFTS/models/multivariate/granular.py index a451c08..c7e31e5 100644 --- a/pyFTS/models/multivariate/granular.py +++ b/pyFTS/models/multivariate/granular.py @@ -21,6 +21,7 @@ class GranularWMVFTS(cmvfts.ClusteredMVFTS): def train(self, data, **kwargs): self.partitioner = grid.IncrementalGridCluster( explanatory_variables=self.explanatory_variables, - target_variable=self.target_variable) + target_variable=self.target_variable, + neighbors=self.knn) super(GranularWMVFTS, self).train(data,**kwargs) diff --git a/pyFTS/tests/multivariate.py b/pyFTS/tests/multivariate.py index bb3010d..2a45476 100644 --- a/pyFTS/tests/multivariate.py +++ b/pyFTS/tests/multivariate.py @@ -186,14 +186,13 @@ vavg = variable.Variable("Radiation", data_label="glo_avg", alias='rad', partitioner=Grid.GridPartitioner, npart=25, alpha_cut=.3, data=train) -from pyFTS.models.multivariate import mvfts, wmvfts, cmvfts, grid +from pyFTS.models.multivariate import mvfts, wmvfts, cmvfts, grid, granular -fs = grid.GridCluster(explanatory_variables=[vmonth, vhour, vavg], target_variable=vavg) - - -model = cmvfts.ClusteredMVFTS(explanatory_variables=[vmonth, vhour, vavg], target_variable=vavg, - partitioner=fs, knn=3) +model = granular.GranularWMVFTS(explanatory_variables=[vmonth, vhour, vavg], target_variable=vavg, + order=2, knn=7) model.fit(train) -model.predict(test) +print(model) + +#model.predict(test)