37 lines
1.4 KiB
Python

from apiflask import Schema, fields
from apiflask.validators import OneOf, Range
from backend.metric.dto import MetrciDto
from backend.tree.dto import RuleDto, TreeNodeDto
class RegressionTreeDto(Schema):
    """Request schema holding the hyperparameters of a regression tree.

    Field names, choices and defaults mirror the constructor parameters of
    scikit-learn's ``DecisionTreeRegressor`` (presumably the estimator these
    values are forwarded to -- confirm against the service layer). The range
    validators reject values that estimator would refuse, so a bad request
    fails schema validation instead of failing later during fitting.

    Fields whose ``load_default`` is ``None`` also accept an explicit ``null``:
    marshmallow enables ``allow_none`` for them and skips validators on
    ``None``, so the ``Range`` checks only apply to concrete numbers.
    """

    criterion = fields.String(
        load_default="squared_error",
        validate=OneOf(["squared_error", "friedman_mse", "absolute_error", "poisson"]),
    )
    splitter = fields.String(load_default="best", validate=OneOf(["best", "random"]))
    # An explicit depth must be at least one level.
    max_depth = fields.Integer(load_default=None, validate=Range(min=1))
    min_samples_split = fields.Integer(load_default=2, validate=Range(min=2))
    min_samples_leaf = fields.Integer(load_default=1, validate=Range(min=1))
    # Fraction of the total sample weight, bounded to [0.0, 0.5].
    min_weight_fraction_leaf = fields.Float(
        load_default=0.0,
        validate=Range(min=0.0, max=0.5),
    )
    # TODO: Add float values support (fraction of features), not only the
    # named strategies below.
    # NOTE(review): scikit-learn deprecated "auto" in 1.1 and removed it in
    # 1.3; confirm it is still accepted by the pinned version.
    max_features = fields.String(
        load_default=None,
        validate=OneOf(["auto", "sqrt", "log2", None]),
    )
    # Seed values must fit in an unsigned 32-bit integer.
    random_state = fields.Integer(
        load_default=None,
        validate=Range(min=0, max=2**32 - 1),
    )
    # A tree limited by leaf count needs at least two leaves.
    max_leaf_nodes = fields.Integer(load_default=None, validate=Range(min=2))
    min_impurity_decrease = fields.Float(load_default=0.0, validate=Range(min=0.0))
    # Cost-complexity pruning strength; 0.0 disables pruning.
    ccp_alpha = fields.Float(load_default=0.0, validate=Range(min=0.0))
class RegressionResultDto(Schema):
    """Response schema for a regression-tree run.

    Carries the list of extracted rules, the list of tree nodes, and one
    nested metric object per regression score (mse, mae, rmse, rmae, r2).
    """

    rules = fields.List(fields.Nested(RuleDto()))
    tree = fields.List(fields.Nested(TreeNodeDto()))

    def _metric_field():
        # Each call builds a fresh nested field with its own schema instance,
        # so no two metric fields share state.
        return fields.Nested(MetrciDto())

    mse = _metric_field()
    mae = _metric_field()
    rmse = _metric_field()
    rmae = _metric_field()
    r2 = _metric_field()

    # Drop the helper so it does not linger in the class namespace.
    del _metric_field