Compare commits

..

2 Commits

Author SHA1 Message Date
bc0c132a35 Add default dataset target value 2025-03-12 13:31:13 +04:00
f617f3e41c Fix classification value type 2025-03-12 13:30:58 +04:00
3 changed files with 5 additions and 4 deletions

View File

@ -46,8 +46,9 @@ class Dataset:
random_state: int,
is_classification: bool = False,
) -> SplittedDataset:
target = params.target or data.columns[-1]
X = data.drop([params.target], axis=1)
y = data[[params.target]]
y = data[[target]]
stratify = None if not is_classification else y
X_train, X_test, y_train, y_test = train_test_split(
X,

View File

@ -8,7 +8,7 @@ class DatasetUploadDto(Schema):
class DatasetDto(Schema):
input = fields.List(fields.String(), load_default=None)
target = fields.String(required=True)
target = fields.String(load_default=None)
sep = fields.String(load_default=",")
decimal = fields.String(load_default=".")
train_volume = fields.Float(load_default=0.8, validate=Range(min=0.1, max=0.9))

View File

@ -35,7 +35,7 @@ def get_rules(
rules.append(Rule(antecedent, value)) # type: ignore
else:
index = np.argmax(tree_.value[node][0])
value = float(classes[index])
value = str(classes[index])
rules.append(Rule(antecedent, value)) # type: ignore
recurse(0, antecedent, rules)
@ -87,7 +87,7 @@ def get_tree(
nodes.append(TreeNode(parent, None, node, "result", "=", value))
else:
index = np.argmax(tree_.value[node][0])
value = float(classes[index])
value = str(classes[index])
nodes.append(TreeNode(parent, None, node, "result", "=", value))
recurse(0, None, nodes)