from apiflask import Schema, fields
from apiflask.validators import OneOf, Range

from backend.metric.dto import MetrciDto
from backend.tree.dto import RuleDto, TreeNodeDto


class RegressionTreeDto(Schema):
    """Input schema describing the hyper-parameters of a regression tree.

    Every field is optional on load and falls back to its ``load_default``.
    NOTE(review): field names, defaults and numeric bounds mirror
    scikit-learn's ``DecisionTreeRegressor`` -- presumably the consumer of
    this DTO; confirm against the service that loads it.
    """

    # Function used to measure the quality of a split.
    criterion = fields.String(
        load_default="squared_error",
        validate=OneOf(
            ["squared_error", "friedman_mse", "absolute_error", "poisson"]
        ),
    )
    # Strategy used to choose the split at each node.
    splitter = fields.String(
        load_default="best",
        validate=OneOf(["best", "random"]),
    )
    # None means "no limit"; an explicit depth must be at least 1.
    max_depth = fields.Integer(load_default=None, validate=Range(min=1))
    min_samples_split = fields.Integer(load_default=2, validate=Range(min=2))
    min_samples_leaf = fields.Integer(load_default=1, validate=Range(min=1))
    # Fraction of the total sample weight, so it is bounded to [0.0, 0.5].
    min_weight_fraction_leaf = fields.Float(
        load_default=0.0,
        validate=Range(min=0.0, max=0.5),
    )
    # TODO: Add float values support
    max_features = fields.String(
        load_default=None,
        validate=OneOf(["auto", "sqrt", "log2", None]),
    )
    random_state = fields.Integer(load_default=None)
    # None means "unlimited"; a tree needs at least two leaves otherwise.
    max_leaf_nodes = fields.Integer(load_default=None, validate=Range(min=2))
    # Both of the following must be non-negative.
    min_impurity_decrease = fields.Float(
        load_default=0.0,
        validate=Range(min=0.0),
    )
    ccp_alpha = fields.Float(load_default=0.0, validate=Range(min=0.0))


class RegressionResultDto(Schema):
    """Output schema for a regression-tree result.

    Holds a list of rules, a list of tree nodes and one nested metric
    object per reported metric.
    """

    # Structure of the tree: a list of RuleDto and a list of TreeNodeDto.
    rules = fields.List(fields.Nested(RuleDto()))
    tree = fields.List(fields.Nested(TreeNodeDto()))

    # One nested MetrciDto per metric (the spelling is the imported name).
    mse = fields.Nested(MetrciDto())
    mae = fields.Nested(MetrciDto())
    rmse = fields.Nested(MetrciDto())
    rmae = fields.Nested(MetrciDto())
    r2 = fields.Nested(MetrciDto())