Add additional systems (temp + density / viscosity)
This commit is contained in:
parent
f576f28814
commit
168a6350b3
BIN
data/temp_density_tree.model.sav
Normal file
BIN
data/temp_density_tree.model.sav
Normal file
Binary file not shown.
BIN
data/temp_viscosity_tree.model.sav
Normal file
BIN
data/temp_viscosity_tree.model.sav
Normal file
Binary file not shown.
773
temp_density_regression.ipynb
Normal file
773
temp_density_regression.ipynb
Normal file
@ -0,0 +1,773 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>T</th>\n",
|
||||
" <th>Al2O3</th>\n",
|
||||
" <th>TiO2</th>\n",
|
||||
" <th>Density</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>20</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.06250</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>25</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.05979</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>35</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.05404</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>40</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.05103</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>45</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.04794</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" T Al2O3 TiO2 Density\n",
|
||||
"0 20 0.0 0.0 1.06250\n",
|
||||
"1 25 0.0 0.0 1.05979\n",
|
||||
"2 35 0.0 0.0 1.05404\n",
|
||||
"3 40 0.0 0.0 1.05103\n",
|
||||
"4 45 0.0 0.0 1.04794"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>T</th>\n",
|
||||
" <th>Al2O3</th>\n",
|
||||
" <th>TiO2</th>\n",
|
||||
" <th>Density</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>30</td>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.05696</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>55</td>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.04158</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>25</td>\n",
|
||||
" <td>0.05</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.08438</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>30</td>\n",
|
||||
" <td>0.05</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.08112</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>35</td>\n",
|
||||
" <td>0.05</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.07781</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" T Al2O3 TiO2 Density\n",
|
||||
"0 30 0.00 0.0 1.05696\n",
|
||||
"1 55 0.00 0.0 1.04158\n",
|
||||
"2 25 0.05 0.0 1.08438\n",
|
||||
"3 30 0.05 0.0 1.08112\n",
|
||||
"4 35 0.05 0.0 1.07781"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"# Both CSV files use ';' between fields and ',' as the decimal mark.\n",
|
||||
"train = pd.read_csv(\n",
|
||||
"    \"data/density_train.csv\", sep=\";\", decimal=\",\"\n",
|
||||
")\n",
|
||||
"test = pd.read_csv(\n",
|
||||
"    \"data/density_test.csv\", sep=\";\", decimal=\",\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Preview the first rows of each split; display() renders every argument in turn.\n",
|
||||
"display(train.head(), test.head())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>Al2O3</th>\n",
|
||||
" <th>TiO2</th>\n",
|
||||
" <th>Density</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.06250</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.05979</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.05404</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.05103</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.04794</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" Al2O3 TiO2 Density\n",
|
||||
"0 0.0 0.0 1.06250\n",
|
||||
"1 0.0 0.0 1.05979\n",
|
||||
"2 0.0 0.0 1.05404\n",
|
||||
"3 0.0 0.0 1.05103\n",
|
||||
"4 0.0 0.0 1.04794"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0 20\n",
|
||||
"1 25\n",
|
||||
"2 35\n",
|
||||
"3 40\n",
|
||||
"4 45\n",
|
||||
"Name: T, dtype: int64"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>Al2O3</th>\n",
|
||||
" <th>TiO2</th>\n",
|
||||
" <th>Density</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.05696</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.04158</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>0.05</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.08438</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>0.05</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.08112</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>0.05</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.07781</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" Al2O3 TiO2 Density\n",
|
||||
"0 0.00 0.0 1.05696\n",
|
||||
"1 0.00 0.0 1.04158\n",
|
||||
"2 0.05 0.0 1.08438\n",
|
||||
"3 0.05 0.0 1.08112\n",
|
||||
"4 0.05 0.0 1.07781"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0 30\n",
|
||||
"1 55\n",
|
||||
"2 25\n",
|
||||
"3 30\n",
|
||||
"4 35\n",
|
||||
"Name: T, dtype: int64"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Column \"T\" is the regression target; every remaining column is a feature.\n",
|
||||
"y_train = train[\"T\"]\n",
|
||||
"X_train = train.drop(columns=[\"T\"])\n",
|
||||
"\n",
|
||||
"display(X_train.head(), y_train.head())\n",
|
||||
"\n",
|
||||
"# Same split for the held-out data.\n",
|
||||
"y_test = test[\"T\"]\n",
|
||||
"X_test = test.drop(columns=[\"T\"])\n",
|
||||
"\n",
|
||||
"display(X_test.head(), y_test.head())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.pipeline import make_pipeline\n",
|
||||
"from sklearn.preprocessing import PolynomialFeatures\n",
|
||||
"from sklearn import linear_model, tree, neighbors, ensemble\n",
|
||||
"\n",
|
||||
"# Seed passed to the estimators that take a random_state.\n",
|
||||
"random_state = 9\n",
|
||||
"\n",
|
||||
"# Candidate regressors keyed by a short name. Each value is a dict holding the\n",
|
||||
"# unfitted estimator under \"model\", so more keys can be added to it later.\n",
|
||||
"models = {\n",
|
||||
"    name: {\"model\": estimator}\n",
|
||||
"    for name, estimator in (\n",
|
||||
"        (\"linear\", linear_model.LinearRegression(n_jobs=-1)),\n",
|
||||
"        (\n",
|
||||
"            \"linear_poly\",\n",
|
||||
"            make_pipeline(\n",
|
||||
"                PolynomialFeatures(degree=2),\n",
|
||||
"                linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
||||
"            ),\n",
|
||||
"        ),\n",
|
||||
"        (\n",
|
||||
"            \"linear_interact\",\n",
|
||||
"            make_pipeline(\n",
|
||||
"                PolynomialFeatures(interaction_only=True),\n",
|
||||
"                linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
||||
"            ),\n",
|
||||
"        ),\n",
|
||||
"        (\"ridge\", linear_model.RidgeCV()),\n",
|
||||
"        (\n",
|
||||
"            \"decision_tree\",\n",
|
||||
"            tree.DecisionTreeRegressor(\n",
|
||||
"                random_state=random_state, max_depth=6, criterion=\"absolute_error\"\n",
|
||||
"            ),\n",
|
||||
"        ),\n",
|
||||
"        (\"knn\", neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)),\n",
|
||||
"        (\n",
|
||||
"            \"random_forest\",\n",
|
||||
"            ensemble.RandomForestRegressor(\n",
|
||||
"                max_depth=7, random_state=random_state, n_jobs=-1\n",
|
||||
"            ),\n",
|
||||
"        ),\n",
|
||||
"    )\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model: linear\n",
|
||||
"Model: linear_poly\n",
|
||||
"Model: linear_interact\n",
|
||||
"Model: ridge\n",
|
||||
"Model: decision_tree\n",
|
||||
"Model: knn\n",
|
||||
"Model: random_forest\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import math  # NOTE(review): math is not referenced in this cell\n",
|
||||
"from sklearn import metrics\n",
|
||||
"\n",
|
||||
"# Fit every candidate on the training split, then store the fitted estimator\n",
|
||||
"# and its train/test MSE, MAE and R2 in the same `models` entry.\n",
|
||||
"for model_name, entry in models.items():\n",
|
||||
"    print(f\"Model: {model_name}\")\n",
|
||||
"    fitted_model = entry[\"model\"].fit(X_train.values, y_train.values.ravel())\n",
|
||||
"    y_train_pred = fitted_model.predict(X_train.values)\n",
|
||||
"    y_test_pred = fitted_model.predict(X_test.values)\n",
|
||||
"    entry.update(\n",
|
||||
"        {\n",
|
||||
"            \"fitted\": fitted_model,\n",
|
||||
"            \"MSE_train\": metrics.mean_squared_error(y_train, y_train_pred),\n",
|
||||
"            \"MSE_test\": metrics.mean_squared_error(y_test, y_test_pred),\n",
|
||||
"            \"MAE_train\": metrics.mean_absolute_error(y_train, y_train_pred),\n",
|
||||
"            \"MAE_test\": metrics.mean_absolute_error(y_test, y_test_pred),\n",
|
||||
"            \"R2_train\": metrics.r2_score(y_train, y_train_pred),\n",
|
||||
"            \"R2_test\": metrics.r2_score(y_test, y_test_pred),\n",
|
||||
"        }\n",
|
||||
"    )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<style type=\"text/css\">\n",
|
||||
"#T_2421b_row0_col0, #T_2421b_row0_col1 {\n",
|
||||
" background-color: #26818e;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row0_col3, #T_2421b_row6_col5 {\n",
|
||||
" background-color: #4e02a2;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row0_col5, #T_2421b_row6_col3 {\n",
|
||||
" background-color: #da5a6a;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row1_col0, #T_2421b_row2_col0, #T_2421b_row3_col0 {\n",
|
||||
" background-color: #25848e;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row1_col1 {\n",
|
||||
" background-color: #24868e;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row1_col3 {\n",
|
||||
" background-color: #6f00a8;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row1_col5, #T_2421b_row5_col3 {\n",
|
||||
" background-color: #d5536f;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row2_col1 {\n",
|
||||
" background-color: #23888e;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row2_col3 {\n",
|
||||
" background-color: #7201a8;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row2_col5 {\n",
|
||||
" background-color: #d35171;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row3_col1 {\n",
|
||||
" background-color: #1f998a;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row3_col3 {\n",
|
||||
" background-color: #9814a0;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row3_col5 {\n",
|
||||
" background-color: #c23c81;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row4_col0 {\n",
|
||||
" background-color: #23898e;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row4_col1 {\n",
|
||||
" background-color: #1e9d89;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row4_col3 {\n",
|
||||
" background-color: #a21d9a;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row4_col5 {\n",
|
||||
" background-color: #be3885;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row5_col0 {\n",
|
||||
" background-color: #5ac864;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row5_col1 {\n",
|
||||
" background-color: #9bd93c;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row5_col5 {\n",
|
||||
" background-color: #5601a4;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_2421b_row6_col0, #T_2421b_row6_col1 {\n",
|
||||
" background-color: #a8db34;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"</style>\n",
|
||||
"<table id=\"T_2421b\">\n",
|
||||
" <thead>\n",
|
||||
" <tr>\n",
|
||||
" <th class=\"blank level0\" > </th>\n",
|
||||
" <th id=\"T_2421b_level0_col0\" class=\"col_heading level0 col0\" >MSE_train</th>\n",
|
||||
" <th id=\"T_2421b_level0_col1\" class=\"col_heading level0 col1\" >MSE_test</th>\n",
|
||||
" <th id=\"T_2421b_level0_col2\" class=\"col_heading level0 col2\" >MAE_train</th>\n",
|
||||
" <th id=\"T_2421b_level0_col3\" class=\"col_heading level0 col3\" >MAE_test</th>\n",
|
||||
" <th id=\"T_2421b_level0_col4\" class=\"col_heading level0 col4\" >R2_train</th>\n",
|
||||
" <th id=\"T_2421b_level0_col5\" class=\"col_heading level0 col5\" >R2_test</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_2421b_level0_row0\" class=\"row_heading level0 row0\" >linear_poly</th>\n",
|
||||
" <td id=\"T_2421b_row0_col0\" class=\"data row0 col0\" >0.302768</td>\n",
|
||||
" <td id=\"T_2421b_row0_col1\" class=\"data row0 col1\" >0.203293</td>\n",
|
||||
" <td id=\"T_2421b_row0_col2\" class=\"data row0 col2\" >0.419467</td>\n",
|
||||
" <td id=\"T_2421b_row0_col3\" class=\"data row0 col3\" >0.392687</td>\n",
|
||||
" <td id=\"T_2421b_row0_col4\" class=\"data row0 col4\" >0.998860</td>\n",
|
||||
" <td id=\"T_2421b_row0_col5\" class=\"data row0 col5\" >0.999047</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_2421b_level0_row1\" class=\"row_heading level0 row1\" >linear_interact</th>\n",
|
||||
" <td id=\"T_2421b_row1_col0\" class=\"data row1 col0\" >9.693323</td>\n",
|
||||
" <td id=\"T_2421b_row1_col1\" class=\"data row1 col1\" >10.875442</td>\n",
|
||||
" <td id=\"T_2421b_row1_col2\" class=\"data row1 col2\" >2.544944</td>\n",
|
||||
" <td id=\"T_2421b_row1_col3\" class=\"data row1 col3\" >2.718424</td>\n",
|
||||
" <td id=\"T_2421b_row1_col4\" class=\"data row1 col4\" >0.963492</td>\n",
|
||||
" <td id=\"T_2421b_row1_col5\" class=\"data row1 col5\" >0.949019</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_2421b_level0_row2\" class=\"row_heading level0 row2\" >linear</th>\n",
|
||||
" <td id=\"T_2421b_row2_col0\" class=\"data row2 col0\" >10.468503</td>\n",
|
||||
" <td id=\"T_2421b_row2_col1\" class=\"data row2 col1\" >14.820315</td>\n",
|
||||
" <td id=\"T_2421b_row2_col2\" class=\"data row2 col2\" >2.657476</td>\n",
|
||||
" <td id=\"T_2421b_row2_col3\" class=\"data row2 col3\" >2.930229</td>\n",
|
||||
" <td id=\"T_2421b_row2_col4\" class=\"data row2 col4\" >0.960572</td>\n",
|
||||
" <td id=\"T_2421b_row2_col5\" class=\"data row2 col5\" >0.930526</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_2421b_level0_row3\" class=\"row_heading level0 row3\" >decision_tree</th>\n",
|
||||
" <td id=\"T_2421b_row3_col0\" class=\"data row3 col0\" >10.526316</td>\n",
|
||||
" <td id=\"T_2421b_row3_col1\" class=\"data row3 col1\" >47.426471</td>\n",
|
||||
" <td id=\"T_2421b_row3_col2\" class=\"data row3 col2\" >1.842105</td>\n",
|
||||
" <td id=\"T_2421b_row3_col3\" class=\"data row3 col3\" >5.735294</td>\n",
|
||||
" <td id=\"T_2421b_row3_col4\" class=\"data row3 col4\" >0.960355</td>\n",
|
||||
" <td id=\"T_2421b_row3_col5\" class=\"data row3 col5\" >0.777676</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_2421b_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
|
||||
" <td id=\"T_2421b_row4_col0\" class=\"data row4 col0\" >20.243876</td>\n",
|
||||
" <td id=\"T_2421b_row4_col1\" class=\"data row4 col1\" >54.501240</td>\n",
|
||||
" <td id=\"T_2421b_row4_col2\" class=\"data row4 col2\" >3.592953</td>\n",
|
||||
" <td id=\"T_2421b_row4_col3\" class=\"data row4 col3\" >6.598133</td>\n",
|
||||
" <td id=\"T_2421b_row4_col4\" class=\"data row4 col4\" >0.923755</td>\n",
|
||||
" <td id=\"T_2421b_row4_col5\" class=\"data row4 col5\" >0.744512</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_2421b_level0_row5\" class=\"row_heading level0 row5\" >knn</th>\n",
|
||||
" <td id=\"T_2421b_row5_col0\" class=\"data row5 col0\" >174.100430</td>\n",
|
||||
" <td id=\"T_2421b_row5_col1\" class=\"data row5 col1\" >191.176471</td>\n",
|
||||
" <td id=\"T_2421b_row5_col2\" class=\"data row5 col2\" >10.808271</td>\n",
|
||||
" <td id=\"T_2421b_row5_col3\" class=\"data row5 col3\" >11.680672</td>\n",
|
||||
" <td id=\"T_2421b_row5_col4\" class=\"data row5 col4\" >0.344285</td>\n",
|
||||
" <td id=\"T_2421b_row5_col5\" class=\"data row5 col5\" >0.103812</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_2421b_level0_row6\" class=\"row_heading level0 row6\" >ridge</th>\n",
|
||||
" <td id=\"T_2421b_row6_col0\" class=\"data row6 col0\" >243.364664</td>\n",
|
||||
" <td id=\"T_2421b_row6_col1\" class=\"data row6 col1\" >199.601477</td>\n",
|
||||
" <td id=\"T_2421b_row6_col2\" class=\"data row6 col2\" >13.472724</td>\n",
|
||||
" <td id=\"T_2421b_row6_col3\" class=\"data row6 col3\" >12.396799</td>\n",
|
||||
" <td id=\"T_2421b_row6_col4\" class=\"data row6 col4\" >0.083415</td>\n",
|
||||
" <td id=\"T_2421b_row6_col5\" class=\"data row6 col5\" >0.064317</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"<pandas.io.formats.style.Styler at 0x16867d550>"
|
||||
]
|
||||
},
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# One row per model, restricted to the score columns.\n",
|
||||
"reg_metrics = pd.DataFrame.from_dict(models, orient=\"index\")[\n",
|
||||
"    [\"MSE_train\", \"MSE_test\", \"MAE_train\", \"MAE_test\", \"R2_train\", \"R2_test\"]\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Order by test MAE; shade the MSE columns with one gradient and the\n",
|
||||
"# MAE_test / R2_test columns with another.\n",
|
||||
"(\n",
|
||||
"    reg_metrics.sort_values(by=\"MAE_test\")\n",
|
||||
"    .style.background_gradient(\n",
|
||||
"        cmap=\"viridis\", low=1, high=0.3, subset=[\"MSE_train\", \"MSE_test\"]\n",
|
||||
"    )\n",
|
||||
"    .background_gradient(\n",
|
||||
"        cmap=\"plasma\", low=0.3, high=1, subset=[\"MAE_test\", \"R2_test\"]\n",
|
||||
"    )\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"|--- Density <= 1.04\n",
|
||||
"| |--- Density <= 1.03\n",
|
||||
"| | |--- value: [70.00]\n",
|
||||
"| |--- Density > 1.03\n",
|
||||
"| | |--- Density <= 1.04\n",
|
||||
"| | | |--- value: [65.00]\n",
|
||||
"| | |--- Density > 1.04\n",
|
||||
"| | | |--- value: [60.00]\n",
|
||||
"|--- Density > 1.04\n",
|
||||
"| |--- Density <= 1.07\n",
|
||||
"| | |--- TiO2 <= 0.03\n",
|
||||
"| | | |--- Al2O3 <= 0.03\n",
|
||||
"| | | | |--- Density <= 1.05\n",
|
||||
"| | | | | |--- Density <= 1.05\n",
|
||||
"| | | | | | |--- value: [50.00]\n",
|
||||
"| | | | | |--- Density > 1.05\n",
|
||||
"| | | | | | |--- value: [42.50]\n",
|
||||
"| | | | |--- Density > 1.05\n",
|
||||
"| | | | | |--- Density <= 1.06\n",
|
||||
"| | | | | | |--- value: [35.00]\n",
|
||||
"| | | | | |--- Density > 1.06\n",
|
||||
"| | | | | | |--- value: [22.50]\n",
|
||||
"| | | |--- Al2O3 > 0.03\n",
|
||||
"| | | | |--- Density <= 1.06\n",
|
||||
"| | | | | |--- Density <= 1.05\n",
|
||||
"| | | | | | |--- value: [70.00]\n",
|
||||
"| | | | | |--- Density > 1.05\n",
|
||||
"| | | | | | |--- value: [65.00]\n",
|
||||
"| | | | |--- Density > 1.06\n",
|
||||
"| | | | | |--- Density <= 1.07\n",
|
||||
"| | | | | | |--- value: [55.00]\n",
|
||||
"| | | | | |--- Density > 1.07\n",
|
||||
"| | | | | | |--- value: [50.00]\n",
|
||||
"| | |--- TiO2 > 0.03\n",
|
||||
"| | | |--- Density <= 1.06\n",
|
||||
"| | | | |--- value: [70.00]\n",
|
||||
"| | | |--- Density > 1.06\n",
|
||||
"| | | | |--- Density <= 1.06\n",
|
||||
"| | | | | |--- value: [65.00]\n",
|
||||
"| | | | |--- Density > 1.06\n",
|
||||
"| | | | | |--- value: [60.00]\n",
|
||||
"| |--- Density > 1.07\n",
|
||||
"| | |--- Density <= 1.12\n",
|
||||
"| | | |--- Density <= 1.08\n",
|
||||
"| | | | |--- Density <= 1.07\n",
|
||||
"| | | | | |--- value: [45.00]\n",
|
||||
"| | | | |--- Density > 1.07\n",
|
||||
"| | | | | |--- Density <= 1.08\n",
|
||||
"| | | | | | |--- value: [40.00]\n",
|
||||
"| | | | | |--- Density > 1.08\n",
|
||||
"| | | | | | |--- value: [35.00]\n",
|
||||
"| | | |--- Density > 1.08\n",
|
||||
"| | | | |--- Density <= 1.09\n",
|
||||
"| | | | | |--- value: [30.00]\n",
|
||||
"| | | | |--- Density > 1.09\n",
|
||||
"| | | | | |--- Al2O3 <= 0.03\n",
|
||||
"| | | | | | |--- value: [22.50]\n",
|
||||
"| | | | | |--- Al2O3 > 0.03\n",
|
||||
"| | | | | | |--- value: [20.00]\n",
|
||||
"| | |--- Density > 1.12\n",
|
||||
"| | | |--- Density <= 1.18\n",
|
||||
"| | | | |--- Density <= 1.15\n",
|
||||
"| | | | | |--- value: [70.00]\n",
|
||||
"| | | | |--- Density > 1.15\n",
|
||||
"| | | | | |--- Al2O3 <= 0.15\n",
|
||||
"| | | | | | |--- value: [65.00]\n",
|
||||
"| | | | | |--- Al2O3 > 0.15\n",
|
||||
"| | | | | | |--- value: [50.00]\n",
|
||||
"| | | |--- Density > 1.18\n",
|
||||
"| | | | |--- Al2O3 <= 0.15\n",
|
||||
"| | | | | |--- Density <= 1.20\n",
|
||||
"| | | | | | |--- value: [50.00]\n",
|
||||
"| | | | | |--- Density > 1.20\n",
|
||||
"| | | | | | |--- value: [30.00]\n",
|
||||
"| | | | |--- Al2O3 > 0.15\n",
|
||||
"| | | | | |--- Density <= 1.18\n",
|
||||
"| | | | | | |--- value: [30.00]\n",
|
||||
"| | | | | |--- Density > 1.18\n",
|
||||
"| | | | | | |--- value: [22.50]\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Print the fitted decision tree as plain-text split rules, labelled with\n",
|
||||
"# the training feature column names.\n",
|
||||
"model = models[\"decision_tree\"][\"fitted\"]\n",
|
||||
"rules = tree.export_text(model, feature_names=list(X_train.columns))\n",
|
||||
"print(rules)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pickle\n",
|
||||
"\n",
|
||||
"# Persist `model` to disk. The context manager closes the file handle\n",
|
||||
"# (flushing the pickled bytes) even if pickle.dump raises; the previous\n",
|
||||
"# bare open() call left the handle open until garbage collection.\n",
|
||||
"with open(\"data/temp_density_tree.model.sav\", \"wb\") as model_file:\n",
|
||||
"    pickle.dump(model, model_file)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
1417
temp_density_tree.ipynb
Normal file
1417
temp_density_tree.ipynb
Normal file
File diff suppressed because one or more lines are too long
763
temp_viscosity_regression.ipynb
Normal file
763
temp_viscosity_regression.ipynb
Normal file
@ -0,0 +1,763 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>T</th>\n",
|
||||
" <th>Al2O3</th>\n",
|
||||
" <th>TiO2</th>\n",
|
||||
" <th>Viscosity</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>20</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>3.707</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>25</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>3.180</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>35</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>2.361</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>45</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.832</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>50</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.629</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" T Al2O3 TiO2 Viscosity\n",
|
||||
"0 20 0.0 0.0 3.707\n",
|
||||
"1 25 0.0 0.0 3.180\n",
|
||||
"2 35 0.0 0.0 2.361\n",
|
||||
"3 45 0.0 0.0 1.832\n",
|
||||
"4 50 0.0 0.0 1.629"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>T</th>\n",
|
||||
" <th>Al2O3</th>\n",
|
||||
" <th>TiO2</th>\n",
|
||||
" <th>Viscosity</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>30</td>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>2.716</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>40</td>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>2.073</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>60</td>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.329</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>65</td>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.211</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>25</td>\n",
|
||||
" <td>0.05</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>4.120</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" T Al2O3 TiO2 Viscosity\n",
|
||||
"0 30 0.00 0.0 2.716\n",
|
||||
"1 40 0.00 0.0 2.073\n",
|
||||
"2 60 0.00 0.0 1.329\n",
|
||||
"3 65 0.00 0.0 1.211\n",
|
||||
"4 25 0.05 0.0 4.120"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"# Both CSV files use ';' between fields and ',' as the decimal mark.\n",
|
||||
"train = pd.read_csv(\n",
|
||||
"    \"data/viscosity_train.csv\", sep=\";\", decimal=\",\"\n",
|
||||
")\n",
|
||||
"test = pd.read_csv(\n",
|
||||
"    \"data/viscosity_test.csv\", sep=\";\", decimal=\",\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Preview the first rows of each split; display() renders every argument in turn.\n",
|
||||
"display(train.head(), test.head())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>Al2O3</th>\n",
|
||||
" <th>TiO2</th>\n",
|
||||
" <th>Viscosity</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>3.707</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>3.180</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>2.361</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.832</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.629</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" Al2O3 TiO2 Viscosity\n",
|
||||
"0 0.0 0.0 3.707\n",
|
||||
"1 0.0 0.0 3.180\n",
|
||||
"2 0.0 0.0 2.361\n",
|
||||
"3 0.0 0.0 1.832\n",
|
||||
"4 0.0 0.0 1.629"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0 20\n",
|
||||
"1 25\n",
|
||||
"2 35\n",
|
||||
"3 45\n",
|
||||
"4 50\n",
|
||||
"Name: T, dtype: int64"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>Al2O3</th>\n",
|
||||
" <th>TiO2</th>\n",
|
||||
" <th>Viscosity</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>2.716</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>2.073</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.329</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.211</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>0.05</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>4.120</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" Al2O3 TiO2 Viscosity\n",
|
||||
"0 0.00 0.0 2.716\n",
|
||||
"1 0.00 0.0 2.073\n",
|
||||
"2 0.00 0.0 1.329\n",
|
||||
"3 0.00 0.0 1.211\n",
|
||||
"4 0.05 0.0 4.120"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0 30\n",
|
||||
"1 40\n",
|
||||
"2 60\n",
|
||||
"3 65\n",
|
||||
"4 25\n",
|
||||
"Name: T, dtype: int64"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
"# The \"T\" column is the regression target; every other column is a feature.\n",
"X_train = train.drop(columns=[\"T\"])\n",
"y_train = train[\"T\"]\n",
"display(X_train.head(), y_train.head())\n",
"\n",
"X_test = test.drop(columns=[\"T\"])\n",
"y_test = test[\"T\"]\n",
"display(X_test.head(), y_test.head())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"from sklearn import linear_model, tree, neighbors, ensemble\n",
"\n",
"random_state = 9\n",
"\n",
"# Candidate regressors, keyed by the name they are reported under.\n",
"estimators = {\n",
"    \"linear\": linear_model.LinearRegression(n_jobs=-1),\n",
"    \"linear_poly\": make_pipeline(\n",
"        PolynomialFeatures(degree=2),\n",
"        linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
"    ),\n",
"    \"linear_interact\": make_pipeline(\n",
"        PolynomialFeatures(interaction_only=True),\n",
"        linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
"    ),\n",
"    \"ridge\": linear_model.RidgeCV(),\n",
"    \"decision_tree\": tree.DecisionTreeRegressor(\n",
"        random_state=random_state, max_depth=6, criterion=\"absolute_error\"\n",
"    ),\n",
"    \"knn\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1),\n",
"    \"random_forest\": ensemble.RandomForestRegressor(\n",
"        max_depth=7, random_state=random_state, n_jobs=-1\n",
"    ),\n",
"}\n",
"\n",
"# Wrap every estimator in its own dict under the \"model\" key.\n",
"models = {name: {\"model\": estimator} for name, estimator in estimators.items()}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model: linear\n",
|
||||
"Model: linear_poly\n",
|
||||
"Model: linear_interact\n",
|
||||
"Model: ridge\n",
|
||||
"Model: decision_tree\n",
|
||||
"Model: knn\n",
|
||||
"Model: random_forest\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
"import math\n",
"from sklearn import metrics\n",
"\n",
"# Metric prefix -> scoring function; iteration order fixes the order of the stored keys.\n",
"scorers = {\n",
"    \"MSE\": metrics.mean_squared_error,\n",
"    \"MAE\": metrics.mean_absolute_error,\n",
"    \"R2\": metrics.r2_score,\n",
"}\n",
"\n",
"# Fit each candidate on the training split and record train/test scores next to it.\n",
"for model_name, entry in models.items():\n",
"    print(f\"Model: {model_name}\")\n",
"    fitted_model = entry[\"model\"].fit(X_train.values, y_train.values.ravel())\n",
"    y_train_pred = fitted_model.predict(X_train.values)\n",
"    y_test_pred = fitted_model.predict(X_test.values)\n",
"    entry[\"fitted\"] = fitted_model\n",
"    for prefix, scorer in scorers.items():\n",
"        entry[f\"{prefix}_train\"] = scorer(y_train, y_train_pred)\n",
"        entry[f\"{prefix}_test\"] = scorer(y_test, y_test_pred)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<style type=\"text/css\">\n",
|
||||
"#T_6c8fb_row0_col0 {\n",
|
||||
" background-color: #25838e;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row0_col1, #T_6c8fb_row5_col0 {\n",
|
||||
" background-color: #26818e;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row0_col3, #T_6c8fb_row6_col5 {\n",
|
||||
" background-color: #4e02a2;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row0_col5, #T_6c8fb_row6_col3 {\n",
|
||||
" background-color: #da5a6a;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row1_col0, #T_6c8fb_row1_col1 {\n",
|
||||
" background-color: #1fa187;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row1_col3 {\n",
|
||||
" background-color: #a31e9a;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row1_col5, #T_6c8fb_row3_col3 {\n",
|
||||
" background-color: #b83289;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row2_col0 {\n",
|
||||
" background-color: #25ac82;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row2_col1 {\n",
|
||||
" background-color: #32b67a;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row2_col3 {\n",
|
||||
" background-color: #b6308b;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row2_col5 {\n",
|
||||
" background-color: #9e199d;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row3_col0 {\n",
|
||||
" background-color: #2eb37c;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row3_col1 {\n",
|
||||
" background-color: #35b779;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row3_col5 {\n",
|
||||
" background-color: #9c179e;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row4_col0 {\n",
|
||||
" background-color: #238a8d;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row4_col1, #T_6c8fb_row5_col1 {\n",
|
||||
" background-color: #54c568;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row4_col3 {\n",
|
||||
" background-color: #c5407e;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row4_col5, #T_6c8fb_row5_col5 {\n",
|
||||
" background-color: #8405a7;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row5_col3 {\n",
|
||||
" background-color: #cd4a76;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_6c8fb_row6_col0, #T_6c8fb_row6_col1 {\n",
|
||||
" background-color: #a8db34;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"</style>\n",
|
||||
"<table id=\"T_6c8fb\">\n",
|
||||
" <thead>\n",
|
||||
" <tr>\n",
|
||||
" <th class=\"blank level0\" > </th>\n",
|
||||
" <th id=\"T_6c8fb_level0_col0\" class=\"col_heading level0 col0\" >MSE_train</th>\n",
|
||||
" <th id=\"T_6c8fb_level0_col1\" class=\"col_heading level0 col1\" >MSE_test</th>\n",
|
||||
" <th id=\"T_6c8fb_level0_col2\" class=\"col_heading level0 col2\" >MAE_train</th>\n",
|
||||
" <th id=\"T_6c8fb_level0_col3\" class=\"col_heading level0 col3\" >MAE_test</th>\n",
|
||||
" <th id=\"T_6c8fb_level0_col4\" class=\"col_heading level0 col4\" >R2_train</th>\n",
|
||||
" <th id=\"T_6c8fb_level0_col5\" class=\"col_heading level0 col5\" >R2_test</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_6c8fb_level0_row0\" class=\"row_heading level0 row0\" >linear_poly</th>\n",
|
||||
" <td id=\"T_6c8fb_row0_col0\" class=\"data row0 col0\" >4.827768</td>\n",
|
||||
" <td id=\"T_6c8fb_row0_col1\" class=\"data row0 col1\" >4.877296</td>\n",
|
||||
" <td id=\"T_6c8fb_row0_col2\" class=\"data row0 col2\" >1.522643</td>\n",
|
||||
" <td id=\"T_6c8fb_row0_col3\" class=\"data row0 col3\" >1.743058</td>\n",
|
||||
" <td id=\"T_6c8fb_row0_col4\" class=\"data row0 col4\" >0.980864</td>\n",
|
||||
" <td id=\"T_6c8fb_row0_col5\" class=\"data row0 col5\" >0.974964</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_6c8fb_level0_row1\" class=\"row_heading level0 row1\" >linear_interact</th>\n",
|
||||
" <td id=\"T_6c8fb_row1_col0\" class=\"data row1 col0\" >21.786348</td>\n",
|
||||
" <td id=\"T_6c8fb_row1_col1\" class=\"data row1 col1\" >23.459572</td>\n",
|
||||
" <td id=\"T_6c8fb_row1_col2\" class=\"data row1 col2\" >3.830996</td>\n",
|
||||
" <td id=\"T_6c8fb_row1_col3\" class=\"data row1 col3\" >4.381115</td>\n",
|
||||
" <td id=\"T_6c8fb_row1_col4\" class=\"data row1 col4\" >0.913644</td>\n",
|
||||
" <td id=\"T_6c8fb_row1_col5\" class=\"data row1 col5\" >0.879577</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_6c8fb_level0_row2\" class=\"row_heading level0 row2\" >linear</th>\n",
|
||||
" <td id=\"T_6c8fb_row2_col0\" class=\"data row2 col0\" >27.766510</td>\n",
|
||||
" <td id=\"T_6c8fb_row2_col1\" class=\"data row2 col1\" >35.430313</td>\n",
|
||||
" <td id=\"T_6c8fb_row2_col2\" class=\"data row2 col2\" >4.088006</td>\n",
|
||||
" <td id=\"T_6c8fb_row2_col3\" class=\"data row2 col3\" >5.106782</td>\n",
|
||||
" <td id=\"T_6c8fb_row2_col4\" class=\"data row2 col4\" >0.889940</td>\n",
|
||||
" <td id=\"T_6c8fb_row2_col5\" class=\"data row2 col5\" >0.818129</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_6c8fb_level0_row3\" class=\"row_heading level0 row3\" >ridge</th>\n",
|
||||
" <td id=\"T_6c8fb_row3_col0\" class=\"data row3 col0\" >31.827476</td>\n",
|
||||
" <td id=\"T_6c8fb_row3_col1\" class=\"data row3 col1\" >36.230606</td>\n",
|
||||
" <td id=\"T_6c8fb_row3_col2\" class=\"data row3 col2\" >4.383008</td>\n",
|
||||
" <td id=\"T_6c8fb_row3_col3\" class=\"data row3 col3\" >5.226480</td>\n",
|
||||
" <td id=\"T_6c8fb_row3_col4\" class=\"data row3 col4\" >0.873843</td>\n",
|
||||
" <td id=\"T_6c8fb_row3_col5\" class=\"data row3 col5\" >0.814021</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_6c8fb_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
|
||||
" <td id=\"T_6c8fb_row4_col0\" class=\"data row4 col0\" >8.525285</td>\n",
|
||||
" <td id=\"T_6c8fb_row4_col1\" class=\"data row4 col1\" >45.444651</td>\n",
|
||||
" <td id=\"T_6c8fb_row4_col2\" class=\"data row4 col2\" >2.542935</td>\n",
|
||||
" <td id=\"T_6c8fb_row4_col3\" class=\"data row4 col3\" >5.749510</td>\n",
|
||||
" <td id=\"T_6c8fb_row4_col4\" class=\"data row4 col4\" >0.966208</td>\n",
|
||||
" <td id=\"T_6c8fb_row4_col5\" class=\"data row4 col5\" >0.766723</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_6c8fb_level0_row5\" class=\"row_heading level0 row5\" >decision_tree</th>\n",
|
||||
" <td id=\"T_6c8fb_row5_col0\" class=\"data row5 col0\" >3.289474</td>\n",
|
||||
" <td id=\"T_6c8fb_row5_col1\" class=\"data row5 col1\" >45.588235</td>\n",
|
||||
" <td id=\"T_6c8fb_row5_col2\" class=\"data row5 col2\" >0.921053</td>\n",
|
||||
" <td id=\"T_6c8fb_row5_col3\" class=\"data row5 col3\" >6.176471</td>\n",
|
||||
" <td id=\"T_6c8fb_row5_col4\" class=\"data row5 col4\" >0.986961</td>\n",
|
||||
" <td id=\"T_6c8fb_row5_col5\" class=\"data row5 col5\" >0.765986</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_6c8fb_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
|
||||
" <td id=\"T_6c8fb_row6_col0\" class=\"data row6 col0\" >61.855532</td>\n",
|
||||
" <td id=\"T_6c8fb_row6_col1\" class=\"data row6 col1\" >64.165666</td>\n",
|
||||
" <td id=\"T_6c8fb_row6_col2\" class=\"data row6 col2\" >6.522556</td>\n",
|
||||
" <td id=\"T_6c8fb_row6_col3\" class=\"data row6 col3\" >6.806723</td>\n",
|
||||
" <td id=\"T_6c8fb_row6_col4\" class=\"data row6 col4\" >0.754819</td>\n",
|
||||
" <td id=\"T_6c8fb_row6_col5\" class=\"data row6 col5\" >0.670624</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"<pandas.io.formats.style.Styler at 0x11b747260>"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
"# One row per model, restricted to the score columns.\n",
"metric_columns = [\"MSE_train\", \"MSE_test\", \"MAE_train\", \"MAE_test\", \"R2_train\", \"R2_test\"]\n",
"reg_metrics = pd.DataFrame.from_dict(models, orient=\"index\")[metric_columns]\n",
"\n",
"# Order rows by test MAE, then colour the MSE columns and the MAE_test / R2_test columns.\n",
"ranked = reg_metrics.sort_values(by=\"MAE_test\")\n",
"styled = ranked.style.background_gradient(\n",
"    cmap=\"viridis\", low=1, high=0.3, subset=[\"MSE_train\", \"MSE_test\"]\n",
")\n",
"styled = styled.background_gradient(\n",
"    cmap=\"plasma\", low=0.3, high=1, subset=[\"MAE_test\", \"R2_test\"]\n",
")\n",
"styled"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"|--- Viscosity <= 2.86\n",
|
||||
"| |--- Viscosity <= 1.38\n",
|
||||
"| | |--- value: [70.00]\n",
|
||||
"| |--- Viscosity > 1.38\n",
|
||||
"| | |--- Viscosity <= 1.78\n",
|
||||
"| | | |--- Viscosity <= 1.72\n",
|
||||
"| | | | |--- TiO2 <= 0.03\n",
|
||||
"| | | | | |--- Al2O3 <= 0.03\n",
|
||||
"| | | | | | |--- value: [52.50]\n",
|
||||
"| | | | | |--- Al2O3 > 0.03\n",
|
||||
"| | | | | | |--- value: [57.50]\n",
|
||||
"| | | | |--- TiO2 > 0.03\n",
|
||||
"| | | | | |--- value: [60.00]\n",
|
||||
"| | | |--- Viscosity > 1.72\n",
|
||||
"| | | | |--- value: [70.00]\n",
|
||||
"| | |--- Viscosity > 1.78\n",
|
||||
"| | | |--- TiO2 <= 0.18\n",
|
||||
"| | | | |--- Al2O3 <= 0.18\n",
|
||||
"| | | | | |--- Viscosity <= 2.24\n",
|
||||
"| | | | | | |--- value: [50.00]\n",
|
||||
"| | | | | |--- Viscosity > 2.24\n",
|
||||
"| | | | | | |--- value: [40.00]\n",
|
||||
"| | | | |--- Al2O3 > 0.18\n",
|
||||
"| | | | | |--- Viscosity <= 2.29\n",
|
||||
"| | | | | | |--- value: [60.00]\n",
|
||||
"| | | | | |--- Viscosity > 2.29\n",
|
||||
"| | | | | | |--- value: [55.00]\n",
|
||||
"| | | |--- TiO2 > 0.18\n",
|
||||
"| | | | |--- Viscosity <= 2.22\n",
|
||||
"| | | | | |--- value: [70.00]\n",
|
||||
"| | | | |--- Viscosity > 2.22\n",
|
||||
"| | | | | |--- Viscosity <= 2.69\n",
|
||||
"| | | | | | |--- value: [60.00]\n",
|
||||
"| | | | | |--- Viscosity > 2.69\n",
|
||||
"| | | | | | |--- value: [55.00]\n",
|
||||
"|--- Viscosity > 2.86\n",
|
||||
"| |--- Viscosity <= 3.64\n",
|
||||
"| | |--- TiO2 <= 0.18\n",
|
||||
"| | | |--- Viscosity <= 3.15\n",
|
||||
"| | | | |--- value: [35.00]\n",
|
||||
"| | | |--- Viscosity > 3.15\n",
|
||||
"| | | | |--- Viscosity <= 3.47\n",
|
||||
"| | | | | |--- Al2O3 <= 0.03\n",
|
||||
"| | | | | | |--- value: [25.00]\n",
|
||||
"| | | | | |--- Al2O3 > 0.03\n",
|
||||
"| | | | | | |--- value: [30.00]\n",
|
||||
"| | | | |--- Viscosity > 3.47\n",
|
||||
"| | | | | |--- value: [40.00]\n",
|
||||
"| | |--- TiO2 > 0.18\n",
|
||||
"| | | |--- value: [45.00]\n",
|
||||
"| |--- Viscosity > 3.64\n",
|
||||
"| | |--- Viscosity <= 6.27\n",
|
||||
"| | | |--- TiO2 <= 0.18\n",
|
||||
"| | | | |--- Al2O3 <= 0.18\n",
|
||||
"| | | | | |--- TiO2 <= 0.03\n",
|
||||
"| | | | | | |--- value: [20.00]\n",
|
||||
"| | | | | |--- TiO2 > 0.03\n",
|
||||
"| | | | | | |--- value: [22.50]\n",
|
||||
"| | | | |--- Al2O3 > 0.18\n",
|
||||
"| | | | | |--- Viscosity <= 4.42\n",
|
||||
"| | | | | | |--- value: [35.00]\n",
|
||||
"| | | | | |--- Viscosity > 4.42\n",
|
||||
"| | | | | | |--- value: [27.50]\n",
|
||||
"| | | |--- TiO2 > 0.18\n",
|
||||
"| | | | |--- Viscosity <= 4.65\n",
|
||||
"| | | | | |--- value: [35.00]\n",
|
||||
"| | | | |--- Viscosity > 4.65\n",
|
||||
"| | | | | |--- Viscosity <= 5.40\n",
|
||||
"| | | | | | |--- value: [30.00]\n",
|
||||
"| | | | | |--- Viscosity > 5.40\n",
|
||||
"| | | | | | |--- value: [25.00]\n",
|
||||
"| | |--- Viscosity > 6.27\n",
|
||||
"| | | |--- value: [20.00]\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
"# Print the fitted decision tree as nested text rules, labelled with the feature column names.\n",
"model = models[\"decision_tree\"][\"fitted\"]\n",
"feature_names = list(X_train.columns)\n",
"rules = tree.export_text(model, feature_names=feature_names)\n",
"print(rules)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
"import pickle\n",
"\n",
"# Persist the fitted decision tree. The context manager closes the file handle\n",
"# deterministically, so the pickle is fully flushed to disk before the cell returns.\n",
"with open(\"data/temp_viscosity_tree.model.sav\", \"wb\") as model_file:\n",
"    pickle.dump(model, model_file)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
1357
temp_viscosity_tree.ipynb
Normal file
1357
temp_viscosity_tree.ipynb
Normal file
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user