Compare commits

..

4 Commits

10 changed files with 5540 additions and 2011 deletions

File diff suppressed because one or more lines are too long

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@ -288,6 +288,26 @@ def simplify_and_group_rules(
return new_rules
def simplify_rules(X: pd.DataFrame, rules: List[Rule]) -> List[Rule]:
    """Replace each rule's interval conditions with single EQUALS atoms.

    For every variable in a rule's antecedent, the interval returned by
    ``_get_varibles_interval`` is collapsed to one representative value:

    * only element 1 of the interval set -> element 0 of the variable's
      entry in ``_get_variables_minmax(X)``;
    * only element 0 of the interval set -> element 1 of that entry;
    * both set -> the arithmetic mean of the two;
    * neither set -> ``0``.

    NOTE(review): the interval is presumably ``(lower, upper)`` and the
    minmax entry ``(min, max)`` - confirm against the helpers.

    Args:
        X: Data passed to ``_get_variables_minmax`` to obtain per-variable
            extreme values.
        rules: Rules whose antecedents are simplified.

    Returns:
        A new list with one rule per input rule, each keeping the original
        consequent and carrying one EQUALS atom per antecedent variable.
    """
    extremes = _get_variables_minmax(X)
    simplified: List[Rule] = []
    for rule in rules:
        atoms = []
        per_variable = _get_varibles_interval(rule.get_antecedent())
        for name, interval in per_variable.items():
            first, second = interval[0], interval[1]
            if first is None:
                # With no bound at all the value stays at the default 0.
                point: float = extremes[name][0] if second is not None else 0
            elif second is None:
                point = extremes[name][1]
            else:
                point = (first + second) / 2
            atoms.append(RuleAtom(name, ComparisonType.EQUALS, point))
        simplified.append(Rule(atoms, rule.get_consequent()))
    return simplified
def _get_fuzzy_rule_atom(
fuzzy_variable: FuzzyVariable, value: float
) -> Tuple[Term, float]:
@ -328,11 +348,11 @@ def _get_fuzzy_rules(
def _delete_same_fuzzy_rules(
rules_cluster: List[Tuple[List[RuleAtom], Term, float]]
rules: List[Tuple[List[RuleAtom], Term, float]]
) -> List[Tuple[List[RuleAtom], Term, float]]:
same_rules: List[int] = []
for rule1_index, rule1 in enumerate(rules_cluster):
for rule2_index, rule2 in enumerate(rules_cluster):
for rule1_index, rule1 in enumerate(rules):
for rule2_index, rule2 in enumerate(rules):
if rule1_index >= rule2_index:
continue
# Remove the same rules
@ -347,10 +367,10 @@ def _delete_same_fuzzy_rules(
if str(rule1[0]) == str(rule2[0]) and str(rule1[2]) > str(rule2[2]):
same_rules.append(rule1_index)
break
return [rule for index, rule in enumerate(rules_cluster) if index not in same_rules]
return [rule for index, rule in enumerate(rules) if index not in same_rules]
def get_fuzzy_rules(
def get_grouped_fuzzy_rules(
clustered_rules: List[List[Rule]], fuzzy_variables: Dict[str, FuzzyVariable]
) -> List[FuzzyRule]:
fuzzy_rules: List[List[Tuple[List[RuleAtom], Term, float]]] = []
@ -363,3 +383,12 @@ def get_fuzzy_rules(
for cluster in fuzzy_rules
for item in cluster
]
def get_fuzzy_rules(
    rules: List[Rule], fuzzy_variables: Dict[str, FuzzyVariable]
) -> List[FuzzyRule]:
    """Build fuzzy rules from a flat (non-clustered) list of rules.

    The rules are converted with ``_get_fuzzy_rules``, filtered through
    ``_delete_same_fuzzy_rules``, and each remaining entry is wrapped in a
    ``FuzzyRule`` whose antecedent is the ``and_``-reduction of the entry's
    atoms and whose consequent is the entry's term.

    Args:
        rules: Rules to convert.
        fuzzy_variables: Fuzzy variable definitions keyed by name, forwarded
            to ``_get_fuzzy_rules``.

    Returns:
        One ``FuzzyRule`` per entry left after duplicate removal.
    """
    converted: List[Tuple[List[RuleAtom], Term, float]] = _get_fuzzy_rules(
        rules, fuzzy_variables
    )
    deduplicated = _delete_same_fuzzy_rules(converted)

    result = []
    for entry in deduplicated:
        # entry[2] (the float) is not used when building the FuzzyRule.
        # NOTE(review): if reduce is functools.reduce, an entry with no atoms
        # raises TypeError here - confirm that cannot occur upstream.
        antecedent = reduce(and_, entry[0])
        result.append(FuzzyRule(antecedent, entry[1]))
    return result

View File

@ -0,0 +1,773 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>T</th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" <th>Density</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>20</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.06250</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>25</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.05979</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>35</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.05404</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>40</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.05103</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>45</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.04794</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" T Al2O3 TiO2 Density\n",
"0 20 0.0 0.0 1.06250\n",
"1 25 0.0 0.0 1.05979\n",
"2 35 0.0 0.0 1.05404\n",
"3 40 0.0 0.0 1.05103\n",
"4 45 0.0 0.0 1.04794"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>T</th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" <th>Density</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>30</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>1.05696</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>55</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>1.04158</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>25</td>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" <td>1.08438</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>30</td>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" <td>1.08112</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>35</td>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" <td>1.07781</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" T Al2O3 TiO2 Density\n",
"0 30 0.00 0.0 1.05696\n",
"1 55 0.00 0.0 1.04158\n",
"2 25 0.05 0.0 1.08438\n",
"3 30 0.05 0.0 1.08112\n",
"4 35 0.05 0.0 1.07781"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pandas as pd\n",
"\n",
"train = pd.read_csv(\"data/density_train.csv\", sep=\";\", decimal=\",\")\n",
"test = pd.read_csv(\"data/density_test.csv\", sep=\";\", decimal=\",\")\n",
"\n",
"display(train.head())\n",
"display(test.head())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" <th>Density</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.06250</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.05979</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.05404</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.05103</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.04794</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Al2O3 TiO2 Density\n",
"0 0.0 0.0 1.06250\n",
"1 0.0 0.0 1.05979\n",
"2 0.0 0.0 1.05404\n",
"3 0.0 0.0 1.05103\n",
"4 0.0 0.0 1.04794"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0 20\n",
"1 25\n",
"2 35\n",
"3 40\n",
"4 45\n",
"Name: T, dtype: int64"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" <th>Density</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>1.05696</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>1.04158</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" <td>1.08438</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" <td>1.08112</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" <td>1.07781</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Al2O3 TiO2 Density\n",
"0 0.00 0.0 1.05696\n",
"1 0.00 0.0 1.04158\n",
"2 0.05 0.0 1.08438\n",
"3 0.05 0.0 1.08112\n",
"4 0.05 0.0 1.07781"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0 30\n",
"1 55\n",
"2 25\n",
"3 30\n",
"4 35\n",
"Name: T, dtype: int64"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"y_train = train[\"T\"]\n",
"X_train = train.drop([\"T\"], axis=1)\n",
"\n",
"display(X_train.head())\n",
"display(y_train.head())\n",
"\n",
"y_test = test[\"T\"]\n",
"X_test = test.drop([\"T\"], axis=1)\n",
"\n",
"display(X_test.head())\n",
"display(y_test.head())"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"from sklearn import linear_model, tree, neighbors, ensemble\n",
"\n",
"random_state = 9\n",
"\n",
"models = {\n",
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
" \"linear_poly\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(degree=2),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"linear_interact\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(interaction_only=True),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeRegressor(random_state=random_state, max_depth=6, criterion=\"absolute_error\")\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestRegressor(\n",
" max_depth=7, random_state=random_state, n_jobs=-1\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: linear\n",
"Model: linear_poly\n",
"Model: linear_interact\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: random_forest\n"
]
}
],
"source": [
"import math\n",
"from sklearn import metrics\n",
"\n",
"for model_name in models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" fitted_model = models[model_name][\"model\"].fit(\n",
" X_train.values, y_train.values.ravel()\n",
" )\n",
" y_train_pred = fitted_model.predict(X_train.values)\n",
" y_test_pred = fitted_model.predict(X_test.values)\n",
" models[model_name][\"fitted\"] = fitted_model\n",
" models[model_name][\"MSE_train\"] = metrics.mean_squared_error(y_train, y_train_pred)\n",
" models[model_name][\"MSE_test\"] = metrics.mean_squared_error(y_test, y_test_pred)\n",
" models[model_name][\"MAE_train\"] = metrics.mean_absolute_error(y_train, y_train_pred)\n",
" models[model_name][\"MAE_test\"] = metrics.mean_absolute_error(y_test, y_test_pred)\n",
" models[model_name][\"R2_train\"] = metrics.r2_score(y_train, y_train_pred)\n",
" models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_2421b_row0_col0, #T_2421b_row0_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row0_col3, #T_2421b_row6_col5 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row0_col5, #T_2421b_row6_col3 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row1_col0, #T_2421b_row2_col0, #T_2421b_row3_col0 {\n",
" background-color: #25848e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row1_col1 {\n",
" background-color: #24868e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row1_col3 {\n",
" background-color: #6f00a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row1_col5, #T_2421b_row5_col3 {\n",
" background-color: #d5536f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row2_col1 {\n",
" background-color: #23888e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row2_col3 {\n",
" background-color: #7201a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row2_col5 {\n",
" background-color: #d35171;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row3_col1 {\n",
" background-color: #1f998a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row3_col3 {\n",
" background-color: #9814a0;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row3_col5 {\n",
" background-color: #c23c81;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row4_col0 {\n",
" background-color: #23898e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row4_col1 {\n",
" background-color: #1e9d89;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row4_col3 {\n",
" background-color: #a21d9a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row4_col5 {\n",
" background-color: #be3885;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row5_col0 {\n",
" background-color: #5ac864;\n",
" color: #000000;\n",
"}\n",
"#T_2421b_row5_col1 {\n",
" background-color: #9bd93c;\n",
" color: #000000;\n",
"}\n",
"#T_2421b_row5_col5 {\n",
" background-color: #5601a4;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2421b_row6_col0, #T_2421b_row6_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
"<table id=\"T_2421b\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_2421b_level0_col0\" class=\"col_heading level0 col0\" >MSE_train</th>\n",
" <th id=\"T_2421b_level0_col1\" class=\"col_heading level0 col1\" >MSE_test</th>\n",
" <th id=\"T_2421b_level0_col2\" class=\"col_heading level0 col2\" >MAE_train</th>\n",
" <th id=\"T_2421b_level0_col3\" class=\"col_heading level0 col3\" >MAE_test</th>\n",
" <th id=\"T_2421b_level0_col4\" class=\"col_heading level0 col4\" >R2_train</th>\n",
" <th id=\"T_2421b_level0_col5\" class=\"col_heading level0 col5\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_2421b_level0_row0\" class=\"row_heading level0 row0\" >linear_poly</th>\n",
" <td id=\"T_2421b_row0_col0\" class=\"data row0 col0\" >0.302768</td>\n",
" <td id=\"T_2421b_row0_col1\" class=\"data row0 col1\" >0.203293</td>\n",
" <td id=\"T_2421b_row0_col2\" class=\"data row0 col2\" >0.419467</td>\n",
" <td id=\"T_2421b_row0_col3\" class=\"data row0 col3\" >0.392687</td>\n",
" <td id=\"T_2421b_row0_col4\" class=\"data row0 col4\" >0.998860</td>\n",
" <td id=\"T_2421b_row0_col5\" class=\"data row0 col5\" >0.999047</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_2421b_level0_row1\" class=\"row_heading level0 row1\" >linear_interact</th>\n",
" <td id=\"T_2421b_row1_col0\" class=\"data row1 col0\" >9.693323</td>\n",
" <td id=\"T_2421b_row1_col1\" class=\"data row1 col1\" >10.875442</td>\n",
" <td id=\"T_2421b_row1_col2\" class=\"data row1 col2\" >2.544944</td>\n",
" <td id=\"T_2421b_row1_col3\" class=\"data row1 col3\" >2.718424</td>\n",
" <td id=\"T_2421b_row1_col4\" class=\"data row1 col4\" >0.963492</td>\n",
" <td id=\"T_2421b_row1_col5\" class=\"data row1 col5\" >0.949019</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_2421b_level0_row2\" class=\"row_heading level0 row2\" >linear</th>\n",
" <td id=\"T_2421b_row2_col0\" class=\"data row2 col0\" >10.468503</td>\n",
" <td id=\"T_2421b_row2_col1\" class=\"data row2 col1\" >14.820315</td>\n",
" <td id=\"T_2421b_row2_col2\" class=\"data row2 col2\" >2.657476</td>\n",
" <td id=\"T_2421b_row2_col3\" class=\"data row2 col3\" >2.930229</td>\n",
" <td id=\"T_2421b_row2_col4\" class=\"data row2 col4\" >0.960572</td>\n",
" <td id=\"T_2421b_row2_col5\" class=\"data row2 col5\" >0.930526</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_2421b_level0_row3\" class=\"row_heading level0 row3\" >decision_tree</th>\n",
" <td id=\"T_2421b_row3_col0\" class=\"data row3 col0\" >10.526316</td>\n",
" <td id=\"T_2421b_row3_col1\" class=\"data row3 col1\" >47.426471</td>\n",
" <td id=\"T_2421b_row3_col2\" class=\"data row3 col2\" >1.842105</td>\n",
" <td id=\"T_2421b_row3_col3\" class=\"data row3 col3\" >5.735294</td>\n",
" <td id=\"T_2421b_row3_col4\" class=\"data row3 col4\" >0.960355</td>\n",
" <td id=\"T_2421b_row3_col5\" class=\"data row3 col5\" >0.777676</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_2421b_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_2421b_row4_col0\" class=\"data row4 col0\" >20.243876</td>\n",
" <td id=\"T_2421b_row4_col1\" class=\"data row4 col1\" >54.501240</td>\n",
" <td id=\"T_2421b_row4_col2\" class=\"data row4 col2\" >3.592953</td>\n",
" <td id=\"T_2421b_row4_col3\" class=\"data row4 col3\" >6.598133</td>\n",
" <td id=\"T_2421b_row4_col4\" class=\"data row4 col4\" >0.923755</td>\n",
" <td id=\"T_2421b_row4_col5\" class=\"data row4 col5\" >0.744512</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_2421b_level0_row5\" class=\"row_heading level0 row5\" >knn</th>\n",
" <td id=\"T_2421b_row5_col0\" class=\"data row5 col0\" >174.100430</td>\n",
" <td id=\"T_2421b_row5_col1\" class=\"data row5 col1\" >191.176471</td>\n",
" <td id=\"T_2421b_row5_col2\" class=\"data row5 col2\" >10.808271</td>\n",
" <td id=\"T_2421b_row5_col3\" class=\"data row5 col3\" >11.680672</td>\n",
" <td id=\"T_2421b_row5_col4\" class=\"data row5 col4\" >0.344285</td>\n",
" <td id=\"T_2421b_row5_col5\" class=\"data row5 col5\" >0.103812</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_2421b_level0_row6\" class=\"row_heading level0 row6\" >ridge</th>\n",
" <td id=\"T_2421b_row6_col0\" class=\"data row6 col0\" >243.364664</td>\n",
" <td id=\"T_2421b_row6_col1\" class=\"data row6 col1\" >199.601477</td>\n",
" <td id=\"T_2421b_row6_col2\" class=\"data row6 col2\" >13.472724</td>\n",
" <td id=\"T_2421b_row6_col3\" class=\"data row6 col3\" >12.396799</td>\n",
" <td id=\"T_2421b_row6_col4\" class=\"data row6 col4\" >0.083415</td>\n",
" <td id=\"T_2421b_row6_col5\" class=\"data row6 col5\" >0.064317</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x16867d550>"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
" [\"MSE_train\", \"MSE_test\", \"MAE_train\", \"MAE_test\", \"R2_train\", \"R2_test\"]\n",
"]\n",
"reg_metrics.sort_values(by=\"MAE_test\").style.background_gradient(\n",
" cmap=\"viridis\", low=1, high=0.3, subset=[\"MSE_train\", \"MSE_test\"]\n",
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"MAE_test\", \"R2_test\"])"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"|--- Density <= 1.04\n",
"| |--- Density <= 1.03\n",
"| | |--- value: [70.00]\n",
"| |--- Density > 1.03\n",
"| | |--- Density <= 1.04\n",
"| | | |--- value: [65.00]\n",
"| | |--- Density > 1.04\n",
"| | | |--- value: [60.00]\n",
"|--- Density > 1.04\n",
"| |--- Density <= 1.07\n",
"| | |--- TiO2 <= 0.03\n",
"| | | |--- Al2O3 <= 0.03\n",
"| | | | |--- Density <= 1.05\n",
"| | | | | |--- Density <= 1.05\n",
"| | | | | | |--- value: [50.00]\n",
"| | | | | |--- Density > 1.05\n",
"| | | | | | |--- value: [42.50]\n",
"| | | | |--- Density > 1.05\n",
"| | | | | |--- Density <= 1.06\n",
"| | | | | | |--- value: [35.00]\n",
"| | | | | |--- Density > 1.06\n",
"| | | | | | |--- value: [22.50]\n",
"| | | |--- Al2O3 > 0.03\n",
"| | | | |--- Density <= 1.06\n",
"| | | | | |--- Density <= 1.05\n",
"| | | | | | |--- value: [70.00]\n",
"| | | | | |--- Density > 1.05\n",
"| | | | | | |--- value: [65.00]\n",
"| | | | |--- Density > 1.06\n",
"| | | | | |--- Density <= 1.07\n",
"| | | | | | |--- value: [55.00]\n",
"| | | | | |--- Density > 1.07\n",
"| | | | | | |--- value: [50.00]\n",
"| | |--- TiO2 > 0.03\n",
"| | | |--- Density <= 1.06\n",
"| | | | |--- value: [70.00]\n",
"| | | |--- Density > 1.06\n",
"| | | | |--- Density <= 1.06\n",
"| | | | | |--- value: [65.00]\n",
"| | | | |--- Density > 1.06\n",
"| | | | | |--- value: [60.00]\n",
"| |--- Density > 1.07\n",
"| | |--- Density <= 1.12\n",
"| | | |--- Density <= 1.08\n",
"| | | | |--- Density <= 1.07\n",
"| | | | | |--- value: [45.00]\n",
"| | | | |--- Density > 1.07\n",
"| | | | | |--- Density <= 1.08\n",
"| | | | | | |--- value: [40.00]\n",
"| | | | | |--- Density > 1.08\n",
"| | | | | | |--- value: [35.00]\n",
"| | | |--- Density > 1.08\n",
"| | | | |--- Density <= 1.09\n",
"| | | | | |--- value: [30.00]\n",
"| | | | |--- Density > 1.09\n",
"| | | | | |--- Al2O3 <= 0.03\n",
"| | | | | | |--- value: [22.50]\n",
"| | | | | |--- Al2O3 > 0.03\n",
"| | | | | | |--- value: [20.00]\n",
"| | |--- Density > 1.12\n",
"| | | |--- Density <= 1.18\n",
"| | | | |--- Density <= 1.15\n",
"| | | | | |--- value: [70.00]\n",
"| | | | |--- Density > 1.15\n",
"| | | | | |--- Al2O3 <= 0.15\n",
"| | | | | | |--- value: [65.00]\n",
"| | | | | |--- Al2O3 > 0.15\n",
"| | | | | | |--- value: [50.00]\n",
"| | | |--- Density > 1.18\n",
"| | | | |--- Al2O3 <= 0.15\n",
"| | | | | |--- Density <= 1.20\n",
"| | | | | | |--- value: [50.00]\n",
"| | | | | |--- Density > 1.20\n",
"| | | | | | |--- value: [30.00]\n",
"| | | | |--- Al2O3 > 0.15\n",
"| | | | | |--- Density <= 1.18\n",
"| | | | | | |--- value: [30.00]\n",
"| | | | | |--- Density > 1.18\n",
"| | | | | | |--- value: [22.50]\n",
"\n"
]
}
],
"source": [
"model = models[\"decision_tree\"][\"fitted\"]\n",
"rules = tree.export_text(\n",
" model, feature_names=X_train.columns.values.tolist()\n",
")\n",
"print(rules)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"pickle.dump(model, open(\"data/temp_density_tree.model.sav\", \"wb\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

1417
temp_density_tree.ipynb Normal file

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,763 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>T</th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" <th>Viscosity</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>20</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3.707</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>25</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3.180</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>35</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2.361</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>45</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.832</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>50</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.629</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" T Al2O3 TiO2 Viscosity\n",
"0 20 0.0 0.0 3.707\n",
"1 25 0.0 0.0 3.180\n",
"2 35 0.0 0.0 2.361\n",
"3 45 0.0 0.0 1.832\n",
"4 50 0.0 0.0 1.629"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>T</th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" <th>Viscosity</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>30</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>2.716</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>40</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>2.073</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>60</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>1.329</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>65</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>1.211</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>25</td>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" <td>4.120</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" T Al2O3 TiO2 Viscosity\n",
"0 30 0.00 0.0 2.716\n",
"1 40 0.00 0.0 2.073\n",
"2 60 0.00 0.0 1.329\n",
"3 65 0.00 0.0 1.211\n",
"4 25 0.05 0.0 4.120"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pandas as pd\n",
"\n",
"train = pd.read_csv(\"data/viscosity_train.csv\", sep=\";\", decimal=\",\")\n",
"test = pd.read_csv(\"data/viscosity_test.csv\", sep=\";\", decimal=\",\")\n",
"\n",
"display(train.head())\n",
"display(test.head())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" <th>Viscosity</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3.707</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3.180</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2.361</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.832</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.629</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Al2O3 TiO2 Viscosity\n",
"0 0.0 0.0 3.707\n",
"1 0.0 0.0 3.180\n",
"2 0.0 0.0 2.361\n",
"3 0.0 0.0 1.832\n",
"4 0.0 0.0 1.629"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0 20\n",
"1 25\n",
"2 35\n",
"3 45\n",
"4 50\n",
"Name: T, dtype: int64"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" <th>Viscosity</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>2.716</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>2.073</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>1.329</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>1.211</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" <td>4.120</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Al2O3 TiO2 Viscosity\n",
"0 0.00 0.0 2.716\n",
"1 0.00 0.0 2.073\n",
"2 0.00 0.0 1.329\n",
"3 0.00 0.0 1.211\n",
"4 0.05 0.0 4.120"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0 30\n",
"1 40\n",
"2 60\n",
"3 65\n",
"4 25\n",
"Name: T, dtype: int64"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"y_train = train[\"T\"]\n",
"X_train = train.drop([\"T\"], axis=1)\n",
"\n",
"display(X_train.head())\n",
"display(y_train.head())\n",
"\n",
"y_test = test[\"T\"]\n",
"X_test = test.drop([\"T\"], axis=1)\n",
"\n",
"display(X_test.head())\n",
"display(y_test.head())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"from sklearn import linear_model, tree, neighbors, ensemble\n",
"\n",
"random_state = 9\n",
"\n",
"models = {\n",
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
" \"linear_poly\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(degree=2),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"linear_interact\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(interaction_only=True),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeRegressor(random_state=random_state, max_depth=6, criterion=\"absolute_error\")\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestRegressor(\n",
" max_depth=7, random_state=random_state, n_jobs=-1\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: linear\n",
"Model: linear_poly\n",
"Model: linear_interact\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: random_forest\n"
]
}
],
"source": [
"import math\n",
"from sklearn import metrics\n",
"\n",
"for model_name in models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" fitted_model = models[model_name][\"model\"].fit(\n",
" X_train.values, y_train.values.ravel()\n",
" )\n",
" y_train_pred = fitted_model.predict(X_train.values)\n",
" y_test_pred = fitted_model.predict(X_test.values)\n",
" models[model_name][\"fitted\"] = fitted_model\n",
" models[model_name][\"MSE_train\"] = metrics.mean_squared_error(y_train, y_train_pred)\n",
" models[model_name][\"MSE_test\"] = metrics.mean_squared_error(y_test, y_test_pred)\n",
" models[model_name][\"MAE_train\"] = metrics.mean_absolute_error(y_train, y_train_pred)\n",
" models[model_name][\"MAE_test\"] = metrics.mean_absolute_error(y_test, y_test_pred)\n",
" models[model_name][\"R2_train\"] = metrics.r2_score(y_train, y_train_pred)\n",
" models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_6c8fb_row0_col0 {\n",
" background-color: #25838e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row0_col1, #T_6c8fb_row5_col0 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row0_col3, #T_6c8fb_row6_col5 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row0_col5, #T_6c8fb_row6_col3 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row1_col0, #T_6c8fb_row1_col1 {\n",
" background-color: #1fa187;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row1_col3 {\n",
" background-color: #a31e9a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row1_col5, #T_6c8fb_row3_col3 {\n",
" background-color: #b83289;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row2_col0 {\n",
" background-color: #25ac82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row2_col1 {\n",
" background-color: #32b67a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row2_col3 {\n",
" background-color: #b6308b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row2_col5 {\n",
" background-color: #9e199d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row3_col0 {\n",
" background-color: #2eb37c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row3_col1 {\n",
" background-color: #35b779;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row3_col5 {\n",
" background-color: #9c179e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row4_col0 {\n",
" background-color: #238a8d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row4_col1, #T_6c8fb_row5_col1 {\n",
" background-color: #54c568;\n",
" color: #000000;\n",
"}\n",
"#T_6c8fb_row4_col3 {\n",
" background-color: #c5407e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row4_col5, #T_6c8fb_row5_col5 {\n",
" background-color: #8405a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row5_col3 {\n",
" background-color: #cd4a76;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6c8fb_row6_col0, #T_6c8fb_row6_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
"<table id=\"T_6c8fb\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_6c8fb_level0_col0\" class=\"col_heading level0 col0\" >MSE_train</th>\n",
" <th id=\"T_6c8fb_level0_col1\" class=\"col_heading level0 col1\" >MSE_test</th>\n",
" <th id=\"T_6c8fb_level0_col2\" class=\"col_heading level0 col2\" >MAE_train</th>\n",
" <th id=\"T_6c8fb_level0_col3\" class=\"col_heading level0 col3\" >MAE_test</th>\n",
" <th id=\"T_6c8fb_level0_col4\" class=\"col_heading level0 col4\" >R2_train</th>\n",
" <th id=\"T_6c8fb_level0_col5\" class=\"col_heading level0 col5\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_6c8fb_level0_row0\" class=\"row_heading level0 row0\" >linear_poly</th>\n",
" <td id=\"T_6c8fb_row0_col0\" class=\"data row0 col0\" >4.827768</td>\n",
" <td id=\"T_6c8fb_row0_col1\" class=\"data row0 col1\" >4.877296</td>\n",
" <td id=\"T_6c8fb_row0_col2\" class=\"data row0 col2\" >1.522643</td>\n",
" <td id=\"T_6c8fb_row0_col3\" class=\"data row0 col3\" >1.743058</td>\n",
" <td id=\"T_6c8fb_row0_col4\" class=\"data row0 col4\" >0.980864</td>\n",
" <td id=\"T_6c8fb_row0_col5\" class=\"data row0 col5\" >0.974964</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6c8fb_level0_row1\" class=\"row_heading level0 row1\" >linear_interact</th>\n",
" <td id=\"T_6c8fb_row1_col0\" class=\"data row1 col0\" >21.786348</td>\n",
" <td id=\"T_6c8fb_row1_col1\" class=\"data row1 col1\" >23.459572</td>\n",
" <td id=\"T_6c8fb_row1_col2\" class=\"data row1 col2\" >3.830996</td>\n",
" <td id=\"T_6c8fb_row1_col3\" class=\"data row1 col3\" >4.381115</td>\n",
" <td id=\"T_6c8fb_row1_col4\" class=\"data row1 col4\" >0.913644</td>\n",
" <td id=\"T_6c8fb_row1_col5\" class=\"data row1 col5\" >0.879577</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6c8fb_level0_row2\" class=\"row_heading level0 row2\" >linear</th>\n",
" <td id=\"T_6c8fb_row2_col0\" class=\"data row2 col0\" >27.766510</td>\n",
" <td id=\"T_6c8fb_row2_col1\" class=\"data row2 col1\" >35.430313</td>\n",
" <td id=\"T_6c8fb_row2_col2\" class=\"data row2 col2\" >4.088006</td>\n",
" <td id=\"T_6c8fb_row2_col3\" class=\"data row2 col3\" >5.106782</td>\n",
" <td id=\"T_6c8fb_row2_col4\" class=\"data row2 col4\" >0.889940</td>\n",
" <td id=\"T_6c8fb_row2_col5\" class=\"data row2 col5\" >0.818129</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6c8fb_level0_row3\" class=\"row_heading level0 row3\" >ridge</th>\n",
" <td id=\"T_6c8fb_row3_col0\" class=\"data row3 col0\" >31.827476</td>\n",
" <td id=\"T_6c8fb_row3_col1\" class=\"data row3 col1\" >36.230606</td>\n",
" <td id=\"T_6c8fb_row3_col2\" class=\"data row3 col2\" >4.383008</td>\n",
" <td id=\"T_6c8fb_row3_col3\" class=\"data row3 col3\" >5.226480</td>\n",
" <td id=\"T_6c8fb_row3_col4\" class=\"data row3 col4\" >0.873843</td>\n",
" <td id=\"T_6c8fb_row3_col5\" class=\"data row3 col5\" >0.814021</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6c8fb_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_6c8fb_row4_col0\" class=\"data row4 col0\" >8.525285</td>\n",
" <td id=\"T_6c8fb_row4_col1\" class=\"data row4 col1\" >45.444651</td>\n",
" <td id=\"T_6c8fb_row4_col2\" class=\"data row4 col2\" >2.542935</td>\n",
" <td id=\"T_6c8fb_row4_col3\" class=\"data row4 col3\" >5.749510</td>\n",
" <td id=\"T_6c8fb_row4_col4\" class=\"data row4 col4\" >0.966208</td>\n",
" <td id=\"T_6c8fb_row4_col5\" class=\"data row4 col5\" >0.766723</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6c8fb_level0_row5\" class=\"row_heading level0 row5\" >decision_tree</th>\n",
" <td id=\"T_6c8fb_row5_col0\" class=\"data row5 col0\" >3.289474</td>\n",
" <td id=\"T_6c8fb_row5_col1\" class=\"data row5 col1\" >45.588235</td>\n",
" <td id=\"T_6c8fb_row5_col2\" class=\"data row5 col2\" >0.921053</td>\n",
" <td id=\"T_6c8fb_row5_col3\" class=\"data row5 col3\" >6.176471</td>\n",
" <td id=\"T_6c8fb_row5_col4\" class=\"data row5 col4\" >0.986961</td>\n",
" <td id=\"T_6c8fb_row5_col5\" class=\"data row5 col5\" >0.765986</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6c8fb_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_6c8fb_row6_col0\" class=\"data row6 col0\" >61.855532</td>\n",
" <td id=\"T_6c8fb_row6_col1\" class=\"data row6 col1\" >64.165666</td>\n",
" <td id=\"T_6c8fb_row6_col2\" class=\"data row6 col2\" >6.522556</td>\n",
" <td id=\"T_6c8fb_row6_col3\" class=\"data row6 col3\" >6.806723</td>\n",
" <td id=\"T_6c8fb_row6_col4\" class=\"data row6 col4\" >0.754819</td>\n",
" <td id=\"T_6c8fb_row6_col5\" class=\"data row6 col5\" >0.670624</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x11b747260>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
" [\"MSE_train\", \"MSE_test\", \"MAE_train\", \"MAE_test\", \"R2_train\", \"R2_test\"]\n",
"]\n",
"reg_metrics.sort_values(by=\"MAE_test\").style.background_gradient(\n",
" cmap=\"viridis\", low=1, high=0.3, subset=[\"MSE_train\", \"MSE_test\"]\n",
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"MAE_test\", \"R2_test\"])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"|--- Viscosity <= 2.86\n",
"| |--- Viscosity <= 1.38\n",
"| | |--- value: [70.00]\n",
"| |--- Viscosity > 1.38\n",
"| | |--- Viscosity <= 1.78\n",
"| | | |--- Viscosity <= 1.72\n",
"| | | | |--- TiO2 <= 0.03\n",
"| | | | | |--- Al2O3 <= 0.03\n",
"| | | | | | |--- value: [52.50]\n",
"| | | | | |--- Al2O3 > 0.03\n",
"| | | | | | |--- value: [57.50]\n",
"| | | | |--- TiO2 > 0.03\n",
"| | | | | |--- value: [60.00]\n",
"| | | |--- Viscosity > 1.72\n",
"| | | | |--- value: [70.00]\n",
"| | |--- Viscosity > 1.78\n",
"| | | |--- TiO2 <= 0.18\n",
"| | | | |--- Al2O3 <= 0.18\n",
"| | | | | |--- Viscosity <= 2.24\n",
"| | | | | | |--- value: [50.00]\n",
"| | | | | |--- Viscosity > 2.24\n",
"| | | | | | |--- value: [40.00]\n",
"| | | | |--- Al2O3 > 0.18\n",
"| | | | | |--- Viscosity <= 2.29\n",
"| | | | | | |--- value: [60.00]\n",
"| | | | | |--- Viscosity > 2.29\n",
"| | | | | | |--- value: [55.00]\n",
"| | | |--- TiO2 > 0.18\n",
"| | | | |--- Viscosity <= 2.22\n",
"| | | | | |--- value: [70.00]\n",
"| | | | |--- Viscosity > 2.22\n",
"| | | | | |--- Viscosity <= 2.69\n",
"| | | | | | |--- value: [60.00]\n",
"| | | | | |--- Viscosity > 2.69\n",
"| | | | | | |--- value: [55.00]\n",
"|--- Viscosity > 2.86\n",
"| |--- Viscosity <= 3.64\n",
"| | |--- TiO2 <= 0.18\n",
"| | | |--- Viscosity <= 3.15\n",
"| | | | |--- value: [35.00]\n",
"| | | |--- Viscosity > 3.15\n",
"| | | | |--- Viscosity <= 3.47\n",
"| | | | | |--- Al2O3 <= 0.03\n",
"| | | | | | |--- value: [25.00]\n",
"| | | | | |--- Al2O3 > 0.03\n",
"| | | | | | |--- value: [30.00]\n",
"| | | | |--- Viscosity > 3.47\n",
"| | | | | |--- value: [40.00]\n",
"| | |--- TiO2 > 0.18\n",
"| | | |--- value: [45.00]\n",
"| |--- Viscosity > 3.64\n",
"| | |--- Viscosity <= 6.27\n",
"| | | |--- TiO2 <= 0.18\n",
"| | | | |--- Al2O3 <= 0.18\n",
"| | | | | |--- TiO2 <= 0.03\n",
"| | | | | | |--- value: [20.00]\n",
"| | | | | |--- TiO2 > 0.03\n",
"| | | | | | |--- value: [22.50]\n",
"| | | | |--- Al2O3 > 0.18\n",
"| | | | | |--- Viscosity <= 4.42\n",
"| | | | | | |--- value: [35.00]\n",
"| | | | | |--- Viscosity > 4.42\n",
"| | | | | | |--- value: [27.50]\n",
"| | | |--- TiO2 > 0.18\n",
"| | | | |--- Viscosity <= 4.65\n",
"| | | | | |--- value: [35.00]\n",
"| | | | |--- Viscosity > 4.65\n",
"| | | | | |--- Viscosity <= 5.40\n",
"| | | | | | |--- value: [30.00]\n",
"| | | | | |--- Viscosity > 5.40\n",
"| | | | | | |--- value: [25.00]\n",
"| | |--- Viscosity > 6.27\n",
"| | | |--- value: [20.00]\n",
"\n"
]
}
],
"source": [
"model = models[\"decision_tree\"][\"fitted\"]\n",
"rules = tree.export_text(\n",
" model, feature_names=X_train.columns.values.tolist()\n",
")\n",
"print(rules)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"# Persist the fitted decision tree for later reuse.\n",
"# A context manager guarantees the file is flushed and closed as soon as the\n",
"# dump finishes; passing a bare open() to pickle.dump() leaves the handle\n",
"# unclosed (ResourceWarning, and the data may not be on disk yet).\n",
"with open(\"data/temp_viscosity_tree.model.sav\", \"wb\") as model_file:\n",
"    pickle.dump(model, model_file)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

1357
temp_viscosity_tree.ipynb Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long