Update wmvfts.py - class_weights
This commit is contained in:
parent
02f6022a53
commit
b8cfea5278
@ -69,6 +69,8 @@ class WeightedMVFTS(mvfts.MVFTS):
|
|||||||
self.shortname = "WeightedMVFTS"
|
self.shortname = "WeightedMVFTS"
|
||||||
self.name = "Weighted Multivariate FTS"
|
self.name = "Weighted Multivariate FTS"
|
||||||
self.has_classification = True
|
self.has_classification = True
|
||||||
|
self.class_weigths : dict = kwargs.get("class_weights", {})
|
||||||
|
|
||||||
|
|
||||||
def generate_flrg(self, flrs):
|
def generate_flrg(self, flrs):
|
||||||
for flr in flrs:
|
for flr in flrs:
|
||||||
@ -98,7 +100,7 @@ class WeightedMVFTS(mvfts.MVFTS):
|
|||||||
for k,v in _flrg.RHS.items():
|
for k,v in _flrg.RHS.items():
|
||||||
classification[k] += (v / _flrg.count) * mb
|
classification[k] += (v / _flrg.count) * mb
|
||||||
|
|
||||||
classification = activation(classification)
|
classification = activation(classification, self.class_weigths)
|
||||||
|
|
||||||
ret.append(classification)
|
ret.append(classification)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user