Compare commits
No commits in common. "bc0c132a35afee23d8d4b2605408c6efc7271ae4" and "5817314562530ac170e16eedb0caf92fd9ed12f2" have entirely different histories.
bc0c132a35
...
5817314562
@ -46,9 +46,8 @@ class Dataset:
|
|||||||
random_state: int,
|
random_state: int,
|
||||||
is_classification: bool = False,
|
is_classification: bool = False,
|
||||||
) -> SplittedDataset:
|
) -> SplittedDataset:
|
||||||
target = params.target or data.columns[-1]
|
|
||||||
X = data.drop([params.target], axis=1)
|
X = data.drop([params.target], axis=1)
|
||||||
y = data[[target]]
|
y = data[[params.target]]
|
||||||
stratify = None if not is_classification else y
|
stratify = None if not is_classification else y
|
||||||
X_train, X_test, y_train, y_test = train_test_split(
|
X_train, X_test, y_train, y_test = train_test_split(
|
||||||
X,
|
X,
|
||||||
|
@ -8,7 +8,7 @@ class DatasetUploadDto(Schema):
|
|||||||
|
|
||||||
class DatasetDto(Schema):
|
class DatasetDto(Schema):
|
||||||
input = fields.List(fields.String(), load_default=None)
|
input = fields.List(fields.String(), load_default=None)
|
||||||
target = fields.String(load_default=None)
|
target = fields.String(required=True)
|
||||||
sep = fields.String(load_default=",")
|
sep = fields.String(load_default=",")
|
||||||
decimal = fields.String(load_default=".")
|
decimal = fields.String(load_default=".")
|
||||||
train_volume = fields.Float(load_default=0.8, validate=Range(min=0.1, max=0.9))
|
train_volume = fields.Float(load_default=0.8, validate=Range(min=0.1, max=0.9))
|
||||||
|
@ -35,7 +35,7 @@ def get_rules(
|
|||||||
rules.append(Rule(antecedent, value)) # type: ignore
|
rules.append(Rule(antecedent, value)) # type: ignore
|
||||||
else:
|
else:
|
||||||
index = np.argmax(tree_.value[node][0])
|
index = np.argmax(tree_.value[node][0])
|
||||||
value = str(classes[index])
|
value = float(classes[index])
|
||||||
rules.append(Rule(antecedent, value)) # type: ignore
|
rules.append(Rule(antecedent, value)) # type: ignore
|
||||||
|
|
||||||
recurse(0, antecedent, rules)
|
recurse(0, antecedent, rules)
|
||||||
@ -87,7 +87,7 @@ def get_tree(
|
|||||||
nodes.append(TreeNode(parent, None, node, "result", "=", value))
|
nodes.append(TreeNode(parent, None, node, "result", "=", value))
|
||||||
else:
|
else:
|
||||||
index = np.argmax(tree_.value[node][0])
|
index = np.argmax(tree_.value[node][0])
|
||||||
value = str(classes[index])
|
value = float(classes[index])
|
||||||
nodes.append(TreeNode(parent, None, node, "result", "=", value))
|
nodes.append(TreeNode(parent, None, node, "result", "=", value))
|
||||||
|
|
||||||
recurse(0, None, nodes)
|
recurse(0, None, nodes)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user