40 lines
1.2 KiB
Python

from apiflask import Schema, fields
from apiflask.validators import OneOf, Range
class DecisionTreeParamsDto(Schema):
    """Hyper-parameters for building a decision tree.

    Field names and defaults mirror the keyword arguments of scikit-learn's
    ``DecisionTreeClassifier``/``DecisionTreeRegressor``; presumably the loaded
    dict is forwarded to the estimator as ``**kwargs`` (confirm against caller).
    The ``Range`` bounds below follow scikit-learn's own parameter constraints,
    so out-of-range values are rejected when the request is validated instead
    of failing later while the model is being fitted.

    Note: marshmallow does not run validators on ``load_default`` values, and
    fields whose ``load_default`` is ``None`` accept an explicit ``null``
    (``allow_none`` defaults to True for them) without running validators.
    """

    splitter = fields.String(load_default="best", validate=OneOf(["best", "random"]))
    # Project default of 7 (scikit-learn's own default is None = unlimited).
    # Because the default is not None, an explicit null is rejected here.
    max_depth = fields.Integer(load_default=7, validate=Range(min=1))
    min_samples_split = fields.Integer(load_default=2, validate=Range(min=2))
    min_samples_leaf = fields.Integer(load_default=1, validate=Range(min=1))
    # Fraction of the total sample weight; scikit-learn accepts [0.0, 0.5].
    min_weight_fraction_leaf = fields.Float(
        load_default=0.0,
        validate=Range(min=0.0, max=0.5),
    )
    # TODO: Add float values support (scikit-learn also accepts an int count
    # or a float fraction for max_features).
    # The None entry in OneOf is never consulted: None values skip validation.
    # NOTE(review): "auto" was removed from tree estimators in scikit-learn 1.3;
    # keep it only if the pinned scikit-learn version still accepts it.
    max_features = fields.String(
        load_default=None,
        validate=OneOf(["auto", "sqrt", "log2", None]),
    )
    # In scikit-learn an int seed is passed to NumPy's RandomState, which only
    # accepts values in [0, 2**32 - 1]; None leaves the estimator unseeded.
    random_state = fields.Integer(
        load_default=None,
        validate=Range(min=0, max=2**32 - 1),
    )
    # In scikit-learn, None means an unlimited number of leaf nodes.
    max_leaf_nodes = fields.Integer(load_default=None, validate=Range(min=2))
    min_impurity_decrease = fields.Float(load_default=0.0, validate=Range(min=0.0))
    # Cost-complexity pruning strength; scikit-learn performs no pruning at 0.0.
    ccp_alpha = fields.Float(load_default=0.0, validate=Range(min=0.0))
class RuleAtomDto(Schema):
    """A single condition (atom) inside a rule's antecedent.

    Serialized as an object with the keys ``variable``, ``type`` and ``value``.
    The exact meaning of ``type`` is not visible from this schema (presumably a
    comparison kind such as ``<=``/``>`` — confirm against the rule builder).
    """

    # Variable the condition refers to.
    variable = fields.String()
    # Kind of condition. The attribute is named ``type`` to match the wire
    # format; it only shadows the builtin inside this class body.
    type = fields.String()
    # Numeric operand of the condition (presumably a split threshold — TODO confirm).
    value = fields.Float()
class RuleDto(Schema):
    """A rule made of antecedent atoms and a consequent value."""

    # Conditions forming the rule's left-hand side, each a RuleAtomDto.
    antecedent = fields.List(fields.Nested(RuleAtomDto()))
    # Arbitrary value passed through without formatting. ``Raw`` is used
    # instead of the bare ``Field`` base class, whose direct use as a schema
    # field is deprecated since marshmallow 3.24; both behave identically.
    consequent = fields.Raw()
class TreeNodeDto(Schema):
    """One node of a tree, serialized as a flat record.

    Nodes are not nested; ``parent`` is a string, presumably the ``name`` of
    the parent node, and ``level`` presumably the node's depth (confirm
    against the code that builds these records).
    """

    parent = fields.String()
    name = fields.String()
    level = fields.Integer()
    # Split condition of the node (same shape of data as in RuleAtomDto).
    variable = fields.String()
    # Named ``type`` to match the wire format; shadows the builtin only here.
    type = fields.String()
    # Arbitrary value passed through without formatting. ``Raw`` replaces the
    # bare ``Field`` base class, whose direct use as a schema field is
    # deprecated since marshmallow 3.24; both behave identically.
    value = fields.Raw()