"""apiflask schemas for classification-tree parameters and results."""

from apiflask import Schema, fields
from apiflask.validators import OneOf

# NOTE(review): "MetrciDto" looks like a typo of "MetricDto"; it must match the
# name exported by backend.metric.dto, so rename it there and here together.
from backend.metric.dto import MetrciDto
from backend.tree.dto import DecisionTreeParamsDto, RuleDto, TreeNodeDto


class ClassificationTreeDto(DecisionTreeParamsDto):
    """Decision-tree parameters extended with a classification split criterion.

    Inherits every field of ``DecisionTreeParamsDto`` and adds ``criterion``.
    """

    # Split-quality criterion. When the key is absent from the input being
    # loaded, "gini" is used; any other value must be one of the listed names.
    # NOTE(review): these names match scikit-learn's DecisionTreeClassifier
    # `criterion` options; presumably forwarded there — confirm with the caller.
    criterion = fields.String(
        load_default="gini",
        validate=OneOf(["gini", "entropy", "log_loss"]),
    )


class ClassificationResultDto(Schema):
    """Result of a classification-tree run: extracted rules, tree nodes, metrics."""

    # List of rules, each serialized with RuleDto.
    rules = fields.List(fields.Nested(RuleDto()))
    # List of tree nodes, each serialized with TreeNodeDto.
    tree = fields.List(fields.Nested(TreeNodeDto()))

    # Evaluation metrics. Each one is a single nested object serialized with
    # the same metric schema.
    precision = fields.Nested(MetrciDto())
    recall = fields.Nested(MetrciDto())
    accuracy = fields.Nested(MetrciDto())
    f1 = fields.Nested(MetrciDto())
    mcc = fields.Nested(MetrciDto())
    cohen_kappa = fields.Nested(MetrciDto())