Fix classification tree and rules generation

This commit is contained in:
Aleksey Filippov 2025-03-11 21:55:13 +04:00
parent a4bdf7c88c
commit 6919b5c4a4
2 changed files with 27 additions and 26 deletions

View File

@@ -15,6 +15,7 @@ def fit_classification_model(
fitted_model = model.fit(data.X_train.values, data.y_train.values.ravel()) fitted_model = model.fit(data.X_train.values, data.y_train.values.ravel())
y = (data.y_train, fitted_model.predict(data.X_train.values)) y = (data.y_train, fitted_model.predict(data.X_train.values))
y_pred = (data.y_test, fitted_model.predict(data.X_test.values)) y_pred = (data.y_test, fitted_model.predict(data.X_test.values))
classes = fitted_model.classes_.tolist() # type: ignore
return ClassificationResult( return ClassificationResult(
precision=get_metric(classification.precision, y, y_pred), precision=get_metric(classification.precision, y, y_pred),
recall=get_metric(classification.recall, y, y_pred), recall=get_metric(classification.recall, y, y_pred),
@@ -22,6 +23,6 @@ def fit_classification_model(
f1=get_metric(classification.f1, y, y_pred), f1=get_metric(classification.f1, y, y_pred),
mcc=get_metric(classification.mcc, y, y_pred), mcc=get_metric(classification.mcc, y, y_pred),
cohen_kappa=get_metric(classification.cohen_kappa, y, y_pred), cohen_kappa=get_metric(classification.cohen_kappa, y, y_pred),
rules=get_rules(fitted_model, list(data.X_train.columns)), rules=get_rules(fitted_model, list(data.X_train.columns), classes),
tree=get_tree(fitted_model, list(data.X_train.columns)), tree=get_tree(fitted_model, list(data.X_train.columns), classes),
) )

View File

@@ -23,7 +23,7 @@ def get_rules(
if tree_.feature[node] != TREE_UNDEFINED: if tree_.feature[node] != TREE_UNDEFINED:
name = feature_name[node] name = feature_name[node]
threshold = tree_.threshold[node] threshold = float(tree_.threshold[node])
p1, p2 = list(antecedent), list(antecedent) p1, p2 = list(antecedent), list(antecedent)
p1.append(RuleAtom(name, ComparisonType.LESS.value, threshold)) p1.append(RuleAtom(name, ComparisonType.LESS.value, threshold))
recurse(tree_.children_left[node], p1, rules) recurse(tree_.children_left[node], p1, rules)
@@ -31,10 +31,12 @@ def get_rules(
recurse(tree_.children_right[node], p2, rules) recurse(tree_.children_right[node], p2, rules)
else: else:
if classes is None: if classes is None:
rules.append(Rule(antecedent, tree_.value[node][0][0])) # type: ignore value = float(tree_.value[node][0][0])
rules.append(Rule(antecedent, value)) # type: ignore
else: else:
value = np.argmax(tree_.value[node][0]) index = np.argmax(tree_.value[node][0])
rules.append(Rule(antecedent, classes[value])) # type: ignore value = float(classes[index])
rules.append(Rule(antecedent, value)) # type: ignore
recurse(0, antecedent, rules) recurse(0, antecedent, rules)
@@ -58,37 +60,35 @@ def get_tree(
parent: str | None = None if parent_node is None else parent_node.name parent: str | None = None if parent_node is None else parent_node.name
if tree_.feature[node] != TREE_UNDEFINED: if tree_.feature[node] != TREE_UNDEFINED:
feature = feature_name[node] feature = feature_name[node]
threshold = tree_.threshold[node] threshold = float(tree_.threshold[node])
p1 = TreeNode( p1 = TreeNode(
parent, parent=parent,
str(uuid.uuid4()), name=str(uuid.uuid4()),
node, level=node,
feature, variable=feature,
ComparisonType.LESS.value, type=ComparisonType.LESS.value,
threshold, value=threshold,
) )
recurse(tree_.children_left[node], p1, nodes) recurse(tree_.children_left[node], p1, nodes)
p2 = TreeNode( p2 = TreeNode(
parent, parent=parent,
str(uuid.uuid4()), name=str(uuid.uuid4()),
node, level=node,
feature, variable=feature,
ComparisonType.GREATER.value, type=ComparisonType.GREATER.value,
threshold, value=threshold,
) )
nodes.append(p1) nodes.append(p1)
nodes.append(p2) nodes.append(p2)
recurse(tree_.children_right[node], p2, nodes) recurse(tree_.children_right[node], p2, nodes)
else: else:
if classes is None: if classes is None:
nodes.append( value = float(tree_.value[node][0][0])
TreeNode(parent, None, node, "result", "=", tree_.value[node][0][0]) nodes.append(TreeNode(parent, None, node, "result", "=", value))
)
else: else:
value = np.argmax(tree_.value[node][0]) index = np.argmax(tree_.value[node][0])
nodes.append( value = float(classes[index])
TreeNode(parent, None, node, "result", "=", classes[value]) nodes.append(TreeNode(parent, None, node, "result", "=", value))
)
recurse(0, None, nodes) recurse(0, None, nodes)