Add initial classification tree support, some minor fixes
This commit is contained in:
parent
9eb4492feb
commit
dec4025331
10
src/rules.py
10
src/rules.py
@ -71,7 +71,7 @@ class Rule:
|
||||
|
||||
|
||||
# https://mljar.com/blog/extract-rules-decision-tree/
|
||||
def get_rules(tree, feature_names) -> List[Rule]:
|
||||
def get_rules(tree, feature_names, classes=None) -> List[Rule]:
|
||||
tree_ = tree.tree_
|
||||
feature_name = [
|
||||
feature_names[i] if i != TREE_UNDEFINED else "undefined!" for i in tree_.feature
|
||||
@ -91,7 +91,11 @@ def get_rules(tree, feature_names) -> List[Rule]:
|
||||
p2.append(RuleAtom(name, ComparisonType.GREATER, threshold))
|
||||
recurse(tree_.children_right[node], p2, rules)
|
||||
else:
|
||||
rules.append(Rule(antecedent, tree_.value[node][0][0]))
|
||||
if classes is None:
|
||||
rules.append(Rule(antecedent, tree_.value[node][0][0]))
|
||||
else:
|
||||
value = np.argmax(tree_.value[node][0])
|
||||
rules.append(Rule(antecedent, classes[value]))
|
||||
|
||||
recurse(0, antecedent, rules)
|
||||
|
||||
@ -302,6 +306,8 @@ def _get_fuzzy_rules(
|
||||
for rule in rules:
|
||||
antecedent = []
|
||||
for atom in rule.get_antecedent():
|
||||
if fuzzy_variables.get(atom.get_varaible(), None) is None:
|
||||
continue
|
||||
antecedent.append(
|
||||
_get_fuzzy_rule_atom(
|
||||
fuzzy_variables[atom.get_varaible()], atom.get_value()
|
||||
|
Loading…
x
Reference in New Issue
Block a user