27 lines
1.1 KiB
Python
27 lines
1.1 KiB
Python
from sklearn import tree
|
|
|
|
from backend.dataset.model import SplittedDataset
|
|
from backend.metric import get_metric, regression
|
|
from backend.regression.model import RegressionResult
|
|
from backend.tree import get_rules, get_tree
|
|
from backend.tree.model import DecisionTreeParams
|
|
|
|
|
|
def fit_regression_model(
    data: SplittedDataset,
    params: DecisionTreeParams,
) -> RegressionResult:
    """Fit a decision-tree regressor and evaluate it on the train and test splits.

    Args:
        data: Pre-split dataset. ``X_train`` / ``X_test`` are used as pandas
            frames (``.columns`` and ``.to_numpy()`` are read); ``y_train`` /
            ``y_test`` are treated as a single regression target.
        params: Hyper-parameters; every instance attribute (``vars(params)``)
            is forwarded as a keyword argument to
            ``sklearn.tree.DecisionTreeRegressor``.

    Returns:
        A ``RegressionResult`` with the mse / mae / rmse / rmae / r2 values,
        each produced by ``get_metric`` from the train and test
        ``(y_true, y_pred)`` pairs, plus the fitted tree's rules and structure
        as produced by ``get_rules`` / ``get_tree``.
    """
    feature_names = list(data.X_train.columns)
    x_train = data.X_train.to_numpy()
    x_test = data.X_test.to_numpy()

    # Flatten both targets the same way: the target may be a one-column frame
    # (hence the ravel needed for fit), while predict() returns a 1-D array.
    # Giving the metrics like-shaped 1-D arrays avoids (n, 1)-vs-(n,)
    # broadcasting or pandas alignment surprises inside the metric functions.
    y_train = data.y_train.to_numpy().ravel()
    y_test = data.y_test.to_numpy().ravel()

    model = tree.DecisionTreeRegressor(**vars(params))
    model.fit(x_train, y_train)  # fit() returns the same estimator instance

    # One (y_true, y_pred) pair per split; get_metric receives the train pair
    # first, then the test pair.
    train_eval = (y_train, model.predict(x_train))
    test_eval = (y_test, model.predict(x_test))

    return RegressionResult(
        mse=get_metric(regression.mse, train_eval, test_eval),
        mae=get_metric(regression.mae, train_eval, test_eval),
        rmse=get_metric(regression.rmse, train_eval, test_eval),
        rmae=get_metric(regression.rmae, train_eval, test_eval),
        r2=get_metric(regression.r2, train_eval, test_eval),
        rules=get_rules(model, feature_names),
        tree=get_tree(model, feature_names),
    )
|