99 lines
3.2 KiB
Python

import uuid
from typing import List
import numpy as np
from sklearn import tree
from sklearn.tree._tree import TREE_UNDEFINED # type: ignore
from backend.tree.model import ComparisonType, Rule, RuleAtom, TreeNode
def get_rules(
    tree: tree.BaseDecisionTree, feature_names: List[str], classes=None
) -> List[Rule]:
    """Flatten a fitted sklearn decision tree into one rule per leaf.

    Each rule's antecedent is the list of split conditions on the path from
    the root to a leaf. Taking the left child of a split is recorded as
    ``ComparisonType.LESS`` and taking the right child as
    ``ComparisonType.GREATER``, both against that split's threshold.

    Args:
        tree: A fitted sklearn tree estimator (anything exposing ``tree_``).
        feature_names: Feature names, indexed by the estimator's feature ids.
        classes: Optional sequence of class labels. When given, a leaf's
            consequent is ``classes[argmax(tree_.value[leaf][0])]``. When
            ``None``, it is ``tree_.value[leaf][0][0]``; presumably this is
            meant for single-output regressors (confirm with callers).

    Returns:
        One ``Rule`` per leaf, ordered by antecedent length (shortest first).
        Rules of equal length keep the left-to-right order of their leaves.
    """
    tree_ = tree.tree_  # type: ignore
    # Leaf nodes carry TREE_UNDEFINED as their feature index. They get a
    # placeholder name so the list stays aligned with node ids.
    feature_name = [
        feature_names[i] if i != TREE_UNDEFINED else "undefined!" for i in tree_.feature
    ]

    rules: List[Rule] = []
    # Use an explicit stack instead of recursion so very deep trees cannot
    # exceed Python's recursion limit. The right child is pushed before the
    # left one, so leaves are visited left-to-right, just as the recursive
    # walk did.
    stack: List[tuple] = [(0, [])]
    while stack:
        node, antecedent = stack.pop()
        if tree_.feature[node] != TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            less = RuleAtom(name, ComparisonType.LESS.value, threshold)
            greater = RuleAtom(name, ComparisonType.GREATER.value, threshold)
            # Each branch gets its own copy of the path so siblings never share lists.
            stack.append((tree_.children_right[node], antecedent + [greater]))
            stack.append((tree_.children_left[node], antecedent + [less]))
        elif classes is None:
            rules.append(Rule(antecedent, tree_.value[node][0][0]))  # type: ignore
        else:
            value = np.argmax(tree_.value[node][0])
            rules.append(Rule(antecedent, classes[value]))  # type: ignore

    # Python's sort is stable, unlike np.argsort's default quicksort, so
    # rules of equal length come out in a deterministic order.
    return sorted(rules, key=lambda rule: len(rule.antecedent))
def get_tree(
    tree: tree.BaseDecisionTree, feature_names: List[str], classes=None
) -> List[TreeNode]:
    """Convert a fitted sklearn decision tree into a flat list of ``TreeNode``.

    Each split yields two ``TreeNode`` entries that share the split's parent
    name, feature, threshold and sklearn node id, and have fresh UUID names:
    one with ``ComparisonType.LESS`` (left branch) and one with
    ``ComparisonType.GREATER`` (right branch). The children of a branch use
    that branch node's ``name`` as their parent. Each leaf yields one
    ``TreeNode(parent, None, node_id, "result", "=", value)``.

    Args:
        tree: A fitted sklearn tree estimator (anything exposing ``tree_``).
        feature_names: Feature names, indexed by the estimator's feature ids.
        classes: Optional sequence of class labels. When given, a leaf's value
            is ``classes[argmax(tree_.value[leaf][0])]``. When ``None``, it is
            ``tree_.value[leaf][0][0]``; presumably this is meant for
            single-output regressors (confirm with callers).

    Returns:
        All nodes, sorted by ``TreeNode.level``. Ties keep their emission
        order, so a split's LESS node always precedes its GREATER node.

    NOTE(review): the third ``TreeNode`` argument is the sklearn node id, and
    the final sort reads ``.level``. If that field is ``level``, the ordering is
    by node id, not necessarily by depth. Confirm against ``TreeNode``.
    """
    tree_ = tree.tree_  # type: ignore
    # Leaf nodes carry TREE_UNDEFINED as their feature index. They get a
    # placeholder name so the list stays aligned with node ids.
    feature_name = [
        feature_names[i] if i != TREE_UNDEFINED else "undefined!" for i in tree_.feature
    ]

    nodes: List[TreeNode] = []
    # Use an explicit work stack instead of recursion so very deep trees
    # cannot exceed Python's recursion limit. Items are
    # ("visit", node_id, parent_tree_node) or ("emit", tree_node). The push
    # order below reproduces the recursive emission order: left subtree,
    # LESS node, GREATER node, right subtree.
    stack: list = [("visit", 0, None)]
    while stack:
        item = stack.pop()
        if item[0] == "emit":
            nodes.append(item[1])
            continue
        _, node, parent_node = item
        parent: str | None = None if parent_node is None else parent_node.name
        if tree_.feature[node] != TREE_UNDEFINED:
            feature = feature_name[node]
            threshold = tree_.threshold[node]
            less = TreeNode(
                parent,
                str(uuid.uuid4()),
                node,
                feature,
                ComparisonType.LESS.value,
                threshold,
            )
            greater = TreeNode(
                parent,
                str(uuid.uuid4()),
                node,
                feature,
                ComparisonType.GREATER.value,
                threshold,
            )
            # LIFO: the left subtree is processed first, then both branch
            # nodes, then the right subtree.
            stack.append(("visit", tree_.children_right[node], greater))
            stack.append(("emit", greater))
            stack.append(("emit", less))
            stack.append(("visit", tree_.children_left[node], less))
        elif classes is None:
            nodes.append(
                TreeNode(parent, None, node, "result", "=", tree_.value[node][0][0])
            )
        else:
            value = np.argmax(tree_.value[node][0])
            nodes.append(
                TreeNode(parent, None, node, "result", "=", classes[value])
            )

    # Python's sort is stable, unlike np.argsort's default quicksort, so the
    # LESS/GREATER nodes of a split (which share a level) never swap.
    return sorted(nodes, key=lambda tree_node: tree_node.level)