26 lines
1.0 KiB
Python

from sklearn import tree
from backend import metric
from backend import tree as tree_helper
from backend.dataset.model import SplittedDataset
from backend.regression.model import RegressionResult, RegressionTreeParams
def learn_regression_model(
    data: SplittedDataset,
    params: RegressionTreeParams,
) -> RegressionResult:
    """Train a decision-tree regressor and summarise its fit and structure.

    A ``sklearn.tree.DecisionTreeRegressor`` is built from the attributes of
    *params* (passed as keyword arguments via ``vars``). It is fitted on the
    training split, then used to predict both the training and the test
    features.

    Args:
        data: Dataset already split into ``X_train``/``y_train`` and
            ``X_test``/``y_test``. The target is flattened with ``ravel()``
            before fitting (presumably a single target column — confirm).
        params: Hyper-parameters forwarded verbatim to the regressor
            constructor.

    Returns:
        A ``RegressionResult`` carrying:

        * the ``mse``, ``mae``, ``rmse``, ``rmae`` and ``r2`` values produced by
          ``metric.get_metric``, which receives an (actual, predicted) pair for
          the train split and one for the test split;
        * the ``rules`` and ``tree`` representations from ``tree_helper``,
          labelled with the training feature names.
    """
    regressor = tree.DecisionTreeRegressor(**vars(params))
    fitted = regressor.fit(data.X_train.values, data.y_train.values.ravel())

    # (actual, predicted) pairs; train is computed first, then test.
    train_pair = (data.y_train, fitted.predict(data.X_train.values))
    test_pair = (data.y_test, fitted.predict(data.X_test.values))

    # Ordered so the metrics are evaluated mse -> mae -> rmse -> rmae -> r2.
    metric_functions = {
        "mse": metric.mse,
        "mae": metric.mae,
        "rmse": metric.rmse,
        "rmae": metric.rmae,
        "r2": metric.r2,
    }
    scores = {
        name: metric.get_metric(func, train_pair, test_pair)
        for name, func in metric_functions.items()
    }

    columns = data.X_train.columns
    return RegressionResult(
        **scores,
        # Each helper gets its own list, as in the original, so neither can
        # observe changes the other makes to it.
        rules=tree_helper.get_rules(fitted, list(columns)),
        tree=tree_helper.get_tree(fitted, list(columns)),
    )