import uuid
from typing import List

import numpy as np
from sklearn import tree
from sklearn.tree._tree import TREE_UNDEFINED  # type: ignore

from backend.tree.model import ComparisonType, Rule, RuleAtom, TreeNode


def _feature_labels(tree_, feature_names: List[str]) -> List[str]:
    """Map every node of ``tree_`` to the name of the feature it splits on.

    Leaf nodes have ``tree_.feature == TREE_UNDEFINED`` and get the
    placeholder ``"undefined!"``; the walkers below never read it for leaves.
    """
    return [
        feature_names[i] if i != TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]


def _leaf_value(tree_, node, classes) -> float:
    """Return the prediction stored at leaf ``node`` as a float.

    Without ``classes`` the entry ``tree_.value[node][0][0]`` is returned
    (presumably the prediction of a single-output regressor; confirm against
    callers). With ``classes`` the position of the largest entry of
    ``tree_.value[node][0]`` selects the label from ``classes``, which is then
    converted with ``float()`` -- so class labels must be numeric.
    """
    if classes is None:
        return float(tree_.value[node][0][0])
    return float(classes[np.argmax(tree_.value[node][0])])


def get_rules(
    tree: tree.BaseDecisionTree, feature_names: List[str], classes=None
) -> List[Rule]:
    """Extract one Rule per leaf of a fitted sklearn decision tree.

    A rule's antecedent is the list of split conditions on the path from the
    root to its leaf: taking the left child of a split is recorded as
    ``ComparisonType.LESS`` and taking the right child as
    ``ComparisonType.GREATER``, both against the split threshold.

    Args:
        tree: fitted sklearn decision tree; only its ``tree_`` is read.
        feature_names: feature names indexed by the column ids stored in
            ``tree_.feature``.
        classes: optional labels used to turn a leaf's value array into a
            class prediction (see ``_leaf_value``); ``None`` uses the raw
            first value instead.

    Returns:
        The rules ordered by antecedent length, shortest first. The sort is
        stable, so rules of equal length keep their left-to-right leaf order.
    """
    tree_ = tree.tree_  # type: ignore
    names = _feature_labels(tree_, feature_names)

    rules: List[Rule] = []
    # Explicit depth-first stack instead of recursion, so very deep trees
    # cannot hit Python's recursion limit. The right child is pushed before
    # the left one, so leaves are still emitted left to right.
    stack: list[tuple[int, List[RuleAtom]]] = [(0, [])]
    while stack:
        node, antecedent = stack.pop()
        if tree_.feature[node] == TREE_UNDEFINED:
            value = _leaf_value(tree_, node, classes)
            rules.append(Rule(antecedent, value))  # type: ignore
            continue

        name = names[node]
        threshold = float(tree_.threshold[node])
        # NOTE(review): sklearn routes samples with x <= threshold to the left
        # child; confirm ComparisonType.LESS is interpreted inclusively.
        greater = RuleAtom(name, ComparisonType.GREATER.value, threshold)
        less = RuleAtom(name, ComparisonType.LESS.value, threshold)
        stack.append((tree_.children_right[node], antecedent + [greater]))
        stack.append((tree_.children_left[node], antecedent + [less]))

    # sorted() is stable; np.argsort's default quicksort is not, which made
    # the order of equal-length rules nondeterministic.
    return sorted(rules, key=lambda rule: len(rule.antecedent))


def get_tree(
    tree: tree.BaseDecisionTree, feature_names: List[str], classes=None
) -> List[TreeNode]:
    """Flatten a fitted sklearn decision tree into a parent-linked node list.

    Every split of the sklearn tree yields two TreeNode entries, one per
    branch: ``ComparisonType.LESS`` for the left child and
    ``ComparisonType.GREATER`` for the right child, both carrying the split
    feature and threshold. Each branch entry gets a fresh UUID4 string as
    ``name``; its ``parent`` is the name of the branch entry that leads to it
    (``None`` at the root). Every leaf yields one entry with variable
    ``"result"``, type ``"="``, ``name=None`` and the value computed by
    ``_leaf_value``.

    Note: ``level`` holds the sklearn node index (of the split for branch
    entries, of the leaf itself for result entries), not the depth.

    Args:
        tree: fitted sklearn decision tree; only its ``tree_`` is read.
        feature_names: feature names indexed by the column ids stored in
            ``tree_.feature``.
        classes: optional labels used to turn a leaf's value array into a
            class prediction (see ``_leaf_value``).

    Returns:
        The entries ordered by ``level``. The sort is stable, so the two
        branch entries of a split always appear LESS first, then GREATER.
    """
    tree_ = tree.tree_  # type: ignore
    names = _feature_labels(tree_, feature_names)

    nodes: List[TreeNode] = []
    # Pending work: (sklearn node index, name of the entry leading to it).
    # Iterative to stay clear of Python's recursion limit on deep trees.
    stack: list[tuple[int, str | None]] = [(0, None)]
    while stack:
        node, parent = stack.pop()
        if tree_.feature[node] == TREE_UNDEFINED:
            value = _leaf_value(tree_, node, classes)
            nodes.append(TreeNode(parent, None, node, "result", "=", value))
            continue

        feature = names[node]
        threshold = float(tree_.threshold[node])
        left_name = str(uuid.uuid4())
        right_name = str(uuid.uuid4())
        # NOTE(review): sklearn routes samples with x <= threshold to the left
        # child; confirm ComparisonType.LESS is interpreted inclusively.
        nodes.append(
            TreeNode(
                parent=parent,
                name=left_name,
                level=node,
                variable=feature,
                type=ComparisonType.LESS.value,
                value=threshold,
            )
        )
        nodes.append(
            TreeNode(
                parent=parent,
                name=right_name,
                level=node,
                variable=feature,
                type=ComparisonType.GREATER.value,
                value=threshold,
            )
        )
        stack.append((tree_.children_right[node], right_name))
        stack.append((tree_.children_left[node], left_name))

    # Stable sort: the LESS/GREATER entries of a split share the same level and
    # must not be reordered (np.argsort's default quicksort could swap them).
    return sorted(nodes, key=lambda entry: entry.level)