Compare commits

..

No commits in common. "bc0c132a35afee23d8d4b2605408c6efc7271ae4" and "5817314562530ac170e16eedb0caf92fd9ed12f2" have entirely different histories.

3 changed files with 4 additions and 5 deletions

View File

@ -46,9 +46,8 @@ class Dataset:
random_state: int,
is_classification: bool = False,
) -> SplittedDataset:
target = params.target or data.columns[-1]
X = data.drop([params.target], axis=1)
y = data[[target]]
y = data[[params.target]]
stratify = None if not is_classification else y
X_train, X_test, y_train, y_test = train_test_split(
X,

View File

@ -8,7 +8,7 @@ class DatasetUploadDto(Schema):
class DatasetDto(Schema):
input = fields.List(fields.String(), load_default=None)
target = fields.String(load_default=None)
target = fields.String(required=True)
sep = fields.String(load_default=",")
decimal = fields.String(load_default=".")
train_volume = fields.Float(load_default=0.8, validate=Range(min=0.1, max=0.9))

View File

@ -35,7 +35,7 @@ def get_rules(
rules.append(Rule(antecedent, value)) # type: ignore
else:
index = np.argmax(tree_.value[node][0])
value = str(classes[index])
value = float(classes[index])
rules.append(Rule(antecedent, value)) # type: ignore
recurse(0, antecedent, rules)
@ -87,7 +87,7 @@ def get_tree(
nodes.append(TreeNode(parent, None, node, "result", "=", value))
else:
index = np.argmax(tree_.value[node][0])
value = str(classes[index])
value = float(classes[index])
nodes.append(TreeNode(parent, None, node, "result", "=", value))
recurse(0, None, nodes)