Fix classification tree and rules generation
This commit is contained in:
parent
a4bdf7c88c
commit
6919b5c4a4
@ -15,6 +15,7 @@ def fit_classification_model(
|
||||
fitted_model = model.fit(data.X_train.values, data.y_train.values.ravel())
|
||||
y = (data.y_train, fitted_model.predict(data.X_train.values))
|
||||
y_pred = (data.y_test, fitted_model.predict(data.X_test.values))
|
||||
classes = fitted_model.classes_.tolist() # type: ignore
|
||||
return ClassificationResult(
|
||||
precision=get_metric(classification.precision, y, y_pred),
|
||||
recall=get_metric(classification.recall, y, y_pred),
|
||||
@ -22,6 +23,6 @@ def fit_classification_model(
|
||||
f1=get_metric(classification.f1, y, y_pred),
|
||||
mcc=get_metric(classification.mcc, y, y_pred),
|
||||
cohen_kappa=get_metric(classification.cohen_kappa, y, y_pred),
|
||||
rules=get_rules(fitted_model, list(data.X_train.columns)),
|
||||
tree=get_tree(fitted_model, list(data.X_train.columns)),
|
||||
rules=get_rules(fitted_model, list(data.X_train.columns), classes),
|
||||
tree=get_tree(fitted_model, list(data.X_train.columns), classes),
|
||||
)
|
||||
|
@ -23,7 +23,7 @@ def get_rules(
|
||||
|
||||
if tree_.feature[node] != TREE_UNDEFINED:
|
||||
name = feature_name[node]
|
||||
threshold = tree_.threshold[node]
|
||||
threshold = float(tree_.threshold[node])
|
||||
p1, p2 = list(antecedent), list(antecedent)
|
||||
p1.append(RuleAtom(name, ComparisonType.LESS.value, threshold))
|
||||
recurse(tree_.children_left[node], p1, rules)
|
||||
@ -31,10 +31,12 @@ def get_rules(
|
||||
recurse(tree_.children_right[node], p2, rules)
|
||||
else:
|
||||
if classes is None:
|
||||
rules.append(Rule(antecedent, tree_.value[node][0][0])) # type: ignore
|
||||
value = float(tree_.value[node][0][0])
|
||||
rules.append(Rule(antecedent, value)) # type: ignore
|
||||
else:
|
||||
value = np.argmax(tree_.value[node][0])
|
||||
rules.append(Rule(antecedent, classes[value])) # type: ignore
|
||||
index = np.argmax(tree_.value[node][0])
|
||||
value = float(classes[index])
|
||||
rules.append(Rule(antecedent, value)) # type: ignore
|
||||
|
||||
recurse(0, antecedent, rules)
|
||||
|
||||
@ -58,37 +60,35 @@ def get_tree(
|
||||
parent: str | None = None if parent_node is None else parent_node.name
|
||||
if tree_.feature[node] != TREE_UNDEFINED:
|
||||
feature = feature_name[node]
|
||||
threshold = tree_.threshold[node]
|
||||
threshold = float(tree_.threshold[node])
|
||||
p1 = TreeNode(
|
||||
parent,
|
||||
str(uuid.uuid4()),
|
||||
node,
|
||||
feature,
|
||||
ComparisonType.LESS.value,
|
||||
threshold,
|
||||
parent=parent,
|
||||
name=str(uuid.uuid4()),
|
||||
level=node,
|
||||
variable=feature,
|
||||
type=ComparisonType.LESS.value,
|
||||
value=threshold,
|
||||
)
|
||||
recurse(tree_.children_left[node], p1, nodes)
|
||||
p2 = TreeNode(
|
||||
parent,
|
||||
str(uuid.uuid4()),
|
||||
node,
|
||||
feature,
|
||||
ComparisonType.GREATER.value,
|
||||
threshold,
|
||||
parent=parent,
|
||||
name=str(uuid.uuid4()),
|
||||
level=node,
|
||||
variable=feature,
|
||||
type=ComparisonType.GREATER.value,
|
||||
value=threshold,
|
||||
)
|
||||
nodes.append(p1)
|
||||
nodes.append(p2)
|
||||
recurse(tree_.children_right[node], p2, nodes)
|
||||
else:
|
||||
if classes is None:
|
||||
nodes.append(
|
||||
TreeNode(parent, None, node, "result", "=", tree_.value[node][0][0])
|
||||
)
|
||||
value = float(tree_.value[node][0][0])
|
||||
nodes.append(TreeNode(parent, None, node, "result", "=", value))
|
||||
else:
|
||||
value = np.argmax(tree_.value[node][0])
|
||||
nodes.append(
|
||||
TreeNode(parent, None, node, "result", "=", classes[value])
|
||||
)
|
||||
index = np.argmax(tree_.value[node][0])
|
||||
value = float(classes[index])
|
||||
nodes.append(TreeNode(parent, None, node, "result", "=", value))
|
||||
|
||||
recurse(0, None, nodes)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user