from apiflask import Schema, fields
from apiflask.validators import OneOf, Range


class DecisionTreeParamsDto(Schema):
    """Hyperparameters for building a decision tree.

    Every field has a ``load_default``, so an empty payload loads as a fully
    populated parameter set. Field names and bounds mirror scikit-learn's
    ``DecisionTreeClassifier`` / ``DecisionTreeRegressor`` constructor
    arguments. NOTE(review): assumes the loaded dict is forwarded to
    scikit-learn -- confirm against the caller.
    """

    splitter = fields.String(load_default="best", validate=OneOf(["best", "random"]))
    # Must be a positive integer; None is rejected because allow_none is not
    # enabled for this field (its load_default is not None).
    max_depth = fields.Integer(load_default=7, validate=Range(min=1))
    min_samples_split = fields.Integer(load_default=2, validate=Range(min=2))
    min_samples_leaf = fields.Integer(load_default=1, validate=Range(min=1))
    # scikit-learn documents this fraction as lying in [0.0, 0.5].
    min_weight_fraction_leaf = fields.Float(
        load_default=0.0,
        validate=Range(min=0.0, max=0.5),
    )
    # TODO: Add float values support
    # None is accepted via allow_none and marshmallow returns it before running
    # validators, so it is deliberately not listed among the OneOf choices
    # (listing it only made the error message advertise a "None" option).
    # NOTE(review): newer scikit-learn releases dropped "auto" for trees --
    # confirm against the pinned scikit-learn version.
    max_features = fields.String(
        load_default=None,
        allow_none=True,
        validate=OneOf(["auto", "sqrt", "log2"]),
    )
    random_state = fields.Integer(load_default=None, allow_none=True)
    # When provided, at least two leaves are required; None skips the check.
    max_leaf_nodes = fields.Integer(
        load_default=None,
        allow_none=True,
        validate=Range(min=2),
    )
    min_impurity_decrease = fields.Float(load_default=0.0, validate=Range(min=0.0))
    # Pruning strength; scikit-learn requires a non-negative value.
    ccp_alpha = fields.Float(load_default=0.0, validate=Range(min=0.0))


class RuleAtomDto(Schema):
    """A single condition inside a rule's antecedent.

    NOTE(review): presumably encodes ``<variable> <type> <value>`` with
    ``type`` naming a comparison -- confirm against the producer.
    """

    variable = fields.String()
    # Named ``type`` to match the serialized key, not the builtin.
    type = fields.String()
    value = fields.Float()


class RuleDto(Schema):
    """A rule made of a list of condition atoms and a consequent."""

    # Each condition is (de)serialized with RuleAtomDto.
    antecedent = fields.List(fields.Nested(RuleAtomDto()))
    # Untyped field: values pass through without type conversion.
    consequent = fields.Field()


class TreeNodeDto(Schema):
    """One node of a flattened tree representation.

    NOTE(review): ``parent`` presumably holds the parent node's ``name`` and
    ``level`` its depth in the tree -- confirm against the producer.
    """

    parent = fields.String()
    name = fields.String()
    level = fields.Integer()
    variable = fields.String()
    # Named ``type`` to match the serialized key, not the builtin.
    type = fields.String()
    # Untyped field: values pass through without type conversion.
    value = fields.Field()