Compare commits

..

No commits in common. "bc0c132a35afee23d8d4b2605408c6efc7271ae4" and "5817314562530ac170e16eedb0caf92fd9ed12f2" have entirely different histories.

3 changed files with 4 additions and 5 deletions

View File

@ -46,9 +46,8 @@ class Dataset:
random_state: int, random_state: int,
is_classification: bool = False, is_classification: bool = False,
) -> SplittedDataset: ) -> SplittedDataset:
target = params.target or data.columns[-1]
X = data.drop([params.target], axis=1) X = data.drop([params.target], axis=1)
y = data[[target]] y = data[[params.target]]
stratify = None if not is_classification else y stratify = None if not is_classification else y
X_train, X_test, y_train, y_test = train_test_split( X_train, X_test, y_train, y_test = train_test_split(
X, X,

View File

@ -8,7 +8,7 @@ class DatasetUploadDto(Schema):
class DatasetDto(Schema): class DatasetDto(Schema):
input = fields.List(fields.String(), load_default=None) input = fields.List(fields.String(), load_default=None)
target = fields.String(load_default=None) target = fields.String(required=True)
sep = fields.String(load_default=",") sep = fields.String(load_default=",")
decimal = fields.String(load_default=".") decimal = fields.String(load_default=".")
train_volume = fields.Float(load_default=0.8, validate=Range(min=0.1, max=0.9)) train_volume = fields.Float(load_default=0.8, validate=Range(min=0.1, max=0.9))

View File

@ -35,7 +35,7 @@ def get_rules(
rules.append(Rule(antecedent, value)) # type: ignore rules.append(Rule(antecedent, value)) # type: ignore
else: else:
index = np.argmax(tree_.value[node][0]) index = np.argmax(tree_.value[node][0])
value = str(classes[index]) value = float(classes[index])
rules.append(Rule(antecedent, value)) # type: ignore rules.append(Rule(antecedent, value)) # type: ignore
recurse(0, antecedent, rules) recurse(0, antecedent, rules)
@ -87,7 +87,7 @@ def get_tree(
nodes.append(TreeNode(parent, None, node, "result", "=", value)) nodes.append(TreeNode(parent, None, node, "result", "=", value))
else: else:
index = np.argmax(tree_.value[node][0]) index = np.argmax(tree_.value[node][0])
value = str(classes[index]) value = float(classes[index])
nodes.append(TreeNode(parent, None, node, "result", "=", value)) nodes.append(TreeNode(parent, None, node, "result", "=", value))
recurse(0, None, nodes) recurse(0, None, nodes)