43 lines
907 B
Python
43 lines
907 B
Python
import enum
|
|
from dataclasses import dataclass
|
|
from typing import List
|
|
|
|
import numpy as np
|
|
|
|
|
|
class ComparisonType(enum.Enum):
    """Comparison operators available to a rule condition.

    Each member's value is the textual symbol of its operator.
    """

    # Despite the name, this is the non-strict "less than or equal" test.
    LESS = "<="
    # Strict "greater than" test.
    GREATER = ">"
|
|
|
|
|
|
@dataclass(repr=False)
class RuleAtom:
    """One condition of a rule, rendered as ``(variable type value)``.

    ``type`` is a comparison symbol string — presumably one of the
    ``ComparisonType`` values such as ``"<="``; confirm with callers.
    """

    variable: str  # name of the variable being tested
    type: str  # comparison symbol placed between variable and value
    value: float  # value the variable is compared against

    def __repr__(self) -> str:
        # Show the value rounded to three decimals to keep the output compact.
        shown_value = np.round(self.value, 3)
        return "({} {} {})".format(self.variable, self.type, shown_value)
|
|
|
|
|
|
@dataclass(repr=False)
class Rule:
    """A rule rendered as ``if <atom> and <atom> ... -> <consequent>``.

    Attributes:
        antecedent: Conditions of the rule; joined with ``" and "`` in the
            rendered text.
        consequent: Outcome of the rule, either numeric or a string.
    """

    antecedent: List[RuleAtom]
    consequent: float | str

    def __repr__(self) -> str:
        """Return the rule as text, rounding a numeric consequent to 3 decimals.

        Real numbers (Python or NumPy ints/floats, but not bools) and strings
        made only of decimal digits are shown via ``np.round(float(x), 3)``;
        any other consequent is shown via ``str()``.
        """
        consequent_value: object = self.consequent
        # Digit-only strings are treated as numbers. isdecimal() is used rather
        # than isnumeric(): isnumeric() also accepts characters such as "½" or
        # "²" that float() rejects, which made __repr__ raise ValueError.
        if isinstance(consequent_value, str) and consequent_value.isdecimal():
            consequent_value = float(consequent_value)
        is_real_number = isinstance(
            consequent_value, (int, float, np.integer, np.floating)
        ) and not isinstance(consequent_value, bool)
        if is_real_number:
            # Round genuine numbers as well: the old str(...).isnumeric() check
            # never matched floats (because of the ".") or negative values, so
            # those were printed unrounded.
            consequent_value = np.round(float(consequent_value), 3)
        else:
            consequent_value = str(consequent_value)
        # Join outside the f-string: reusing the enclosing quote character
        # inside an f-string expression is a SyntaxError before Python 3.12.
        conditions = " and ".join(str(atom) for atom in self.antecedent)
        return f"if {conditions} -> {consequent_value}"
|
|
|
|
|
|
@dataclass(repr=False)
|
|
class TreeNode:
|
|
parent: str | None
|
|
name: str | None
|
|
level: int
|
|
variable: str
|
|
type: str
|
|
value: float | str
|