Add classification tree implementation, some refactoring
This commit is contained in:
parent
1682e25dcb
commit
a4bdf7c88c
@ -1,22 +1,37 @@
|
|||||||
from werkzeug.datastructures import FileStorage
|
from werkzeug.datastructures import FileStorage
|
||||||
|
|
||||||
from backend import api_bp, dataset_path, service
|
from backend import api_bp, dataset_path, service
|
||||||
|
from backend.classification.dto import ClassificationResultDto
|
||||||
from backend.dataset.dto import DatasetUploadDto
|
from backend.dataset.dto import DatasetUploadDto
|
||||||
from backend.dataset.model import DatasetParams
|
from backend.dataset.model import DatasetParams
|
||||||
from backend.dto import RegressionDto
|
from backend.dto import ClassificationDto, RegressionDto
|
||||||
from backend.regression.dto import RegressionResultDto
|
from backend.regression.dto import RegressionResultDto
|
||||||
from backend.regression.model import RegressionTreeParams
|
from backend.tree.model import DecisionTreeParams
|
||||||
|
|
||||||
|
|
||||||
@api_bp.post("/regression")
|
@api_bp.post("/regression")
|
||||||
@api_bp.input(DatasetUploadDto, location="files")
|
@api_bp.input(DatasetUploadDto, location="files")
|
||||||
@api_bp.input(RegressionDto, location="query")
|
@api_bp.input(RegressionDto, location="query")
|
||||||
@api_bp.output(RegressionResultDto)
|
@api_bp.output(RegressionResultDto)
|
||||||
def upload_dataset(files_data, query_data):
|
def regression(files_data, query_data):
|
||||||
uploaded_file: FileStorage = files_data["dataset"]
|
uploaded_file: FileStorage = files_data["dataset"]
|
||||||
schema = RegressionDto()
|
schema = RegressionDto()
|
||||||
dataset_params: DatasetParams = schema.get_dataset_params(query_data)
|
dataset_params: DatasetParams = schema.get_dataset_params(query_data)
|
||||||
tree_params: RegressionTreeParams = schema.get_tree_params(query_data)
|
tree_params: DecisionTreeParams = schema.get_tree_params(query_data)
|
||||||
return service.run_regression(
|
return service.run_regression(
|
||||||
dataset_path, uploaded_file, dataset_params, tree_params
|
dataset_path, uploaded_file, dataset_params, tree_params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_bp.post("/classification")
|
||||||
|
@api_bp.input(DatasetUploadDto, location="files")
|
||||||
|
@api_bp.input(ClassificationDto, location="query")
|
||||||
|
@api_bp.output(ClassificationResultDto)
|
||||||
|
def classification(files_data, query_data):
|
||||||
|
uploaded_file: FileStorage = files_data["dataset"]
|
||||||
|
schema = ClassificationDto()
|
||||||
|
dataset_params: DatasetParams = schema.get_dataset_params(query_data)
|
||||||
|
tree_params: DecisionTreeParams = schema.get_tree_params(query_data)
|
||||||
|
return service.run_classification(
|
||||||
|
dataset_path, uploaded_file, dataset_params, tree_params
|
||||||
|
)
|
||||||
|
27
dt-cart/backend/classification/__init__.py
Normal file
27
dt-cart/backend/classification/__init__.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from sklearn import tree
|
||||||
|
|
||||||
|
from backend.classification.model import ClassificationResult
|
||||||
|
from backend.dataset.model import SplittedDataset
|
||||||
|
from backend.metric import classification, get_metric
|
||||||
|
from backend.tree import get_rules, get_tree
|
||||||
|
from backend.tree.model import DecisionTreeParams
|
||||||
|
|
||||||
|
|
||||||
|
def fit_classification_model(
|
||||||
|
data: SplittedDataset,
|
||||||
|
params: DecisionTreeParams,
|
||||||
|
) -> ClassificationResult:
|
||||||
|
model = tree.DecisionTreeClassifier(**vars(params))
|
||||||
|
fitted_model = model.fit(data.X_train.values, data.y_train.values.ravel())
|
||||||
|
y = (data.y_train, fitted_model.predict(data.X_train.values))
|
||||||
|
y_pred = (data.y_test, fitted_model.predict(data.X_test.values))
|
||||||
|
return ClassificationResult(
|
||||||
|
precision=get_metric(classification.precision, y, y_pred),
|
||||||
|
recall=get_metric(classification.recall, y, y_pred),
|
||||||
|
accuracy=get_metric(classification.accuracy, y, y_pred),
|
||||||
|
f1=get_metric(classification.f1, y, y_pred),
|
||||||
|
mcc=get_metric(classification.mcc, y, y_pred),
|
||||||
|
cohen_kappa=get_metric(classification.cohen_kappa, y, y_pred),
|
||||||
|
rules=get_rules(fitted_model, list(data.X_train.columns)),
|
||||||
|
tree=get_tree(fitted_model, list(data.X_train.columns)),
|
||||||
|
)
|
23
dt-cart/backend/classification/dto.py
Normal file
23
dt-cart/backend/classification/dto.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from apiflask import Schema, fields
|
||||||
|
from apiflask.validators import OneOf
|
||||||
|
|
||||||
|
from backend.metric.dto import MetrciDto
|
||||||
|
from backend.tree.dto import DecisionTreeParamsDto, RuleDto, TreeNodeDto
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationTreeDto(DecisionTreeParamsDto):
|
||||||
|
criterion = fields.String(
|
||||||
|
load_default="gini",
|
||||||
|
validate=OneOf(["gini", "entropy", "log_loss"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationResultDto(Schema):
|
||||||
|
rules = fields.List(fields.Nested(RuleDto()))
|
||||||
|
tree = fields.List(fields.Nested(TreeNodeDto()))
|
||||||
|
precision = fields.Nested(MetrciDto())
|
||||||
|
recall = fields.Nested(MetrciDto())
|
||||||
|
accuracy = fields.Nested(MetrciDto())
|
||||||
|
f1 = fields.Nested(MetrciDto())
|
||||||
|
mcc = fields.Nested(MetrciDto())
|
||||||
|
cohen_kappa = fields.Nested(MetrciDto())
|
17
dt-cart/backend/classification/model.py
Normal file
17
dt-cart/backend/classification/model.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from backend.metric.model import MetricValue
|
||||||
|
from backend.tree.model import Rule, TreeNode
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClassificationResult:
|
||||||
|
tree: List[TreeNode]
|
||||||
|
rules: List[Rule]
|
||||||
|
precision: MetricValue
|
||||||
|
recall: MetricValue
|
||||||
|
accuracy: MetricValue
|
||||||
|
f1: MetricValue
|
||||||
|
mcc: MetricValue
|
||||||
|
cohen_kappa: MetricValue
|
@ -39,18 +39,38 @@ class Dataset:
|
|||||||
|
|
||||||
return df
|
return df
|
||||||
|
|
||||||
def split(
|
def __split(
|
||||||
self, data: DataFrame, params: DatasetParams, random_state: int
|
self,
|
||||||
|
data: DataFrame,
|
||||||
|
params: DatasetParams,
|
||||||
|
random_state: int,
|
||||||
|
is_classification: bool = False,
|
||||||
) -> SplittedDataset:
|
) -> SplittedDataset:
|
||||||
X = data.drop([params.target], axis=1)
|
X = data.drop([params.target], axis=1)
|
||||||
y = data[[params.target]]
|
y = data[[params.target]]
|
||||||
|
stratify = None if not is_classification else y
|
||||||
X_train, X_test, y_train, y_test = train_test_split(
|
X_train, X_test, y_train, y_test = train_test_split(
|
||||||
X,
|
X,
|
||||||
y,
|
y,
|
||||||
test_size=(1.0 - params.train_volume),
|
test_size=(1.0 - params.train_volume),
|
||||||
random_state=random_state,
|
random_state=random_state,
|
||||||
|
stratify=stratify,
|
||||||
)
|
)
|
||||||
return SplittedDataset(X_train, X_test, y_train, y_test)
|
return SplittedDataset(X_train, X_test, y_train, y_test)
|
||||||
|
|
||||||
|
def split_regression(
|
||||||
|
self, data: DataFrame, params: DatasetParams, random_state: int
|
||||||
|
) -> SplittedDataset:
|
||||||
|
return self.__split(
|
||||||
|
data=data, params=params, random_state=random_state, is_classification=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def split_classification(
|
||||||
|
self, data: DataFrame, params: DatasetParams, random_state: int
|
||||||
|
) -> SplittedDataset:
|
||||||
|
return self.__split(
|
||||||
|
data=data, params=params, random_state=random_state, is_classification=True
|
||||||
|
)
|
||||||
|
|
||||||
def remove(self):
|
def remove(self):
|
||||||
os.remove(self.__file_name)
|
os.remove(self.__file_name)
|
||||||
|
@ -1,18 +1,25 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
|
|
||||||
|
from backend.classification.dto import ClassificationTreeDto
|
||||||
from backend.dataset.dto import DatasetDto
|
from backend.dataset.dto import DatasetDto
|
||||||
from backend.dataset.model import DatasetParams
|
from backend.dataset.model import DatasetParams
|
||||||
from backend.regression.dto import RegressionTreeDto
|
from backend.regression.dto import RegressionTreeDto
|
||||||
from backend.regression.model import RegressionTreeParams
|
from backend.tree.model import DecisionTreeParams
|
||||||
|
|
||||||
|
|
||||||
class RegressionDto(DatasetDto, RegressionTreeDto):
|
class Desiralizer:
|
||||||
def get_dataset_params(self, data) -> DatasetParams:
|
def get_dataset_params(self, data) -> DatasetParams:
|
||||||
field_names = set(f.name for f in dataclasses.fields(DatasetParams))
|
field_names = set(f.name for f in dataclasses.fields(DatasetParams))
|
||||||
return DatasetParams(**{k: v for k, v in data.items() if k in field_names})
|
return DatasetParams(**{k: v for k, v in data.items() if k in field_names})
|
||||||
|
|
||||||
def get_tree_params(self, data) -> RegressionTreeParams:
|
def get_tree_params(self, data) -> DecisionTreeParams:
|
||||||
field_names = set(f.name for f in dataclasses.fields(RegressionTreeParams))
|
field_names = set(f.name for f in dataclasses.fields(DecisionTreeParams))
|
||||||
return RegressionTreeParams(
|
return DecisionTreeParams(**{k: v for k, v in data.items() if k in field_names})
|
||||||
**{k: v for k, v in data.items() if k in field_names}
|
|
||||||
)
|
|
||||||
|
class RegressionDto(DatasetDto, RegressionTreeDto, Desiralizer):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationDto(DatasetDto, ClassificationTreeDto, Desiralizer):
|
||||||
|
pass
|
||||||
|
@ -1,30 +1,7 @@
|
|||||||
import math
|
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from sklearn import metrics
|
|
||||||
|
|
||||||
from backend.metric.model import MetricValue
|
from backend.metric.model import MetricValue
|
||||||
|
|
||||||
|
|
||||||
def mse(y, y_pred) -> float:
|
|
||||||
return float(metrics.mean_squared_error(y, y_pred))
|
|
||||||
|
|
||||||
|
|
||||||
def rmse(y, y_pred) -> float:
|
|
||||||
return float(math.sqrt(metrics.mean_squared_error(y, y_pred)))
|
|
||||||
|
|
||||||
|
|
||||||
def mae(y, y_pred) -> float:
|
|
||||||
return float(metrics.mean_absolute_error(y, y_pred))
|
|
||||||
|
|
||||||
|
|
||||||
def rmae(y, y_pred) -> float:
|
|
||||||
return float(math.sqrt(metrics.mean_absolute_error(y, y_pred)))
|
|
||||||
|
|
||||||
|
|
||||||
def r2(y, y_pred) -> float:
|
|
||||||
return float(metrics.r2_score(y, y_pred))
|
|
||||||
|
|
||||||
|
|
||||||
def get_metric(metric: Callable, y, y_pred) -> MetricValue:
|
def get_metric(metric: Callable, y, y_pred) -> MetricValue:
|
||||||
return MetricValue(metric(y[0], y[1]), metric(y_pred[0], y_pred[1]))
|
return MetricValue(metric(y[0], y[1]), metric(y_pred[0], y_pred[1]))
|
||||||
|
25
dt-cart/backend/metric/classification.py
Normal file
25
dt-cart/backend/metric/classification.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from sklearn import metrics
|
||||||
|
|
||||||
|
|
||||||
|
def precision(y, y_pred) -> float:
|
||||||
|
return float(metrics.precision_score(y, y_pred))
|
||||||
|
|
||||||
|
|
||||||
|
def recall(y, y_pred) -> float:
|
||||||
|
return float(metrics.recall_score(y, y_pred))
|
||||||
|
|
||||||
|
|
||||||
|
def accuracy(y, y_pred) -> float:
|
||||||
|
return float(metrics.accuracy_score(y, y_pred))
|
||||||
|
|
||||||
|
|
||||||
|
def f1(y, y_pred) -> float:
|
||||||
|
return float(metrics.f1_score(y, y_pred))
|
||||||
|
|
||||||
|
|
||||||
|
def mcc(y, y_pred) -> float:
|
||||||
|
return float(metrics.matthews_corrcoef(y, y_pred))
|
||||||
|
|
||||||
|
|
||||||
|
def cohen_kappa(y, y_pred) -> float:
|
||||||
|
return float(metrics.cohen_kappa_score(y, y_pred))
|
23
dt-cart/backend/metric/regression.py
Normal file
23
dt-cart/backend/metric/regression.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
from sklearn import metrics
|
||||||
|
|
||||||
|
|
||||||
|
def mse(y, y_pred) -> float:
|
||||||
|
return float(metrics.mean_squared_error(y, y_pred))
|
||||||
|
|
||||||
|
|
||||||
|
def rmse(y, y_pred) -> float:
|
||||||
|
return float(math.sqrt(metrics.mean_squared_error(y, y_pred)))
|
||||||
|
|
||||||
|
|
||||||
|
def mae(y, y_pred) -> float:
|
||||||
|
return float(metrics.mean_absolute_error(y, y_pred))
|
||||||
|
|
||||||
|
|
||||||
|
def rmae(y, y_pred) -> float:
|
||||||
|
return float(math.sqrt(metrics.mean_absolute_error(y, y_pred)))
|
||||||
|
|
||||||
|
|
||||||
|
def r2(y, y_pred) -> float:
|
||||||
|
return float(metrics.r2_score(y, y_pred))
|
@ -1,25 +1,26 @@
|
|||||||
from sklearn import tree
|
from sklearn import tree
|
||||||
|
|
||||||
from backend import metric
|
|
||||||
from backend import tree as tree_helper
|
|
||||||
from backend.dataset.model import SplittedDataset
|
from backend.dataset.model import SplittedDataset
|
||||||
from backend.regression.model import RegressionResult, RegressionTreeParams
|
from backend.metric import get_metric, regression
|
||||||
|
from backend.regression.model import RegressionResult
|
||||||
|
from backend.tree import get_rules, get_tree
|
||||||
|
from backend.tree.model import DecisionTreeParams
|
||||||
|
|
||||||
|
|
||||||
def learn_regression_model(
|
def fit_regression_model(
|
||||||
data: SplittedDataset,
|
data: SplittedDataset,
|
||||||
params: RegressionTreeParams,
|
params: DecisionTreeParams,
|
||||||
) -> RegressionResult:
|
) -> RegressionResult:
|
||||||
model = tree.DecisionTreeRegressor(**vars(params))
|
model = tree.DecisionTreeRegressor(**vars(params))
|
||||||
fitted_model = model.fit(data.X_train.values, data.y_train.values.ravel())
|
fitted_model = model.fit(data.X_train.values, data.y_train.values.ravel())
|
||||||
y = (data.y_train, fitted_model.predict(data.X_train.values))
|
y = (data.y_train, fitted_model.predict(data.X_train.values))
|
||||||
y_pred = (data.y_test, fitted_model.predict(data.X_test.values))
|
y_pred = (data.y_test, fitted_model.predict(data.X_test.values))
|
||||||
return RegressionResult(
|
return RegressionResult(
|
||||||
mse=metric.get_metric(metric.mse, y, y_pred),
|
mse=get_metric(regression.mse, y, y_pred),
|
||||||
mae=metric.get_metric(metric.mae, y, y_pred),
|
mae=get_metric(regression.mae, y, y_pred),
|
||||||
rmse=metric.get_metric(metric.rmse, y, y_pred),
|
rmse=get_metric(regression.rmse, y, y_pred),
|
||||||
rmae=metric.get_metric(metric.rmae, y, y_pred),
|
rmae=get_metric(regression.rmae, y, y_pred),
|
||||||
r2=metric.get_metric(metric.r2, y, y_pred),
|
r2=get_metric(regression.r2, y, y_pred),
|
||||||
rules=tree_helper.get_rules(fitted_model, list(data.X_train.columns)),
|
rules=get_rules(fitted_model, list(data.X_train.columns)),
|
||||||
tree=tree_helper.get_tree(fitted_model, list(data.X_train.columns)),
|
tree=get_tree(fitted_model, list(data.X_train.columns)),
|
||||||
)
|
)
|
||||||
|
@ -1,29 +1,15 @@
|
|||||||
from apiflask import Schema, fields
|
from apiflask import Schema, fields
|
||||||
from apiflask.validators import OneOf, Range
|
from apiflask.validators import OneOf
|
||||||
|
|
||||||
from backend.metric.dto import MetrciDto
|
from backend.metric.dto import MetrciDto
|
||||||
from backend.tree.dto import RuleDto, TreeNodeDto
|
from backend.tree.dto import DecisionTreeParamsDto, RuleDto, TreeNodeDto
|
||||||
|
|
||||||
|
|
||||||
class RegressionTreeDto(Schema):
|
class RegressionTreeDto(DecisionTreeParamsDto):
|
||||||
criterion = fields.String(
|
criterion = fields.String(
|
||||||
load_default="squared_error",
|
load_default="squared_error",
|
||||||
validate=OneOf(["squared_error", "friedman_mse", "absolute_error", "poisson"]),
|
validate=OneOf(["squared_error", "friedman_mse", "absolute_error", "poisson"]),
|
||||||
)
|
)
|
||||||
splitter = fields.String(load_default="best", validate=OneOf(["best", "random"]))
|
|
||||||
max_depth = fields.Integer(load_default=None)
|
|
||||||
min_samples_split = fields.Integer(load_default=2, validate=Range(min=2))
|
|
||||||
min_samples_leaf = fields.Integer(load_default=1, validate=Range(min=1))
|
|
||||||
min_weight_fraction_leaf = fields.Float(load_default=0.0)
|
|
||||||
# TODO: Add float values support
|
|
||||||
max_features = fields.String(
|
|
||||||
load_default=None,
|
|
||||||
validate=OneOf(["auto", "sqrt", "log2", None]),
|
|
||||||
)
|
|
||||||
random_state = fields.Integer(load_default=None)
|
|
||||||
max_leaf_nodes = fields.Integer(load_default=None)
|
|
||||||
min_impurity_decrease = fields.Float(load_default=0.0)
|
|
||||||
ccp_alpha = fields.Float(load_default=0.0)
|
|
||||||
|
|
||||||
|
|
||||||
class RegressionResultDto(Schema):
|
class RegressionResultDto(Schema):
|
||||||
|
@ -5,21 +5,6 @@ from backend.metric.model import MetricValue
|
|||||||
from backend.tree.model import Rule, TreeNode
|
from backend.tree.model import Rule, TreeNode
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RegressionTreeParams:
|
|
||||||
criterion: str
|
|
||||||
splitter: str
|
|
||||||
max_depth: int
|
|
||||||
min_samples_split: int
|
|
||||||
min_samples_leaf: int
|
|
||||||
min_weight_fraction_leaf: float
|
|
||||||
max_features: str
|
|
||||||
random_state: int
|
|
||||||
max_leaf_nodes: int
|
|
||||||
min_impurity_decrease: float
|
|
||||||
ccp_alpha: float
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RegressionResult:
|
class RegressionResult:
|
||||||
tree: List[TreeNode]
|
tree: List[TreeNode]
|
||||||
|
@ -1,26 +1,52 @@
|
|||||||
from werkzeug.datastructures import FileStorage
|
from werkzeug.datastructures import FileStorage
|
||||||
|
|
||||||
from backend import regression
|
from backend.classification import fit_classification_model
|
||||||
|
from backend.classification.model import ClassificationResult
|
||||||
from backend.dataset import Dataset
|
from backend.dataset import Dataset
|
||||||
from backend.dataset.model import DatasetParams, SplittedDataset
|
from backend.dataset.model import DatasetParams, SplittedDataset
|
||||||
from backend.regression.model import RegressionResult, RegressionTreeParams
|
from backend.regression import fit_regression_model
|
||||||
|
from backend.regression.model import RegressionResult
|
||||||
|
from backend.tree.model import DecisionTreeParams
|
||||||
|
|
||||||
|
|
||||||
def run_regression(
|
def run_regression(
|
||||||
path: str | None,
|
path: str | None,
|
||||||
file: FileStorage,
|
file: FileStorage,
|
||||||
dataset_params: DatasetParams,
|
dataset_params: DatasetParams,
|
||||||
tree_params: RegressionTreeParams,
|
tree_params: DecisionTreeParams,
|
||||||
) -> RegressionResult:
|
) -> RegressionResult:
|
||||||
try:
|
try:
|
||||||
dataset: Dataset = Dataset(path=path, file=file)
|
dataset: Dataset = Dataset(path=path, file=file)
|
||||||
data = dataset.read(dataset_params)
|
data = dataset.read(dataset_params)
|
||||||
splitted_dataset: SplittedDataset = dataset.split(
|
splitted_dataset: SplittedDataset = dataset.split_regression(
|
||||||
data=data,
|
data=data,
|
||||||
params=dataset_params,
|
params=dataset_params,
|
||||||
random_state=tree_params.random_state,
|
random_state=tree_params.random_state,
|
||||||
)
|
)
|
||||||
result = regression.learn_regression_model(
|
result = fit_regression_model(
|
||||||
|
data=splitted_dataset,
|
||||||
|
params=tree_params,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
dataset.remove()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def run_classification(
|
||||||
|
path: str | None,
|
||||||
|
file: FileStorage,
|
||||||
|
dataset_params: DatasetParams,
|
||||||
|
tree_params: DecisionTreeParams,
|
||||||
|
) -> ClassificationResult:
|
||||||
|
try:
|
||||||
|
dataset: Dataset = Dataset(path=path, file=file)
|
||||||
|
data = dataset.read(dataset_params)
|
||||||
|
splitted_dataset: SplittedDataset = dataset.split_classification(
|
||||||
|
data=data,
|
||||||
|
params=dataset_params,
|
||||||
|
random_state=tree_params.random_state,
|
||||||
|
)
|
||||||
|
result = fit_classification_model(
|
||||||
data=splitted_dataset,
|
data=splitted_dataset,
|
||||||
params=tree_params,
|
params=tree_params,
|
||||||
)
|
)
|
||||||
|
@ -1,4 +1,22 @@
|
|||||||
from apiflask import Schema, fields
|
from apiflask import Schema, fields
|
||||||
|
from apiflask.validators import OneOf, Range
|
||||||
|
|
||||||
|
|
||||||
|
class DecisionTreeParamsDto(Schema):
|
||||||
|
splitter = fields.String(load_default="best", validate=OneOf(["best", "random"]))
|
||||||
|
max_depth = fields.Integer(load_default=7)
|
||||||
|
min_samples_split = fields.Integer(load_default=2, validate=Range(min=2))
|
||||||
|
min_samples_leaf = fields.Integer(load_default=1, validate=Range(min=1))
|
||||||
|
min_weight_fraction_leaf = fields.Float(load_default=0.0)
|
||||||
|
# TODO: Add float values support
|
||||||
|
max_features = fields.String(
|
||||||
|
load_default=None,
|
||||||
|
validate=OneOf(["auto", "sqrt", "log2", None]),
|
||||||
|
)
|
||||||
|
random_state = fields.Integer(load_default=None)
|
||||||
|
max_leaf_nodes = fields.Integer(load_default=None)
|
||||||
|
min_impurity_decrease = fields.Float(load_default=0.0)
|
||||||
|
ccp_alpha = fields.Float(load_default=0.0)
|
||||||
|
|
||||||
|
|
||||||
class RuleAtomDto(Schema):
|
class RuleAtomDto(Schema):
|
||||||
|
@ -5,6 +5,21 @@ from typing import List
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DecisionTreeParams:
|
||||||
|
criterion: str
|
||||||
|
splitter: str
|
||||||
|
max_depth: int
|
||||||
|
min_samples_split: int
|
||||||
|
min_samples_leaf: int
|
||||||
|
min_weight_fraction_leaf: float
|
||||||
|
max_features: str
|
||||||
|
random_state: int
|
||||||
|
max_leaf_nodes: int
|
||||||
|
min_impurity_decrease: float
|
||||||
|
ccp_alpha: float
|
||||||
|
|
||||||
|
|
||||||
class ComparisonType(enum.Enum):
|
class ComparisonType(enum.Enum):
|
||||||
LESS = "<="
|
LESS = "<="
|
||||||
GREATER = ">"
|
GREATER = ">"
|
||||||
|
68986
dt-cart/data/cardio_clear.csv
Normal file
68986
dt-cart/data/cardio_clear.csv
Normal file
File diff suppressed because it is too large
Load Diff
70001
dt-cart/data/cardio_train.csv
Normal file
70001
dt-cart/data/cardio_train.csv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user