Fix classification tree and rules generation

This commit is contained in:
Aleksey Filippov 2025-03-11 21:55:13 +04:00
parent a4bdf7c88c
commit 6919b5c4a4
2 changed files with 27 additions and 26 deletions

View File

@ -15,6 +15,7 @@ def fit_classification_model(
fitted_model = model.fit(data.X_train.values, data.y_train.values.ravel())
y = (data.y_train, fitted_model.predict(data.X_train.values))
y_pred = (data.y_test, fitted_model.predict(data.X_test.values))
classes = fitted_model.classes_.tolist() # type: ignore
return ClassificationResult(
precision=get_metric(classification.precision, y, y_pred),
recall=get_metric(classification.recall, y, y_pred),
@ -22,6 +23,6 @@ def fit_classification_model(
f1=get_metric(classification.f1, y, y_pred),
mcc=get_metric(classification.mcc, y, y_pred),
cohen_kappa=get_metric(classification.cohen_kappa, y, y_pred),
rules=get_rules(fitted_model, list(data.X_train.columns)),
tree=get_tree(fitted_model, list(data.X_train.columns)),
rules=get_rules(fitted_model, list(data.X_train.columns), classes),
tree=get_tree(fitted_model, list(data.X_train.columns), classes),
)

View File

@ -23,7 +23,7 @@ def get_rules(
if tree_.feature[node] != TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
threshold = float(tree_.threshold[node])
p1, p2 = list(antecedent), list(antecedent)
p1.append(RuleAtom(name, ComparisonType.LESS.value, threshold))
recurse(tree_.children_left[node], p1, rules)
@ -31,10 +31,12 @@ def get_rules(
recurse(tree_.children_right[node], p2, rules)
else:
if classes is None:
rules.append(Rule(antecedent, tree_.value[node][0][0])) # type: ignore
value = float(tree_.value[node][0][0])
rules.append(Rule(antecedent, value)) # type: ignore
else:
value = np.argmax(tree_.value[node][0])
rules.append(Rule(antecedent, classes[value])) # type: ignore
index = np.argmax(tree_.value[node][0])
value = float(classes[index])
rules.append(Rule(antecedent, value)) # type: ignore
recurse(0, antecedent, rules)
@ -58,37 +60,35 @@ def get_tree(
parent: str | None = None if parent_node is None else parent_node.name
if tree_.feature[node] != TREE_UNDEFINED:
feature = feature_name[node]
threshold = tree_.threshold[node]
threshold = float(tree_.threshold[node])
p1 = TreeNode(
parent,
str(uuid.uuid4()),
node,
feature,
ComparisonType.LESS.value,
threshold,
parent=parent,
name=str(uuid.uuid4()),
level=node,
variable=feature,
type=ComparisonType.LESS.value,
value=threshold,
)
recurse(tree_.children_left[node], p1, nodes)
p2 = TreeNode(
parent,
str(uuid.uuid4()),
node,
feature,
ComparisonType.GREATER.value,
threshold,
parent=parent,
name=str(uuid.uuid4()),
level=node,
variable=feature,
type=ComparisonType.GREATER.value,
value=threshold,
)
nodes.append(p1)
nodes.append(p2)
recurse(tree_.children_right[node], p2, nodes)
else:
if classes is None:
nodes.append(
TreeNode(parent, None, node, "result", "=", tree_.value[node][0][0])
)
value = float(tree_.value[node][0][0])
nodes.append(TreeNode(parent, None, node, "result", "=", value))
else:
value = np.argmax(tree_.value[node][0])
nodes.append(
TreeNode(parent, None, node, "result", "=", classes[value])
)
index = np.argmax(tree_.value[node][0])
value = float(classes[index])
nodes.append(TreeNode(parent, None, node, "result", "=", value))
recurse(0, None, nodes)