diff --git a/pyFTS/models/multivariate/wmvfts.py b/pyFTS/models/multivariate/wmvfts.py index 4c075eb..a944970 100644 --- a/pyFTS/models/multivariate/wmvfts.py +++ b/pyFTS/models/multivariate/wmvfts.py @@ -69,6 +69,8 @@ class WeightedMVFTS(mvfts.MVFTS): self.shortname = "WeightedMVFTS" self.name = "Weighted Multivariate FTS" self.has_classification = True + self.class_weigths : dict = kwargs.get("class_weights", {}) + def generate_flrg(self, flrs): for flr in flrs: @@ -98,7 +100,7 @@ class WeightedMVFTS(mvfts.MVFTS): for k,v in _flrg.RHS.items(): classification[k] += (v / _flrg.count) * mb - classification = activation(classification) + classification = activation(classification, self.class_weigths) ret.append(classification)