37 lines
1.4 KiB
Python
37 lines
1.4 KiB
Python
from apiflask import Schema, fields
|
|
from apiflask.validators import OneOf, Range
|
|
|
|
from backend.metric.dto import MetrciDto
|
|
from backend.tree.dto import RuleDto, TreeNodeDto
|
|
|
|
|
|
class RegressionTreeDto(Schema):
    """Request schema holding the hyperparameters of a regression decision tree.

    Field names, defaults and allowed choices mirror the constructor of
    scikit-learn's ``DecisionTreeRegressor``; presumably the loaded dict is
    forwarded to it (NOTE(review): confirm against the caller). The numeric
    bounds follow the constraints scikit-learn enforces, so invalid values are
    rejected at schema validation instead of failing later during fitting.

    Fields whose ``load_default`` is ``None`` implicitly allow ``None``
    (marshmallow sets ``allow_none=True`` in that case), and a ``None`` value
    is returned before validators run, so the ``Range`` checks below never
    reject an omitted or null value.
    """

    criterion = fields.String(
        load_default="squared_error",
        validate=OneOf(["squared_error", "friedman_mse", "absolute_error", "poisson"]),
    )
    splitter = fields.String(load_default="best", validate=OneOf(["best", "random"]))
    # None = no depth limit (scikit-learn semantics); otherwise at least 1.
    max_depth = fields.Integer(load_default=None, validate=Range(min=1))
    min_samples_split = fields.Integer(load_default=2, validate=Range(min=2))
    min_samples_leaf = fields.Integer(load_default=1, validate=Range(min=1))
    # scikit-learn accepts a fraction of the total sample weight in [0, 0.5].
    min_weight_fraction_leaf = fields.Float(
        load_default=0.0,
        validate=Range(min=0.0, max=0.5),
    )
    # TODO: Add float values support
    # NOTE(review): "auto" was deprecated in scikit-learn 1.1 and removed in 1.3;
    # it is kept for backward compatibility — drop it once the pinned
    # scikit-learn version is >= 1.3.
    max_features = fields.String(
        load_default=None,
        validate=OneOf(["auto", "sqrt", "log2", None]),
    )
    # Integer seeds must fit numpy's RandomState seed range.
    random_state = fields.Integer(
        load_default=None,
        validate=Range(min=0, max=2**32 - 1),
    )
    # None = unlimited leaves; otherwise scikit-learn requires at least 2.
    max_leaf_nodes = fields.Integer(load_default=None, validate=Range(min=2))
    min_impurity_decrease = fields.Float(load_default=0.0, validate=Range(min=0.0))
    # Cost-complexity pruning strength; must be non-negative.
    ccp_alpha = fields.Float(load_default=0.0, validate=Range(min=0.0))
|
|
|
|
|
|
def _metric_field():
    """Build a new nested field wrapping a freshly created ``MetrciDto`` schema."""
    return fields.Nested(MetrciDto())


class RegressionResultDto(Schema):
    """Response schema for a regression run.

    Contains a list of ``RuleDto`` rules and a list of ``TreeNodeDto`` nodes.
    It also has one nested ``MetrciDto`` object for each of the metrics mse,
    mae, rmse, rmae and r2.
    """

    # Both collections are lists of nested objects.
    rules = fields.List(fields.Nested(RuleDto()))
    tree = fields.List(fields.Nested(TreeNodeDto()))

    # Every metric gets its own field and its own MetrciDto instance.
    mse = _metric_field()
    mae = _metric_field()
    rmse = _metric_field()
    rmae = _metric_field()
    r2 = _metric_field()
|