fuzzy-rules-generator/viscosity_regression.ipynb

815 lines
26 KiB
Plaintext
Raw Permalink Normal View History

2024-11-01 11:04:05 +04:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>T</th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" <th>Viscosity</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>20</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3.707</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>25</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3.180</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>35</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2.361</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" T Al2O3 TiO2 Viscosity\n",
"0 20 0.0 0.0 3.707\n",
"1 25 0.0 0.0 3.180\n",
"2 35 0.0 0.0 2.361"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>T</th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" <th>Viscosity</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>30</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2.716</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>40</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2.073</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>60</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.329</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" T Al2O3 TiO2 Viscosity\n",
"0 30 0.0 0.0 2.716\n",
"1 40 0.0 0.0 2.073\n",
"2 60 0.0 0.0 1.329"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pandas as pd\n",
"\n",
"viscosity_train = pd.read_csv(\"data/viscosity_train.csv\", sep=\";\", decimal=\",\")\n",
"viscosity_test = pd.read_csv(\"data/viscosity_test.csv\", sep=\";\", decimal=\",\")\n",
"\n",
"display(viscosity_train.head(3))\n",
"display(viscosity_test.head(3))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>T</th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>20</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>25</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>35</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" T Al2O3 TiO2\n",
"0 20 0.0 0.0\n",
"1 25 0.0 0.0\n",
"2 35 0.0 0.0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0 3.707\n",
"1 3.180\n",
"2 2.361\n",
"Name: Viscosity, dtype: float64"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>T</th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>30</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>40</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>60</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" T Al2O3 TiO2\n",
"0 30 0.0 0.0\n",
"1 40 0.0 0.0\n",
"2 60 0.0 0.0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0 2.716\n",
"1 2.073\n",
"2 1.329\n",
"Name: Viscosity, dtype: float64"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"viscosity_y_train = viscosity_train[\"Viscosity\"]\n",
"viscosity_train = viscosity_train.drop([\"Viscosity\"], axis=1)\n",
"\n",
"display(viscosity_train.head(3))\n",
"display(viscosity_y_train.head(3))\n",
"\n",
"viscosity_y_test = viscosity_test[\"Viscosity\"]\n",
"viscosity_test = viscosity_test.drop([\"Viscosity\"], axis=1)\n",
"\n",
"display(viscosity_test.head(3))\n",
"display(viscosity_y_test.head(3))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"from sklearn import linear_model, tree, neighbors, ensemble\n",
"\n",
"random_state = 9\n",
"\n",
"models = {\n",
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
" \"linear_poly\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(degree=2),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"linear_interact\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(interaction_only=True),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestRegressor(\n",
" max_depth=7, random_state=random_state, n_jobs=-1\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: linear\n",
"Model: linear_poly\n",
"Model: linear_interact\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: random_forest\n"
]
}
],
"source": [
"import math\n",
"from sklearn import metrics\n",
"\n",
"for model_name in models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" fitted_model = models[model_name][\"model\"].fit(\n",
" viscosity_train.values, viscosity_y_train.values.ravel()\n",
" )\n",
" y_train_pred = fitted_model.predict(viscosity_train.values)\n",
" y_test_pred = fitted_model.predict(viscosity_test.values)\n",
" models[model_name][\"fitted\"] = fitted_model\n",
" models[model_name][\"train_preds\"] = y_train_pred\n",
" models[model_name][\"preds\"] = y_test_pred\n",
" models[model_name][\"RMSE_train\"] = math.sqrt(\n",
" metrics.mean_squared_error(viscosity_y_train, y_train_pred)\n",
" )\n",
" models[model_name][\"RMSE_test\"] = math.sqrt(\n",
" metrics.mean_squared_error(viscosity_y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"RMAE_test\"] = math.sqrt(\n",
" metrics.mean_absolute_error(viscosity_y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"R2_test\"] = metrics.r2_score(viscosity_y_test, y_test_pred)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_0b35b_row0_col0 {\n",
" background-color: #21918c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row0_col1, #T_0b35b_row4_col0 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row0_col2, #T_0b35b_row6_col3 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row0_col3, #T_0b35b_row6_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row1_col0 {\n",
" background-color: #31b57b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row1_col1 {\n",
" background-color: #22a884;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row1_col2 {\n",
" background-color: #a31e9a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row1_col3 {\n",
" background-color: #c13b82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row2_col0 {\n",
" background-color: #1f9e89;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row2_col1 {\n",
" background-color: #2cb17e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row2_col2 {\n",
" background-color: #ab2494;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row2_col3 {\n",
" background-color: #b7318a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row3_col0 {\n",
" background-color: #54c568;\n",
" color: #000000;\n",
"}\n",
"#T_0b35b_row3_col1, #T_0b35b_row4_col1 {\n",
" background-color: #3aba76;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row3_col2 {\n",
" background-color: #b02991;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row3_col3 {\n",
" background-color: #ad2793;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row4_col2 {\n",
" background-color: #b83289;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row4_col3 {\n",
" background-color: #ac2694;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row5_col0 {\n",
" background-color: #48c16e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row5_col1 {\n",
" background-color: #52c569;\n",
" color: #000000;\n",
"}\n",
"#T_0b35b_row5_col2 {\n",
" background-color: #c23c81;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row5_col3 {\n",
" background-color: #9a169f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0b35b_row6_col0, #T_0b35b_row6_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
"<table id=\"T_0b35b\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_0b35b_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_0b35b_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_0b35b_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_0b35b_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_0b35b_level0_row0\" class=\"row_heading level0 row0\" >linear_poly</th>\n",
" <td id=\"T_0b35b_row0_col0\" class=\"data row0 col0\" >0.150745</td>\n",
" <td id=\"T_0b35b_row0_col1\" class=\"data row0 col1\" >0.139507</td>\n",
" <td id=\"T_0b35b_row0_col2\" class=\"data row0 col2\" >0.336239</td>\n",
" <td id=\"T_0b35b_row0_col3\" class=\"data row0 col3\" >0.978119</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0b35b_level0_row1\" class=\"row_heading level0 row1\" >linear_interact</th>\n",
" <td id=\"T_0b35b_row1_col0\" class=\"data row1 col0\" >0.361309</td>\n",
" <td id=\"T_0b35b_row1_col1\" class=\"data row1 col1\" >0.303389</td>\n",
" <td id=\"T_0b35b_row1_col2\" class=\"data row1 col2\" >0.527911</td>\n",
" <td id=\"T_0b35b_row1_col3\" class=\"data row1 col3\" >0.896517</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0b35b_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_0b35b_row2_col0\" class=\"data row2 col0\" >0.226420</td>\n",
" <td id=\"T_0b35b_row2_col1\" class=\"data row2 col1\" >0.341014</td>\n",
" <td id=\"T_0b35b_row2_col2\" class=\"data row2 col2\" >0.545765</td>\n",
" <td id=\"T_0b35b_row2_col3\" class=\"data row2 col3\" >0.869259</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0b35b_level0_row3\" class=\"row_heading level0 row3\" >ridge</th>\n",
" <td id=\"T_0b35b_row3_col0\" class=\"data row3 col0\" >0.472399</td>\n",
" <td id=\"T_0b35b_row3_col1\" class=\"data row3 col1\" >0.378573</td>\n",
" <td id=\"T_0b35b_row3_col2\" class=\"data row3 col2\" >0.559409</td>\n",
" <td id=\"T_0b35b_row3_col3\" class=\"data row3 col3\" >0.838873</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0b35b_level0_row4\" class=\"row_heading level0 row4\" >decision_tree</th>\n",
" <td id=\"T_0b35b_row4_col0\" class=\"data row4 col0\" >0.054533</td>\n",
" <td id=\"T_0b35b_row4_col1\" class=\"data row4 col1\" >0.379017</td>\n",
" <td id=\"T_0b35b_row4_col2\" class=\"data row4 col2\" >0.587467</td>\n",
" <td id=\"T_0b35b_row4_col3\" class=\"data row4 col3\" >0.838495</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0b35b_level0_row5\" class=\"row_heading level0 row5\" >linear</th>\n",
" <td id=\"T_0b35b_row5_col0\" class=\"data row5 col0\" >0.441760</td>\n",
" <td id=\"T_0b35b_row5_col1\" class=\"data row5 col1\" >0.428940</td>\n",
" <td id=\"T_0b35b_row5_col2\" class=\"data row5 col2\" >0.617212</td>\n",
" <td id=\"T_0b35b_row5_col3\" class=\"data row5 col3\" >0.793147</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0b35b_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_0b35b_row6_col0\" class=\"data row6 col0\" >0.666903</td>\n",
" <td id=\"T_0b35b_row6_col1\" class=\"data row6 col1\" >0.566901</td>\n",
" <td id=\"T_0b35b_row6_col2\" class=\"data row6 col2\" >0.702700</td>\n",
" <td id=\"T_0b35b_row6_col3\" class=\"data row6 col3\" >0.638689</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x24995879c40>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"]\n",
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'criterion': 'poisson', 'max_depth': 9, 'min_samples_split': 2}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"from sklearn import model_selection\n",
"\n",
"parameters = {\n",
" \"criterion\": [\"squared_error\", \"absolute_error\", \"friedman_mse\", \"poisson\"],\n",
" \"max_depth\": np.arange(1, 21).tolist()[0::2],\n",
" \"min_samples_split\": np.arange(2, 20).tolist()[0::2],\n",
"}\n",
"\n",
"grid = model_selection.GridSearchCV(\n",
" tree.DecisionTreeRegressor(random_state=random_state), parameters, n_jobs=-1\n",
")\n",
"\n",
"grid.fit(viscosity_train, viscosity_y_train)\n",
"grid.best_params_"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'RMSE_test': 0.37901722760783496,\n",
" 'RMAE_test': 0.5874671455143883,\n",
" 'R2_test': 0.8384951109125148}"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'RMSE_test': 0.39412315184917696,\n",
" 'RMAE_test': 0.593196723643326,\n",
" 'R2_test': 0.8253648477295591}"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = grid.best_estimator_\n",
"y_pred = model.predict(viscosity_test)\n",
"old_metrics = {\n",
" \"RMSE_test\": models[\"decision_tree\"][\"RMSE_test\"],\n",
" \"RMAE_test\": models[\"decision_tree\"][\"RMAE_test\"],\n",
" \"R2_test\": models[\"decision_tree\"][\"R2_test\"],\n",
"}\n",
"new_metrics = {}\n",
"new_metrics[\"RMSE_test\"] = math.sqrt(\n",
" metrics.mean_squared_error(viscosity_y_test, y_pred)\n",
")\n",
"new_metrics[\"RMAE_test\"] = math.sqrt(\n",
" metrics.mean_absolute_error(viscosity_y_test, y_pred)\n",
")\n",
"new_metrics[\"R2_test\"] = metrics.r2_score(viscosity_y_test, y_pred)\n",
"\n",
"display(old_metrics)\n",
"display(new_metrics)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"|--- T <= 32.50\n",
"| |--- TiO2 <= 0.18\n",
"| | |--- Al2O3 <= 0.18\n",
"| | | |--- T <= 22.50\n",
"| | | | |--- TiO2 <= 0.03\n",
"| | | | | |--- Al2O3 <= 0.03\n",
"| | | | | | |--- value: [3.71]\n",
"| | | | | |--- Al2O3 > 0.03\n",
"| | | | | | |--- value: [4.66]\n",
"| | | | |--- TiO2 > 0.03\n",
"| | | | | |--- value: [4.88]\n",
"| | | |--- T > 22.50\n",
"| | | | |--- TiO2 <= 0.03\n",
"| | | | | |--- Al2O3 <= 0.03\n",
"| | | | | | |--- value: [3.18]\n",
"| | | | | |--- Al2O3 > 0.03\n",
"| | | | | | |--- value: [3.38]\n",
"| | | | |--- TiO2 > 0.03\n",
"| | | | | |--- value: [4.24]\n",
"| | |--- Al2O3 > 0.18\n",
"| | | |--- T <= 22.50\n",
"| | | | |--- value: [6.67]\n",
"| | | |--- T > 22.50\n",
"| | | | |--- T <= 27.50\n",
"| | | | | |--- value: [5.59]\n",
"| | | | |--- T > 27.50\n",
"| | | | | |--- value: [4.73]\n",
"| |--- TiO2 > 0.18\n",
"| | |--- T <= 22.50\n",
"| | | |--- value: [7.13]\n",
"| | |--- T > 22.50\n",
"| | | |--- T <= 27.50\n",
"| | | | |--- value: [5.87]\n",
"| | | |--- T > 27.50\n",
"| | | | |--- value: [4.94]\n",
"|--- T > 32.50\n",
"| |--- T <= 47.50\n",
"| | |--- TiO2 <= 0.18\n",
"| | | |--- Al2O3 <= 0.18\n",
"| | | | |--- T <= 42.50\n",
"| | | | | |--- TiO2 <= 0.03\n",
"| | | | | | |--- Al2O3 <= 0.03\n",
"| | | | | | | |--- value: [2.36]\n",
"| | | | | | |--- Al2O3 > 0.03\n",
"| | | | | | | |--- value: [2.68]\n",
"| | | | | |--- TiO2 > 0.03\n",
"| | | | | | |--- T <= 37.50\n",
"| | | | | | | |--- value: [3.12]\n",
"| | | | | | |--- T > 37.50\n",
"| | | | | | | |--- value: [2.65]\n",
"| | | | |--- T > 42.50\n",
"| | | | | |--- TiO2 <= 0.03\n",
"| | | | | | |--- value: [1.83]\n",
"| | | | | |--- TiO2 > 0.03\n",
"| | | | | | |--- value: [2.40]\n",
"| | | |--- Al2O3 > 0.18\n",
"| | | | |--- T <= 37.50\n",
"| | | | | |--- value: [4.12]\n",
"| | | | |--- T > 37.50\n",
"| | | | | |--- value: [3.56]\n",
"| | |--- TiO2 > 0.18\n",
"| | | |--- T <= 40.00\n",
"| | | | |--- value: [4.35]\n",
"| | | |--- T > 40.00\n",
"| | | | |--- value: [3.56]\n",
"| |--- T > 47.50\n",
"| | |--- TiO2 <= 0.18\n",
"| | | |--- Al2O3 <= 0.18\n",
"| | | | |--- T <= 52.50\n",
"| | | | | |--- TiO2 <= 0.03\n",
"| | | | | | |--- Al2O3 <= 0.03\n",
"| | | | | | | |--- value: [1.63]\n",
"| | | | | | |--- Al2O3 > 0.03\n",
"| | | | | | | |--- value: [1.90]\n",
"| | | | | |--- TiO2 > 0.03\n",
"| | | | | | |--- value: [2.11]\n",
"| | | | |--- T > 52.50\n",
"| | | | | |--- T <= 65.00\n",
"| | | | | | |--- TiO2 <= 0.03\n",
"| | | | | | | |--- value: [1.55]\n",
"| | | | | | |--- TiO2 > 0.03\n",
"| | | | | | | |--- value: [1.66]\n",
"| | | | | |--- T > 65.00\n",
"| | | | | | |--- TiO2 <= 0.03\n",
"| | | | | | | |--- value: [1.19]\n",
"| | | | | | |--- TiO2 > 0.03\n",
"| | | | | | | |--- value: [1.29]\n",
"| | | |--- Al2O3 > 0.18\n",
"| | | | |--- T <= 65.00\n",
"| | | | | |--- T <= 57.50\n",
"| | | | | | |--- value: [2.43]\n",
"| | | | | |--- T > 57.50\n",
"| | | | | | |--- value: [2.16]\n",
"| | | | |--- T > 65.00\n",
"| | | | | |--- value: [1.73]\n",
"| | |--- TiO2 > 0.18\n",
"| | | |--- T <= 65.00\n",
"| | | | |--- T <= 57.50\n",
"| | | | | |--- value: [2.84]\n",
"| | | | |--- T > 57.50\n",
"| | | | | |--- value: [2.54]\n",
"| | | |--- T > 65.00\n",
"| | | | |--- value: [1.91]\n",
"\n"
]
}
],
"source": [
"rules = tree.export_text(\n",
" models[\"decision_tree\"][\"fitted\"],\n",
" feature_names=viscosity_train.columns.values.tolist(),\n",
")\n",
"print(rules)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"pickle.dump(models[\"decision_tree\"][\"fitted\"], open(\"data/vtree.model.sav\", \"wb\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}