import enum from dataclasses import dataclass from typing import List import numpy as np @dataclass class DecisionTreeParams: criterion: str splitter: str max_depth: int min_samples_split: int min_samples_leaf: int min_weight_fraction_leaf: float max_features: str random_state: int max_leaf_nodes: int min_impurity_decrease: float ccp_alpha: float class ComparisonType(enum.Enum): LESS = "<=" GREATER = ">" @dataclass(repr=False) class RuleAtom: variable: str type: str value: float def __repr__(self) -> str: return f"({self.variable} {self.type} {np.round(self.value, 3)})" @dataclass(repr=False) class Rule: antecedent: List[RuleAtom] consequent: float | str def __repr__(self) -> str: consequent_value: float | str = str(self.consequent) if consequent_value.isnumeric(): consequent_value = np.round(float(consequent_value), 3) return f"if {" and ".join([str(atom) for atom in self.antecedent])} -> {consequent_value}" @dataclass(repr=False) class TreeNode: parent: str | None name: str | None level: int variable: str type: str value: float | str