diff --git a/dt-cart/backend/classification/__init__.py b/dt-cart/backend/classification/__init__.py index 5a00947..0ddeb19 100644 --- a/dt-cart/backend/classification/__init__.py +++ b/dt-cart/backend/classification/__init__.py @@ -15,6 +15,7 @@ def fit_classification_model( fitted_model = model.fit(data.X_train.values, data.y_train.values.ravel()) y = (data.y_train, fitted_model.predict(data.X_train.values)) y_pred = (data.y_test, fitted_model.predict(data.X_test.values)) + classes = fitted_model.classes_.tolist() # type: ignore return ClassificationResult( precision=get_metric(classification.precision, y, y_pred), recall=get_metric(classification.recall, y, y_pred), @@ -22,6 +23,6 @@ def fit_classification_model( f1=get_metric(classification.f1, y, y_pred), mcc=get_metric(classification.mcc, y, y_pred), cohen_kappa=get_metric(classification.cohen_kappa, y, y_pred), - rules=get_rules(fitted_model, list(data.X_train.columns)), - tree=get_tree(fitted_model, list(data.X_train.columns)), + rules=get_rules(fitted_model, list(data.X_train.columns), classes), + tree=get_tree(fitted_model, list(data.X_train.columns), classes), ) diff --git a/dt-cart/backend/tree/__init__.py b/dt-cart/backend/tree/__init__.py index e12563b..1aee135 100644 --- a/dt-cart/backend/tree/__init__.py +++ b/dt-cart/backend/tree/__init__.py @@ -23,7 +23,7 @@ def get_rules( if tree_.feature[node] != TREE_UNDEFINED: name = feature_name[node] - threshold = tree_.threshold[node] + threshold = float(tree_.threshold[node]) p1, p2 = list(antecedent), list(antecedent) p1.append(RuleAtom(name, ComparisonType.LESS.value, threshold)) recurse(tree_.children_left[node], p1, rules) @@ -31,10 +31,12 @@ def get_rules( recurse(tree_.children_right[node], p2, rules) else: if classes is None: - rules.append(Rule(antecedent, tree_.value[node][0][0])) # type: ignore + value = float(tree_.value[node][0][0]) + rules.append(Rule(antecedent, value)) # type: ignore else: - value = np.argmax(tree_.value[node][0]) - rules.append(Rule(antecedent, classes[value])) # type: ignore + index = np.argmax(tree_.value[node][0]) + value = float(classes[index]) + rules.append(Rule(antecedent, value)) # type: ignore recurse(0, antecedent, rules) @@ -58,37 +60,35 @@ def get_tree( parent: str | None = None if parent_node is None else parent_node.name if tree_.feature[node] != TREE_UNDEFINED: feature = feature_name[node] - threshold = tree_.threshold[node] + threshold = float(tree_.threshold[node]) p1 = TreeNode( - parent, - str(uuid.uuid4()), - node, - feature, - ComparisonType.LESS.value, - threshold, + parent=parent, + name=str(uuid.uuid4()), + level=node, + variable=feature, + type=ComparisonType.LESS.value, + value=threshold, ) recurse(tree_.children_left[node], p1, nodes) p2 = TreeNode( - parent, - str(uuid.uuid4()), - node, - feature, - ComparisonType.GREATER.value, - threshold, + parent=parent, + name=str(uuid.uuid4()), + level=node, + variable=feature, + type=ComparisonType.GREATER.value, + value=threshold, ) nodes.append(p1) nodes.append(p2) recurse(tree_.children_right[node], p2, nodes) else: if classes is None: - nodes.append( - TreeNode(parent, None, node, "result", "=", tree_.value[node][0][0]) - ) + value = float(tree_.value[node][0][0]) + nodes.append(TreeNode(parent, None, node, "result", "=", value)) else: - value = np.argmax(tree_.value[node][0]) - nodes.append( - TreeNode(parent, None, node, "result", "=", classes[value]) - ) + index = np.argmax(tree_.value[node][0]) + value = float(classes[index]) + nodes.append(TreeNode(parent, None, node, "result", "=", value)) recurse(0, None, nodes)