99 lines
3.4 KiB
Python
99 lines
3.4 KiB
Python
import uuid
|
|
from typing import List
|
|
|
|
import numpy as np
|
|
from sklearn import tree
|
|
from sklearn.tree._tree import TREE_UNDEFINED # type: ignore
|
|
|
|
from backend.tree.model import ComparisonType, Rule, RuleAtom, TreeNode
|
|
|
|
|
|
def get_rules(
    tree: tree.BaseDecisionTree, feature_names: List[str], classes=None
) -> List[Rule]:
    """Flatten a fitted sklearn decision tree into one ``Rule`` per leaf.

    A rule's antecedent is the list of split conditions on the path from the
    root to the leaf. Going to a split's left child adds a
    ``ComparisonType.LESS`` atom; going to its right child adds a
    ``ComparisonType.GREATER`` atom. Both atoms use the split's feature name
    and threshold. (sklearn routes samples with ``x[feature] <= threshold`` to
    the left child. Whether ``LESS`` means ``<`` or ``<=`` is defined in
    ``backend.tree.model``; NOTE(review): verify.)

    Only the first output of a multi-output tree is used
    (``tree_.value[node][0]``).

    Args:
        tree: A fitted sklearn tree estimator. Only its ``tree_`` attribute
            is read.
        feature_names: Feature names, indexed by sklearn feature index.
        classes: If ``None``, the leaf value is the first entry of the leaf's
            ``tree_.value[node][0]`` (for a regressor, its prediction).
            Otherwise the leaf value is ``float(classes[i])``, where ``i`` is
            the argmax of ``tree_.value[node][0]``. Presumably this is the
            estimator's ``classes_``; confirm with callers. The selected
            labels must be convertible with ``float()``.

    Returns:
        One ``Rule`` per leaf, ordered by antecedent length (shortest first).
        Rules of equal length keep the left-to-right order of their leaves.

    Raises:
        IndexError: If a split uses a feature index outside ``feature_names``.
    """
    tree_ = tree.tree_  # type: ignore
    rules: List[Rule] = []

    # Iterative depth-first walk instead of recursion: trees grown without
    # max_depth can be deeper than Python's default recursion limit.
    # Each entry is (sklearn node id, conditions on the path to that node).
    stack = [(0, [])]
    while stack:
        node, antecedent = stack.pop()
        feature = tree_.feature[node]
        if feature != TREE_UNDEFINED:
            name = feature_names[feature]
            threshold = float(tree_.threshold[node])
            less = RuleAtom(name, ComparisonType.LESS.value, threshold)
            greater = RuleAtom(name, ComparisonType.GREATER.value, threshold)
            # Push the right child first so the left subtree is expanded
            # first. This yields leaves in the same left-to-right order as a
            # recursive walk.
            stack.append((tree_.children_right[node], antecedent + [greater]))
            stack.append((tree_.children_left[node], antecedent + [less]))
            continue

        # Leaf node: compute its value from the first output only.
        leaf_values = tree_.value[node][0]
        if classes is None:
            value = float(leaf_values[0])
        else:
            value = float(classes[np.argmax(leaf_values)])
        rules.append(Rule(antecedent, value))  # type: ignore

    # list.sort is stable, whereas np.argsort's default "quicksort" is not.
    # Rules of equal length therefore keep their leaf order deterministically.
    rules.sort(key=lambda rule: len(rule.antecedent))
    return rules
|
|
|
|
|
|
def get_tree(
    tree: tree.BaseDecisionTree, feature_names: List[str], classes=None
) -> List[TreeNode]:
    """Convert a fitted sklearn decision tree into a flat list of ``TreeNode``s.

    Each sklearn split node becomes two sibling ``TreeNode``s with the same
    ``parent``:

    * a ``ComparisonType.LESS`` branch, whose subtree is the sklearn left
      child;
    * a ``ComparisonType.GREATER`` branch, whose subtree is the right child.

    Both branches carry the split's feature name and threshold, and each gets
    a fresh random UUID string as its ``name``. The branches of the root split
    have ``parent=None``. Every entry inside a branch's subtree points at that
    branch's ``name``.

    Each sklearn leaf becomes
    ``TreeNode(parent, None, level, "result", "=", value)``.

    ``level`` is the sklearn node id of the node that produced the entry. It
    is not the depth in the tree.

    Only the first output of a multi-output tree is used
    (``tree_.value[node][0]``).

    Args:
        tree: A fitted sklearn tree estimator. Only its ``tree_`` attribute
            is read.
        feature_names: Feature names, indexed by sklearn feature index.
        classes: If ``None``, the leaf value is the first entry of the leaf's
            ``tree_.value[node][0]`` (for a regressor, its prediction).
            Otherwise the leaf value is ``float(classes[i])``, where ``i`` is
            the argmax of ``tree_.value[node][0]``. Presumably this is the
            estimator's ``classes_``; confirm with callers. The selected
            labels must be convertible with ``float()``.

    Returns:
        All nodes sorted by ``level`` (a plain ``int``). Within a split, the
        LESS branch comes before the GREATER branch.

    Raises:
        IndexError: If a split uses a feature index outside ``feature_names``.
    """
    tree_ = tree.tree_  # type: ignore
    nodes: List[TreeNode] = []

    # Iterative depth-first walk instead of recursion: trees grown without
    # max_depth can be deeper than Python's default recursion limit.
    # Each entry is (sklearn node id, name of the TreeNode the subtree hangs
    # under, or None at the root).
    stack = [(0, None)]
    while stack:
        node, parent = stack.pop()
        # Child ids read from tree_.children_* are numpy integers. Convert to
        # a plain int so every level has the same type, matching the float()
        # conversions below. Numpy ints are rejected by the stdlib json
        # encoder.
        level = int(node)
        feature = tree_.feature[node]
        if feature != TREE_UNDEFINED:
            variable = feature_names[feature]
            threshold = float(tree_.threshold[node])
            less = TreeNode(
                parent=parent,
                name=str(uuid.uuid4()),
                level=level,
                variable=variable,
                type=ComparisonType.LESS.value,
                value=threshold,
            )
            greater = TreeNode(
                parent=parent,
                name=str(uuid.uuid4()),
                level=level,
                variable=variable,
                type=ComparisonType.GREATER.value,
                value=threshold,
            )
            # Append LESS before GREATER. The stable sort below keeps this
            # order, because both branches share the same level.
            nodes.append(less)
            nodes.append(greater)
            stack.append((tree_.children_right[node], greater.name))
            stack.append((tree_.children_left[node], less.name))
            continue

        # Leaf node: compute its value from the first output only.
        leaf_values = tree_.value[node][0]
        if classes is None:
            value = float(leaf_values[0])
        else:
            value = float(classes[np.argmax(leaf_values)])
        # The positional arguments are kept as in the original. They
        # presumably map to (parent, name, level, variable, type, value),
        # like the keyword calls above; NOTE(review): verify against TreeNode.
        nodes.append(TreeNode(parent, None, level, "result", "=", value))

    # list.sort is stable, whereas np.argsort's default "quicksort" is not.
    # With the stable sort, each LESS/GREATER pair cannot swap places.
    nodes.sort(key=lambda tree_node: tree_node.level)
    return nodes
|