From b8cfea52786e4b8b887edc01fe94c0f0a4ff2390 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petr=C3=B4nio=20C=C3=A2ndido=20de=20Lima=20e=20Silva?= Date: Fri, 26 May 2023 14:13:41 -0300 Subject: [PATCH] Update wmvfts.py - class_weights --- pyFTS/models/multivariate/wmvfts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyFTS/models/multivariate/wmvfts.py b/pyFTS/models/multivariate/wmvfts.py index 4c075eb..a944970 100644 --- a/pyFTS/models/multivariate/wmvfts.py +++ b/pyFTS/models/multivariate/wmvfts.py @@ -69,6 +69,8 @@ class WeightedMVFTS(mvfts.MVFTS): self.shortname = "WeightedMVFTS" self.name = "Weighted Multivariate FTS" self.has_classification = True + self.class_weigths : dict = kwargs.get("class_weights", {}) + def generate_flrg(self, flrs): for flr in flrs: @@ -98,7 +100,7 @@ class WeightedMVFTS(mvfts.MVFTS): for k,v in _flrg.RHS.items(): classification[k] += (v / _flrg.count) * mb - classification = activation(classification) + classification = activation(classification, self.class_weigths) ret.append(classification)