from sklearn import tree

from backend.classification.model import ClassificationResult
from backend.dataset.model import SplittedDataset
from backend.metric import classification, get_metric
from backend.tree import get_rules, get_tree
from backend.tree.model import DecisionTreeParams


def fit_classification_model(
    data: SplittedDataset,
    params: DecisionTreeParams,
) -> ClassificationResult:
    """Fit a decision-tree classifier and bundle its metrics, rules and tree.

    A ``DecisionTreeClassifier`` is built from the attributes of ``params``,
    trained on the training split of ``data``, and then used to predict both
    the training and the test split. Every metric is handed two
    ``(true labels, predicted labels)`` pairs: the training pair first, the
    test pair second.

    Args:
        data: Dataset split exposing ``X_train``, ``y_train``, ``X_test`` and
            ``y_test``. The code reads ``.values`` and ``.columns`` from them,
            so they are presumably pandas objects -- TODO confirm.
        params: Hyper-parameter object; its instance attributes are expanded
            (via ``vars``) into keyword arguments of the classifier.

    Returns:
        A ``ClassificationResult`` carrying precision, recall, accuracy, F1,
        MCC and Cohen's kappa as returned by ``get_metric``, together with the
        outputs of ``get_rules`` and ``get_tree`` for the fitted model.
    """
    # The estimator is trained on the raw ``.values`` of the features, with the
    # targets flattened by ``ravel()``.
    classifier = tree.DecisionTreeClassifier(**vars(params)).fit(
        data.X_train.values,
        data.y_train.values.ravel(),
    )

    # (true, predicted) pairs, one per split.
    train_pair = (data.y_train, classifier.predict(data.X_train.values))
    test_pair = (data.y_test, classifier.predict(data.X_test.values))

    def score(metric):
        """Evaluate ``metric`` through ``get_metric`` on both split pairs."""
        return get_metric(metric, train_pair, test_pair)

    class_labels = classifier.classes_.tolist()  # type: ignore
    columns = data.X_train.columns

    return ClassificationResult(
        precision=score(classification.precision),
        recall=score(classification.recall),
        accuracy=score(classification.accuracy),
        f1=score(classification.f1),
        mcc=score(classification.mcc),
        cohen_kappa=score(classification.cohen_kappa),
        # Each consumer receives its own copy of the feature-name list.
        rules=get_rules(classifier, list(columns), class_labels),
        tree=get_tree(classifier, list(columns), class_labels),
    )