29 lines
1.3 KiB
Python
29 lines
1.3 KiB
Python
from sklearn import tree
|
|
|
|
from backend.classification.model import ClassificationResult
|
|
from backend.dataset.model import SplittedDataset
|
|
from backend.metric import classification, get_metric
|
|
from backend.tree import get_rules, get_tree
|
|
from backend.tree.model import DecisionTreeParams
|
|
|
|
|
|
def fit_classification_model(
    data: SplittedDataset,
    params: DecisionTreeParams,
) -> ClassificationResult:
    """Train a decision-tree classifier and bundle its evaluation results.

    Args:
        data: Train/test split. The function reads ``X_train``, ``y_train``,
            ``X_test`` and ``y_test`` from it and uses their ``.values``
            (and ``X_train.columns``), so these are presumably pandas
            objects -- confirm against ``SplittedDataset``.
        params: Hyper-parameters; every instance attribute is forwarded
            verbatim as a keyword argument to ``DecisionTreeClassifier``.

    Returns:
        A ``ClassificationResult`` holding precision, recall, accuracy, F1,
        MCC and Cohen's kappa (each produced by ``get_metric``), plus the
        rules and tree extracted from the fitted classifier.
    """
    # Build the estimator from the supplied parameters and fit it in one go;
    # the targets are flattened with ravel() before being handed to fit().
    classifier = tree.DecisionTreeClassifier(**vars(params)).fit(
        data.X_train.values,
        data.y_train.values.ravel(),
    )

    # (actual, predicted) pairs for the training and the test split.
    train_pair = (data.y_train, classifier.predict(data.X_train.values))
    test_pair = (data.y_test, classifier.predict(data.X_test.values))

    class_labels = classifier.classes_.tolist()  # type: ignore
    columns = data.X_train.columns

    def score(metric):
        # Each metric receives both pairs; how they are combined is up to
        # get_metric, which is defined outside this module.
        return get_metric(metric, train_pair, test_pair)

    return ClassificationResult(
        precision=score(classification.precision),
        recall=score(classification.recall),
        accuracy=score(classification.accuracy),
        f1=score(classification.f1),
        mcc=score(classification.mcc),
        cohen_kappa=score(classification.cohen_kappa),
        # A fresh list of feature names is built for each consumer.
        rules=get_rules(classifier, list(columns), class_labels),
        tree=get_tree(classifier, list(columns), class_labels),
    )
|