99 lines
3.4 KiB
Python

import uuid
from typing import List
import numpy as np
from sklearn import tree
from sklearn.tree._tree import TREE_UNDEFINED # type: ignore
from backend.tree.model import ComparisonType, Rule, RuleAtom, TreeNode
def get_rules(
    tree: tree.BaseDecisionTree, feature_names: List[str], classes=None
) -> List[Rule]:
    """Turn every leaf of a fitted decision tree into a ``Rule``.

    The tree is walked from the root (node 0). Each split node extends the
    current path with one ``RuleAtom`` built from the split's feature name
    and threshold: ``ComparisonType.LESS`` on the way to the left child and
    ``ComparisonType.GREATER`` on the way to the right child. Each leaf
    closes the path and yields ``Rule(antecedent, value)``.

    NOTE(review): scikit-learn routes samples with ``feature <= threshold``
    to the left child -- confirm that ``LESS`` is meant inclusively.

    Args:
        tree: Fitted estimator exposing a ``tree_`` attribute with the
            ``feature``, ``threshold``, ``children_left``,
            ``children_right`` and ``value`` arrays.
        feature_names: Feature names, indexed by the feature indices stored
            in ``tree_.feature``.
        classes: Optional class labels. When given, a leaf's value is the
            label at the argmax of ``tree_.value[node][0]`` converted with
            ``float``; when ``None``, it is ``tree_.value[node][0][0]``.

    Returns:
        The rules ordered by ascending antecedent length. Rules of equal
        length keep their left-to-right traversal order.
    """
    tree_ = tree.tree_  # type: ignore
    # One name per tree node; leaves store TREE_UNDEFINED as their feature.
    feature_name = [
        feature_names[i] if i != TREE_UNDEFINED else "undefined!" for i in tree_.feature
    ]
    rules: List[Rule] = []

    def leaf_value(node) -> float:
        """Return the outcome stored at leaf ``node`` as a float."""
        if classes is None:
            return float(tree_.value[node][0][0])
        return float(classes[np.argmax(tree_.value[node][0])])

    def recurse(node, antecedent) -> None:
        """Collect the rules of the subtree rooted at ``node``."""
        if tree_.feature[node] == TREE_UNDEFINED:
            rules.append(Rule(antecedent, leaf_value(node)))  # type: ignore
            return
        name = feature_name[node]
        threshold = float(tree_.threshold[node])
        # Each branch gets its own copy of the path so siblings never share
        # (and mutate) the same antecedent list.
        recurse(
            tree_.children_left[node],
            [*antecedent, RuleAtom(name, ComparisonType.LESS.value, threshold)],
        )
        recurse(
            tree_.children_right[node],
            [*antecedent, RuleAtom(name, ComparisonType.GREATER.value, threshold)],
        )

    recurse(0, [])
    # sorted() is stable, unlike np.argsort's default quicksort, so rules
    # with the same antecedent length come back in a deterministic order.
    return sorted(rules, key=lambda rule: len(rule.antecedent))
def get_tree(
    tree: tree.BaseDecisionTree, feature_names: List[str], classes=None
) -> List[TreeNode]:
    """Flatten a fitted decision tree into a list of ``TreeNode`` records.

    The tree is walked from the root (node 0):

    * A split node produces two records, one for each outgoing branch
      (``ComparisonType.LESS`` then ``ComparisonType.GREATER``). Both carry
      the split's feature name and threshold and a fresh UUID as ``name``.
      Everything in the left subtree uses the LESS record's name as its
      ``parent``; everything in the right subtree uses the GREATER
      record's name.
    * A leaf produces a single record with ``name=None``, variable
      ``"result"``, type ``"="`` and the leaf's outcome as ``value``.

    ``level`` is set to the node's index in the ``tree_`` arrays, not to
    its depth in the tree.

    Args:
        tree: Fitted estimator exposing a ``tree_`` attribute with the
            ``feature``, ``threshold``, ``children_left``,
            ``children_right`` and ``value`` arrays.
        feature_names: Feature names, indexed by the feature indices stored
            in ``tree_.feature``.
        classes: Optional class labels. When given, a leaf's value is the
            label at the argmax of ``tree_.value[node][0]`` converted with
            ``float``; when ``None``, it is ``tree_.value[node][0][0]``.

    Returns:
        The records ordered by ascending ``level``. The two branch records
        of a split share a level and are returned LESS first.
    """
    tree_ = tree.tree_  # type: ignore
    # One name per tree node; leaves store TREE_UNDEFINED as their feature.
    feature_name = [
        feature_names[i] if i != TREE_UNDEFINED else "undefined!" for i in tree_.feature
    ]
    nodes: List[TreeNode] = []

    def leaf_value(node) -> float:
        """Return the outcome stored at leaf ``node`` as a float."""
        if classes is None:
            return float(tree_.value[node][0][0])
        return float(classes[np.argmax(tree_.value[node][0])])

    def branch(parent, node, comparison) -> TreeNode:
        """Build the record for one outgoing branch of split ``node``."""
        return TreeNode(
            parent=parent,
            name=str(uuid.uuid4()),
            level=node,
            variable=feature_name[node],
            type=comparison.value,
            value=float(tree_.threshold[node]),
        )

    def recurse(node, parent_node) -> None:
        """Collect the records of the subtree rooted at ``node``."""
        parent: str | None = None if parent_node is None else parent_node.name
        if tree_.feature[node] == TREE_UNDEFINED:
            nodes.append(TreeNode(parent, None, node, "result", "=", leaf_value(node)))
            return
        lower = branch(parent, node, ComparisonType.LESS)
        upper = branch(parent, node, ComparisonType.GREATER)
        nodes.extend((lower, upper))
        recurse(tree_.children_left[node], lower)
        recurse(tree_.children_right[node], upper)

    recurse(0, None)
    # The two branch records of a split always tie on level. sorted() is
    # stable, unlike np.argsort's default quicksort, so LESS reliably
    # precedes GREATER in the result.
    return sorted(nodes, key=lambda tree_node: tree_node.level)