diff --git a/src/rules.py b/src/rules.py index b990aa7..00a2e0f 100644 --- a/src/rules.py +++ b/src/rules.py @@ -71,7 +71,7 @@ class Rule: # https://mljar.com/blog/extract-rules-decision-tree/ -def get_rules(tree, feature_names) -> List[Rule]: +def get_rules(tree, feature_names, classes=None) -> List[Rule]: tree_ = tree.tree_ feature_name = [ feature_names[i] if i != TREE_UNDEFINED else "undefined!" for i in tree_.feature @@ -91,7 +91,11 @@ def get_rules(tree, feature_names) -> List[Rule]: p2.append(RuleAtom(name, ComparisonType.GREATER, threshold)) recurse(tree_.children_right[node], p2, rules) else: - rules.append(Rule(antecedent, tree_.value[node][0][0])) + if classes is None: + rules.append(Rule(antecedent, tree_.value[node][0][0])) + else: + value = np.argmax(tree_.value[node][0]) + rules.append(Rule(antecedent, classes[value])) recurse(0, antecedent, rules) @@ -302,6 +306,8 @@ def _get_fuzzy_rules( for rule in rules: antecedent = [] for atom in rule.get_antecedent(): + if fuzzy_variables.get(atom.get_varaible(), None) is None: + continue antecedent.append( _get_fuzzy_rule_atom( fuzzy_variables[atom.get_varaible()], atom.get_value()