Update wmvfts.py - class_weights

This commit is contained in:
Petrônio Cândido de Lima e Silva 2023-05-26 14:13:41 -03:00 committed by GitHub
parent 02f6022a53
commit b8cfea5278
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -69,6 +69,8 @@ class WeightedMVFTS(mvfts.MVFTS):
self.shortname = "WeightedMVFTS"
self.name = "Weighted Multivariate FTS"
self.has_classification = True
self.class_weigths : dict = kwargs.get("class_weights", {})
def generate_flrg(self, flrs):
for flr in flrs:
@ -98,7 +100,7 @@ class WeightedMVFTS(mvfts.MVFTS):
for k,v in _flrg.RHS.items():
classification[k] += (v / _flrg.count) * mb
classification = activation(classification)
classification = activation(classification, self.class_weigths)
ret.append(classification)