Fix classification tree and rules generation
This commit is contained in:
parent
a4bdf7c88c
commit
6919b5c4a4
@ -15,6 +15,7 @@ def fit_classification_model(
|
|||||||
fitted_model = model.fit(data.X_train.values, data.y_train.values.ravel())
|
fitted_model = model.fit(data.X_train.values, data.y_train.values.ravel())
|
||||||
y = (data.y_train, fitted_model.predict(data.X_train.values))
|
y = (data.y_train, fitted_model.predict(data.X_train.values))
|
||||||
y_pred = (data.y_test, fitted_model.predict(data.X_test.values))
|
y_pred = (data.y_test, fitted_model.predict(data.X_test.values))
|
||||||
|
classes = fitted_model.classes_.tolist() # type: ignore
|
||||||
return ClassificationResult(
|
return ClassificationResult(
|
||||||
precision=get_metric(classification.precision, y, y_pred),
|
precision=get_metric(classification.precision, y, y_pred),
|
||||||
recall=get_metric(classification.recall, y, y_pred),
|
recall=get_metric(classification.recall, y, y_pred),
|
||||||
@ -22,6 +23,6 @@ def fit_classification_model(
|
|||||||
f1=get_metric(classification.f1, y, y_pred),
|
f1=get_metric(classification.f1, y, y_pred),
|
||||||
mcc=get_metric(classification.mcc, y, y_pred),
|
mcc=get_metric(classification.mcc, y, y_pred),
|
||||||
cohen_kappa=get_metric(classification.cohen_kappa, y, y_pred),
|
cohen_kappa=get_metric(classification.cohen_kappa, y, y_pred),
|
||||||
rules=get_rules(fitted_model, list(data.X_train.columns)),
|
rules=get_rules(fitted_model, list(data.X_train.columns), classes),
|
||||||
tree=get_tree(fitted_model, list(data.X_train.columns)),
|
tree=get_tree(fitted_model, list(data.X_train.columns), classes),
|
||||||
)
|
)
|
||||||
|
@ -23,7 +23,7 @@ def get_rules(
|
|||||||
|
|
||||||
if tree_.feature[node] != TREE_UNDEFINED:
|
if tree_.feature[node] != TREE_UNDEFINED:
|
||||||
name = feature_name[node]
|
name = feature_name[node]
|
||||||
threshold = tree_.threshold[node]
|
threshold = float(tree_.threshold[node])
|
||||||
p1, p2 = list(antecedent), list(antecedent)
|
p1, p2 = list(antecedent), list(antecedent)
|
||||||
p1.append(RuleAtom(name, ComparisonType.LESS.value, threshold))
|
p1.append(RuleAtom(name, ComparisonType.LESS.value, threshold))
|
||||||
recurse(tree_.children_left[node], p1, rules)
|
recurse(tree_.children_left[node], p1, rules)
|
||||||
@ -31,10 +31,12 @@ def get_rules(
|
|||||||
recurse(tree_.children_right[node], p2, rules)
|
recurse(tree_.children_right[node], p2, rules)
|
||||||
else:
|
else:
|
||||||
if classes is None:
|
if classes is None:
|
||||||
rules.append(Rule(antecedent, tree_.value[node][0][0])) # type: ignore
|
value = float(tree_.value[node][0][0])
|
||||||
|
rules.append(Rule(antecedent, value)) # type: ignore
|
||||||
else:
|
else:
|
||||||
value = np.argmax(tree_.value[node][0])
|
index = np.argmax(tree_.value[node][0])
|
||||||
rules.append(Rule(antecedent, classes[value])) # type: ignore
|
value = float(classes[index])
|
||||||
|
rules.append(Rule(antecedent, value)) # type: ignore
|
||||||
|
|
||||||
recurse(0, antecedent, rules)
|
recurse(0, antecedent, rules)
|
||||||
|
|
||||||
@ -58,37 +60,35 @@ def get_tree(
|
|||||||
parent: str | None = None if parent_node is None else parent_node.name
|
parent: str | None = None if parent_node is None else parent_node.name
|
||||||
if tree_.feature[node] != TREE_UNDEFINED:
|
if tree_.feature[node] != TREE_UNDEFINED:
|
||||||
feature = feature_name[node]
|
feature = feature_name[node]
|
||||||
threshold = tree_.threshold[node]
|
threshold = float(tree_.threshold[node])
|
||||||
p1 = TreeNode(
|
p1 = TreeNode(
|
||||||
parent,
|
parent=parent,
|
||||||
str(uuid.uuid4()),
|
name=str(uuid.uuid4()),
|
||||||
node,
|
level=node,
|
||||||
feature,
|
variable=feature,
|
||||||
ComparisonType.LESS.value,
|
type=ComparisonType.LESS.value,
|
||||||
threshold,
|
value=threshold,
|
||||||
)
|
)
|
||||||
recurse(tree_.children_left[node], p1, nodes)
|
recurse(tree_.children_left[node], p1, nodes)
|
||||||
p2 = TreeNode(
|
p2 = TreeNode(
|
||||||
parent,
|
parent=parent,
|
||||||
str(uuid.uuid4()),
|
name=str(uuid.uuid4()),
|
||||||
node,
|
level=node,
|
||||||
feature,
|
variable=feature,
|
||||||
ComparisonType.GREATER.value,
|
type=ComparisonType.GREATER.value,
|
||||||
threshold,
|
value=threshold,
|
||||||
)
|
)
|
||||||
nodes.append(p1)
|
nodes.append(p1)
|
||||||
nodes.append(p2)
|
nodes.append(p2)
|
||||||
recurse(tree_.children_right[node], p2, nodes)
|
recurse(tree_.children_right[node], p2, nodes)
|
||||||
else:
|
else:
|
||||||
if classes is None:
|
if classes is None:
|
||||||
nodes.append(
|
value = float(tree_.value[node][0][0])
|
||||||
TreeNode(parent, None, node, "result", "=", tree_.value[node][0][0])
|
nodes.append(TreeNode(parent, None, node, "result", "=", value))
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
value = np.argmax(tree_.value[node][0])
|
index = np.argmax(tree_.value[node][0])
|
||||||
nodes.append(
|
value = float(classes[index])
|
||||||
TreeNode(parent, None, node, "result", "=", classes[value])
|
nodes.append(TreeNode(parent, None, node, "result", "=", value))
|
||||||
)
|
|
||||||
|
|
||||||
recurse(0, None, nodes)
|
recurse(0, None, nodes)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user