Add classification tree implementation, some refactoring

This commit is contained in:
Aleksey Filippov 2025-03-11 21:39:00 +04:00
parent 1682e25dcb
commit a4bdf7c88c
17 changed files with 139237 additions and 85 deletions

View File

@ -1,22 +1,37 @@
from werkzeug.datastructures import FileStorage
from backend import api_bp, dataset_path, service
from backend.classification.dto import ClassificationResultDto
from backend.dataset.dto import DatasetUploadDto
from backend.dataset.model import DatasetParams
from backend.dto import RegressionDto
from backend.dto import ClassificationDto, RegressionDto
from backend.regression.dto import RegressionResultDto
from backend.regression.model import RegressionTreeParams
from backend.tree.model import DecisionTreeParams
@api_bp.post("/regression")
@api_bp.input(DatasetUploadDto, location="files")
@api_bp.input(RegressionDto, location="query")
@api_bp.output(RegressionResultDto)
def upload_dataset(files_data, query_data):
def regression(files_data, query_data):
uploaded_file: FileStorage = files_data["dataset"]
schema = RegressionDto()
dataset_params: DatasetParams = schema.get_dataset_params(query_data)
tree_params: RegressionTreeParams = schema.get_tree_params(query_data)
tree_params: DecisionTreeParams = schema.get_tree_params(query_data)
return service.run_regression(
dataset_path, uploaded_file, dataset_params, tree_params
)
@api_bp.post("/classification")
@api_bp.input(DatasetUploadDto, location="files")
@api_bp.input(ClassificationDto, location="query")
@api_bp.output(ClassificationResultDto)
def classification(files_data, query_data):
    """Train a classification tree on the uploaded dataset.

    ``files_data`` carries the uploaded file under the ``"dataset"`` key and
    ``query_data`` carries the query-string values loaded through
    ``ClassificationDto``. The query values are split into dataset and tree
    parameter objects and handed to ``service.run_classification``, whose
    result is returned as the response payload.
    """
    dataset_file: FileStorage = files_data["dataset"]
    query_schema = ClassificationDto()
    split_options: DatasetParams = query_schema.get_dataset_params(query_data)
    tree_options: DecisionTreeParams = query_schema.get_tree_params(query_data)
    return service.run_classification(
        path=dataset_path,
        file=dataset_file,
        dataset_params=split_options,
        tree_params=tree_options,
    )

View File

@ -0,0 +1,27 @@
from sklearn import tree
from backend.classification.model import ClassificationResult
from backend.dataset.model import SplittedDataset
from backend.metric import classification, get_metric
from backend.tree import get_rules, get_tree
from backend.tree.model import DecisionTreeParams
def fit_classification_model(
    data: SplittedDataset,
    params: DecisionTreeParams,
) -> ClassificationResult:
    """Fit a decision-tree classifier and score it on both data splits.

    The fields of ``params`` are passed as keyword arguments to
    ``DecisionTreeClassifier``. The model is fitted on the training split,
    then used to predict both the training and the test split; every metric
    is computed for each of the two (true values, predictions) pairs.

    Returns a ``ClassificationResult`` holding the metrics together with the
    rules and tree produced by ``get_rules`` / ``get_tree``.
    """
    estimator = tree.DecisionTreeClassifier(**vars(params))
    fitted = estimator.fit(data.X_train.values, data.y_train.values.ravel())

    # Each pair is (true values, model predictions) for one split.
    train_pair = (data.y_train, fitted.predict(data.X_train.values))
    test_pair = (data.y_test, fitted.predict(data.X_test.values))

    def score(metric_fn):
        # get_metric receives the training pair first, the test pair second.
        return get_metric(metric_fn, train_pair, test_pair)

    return ClassificationResult(
        precision=score(classification.precision),
        recall=score(classification.recall),
        accuracy=score(classification.accuracy),
        f1=score(classification.f1),
        mcc=score(classification.mcc),
        cohen_kappa=score(classification.cohen_kappa),
        rules=get_rules(fitted, list(data.X_train.columns)),
        tree=get_tree(fitted, list(data.X_train.columns)),
    )

View File

@ -0,0 +1,23 @@
from apiflask import Schema, fields
from apiflask.validators import OneOf
from backend.metric.dto import MetrciDto
from backend.tree.dto import DecisionTreeParamsDto, RuleDto, TreeNodeDto
class ClassificationTreeDto(DecisionTreeParamsDto):
    """Tree parameters for classification: the shared ones plus ``criterion``."""

    # Split-quality measure; "gini" is used when the request omits the field.
    criterion = fields.String(
        validate=OneOf(["gini", "entropy", "log_loss"]),
        load_default="gini",
    )
class ClassificationResultDto(Schema):
    """Response schema for a classification run.

    Holds the list of rules, the list of tree nodes and one nested metric
    object per quality measure.
    """

    # Structure of the fitted tree.
    rules = fields.List(fields.Nested(RuleDto()))
    tree = fields.List(fields.Nested(TreeNodeDto()))

    # Quality measures, each serialized through its own MetrciDto instance.
    precision = fields.Nested(MetrciDto())
    recall = fields.Nested(MetrciDto())
    accuracy = fields.Nested(MetrciDto())
    f1 = fields.Nested(MetrciDto())
    mcc = fields.Nested(MetrciDto())
    cohen_kappa = fields.Nested(MetrciDto())

View File

@ -0,0 +1,17 @@
from dataclasses import dataclass
from typing import List
from backend.metric.model import MetricValue
from backend.tree.model import Rule, TreeNode
@dataclass
class ClassificationResult:
    """Outcome of fitting a classification tree.

    Bundles the tree nodes and rules with one ``MetricValue`` per
    quality measure.
    """

    # Structure of the fitted model.
    tree: List[TreeNode]
    rules: List[Rule]

    # Quality measures.
    precision: MetricValue
    recall: MetricValue
    accuracy: MetricValue
    f1: MetricValue
    mcc: MetricValue
    cohen_kappa: MetricValue

View File

@ -39,18 +39,38 @@ class Dataset:
return df
def split(
self, data: DataFrame, params: DatasetParams, random_state: int
def __split(
    self,
    data: DataFrame,
    params: DatasetParams,
    random_state: int,
    is_classification: bool = False,
) -> SplittedDataset:
    """Split ``data`` into train and test parts of features and target.

    The column named by ``params.target`` becomes the target (kept as a
    one-column frame); every other column is a feature. The test share is
    ``1.0 - params.train_volume``. For classification the split is
    stratified by the target; otherwise no stratification is requested.
    """
    target_column = params.target
    features = data.drop(columns=[target_column])
    target = data[[target_column]]

    # Stratification is requested for classification only.
    stratify_by = target if is_classification else None
    holdout_share = 1.0 - params.train_volume

    parts = train_test_split(
        features,
        target,
        test_size=holdout_share,
        random_state=random_state,
        stratify=stratify_by,
    )
    # train_test_split yields X_train, X_test, y_train, y_test in that order.
    return SplittedDataset(*parts)
def split_regression(
self, data: DataFrame, params: DatasetParams, random_state: int
) -> SplittedDataset:
return self.__split(
data=data, params=params, random_state=random_state, is_classification=False
)
def split_classification(
self, data: DataFrame, params: DatasetParams, random_state: int
) -> SplittedDataset:
return self.__split(
data=data, params=params, random_state=random_state, is_classification=True
)
def remove(self):
    """Delete the file at ``self.__file_name``; a missing file is not an error.

    Cleanup code calls this from a ``finally`` clause, so raising here
    because the file is already gone would replace the exception that is
    actually being propagated. Other ``OSError`` failures (for example
    permission problems) still surface.
    """
    try:
        os.remove(self.__file_name)
    except FileNotFoundError:
        # Already gone: nothing left to clean up.
        pass

View File

@ -1,18 +1,25 @@
import dataclasses
from backend.classification.dto import ClassificationTreeDto
from backend.dataset.dto import DatasetDto
from backend.dataset.model import DatasetParams
from backend.regression.dto import RegressionTreeDto
from backend.regression.model import RegressionTreeParams
from backend.tree.model import DecisionTreeParams
class RegressionDto(DatasetDto, RegressionTreeDto):
class Desiralizer:
def get_dataset_params(self, data) -> DatasetParams:
    """Build ``DatasetParams`` from the entries of ``data`` that name its fields.

    Keys of ``data`` that are not ``DatasetParams`` fields are ignored.
    """
    allowed = {field.name for field in dataclasses.fields(DatasetParams)}
    selected = {}
    for key, value in data.items():
        if key in allowed:
            selected[key] = value
    return DatasetParams(**selected)
def get_tree_params(self, data) -> RegressionTreeParams:
field_names = set(f.name for f in dataclasses.fields(RegressionTreeParams))
return RegressionTreeParams(
**{k: v for k, v in data.items() if k in field_names}
)
def get_tree_params(self, data) -> DecisionTreeParams:
    """Build ``DecisionTreeParams`` from the entries of ``data`` that name its fields.

    Keys of ``data`` that are not ``DecisionTreeParams`` fields are ignored.
    """
    known_fields = {spec.name for spec in dataclasses.fields(DecisionTreeParams)}
    tree_kwargs = {
        name: value for name, value in data.items() if name in known_fields
    }
    return DecisionTreeParams(**tree_kwargs)
class RegressionDto(DatasetDto, RegressionTreeDto, Desiralizer):
    """Query schema for the regression endpoint.

    Combines the dataset fields and the regression-tree fields with the
    ``Desiralizer`` helper methods; it declares nothing of its own.
    """
class ClassificationDto(DatasetDto, ClassificationTreeDto, Desiralizer):
    """Query schema for the classification endpoint.

    Combines the dataset fields and the classification-tree fields with the
    ``Desiralizer`` helper methods; it declares nothing of its own.
    """

View File

@ -1,30 +1,7 @@
import math
from typing import Callable
from sklearn import metrics
from backend.metric.model import MetricValue
def mse(y, y_pred) -> float:
return float(metrics.mean_squared_error(y, y_pred))
def rmse(y, y_pred) -> float:
return float(math.sqrt(metrics.mean_squared_error(y, y_pred)))
def mae(y, y_pred) -> float:
return float(metrics.mean_absolute_error(y, y_pred))
def rmae(y, y_pred) -> float:
return float(math.sqrt(metrics.mean_absolute_error(y, y_pred)))
def r2(y, y_pred) -> float:
return float(metrics.r2_score(y, y_pred))
def get_metric(metric: Callable, y, y_pred) -> MetricValue:
    """Evaluate ``metric`` on two pairs and wrap both results.

    ``y`` and ``y_pred`` are each indexed at positions 0 and 1, and those two
    items are passed to ``metric``. The model-fitting callers pass the
    training (true values, predictions) pair as ``y`` and the test pair as
    ``y_pred``. The value for ``y`` is the first ``MetricValue`` argument,
    the value for ``y_pred`` the second.
    """
    first_value = metric(y[0], y[1])
    second_value = metric(y_pred[0], y_pred[1])
    return MetricValue(first_value, second_value)

View File

@ -0,0 +1,25 @@
from sklearn import metrics
def precision(y, y_pred, average: str = "binary") -> float:
    """Return the precision score of ``y_pred`` against ``y`` as a float.

    ``average`` is forwarded to ``sklearn.metrics.precision_score``. The
    default, ``"binary"``, is the value scikit-learn itself uses, so existing
    two-argument calls behave as before; pass e.g. ``"macro"`` or
    ``"weighted"`` for targets with more than two classes.
    """
    return float(metrics.precision_score(y, y_pred, average=average))
def recall(y, y_pred, average: str = "binary") -> float:
    """Return the recall score of ``y_pred`` against ``y`` as a float.

    ``average`` is forwarded to ``sklearn.metrics.recall_score``. The
    default, ``"binary"``, is the value scikit-learn itself uses, so existing
    two-argument calls behave as before; pass e.g. ``"macro"`` or
    ``"weighted"`` for targets with more than two classes.
    """
    return float(metrics.recall_score(y, y_pred, average=average))
def accuracy(y, y_pred) -> float:
    """Return scikit-learn's accuracy score for ``y_pred`` against ``y`` as a float."""
    score = metrics.accuracy_score(y, y_pred)
    return float(score)
def f1(y, y_pred, average: str = "binary") -> float:
    """Return the F1 score of ``y_pred`` against ``y`` as a float.

    ``average`` is forwarded to ``sklearn.metrics.f1_score``. The default,
    ``"binary"``, is the value scikit-learn itself uses, so existing
    two-argument calls behave as before; pass e.g. ``"macro"`` or
    ``"weighted"`` for targets with more than two classes.
    """
    return float(metrics.f1_score(y, y_pred, average=average))
def mcc(y, y_pred) -> float:
    """Return the Matthews correlation coefficient for ``y_pred`` against ``y`` as a float."""
    coefficient = metrics.matthews_corrcoef(y, y_pred)
    return float(coefficient)
def cohen_kappa(y, y_pred) -> float:
    """Return scikit-learn's Cohen's kappa score for ``y`` and ``y_pred`` as a float."""
    kappa = metrics.cohen_kappa_score(y, y_pred)
    return float(kappa)

View File

@ -0,0 +1,23 @@
import math
from sklearn import metrics
def mse(y, y_pred) -> float:
    """Return the mean squared error between ``y`` and ``y_pred`` as a float."""
    squared_error = metrics.mean_squared_error(y, y_pred)
    return float(squared_error)
def rmse(y, y_pred) -> float:
    """Return the square root of the mean squared error as a float."""
    squared_error = metrics.mean_squared_error(y, y_pred)
    return float(math.sqrt(squared_error))
def mae(y, y_pred) -> float:
    """Return the mean absolute error between ``y`` and ``y_pred`` as a float."""
    absolute_error = metrics.mean_absolute_error(y, y_pred)
    return float(absolute_error)
def rmae(y, y_pred) -> float:
    """Return the square root of the mean absolute error as a float."""
    absolute_error = metrics.mean_absolute_error(y, y_pred)
    return float(math.sqrt(absolute_error))
def r2(y, y_pred) -> float:
    """Return scikit-learn's R² score for ``y_pred`` against ``y`` as a float."""
    determination = metrics.r2_score(y, y_pred)
    return float(determination)

View File

@ -1,25 +1,26 @@
from sklearn import tree
from backend import metric
from backend import tree as tree_helper
from backend.dataset.model import SplittedDataset
from backend.regression.model import RegressionResult, RegressionTreeParams
from backend.metric import get_metric, regression
from backend.regression.model import RegressionResult
from backend.tree import get_rules, get_tree
from backend.tree.model import DecisionTreeParams
def learn_regression_model(
def fit_regression_model(
data: SplittedDataset,
params: RegressionTreeParams,
params: DecisionTreeParams,
) -> RegressionResult:
model = tree.DecisionTreeRegressor(**vars(params))
fitted_model = model.fit(data.X_train.values, data.y_train.values.ravel())
y = (data.y_train, fitted_model.predict(data.X_train.values))
y_pred = (data.y_test, fitted_model.predict(data.X_test.values))
return RegressionResult(
mse=metric.get_metric(metric.mse, y, y_pred),
mae=metric.get_metric(metric.mae, y, y_pred),
rmse=metric.get_metric(metric.rmse, y, y_pred),
rmae=metric.get_metric(metric.rmae, y, y_pred),
r2=metric.get_metric(metric.r2, y, y_pred),
rules=tree_helper.get_rules(fitted_model, list(data.X_train.columns)),
tree=tree_helper.get_tree(fitted_model, list(data.X_train.columns)),
mse=get_metric(regression.mse, y, y_pred),
mae=get_metric(regression.mae, y, y_pred),
rmse=get_metric(regression.rmse, y, y_pred),
rmae=get_metric(regression.rmae, y, y_pred),
r2=get_metric(regression.r2, y, y_pred),
rules=get_rules(fitted_model, list(data.X_train.columns)),
tree=get_tree(fitted_model, list(data.X_train.columns)),
)

View File

@ -1,29 +1,15 @@
from apiflask import Schema, fields
from apiflask.validators import OneOf, Range
from apiflask.validators import OneOf
from backend.metric.dto import MetrciDto
from backend.tree.dto import RuleDto, TreeNodeDto
from backend.tree.dto import DecisionTreeParamsDto, RuleDto, TreeNodeDto
class RegressionTreeDto(Schema):
class RegressionTreeDto(DecisionTreeParamsDto):
criterion = fields.String(
load_default="squared_error",
validate=OneOf(["squared_error", "friedman_mse", "absolute_error", "poisson"]),
)
splitter = fields.String(load_default="best", validate=OneOf(["best", "random"]))
max_depth = fields.Integer(load_default=None)
min_samples_split = fields.Integer(load_default=2, validate=Range(min=2))
min_samples_leaf = fields.Integer(load_default=1, validate=Range(min=1))
min_weight_fraction_leaf = fields.Float(load_default=0.0)
# TODO: Add float values support
max_features = fields.String(
load_default=None,
validate=OneOf(["auto", "sqrt", "log2", None]),
)
random_state = fields.Integer(load_default=None)
max_leaf_nodes = fields.Integer(load_default=None)
min_impurity_decrease = fields.Float(load_default=0.0)
ccp_alpha = fields.Float(load_default=0.0)
class RegressionResultDto(Schema):

View File

@ -5,21 +5,6 @@ from backend.metric.model import MetricValue
from backend.tree.model import Rule, TreeNode
@dataclass
class RegressionTreeParams:
criterion: str
splitter: str
max_depth: int
min_samples_split: int
min_samples_leaf: int
min_weight_fraction_leaf: float
max_features: str
random_state: int
max_leaf_nodes: int
min_impurity_decrease: float
ccp_alpha: float
@dataclass
class RegressionResult:
tree: List[TreeNode]

View File

@ -1,26 +1,52 @@
from werkzeug.datastructures import FileStorage
from backend import regression
from backend.classification import fit_classification_model
from backend.classification.model import ClassificationResult
from backend.dataset import Dataset
from backend.dataset.model import DatasetParams, SplittedDataset
from backend.regression.model import RegressionResult, RegressionTreeParams
from backend.regression import fit_regression_model
from backend.regression.model import RegressionResult
from backend.tree.model import DecisionTreeParams
def run_regression(
path: str | None,
file: FileStorage,
dataset_params: DatasetParams,
tree_params: RegressionTreeParams,
tree_params: DecisionTreeParams,
) -> RegressionResult:
try:
dataset: Dataset = Dataset(path=path, file=file)
data = dataset.read(dataset_params)
splitted_dataset: SplittedDataset = dataset.split(
splitted_dataset: SplittedDataset = dataset.split_regression(
data=data,
params=dataset_params,
random_state=tree_params.random_state,
)
result = regression.learn_regression_model(
result = fit_regression_model(
data=splitted_dataset,
params=tree_params,
)
finally:
dataset.remove()
return result
def run_classification(
path: str | None,
file: FileStorage,
dataset_params: DatasetParams,
tree_params: DecisionTreeParams,
) -> ClassificationResult:
try:
dataset: Dataset = Dataset(path=path, file=file)
data = dataset.read(dataset_params)
splitted_dataset: SplittedDataset = dataset.split_classification(
data=data,
params=dataset_params,
random_state=tree_params.random_state,
)
result = fit_classification_model(
data=splitted_dataset,
params=tree_params,
)

View File

@ -1,4 +1,22 @@
from apiflask import Schema, fields
from apiflask.validators import OneOf, Range
class DecisionTreeParamsDto(Schema):
    """Decision-tree parameters shared by the regression and classification DTOs.

    Subclasses add their own ``criterion`` field. The loaded values are later
    passed as keyword arguments to the scikit-learn tree estimators, so each
    numeric field is validated here against the range scikit-learn documents
    for it; an out-of-range value is rejected by the schema instead of
    failing inside the estimator. ``None`` values and load defaults are not
    run through the validators.
    """

    splitter = fields.String(load_default="best", validate=OneOf(["best", "random"]))
    max_depth = fields.Integer(load_default=7, validate=Range(min=1))
    min_samples_split = fields.Integer(load_default=2, validate=Range(min=2))
    min_samples_leaf = fields.Integer(load_default=1, validate=Range(min=1))
    # scikit-learn accepts a weight fraction in [0.0, 0.5].
    min_weight_fraction_leaf = fields.Float(
        load_default=0.0,
        validate=Range(min=0.0, max=0.5),
    )
    # TODO: Add float values support
    max_features = fields.String(
        load_default=None,
        validate=OneOf(["auto", "sqrt", "log2", None]),
    )
    random_state = fields.Integer(load_default=None)
    # A tree limited by leaf count needs at least two leaves.
    max_leaf_nodes = fields.Integer(load_default=None, validate=Range(min=2))
    min_impurity_decrease = fields.Float(load_default=0.0, validate=Range(min=0.0))
    ccp_alpha = fields.Float(load_default=0.0, validate=Range(min=0.0))
class RuleAtomDto(Schema):

View File

@ -5,6 +5,21 @@ from typing import List
import numpy as np
@dataclass
class DecisionTreeParams:
criterion: str
splitter: str
max_depth: int
min_samples_split: int
min_samples_leaf: int
min_weight_fraction_leaf: float
max_features: str
random_state: int
max_leaf_nodes: int
min_impurity_decrease: float
ccp_alpha: float
class ComparisonType(enum.Enum):
LESS = "<="
GREATER = ">"

68986
dt-cart/data/cardio_clear.csv Normal file

File diff suppressed because it is too large Load Diff

70001
dt-cart/data/cardio_train.csv Normal file

File diff suppressed because it is too large Load Diff