diff --git a/dt-cart/backend/tree/__init__.py b/dt-cart/backend/tree/__init__.py index 1aee135..e36523e 100644 --- a/dt-cart/backend/tree/__init__.py +++ b/dt-cart/backend/tree/__init__.py @@ -35,7 +35,7 @@ def get_rules( rules.append(Rule(antecedent, value)) # type: ignore else: index = np.argmax(tree_.value[node][0]) - value = float(classes[index]) + value = str(classes[index]) rules.append(Rule(antecedent, value)) # type: ignore recurse(0, antecedent, rules) @@ -87,7 +87,7 @@ def get_tree( nodes.append(TreeNode(parent, None, node, "result", "=", value)) else: index = np.argmax(tree_.value[node][0]) - value = float(classes[index]) + value = str(classes[index]) nodes.append(TreeNode(parent, None, node, "result", "=", value)) recurse(0, None, nodes)