40 lines
1.2 KiB
Python
40 lines
1.2 KiB
Python
from apiflask import Schema, fields
|
|
from apiflask.validators import OneOf, Range
|
|
|
|
|
|
class DecisionTreeParamsDto(Schema):
    """Request schema holding decision-tree training hyper-parameters.

    Every field is optional; a missing key is filled with its ``load_default``.
    The field names match scikit-learn's decision-tree keyword arguments, and
    the ``Range`` bounds below mirror that estimator's documented parameter
    constraints. Invalid values are therefore rejected during request
    validation instead of failing later.
    NOTE(review): assumes the loaded dict is forwarded to a scikit-learn
    decision tree; confirm against the caller.
    """

    # Split strategy at each node; only the two listed values are accepted.
    splitter = fields.String(load_default="best", validate=OneOf(["best", "random"]))
    # Maximum tree depth, at least 1. Because the default is 7 (not None), this
    # field is not nullable, so clients cannot request an unbounded depth.
    max_depth = fields.Integer(load_default=7, validate=Range(min=1))
    min_samples_split = fields.Integer(load_default=2, validate=Range(min=2))
    min_samples_leaf = fields.Integer(load_default=1, validate=Range(min=1))
    # Fraction of total sample weight required at a leaf, in [0, 0.5].
    min_weight_fraction_leaf = fields.Float(
        load_default=0.0,
        validate=Range(min=0.0, max=0.5),
    )
    # TODO: Add float values support
    # load_default=None makes marshmallow treat the field as nullable, so an
    # explicit null is accepted without running the OneOf validator.
    # NOTE(review): scikit-learn deprecated "auto" for decision trees and later
    # removed it; confirm it is still valid for the pinned library version.
    max_features = fields.String(
        load_default=None,
        validate=OneOf(["auto", "sqrt", "log2", None]),
    )
    # Integer seed, or null (the default). Bounded to the seed range that
    # NumPy's RandomState accepts.
    random_state = fields.Integer(
        load_default=None,
        validate=Range(min=0, max=2**32 - 1),
    )
    # Upper bound on leaf count, at least 2; null (the default) means no bound.
    max_leaf_nodes = fields.Integer(load_default=None, validate=Range(min=2))
    min_impurity_decrease = fields.Float(load_default=0.0, validate=Range(min=0.0))
    # Cost-complexity pruning parameter; must be non-negative.
    ccp_alpha = fields.Float(load_default=0.0, validate=Range(min=0.0))
|
|
|
|
|
|
class RuleAtomDto(Schema):
    """Schema for one atom of a rule.

    Instances appear as the elements of a rule's ``antecedent`` list. All
    three fields are optional and carry no validation beyond their type.
    """

    # Name of the variable the atom refers to (presumably a feature name).
    variable = fields.String()
    # Kind of atom, e.g. a comparison operator. NOTE(review): the allowed
    # values are not constrained by this schema.
    type = fields.String()
    # Numeric operand of the atom.
    value = fields.Float()
|
|
|
|
|
|
class RuleDto(Schema):
    """Schema for a rule made of ``antecedent`` atoms and a ``consequent``."""

    # List of conditions, each one (de)serialized with RuleAtomDto.
    antecedent = fields.List(fields.Nested(RuleAtomDto()))
    # Outcome of the rule. Its type is not constrained, so it is passed through
    # unchanged. ``Raw`` is the marshmallow field for pass-through values;
    # instantiating the base ``Field`` directly in a schema is deprecated
    # (marshmallow 3.24+).
    consequent = fields.Raw()
|
|
|
|
|
|
class TreeNodeDto(Schema):
    """Schema for one node of a serialized tree.

    The nodes appear to form a flat list linked by ``parent``. Presumably
    ``parent`` holds another node's ``name``; verify against the producer.
    """

    parent = fields.String()
    name = fields.String()
    # Depth of the node within the tree (integer).
    level = fields.Integer()
    variable = fields.String()
    type = fields.String()
    # Untyped payload, passed through unchanged. ``Raw`` replaces the direct use
    # of the base ``Field`` class, which marshmallow deprecated inside schemas
    # (3.24+).
    value = fields.Raw()
|