{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Financial Distress | \n",
" x1 | \n",
" x2 | \n",
" x3 | \n",
" x4 | \n",
" x5 | \n",
" x6 | \n",
" x7 | \n",
" x8 | \n",
" x9 | \n",
" ... | \n",
" x73 | \n",
" x74 | \n",
" x75 | \n",
" x76 | \n",
" x77 | \n",
" x78 | \n",
" x79 | \n",
" x81 | \n",
" x82 | \n",
" x83 | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0 | \n",
" 1.2810 | \n",
" 0.022934 | \n",
" 0.87454 | \n",
" 1.21640 | \n",
" 0.060940 | \n",
" 0.188270 | \n",
" 0.52510 | \n",
" 0.018854 | \n",
" 0.182790 | \n",
" ... | \n",
" 36.0 | \n",
" 85.437 | \n",
" 27.07 | \n",
" 26.102 | \n",
" 16.000 | \n",
" 16.0 | \n",
" 0.2 | \n",
" 0.060390 | \n",
" 30 | \n",
" 49 | \n",
"
\n",
" \n",
" 1 | \n",
" 0 | \n",
" 1.2700 | \n",
" 0.006454 | \n",
" 0.82067 | \n",
" 1.00490 | \n",
" -0.014080 | \n",
" 0.181040 | \n",
" 0.62288 | \n",
" 0.006423 | \n",
" 0.035991 | \n",
" ... | \n",
" 36.0 | \n",
" 107.090 | \n",
" 31.31 | \n",
" 30.194 | \n",
" 17.000 | \n",
" 16.0 | \n",
" 0.4 | \n",
" 0.010636 | \n",
" 31 | \n",
" 50 | \n",
"
\n",
" \n",
" 2 | \n",
" 0 | \n",
" 1.0529 | \n",
" -0.059379 | \n",
" 0.92242 | \n",
" 0.72926 | \n",
" 0.020476 | \n",
" 0.044865 | \n",
" 0.43292 | \n",
" -0.081423 | \n",
" -0.765400 | \n",
" ... | \n",
" 35.0 | \n",
" 120.870 | \n",
" 36.07 | \n",
" 35.273 | \n",
" 17.000 | \n",
" 15.0 | \n",
" -0.2 | \n",
" -0.455970 | \n",
" 32 | \n",
" 51 | \n",
"
\n",
" \n",
" 3 | \n",
" 1 | \n",
" 1.1131 | \n",
" -0.015229 | \n",
" 0.85888 | \n",
" 0.80974 | \n",
" 0.076037 | \n",
" 0.091033 | \n",
" 0.67546 | \n",
" -0.018807 | \n",
" -0.107910 | \n",
" ... | \n",
" 33.0 | \n",
" 54.806 | \n",
" 39.80 | \n",
" 38.377 | \n",
" 17.167 | \n",
" 16.0 | \n",
" 5.6 | \n",
" -0.325390 | \n",
" 33 | \n",
" 52 | \n",
"
\n",
" \n",
" 4 | \n",
" 0 | \n",
" 1.0623 | \n",
" 0.107020 | \n",
" 0.81460 | \n",
" 0.83593 | \n",
" 0.199960 | \n",
" 0.047800 | \n",
" 0.74200 | \n",
" 0.128030 | \n",
" 0.577250 | \n",
" ... | \n",
" 36.0 | \n",
" 85.437 | \n",
" 27.07 | \n",
" 26.102 | \n",
" 16.000 | \n",
" 16.0 | \n",
" 0.2 | \n",
" 1.251000 | \n",
" 7 | \n",
" 27 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 3667 | \n",
" 0 | \n",
" 2.2605 | \n",
" 0.202890 | \n",
" 0.16037 | \n",
" 0.18588 | \n",
" 0.175970 | \n",
" 0.198400 | \n",
" 2.22360 | \n",
" 1.091500 | \n",
" 0.241640 | \n",
" ... | \n",
" 22.0 | \n",
" 100.000 | \n",
" 100.00 | \n",
" 100.000 | \n",
" 17.125 | \n",
" 14.5 | \n",
" -7.0 | \n",
" 0.436380 | \n",
" 4 | \n",
" 41 | \n",
"
\n",
" \n",
" 3668 | \n",
" 0 | \n",
" 1.9615 | \n",
" 0.216440 | \n",
" 0.20095 | \n",
" 0.21642 | \n",
" 0.203590 | \n",
" 0.189870 | \n",
" 1.93820 | \n",
" 1.000100 | \n",
" 0.270870 | \n",
" ... | \n",
" 28.0 | \n",
" 91.500 | \n",
" 130.50 | \n",
" 132.400 | \n",
" 20.000 | \n",
" 14.5 | \n",
" -16.0 | \n",
" 0.438020 | \n",
" 5 | \n",
" 42 | \n",
"
\n",
" \n",
" 3669 | \n",
" 0 | \n",
" 1.7099 | \n",
" 0.207970 | \n",
" 0.26136 | \n",
" 0.21399 | \n",
" 0.193670 | \n",
" 0.183890 | \n",
" 1.68980 | \n",
" 0.971860 | \n",
" 0.281560 | \n",
" ... | \n",
" 32.0 | \n",
" 87.100 | \n",
" 175.90 | \n",
" 178.100 | \n",
" 20.000 | \n",
" 14.5 | \n",
" -20.2 | \n",
" 0.482410 | \n",
" 6 | \n",
" 43 | \n",
"
\n",
" \n",
" 3670 | \n",
" 0 | \n",
" 1.5590 | \n",
" 0.185450 | \n",
" 0.30728 | \n",
" 0.19307 | \n",
" 0.172140 | \n",
" 0.170680 | \n",
" 1.53890 | \n",
" 0.960570 | \n",
" 0.267720 | \n",
" ... | \n",
" 30.0 | \n",
" 92.900 | \n",
" 203.20 | \n",
" 204.500 | \n",
" 22.000 | \n",
" 22.0 | \n",
" 6.4 | \n",
" 0.500770 | \n",
" 7 | \n",
" 44 | \n",
"
\n",
" \n",
" 3671 | \n",
" 0 | \n",
" 1.6148 | \n",
" 0.176760 | \n",
" 0.36369 | \n",
" 0.18442 | \n",
" 0.169550 | \n",
" 0.197860 | \n",
" 1.58420 | \n",
" 0.958450 | \n",
" 0.277780 | \n",
" ... | \n",
" 29.0 | \n",
" 91.700 | \n",
" 227.50 | \n",
" 214.500 | \n",
" 21.000 | \n",
" 20.5 | \n",
" 8.6 | \n",
" 0.611030 | \n",
" 8 | \n",
" 45 | \n",
"
\n",
" \n",
"
\n",
"
3672 rows × 83 columns
\n",
"
"
],
"text/plain": [
" Financial Distress x1 x2 x3 x4 x5 \\\n",
"0 0 1.2810 0.022934 0.87454 1.21640 0.060940 \n",
"1 0 1.2700 0.006454 0.82067 1.00490 -0.014080 \n",
"2 0 1.0529 -0.059379 0.92242 0.72926 0.020476 \n",
"3 1 1.1131 -0.015229 0.85888 0.80974 0.076037 \n",
"4 0 1.0623 0.107020 0.81460 0.83593 0.199960 \n",
"... ... ... ... ... ... ... \n",
"3667 0 2.2605 0.202890 0.16037 0.18588 0.175970 \n",
"3668 0 1.9615 0.216440 0.20095 0.21642 0.203590 \n",
"3669 0 1.7099 0.207970 0.26136 0.21399 0.193670 \n",
"3670 0 1.5590 0.185450 0.30728 0.19307 0.172140 \n",
"3671 0 1.6148 0.176760 0.36369 0.18442 0.169550 \n",
"\n",
" x6 x7 x8 x9 ... x73 x74 x75 \\\n",
"0 0.188270 0.52510 0.018854 0.182790 ... 36.0 85.437 27.07 \n",
"1 0.181040 0.62288 0.006423 0.035991 ... 36.0 107.090 31.31 \n",
"2 0.044865 0.43292 -0.081423 -0.765400 ... 35.0 120.870 36.07 \n",
"3 0.091033 0.67546 -0.018807 -0.107910 ... 33.0 54.806 39.80 \n",
"4 0.047800 0.74200 0.128030 0.577250 ... 36.0 85.437 27.07 \n",
"... ... ... ... ... ... ... ... ... \n",
"3667 0.198400 2.22360 1.091500 0.241640 ... 22.0 100.000 100.00 \n",
"3668 0.189870 1.93820 1.000100 0.270870 ... 28.0 91.500 130.50 \n",
"3669 0.183890 1.68980 0.971860 0.281560 ... 32.0 87.100 175.90 \n",
"3670 0.170680 1.53890 0.960570 0.267720 ... 30.0 92.900 203.20 \n",
"3671 0.197860 1.58420 0.958450 0.277780 ... 29.0 91.700 227.50 \n",
"\n",
" x76 x77 x78 x79 x81 x82 x83 \n",
"0 26.102 16.000 16.0 0.2 0.060390 30 49 \n",
"1 30.194 17.000 16.0 0.4 0.010636 31 50 \n",
"2 35.273 17.000 15.0 -0.2 -0.455970 32 51 \n",
"3 38.377 17.167 16.0 5.6 -0.325390 33 52 \n",
"4 26.102 16.000 16.0 0.2 1.251000 7 27 \n",
"... ... ... ... ... ... ... ... \n",
"3667 100.000 17.125 14.5 -7.0 0.436380 4 41 \n",
"3668 132.400 20.000 14.5 -16.0 0.438020 5 42 \n",
"3669 178.100 20.000 14.5 -20.2 0.482410 6 43 \n",
"3670 204.500 22.000 22.0 6.4 0.500770 7 44 \n",
"3671 214.500 21.000 20.5 8.6 0.611030 8 45 \n",
"\n",
"[3672 rows x 83 columns]"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"random_state = 9\n",
"\n",
"# Label rule: 1 = financially distressed (target <= -0.5), 0 = healthy.\n",
"# Kept for reference/reuse; the assignment below applies the same rule vectorized.\n",
"def get_class(row):\n",
"    return 0 if row[\"Financial Distress\"] > -0.5 else 1\n",
"\n",
"\n",
"# Drop identifier columns and x80 before modelling.\n",
"df = pd.read_csv(\"data-distress/FinancialDistress.csv\").drop(\n",
"    [\"Company\", \"Time\", \"x80\"], axis=1\n",
")\n",
"# Vectorized comparison instead of a row-wise df.apply — same labels, much faster.\n",
"df[\"Financial Distress\"] = (df[\"Financial Distress\"] <= -0.5).astype(int)\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" x1 | \n",
" x2 | \n",
" x3 | \n",
" x4 | \n",
" x5 | \n",
" x6 | \n",
" x7 | \n",
" x8 | \n",
" x9 | \n",
" x10 | \n",
" ... | \n",
" x73 | \n",
" x74 | \n",
" x75 | \n",
" x76 | \n",
" x77 | \n",
" x78 | \n",
" x79 | \n",
" x81 | \n",
" x82 | \n",
" x83 | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1.0902 | \n",
" 0.081715 | \n",
" 0.71056 | \n",
" 0.93446 | \n",
" 0.14445 | \n",
" 0.060042 | \n",
" 0.74048 | \n",
" 0.087447 | \n",
" 0.28232 | \n",
" 0.14572 | \n",
" ... | \n",
" 28.0 | \n",
" 79.951 | \n",
" 66.12 | \n",
" 59.471 | \n",
" 18.0 | \n",
" 12.0 | \n",
" -13.4 | \n",
" 1.1152 | \n",
" 13 | \n",
" 28 | \n",
"
\n",
" \n",
" 1 | \n",
" 1.6711 | \n",
" 0.445250 | \n",
" 0.21104 | \n",
" 0.59523 | \n",
" 0.30998 | \n",
" 0.133260 | \n",
" 1.34060 | \n",
" 0.748040 | \n",
" 0.56436 | \n",
" 0.48288 | \n",
" ... | \n",
" 32.0 | \n",
" 87.100 | \n",
" 175.90 | \n",
" 178.100 | \n",
" 20.0 | \n",
" 14.5 | \n",
" -20.2 | \n",
" 1.7354 | \n",
" 2 | \n",
" 18 | \n",
"
\n",
" \n",
" 2 | \n",
" 1.6321 | \n",
" 0.375790 | \n",
" 0.46072 | \n",
" 0.90327 | \n",
" 0.28563 | \n",
" 0.287440 | \n",
" 1.35890 | \n",
" 0.416040 | \n",
" 0.69684 | \n",
" 0.45008 | \n",
" ... | \n",
" 30.0 | \n",
" 92.900 | \n",
" 203.20 | \n",
" 204.500 | \n",
" 22.0 | \n",
" 22.0 | \n",
" 6.4 | \n",
" 5.5809 | \n",
" 8 | \n",
" 17 | \n",
"
\n",
" \n",
"
\n",
"
3 rows × 82 columns
\n",
"
"
],
"text/plain": [
" x1 x2 x3 x4 x5 x6 x7 x8 \\\n",
"0 1.0902 0.081715 0.71056 0.93446 0.14445 0.060042 0.74048 0.087447 \n",
"1 1.6711 0.445250 0.21104 0.59523 0.30998 0.133260 1.34060 0.748040 \n",
"2 1.6321 0.375790 0.46072 0.90327 0.28563 0.287440 1.35890 0.416040 \n",
"\n",
" x9 x10 ... x73 x74 x75 x76 x77 x78 x79 \\\n",
"0 0.28232 0.14572 ... 28.0 79.951 66.12 59.471 18.0 12.0 -13.4 \n",
"1 0.56436 0.48288 ... 32.0 87.100 175.90 178.100 20.0 14.5 -20.2 \n",
"2 0.69684 0.45008 ... 30.0 92.900 203.20 204.500 22.0 22.0 6.4 \n",
"\n",
" x81 x82 x83 \n",
"0 1.1152 13 28 \n",
"1 1.7354 2 18 \n",
"2 5.5809 8 17 \n",
"\n",
"[3 rows x 82 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Financial Distress | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" 0 | \n",
"
\n",
" \n",
" 2 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Financial Distress\n",
"0 0\n",
"1 0\n",
"2 0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" x1 | \n",
" x2 | \n",
" x3 | \n",
" x4 | \n",
" x5 | \n",
" x6 | \n",
" x7 | \n",
" x8 | \n",
" x9 | \n",
" x10 | \n",
" ... | \n",
" x73 | \n",
" x74 | \n",
" x75 | \n",
" x76 | \n",
" x77 | \n",
" x78 | \n",
" x79 | \n",
" x81 | \n",
" x82 | \n",
" x83 | \n",
"
\n",
" \n",
" \n",
" \n",
" 3379 | \n",
" 0.98569 | \n",
" 0.12912 | \n",
" 0.62266 | \n",
" 0.74377 | \n",
" 0.13716 | \n",
" -0.006040 | \n",
" 0.65175 | \n",
" 0.173600 | \n",
" 0.34218 | \n",
" 0.008050 | \n",
" ... | \n",
" 35.5 | \n",
" 113.980 | \n",
" 33.690 | \n",
" 32.734 | \n",
" 17.000 | \n",
" 15.5 | \n",
" 0.10000 | \n",
" 0.045774 | \n",
" 9 | \n",
" 19 | \n",
"
\n",
" \n",
" 156 | \n",
" 0.91084 | \n",
" 0.04889 | \n",
" 0.79108 | \n",
" 0.68615 | \n",
" 0.10943 | \n",
" -0.050311 | \n",
" 0.60633 | \n",
" 0.071253 | \n",
" 0.23401 | \n",
" 0.011391 | \n",
" ... | \n",
" 35.0 | \n",
" 120.870 | \n",
" 36.070 | \n",
" 35.273 | \n",
" 17.000 | \n",
" 15.0 | \n",
" -0.20000 | \n",
" 0.152330 | \n",
" 7 | \n",
" 29 | \n",
"
\n",
" \n",
" 2215 | \n",
" 1.43350 | \n",
" 0.20068 | \n",
" 0.46538 | \n",
" 0.54146 | \n",
" 0.25140 | \n",
" 0.096212 | \n",
" 0.82271 | \n",
" 0.370630 | \n",
" 0.37538 | \n",
" 0.187750 | \n",
" ... | \n",
" 36.0 | \n",
" 98.066 | \n",
" 29.543 | \n",
" 28.489 | \n",
" 16.583 | \n",
" 16.0 | \n",
" 0.31667 | \n",
" 0.563030 | \n",
" 1 | \n",
" 19 | \n",
"
\n",
" \n",
"
\n",
"
3 rows × 82 columns
\n",
"
"
],
"text/plain": [
" x1 x2 x3 x4 x5 x6 x7 \\\n",
"3379 0.98569 0.12912 0.62266 0.74377 0.13716 -0.006040 0.65175 \n",
"156 0.91084 0.04889 0.79108 0.68615 0.10943 -0.050311 0.60633 \n",
"2215 1.43350 0.20068 0.46538 0.54146 0.25140 0.096212 0.82271 \n",
"\n",
" x8 x9 x10 ... x73 x74 x75 x76 x77 \\\n",
"3379 0.173600 0.34218 0.008050 ... 35.5 113.980 33.690 32.734 17.000 \n",
"156 0.071253 0.23401 0.011391 ... 35.0 120.870 36.070 35.273 17.000 \n",
"2215 0.370630 0.37538 0.187750 ... 36.0 98.066 29.543 28.489 16.583 \n",
"\n",
" x78 x79 x81 x82 x83 \n",
"3379 15.5 0.10000 0.045774 9 19 \n",
"156 15.0 -0.20000 0.152330 7 29 \n",
"2215 16.0 0.31667 0.563030 1 19 \n",
"\n",
"[3 rows x 82 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Financial Distress | \n",
"
\n",
" \n",
" \n",
" \n",
" 3379 | \n",
" 0 | \n",
"
\n",
" \n",
" 156 | \n",
" 0 | \n",
"
\n",
" \n",
" 2215 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Financial Distress\n",
"3379 0\n",
"156 0\n",
"2215 0"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from collections import Counter\n",
"from src.utils import split_stratified_into_train_val_test\n",
"from imblearn.over_sampling import ADASYN\n",
"\n",
"# Stratified 80/20 train/test split on the class label.\n",
"X_train, X_test, y_train, y_test = split_stratified_into_train_val_test(\n",
"    df,\n",
"    stratify_colname=\"Financial Distress\",\n",
"    frac_train=0.8,\n",
"    frac_val=0,\n",
"    frac_test=0.2,\n",
"    random_state=random_state,\n",
")\n",
"\n",
"# Oversample the minority class on the TRAIN split only. Pass random_state so\n",
"# the synthetic samples are reproducible across Restart & Run All.\n",
"ada = ADASYN(random_state=random_state)\n",
"\n",
"X_train, y_train = ada.fit_resample(X_train, y_train)\n",
"\n",
"display(X_train.head(3))\n",
"display(y_train.head(3))\n",
"display(X_test.head(3))\n",
"display(y_test.head(3))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\base.py:486: UserWarning: X has feature names, but DecisionTreeClassifier was fitted without feature names\n",
" warnings.warn(\n",
"c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\base.py:486: UserWarning: X has feature names, but DecisionTreeClassifier was fitted without feature names\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"{'pipeline': DecisionTreeClassifier(max_depth=7, random_state=9),\n",
" 'probs': array([1. , 1. , 1. , 0.04451683, 1. ,\n",
" 1. , 1. , 1. , 0.98984772, 1. ,\n",
" 0.04451683, 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 0.04451683, 1. , 1. , 0.04451683, 0.98984772,\n",
" 0.98984772, 1. , 1. , 0.04451683, 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 0.98984772, 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 0.15901814, 1. , 1. , 1. ,\n",
" 1. , 0.98984772, 0.04451683, 0.98984772, 1. ,\n",
" 1. , 1. , 1. , 0.98984772, 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 0.15901814, 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 0.98984772, 1. , 1. ,\n",
" 1. , 1. , 1. , 0.64285714, 0.98984772,\n",
" 0.04451683, 1. , 0.98984772, 1. , 1. ,\n",
" 1. , 1. , 1. , 0.98984772, 1. ,\n",
" 1. , 0.64285714, 0.04451683, 1. , 1. ,\n",
" 1. , 0.04451683, 1. , 1. , 0.98984772,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 0.15901814, 1. ,\n",
" 0.15901814, 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 0. , 1. , 1. ,\n",
" 0.98984772, 1. , 1. , 1. , 0.04451683,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 0.04451683, 1. , 0.15901814, 1. ,\n",
" 1. , 0.64285714, 1. , 1. , 0.64285714,\n",
" 1. , 0.15901814, 0.15901814, 0.64285714, 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 0.04451683, 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 0.04451683, 1. , 1. ,\n",
" 1. , 0.04451683, 1. , 0.04451683, 0.15901814,\n",
" 1. , 0.98984772, 1. , 1. , 0.625 ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 0.98984772, 1. , 1. , 1. ,\n",
" 1. , 1. , 0.04451683, 0.98984772, 0.98984772,\n",
" 1. , 0.04451683, 1. , 1. , 1. ,\n",
" 1. , 1. , 0.04451683, 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 0.64285714, 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 0.64285714, 0.15901814, 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 0.98984772, 1. , 1. , 1. ,\n",
" 0.15901814, 0.15901814, 1. , 0.98984772, 1. ,\n",
" 1. , 1. , 0.15901814, 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 0.98984772, 0.15901814, 1. , 1. , 1. ,\n",
" 1. , 0.04451683, 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 0.98984772, 1. , 0.04451683, 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 0.64285714,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 0.04451683,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 0.98984772, 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 0.04451683, 1. , 1. , 1. , 1. ,\n",
" 0.15901814, 1. , 0.98984772, 0.04451683, 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 0.15901814, 1. ,\n",
" 1. , 1. , 0.04451683, 1. , 1. ,\n",
" 1. , 1. , 1. , 0.98984772, 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 0.04451683,\n",
" 0.15901814, 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 0. , 0.15901814, 1. , 0.64285714,\n",
" 0.04451683, 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 0.98984772, 1. , 0.98984772,\n",
" 1. , 1. , 1. , 0.98984772, 1. ,\n",
" 1. , 1. , 1. , 0.98984772, 1. ,\n",
" 0.08333333, 1. , 1. , 1. , 1. ,\n",
" 1. , 0. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 0.15901814,\n",
" 1. , 1. , 1. , 0.98984772, 1. ,\n",
" 1. , 0. , 1. , 1. , 0.64285714,\n",
" 0.98984772, 1. , 0.04451683, 1. , 1. ,\n",
" 0.04451683, 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 0.98984772, 1. ,\n",
" 1. , 1. , 0. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 0.98984772, 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 0.98984772, 0.15901814, 1. ,\n",
" 1. , 1. , 0.04451683, 1. , 1. ,\n",
" 1. , 1. , 0.15901814, 1. , 0.8 ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 0.64285714, 0.64285714, 1. , 1. ,\n",
" 0.08333333, 1. , 1. , 1. , 0.8 ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 0.04451683, 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 0.98984772, 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 0.04451683, 1. , 1. ,\n",
" 0.04451683, 1. , 0.15901814, 1. , 0.04451683,\n",
" 1. , 1. , 1. , 0.98984772, 0.15901814,\n",
" 1. , 0.64285714, 1. , 1. , 1. ,\n",
" 0.64285714, 1. , 1. , 0.64285714, 0. ,\n",
" 0.04451683, 1. , 1. , 1. , 1. ,\n",
" 1. , 0.04451683, 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 0.15901814,\n",
" 0.98984772, 1. , 0.15901814, 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 0.04451683, 1. , 0.98984772, 0.04451683,\n",
" 1. , 0.98984772, 1. , 1. , 0.04451683,\n",
" 1. , 1. , 1. , 0.64285714, 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 0.98984772, 1. , 1. , 1. , 1. ,\n",
" 1. , 0.98984772, 1. , 1. , 1. ,\n",
" 0.35294118, 0.15901814, 1. , 1. , 0.04451683,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 0.98984772, 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 0.04451683, 1. , 1. , 0.04451683, 1. ,\n",
" 0.15901814, 1. , 0.04451683, 1. , 1. ,\n",
" 1. , 0.64285714, 1. , 0.04451683, 1. ,\n",
" 0.15901814, 1. , 1. , 0.15901814, 1. ,\n",
" 1. , 0.04451683, 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 0.04451683, 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 0.04451683, 1. , 1. ,\n",
" 1. , 1. , 0.98984772, 0.04451683, 0.04451683,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 0.04451683, 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 0.98984772,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 0.98984772, 0.04451683, 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 0.15901814, 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 0.15901814, 1. , 1. ,\n",
" 1. , 0.98984772, 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 0. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 0.15901814,\n",
" 0.04451683, 1. , 1. , 0.98984772, 1. ]),\n",
" 'preds': array([1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n",
" 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1,\n",
" 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,\n",
" 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n",
" 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,\n",
" 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0,\n",
" 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0,\n",
" 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1,\n",
" 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1,\n",
" 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n",
" 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
" 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
" 1, 1, 1, 0, 0, 1, 1, 1, 1]),\n",
" 'Precision_train': np.float64(0.9209039548022598),\n",
" 'Precision_test': np.float64(0.015552099533437015),\n",
" 'Recall_train': np.float64(0.9889364739471805),\n",
" 'Recall_test': np.float64(0.45454545454545453),\n",
" 'Accuracy_train': 0.9521777777777778,\n",
" 'Accuracy_test': 0.12244897959183673,\n",
" 'ROC_AUC_test': np.float64(0.23084278974882058),\n",
" 'F1_train': np.float64(0.9537084839098262),\n",
" 'F1_test': np.float64(0.03007518796992481),\n",
" 'MCC_test': np.float64(-0.22309912384470268),\n",
" 'Cohen_kappa_test': np.float64(-0.029516833411874055),\n",
" 'Confusion_matrix': array([[ 80, 633],\n",
" [ 12, 10]])}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from src.utils import run_classification, run_regression\n",
"from sklearn import tree\n",
"\n",
"\n",
"# Fit on the DataFrame itself (not .values) so the tree keeps feature names;\n",
"# otherwise run_classification's DataFrame inputs trigger sklearn's\n",
"# \"X has feature names, but ... was fitted without feature names\" warning\n",
"# on every predict call. y is ravelled to a 1-D array as sklearn expects.\n",
"fitted_model = tree.DecisionTreeClassifier(max_depth=7, random_state=random_state).fit(\n",
"    X_train, y_train.values.ravel()\n",
")\n",
"result = run_classification(\n",
"    fitted_model, X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test\n",
")\n",
"result"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py:540: FitFailedWarning: \n",
"450 fits failed out of a total of 1800.\n",
"The score on these train-test partitions for these parameters will be set to nan.\n",
"If these failures are not expected, you can try to debug them by setting error_score='raise'.\n",
"\n",
"Below are more details about the failures:\n",
"--------------------------------------------------------------------------------\n",
"450 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n",
" return fit_method(estimator, *args, **kwargs)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\tree\\_classes.py\", line 1377, in fit\n",
" super()._fit(\n",
" File \"c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\tree\\_classes.py\", line 269, in _fit\n",
" raise ValueError(\n",
"ValueError: Some value(s) of y are negative which is not allowed for Poisson regression.\n",
"\n",
" warnings.warn(some_fits_failed_message, FitFailedWarning)\n",
"c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1103: UserWarning: One or more of the test scores are non-finite: [-4.40722975 -4.40722975 -4.40722975 -4.40722975 -4.40722975 -4.40722975\n",
" -4.40722975 -4.40722975 -4.40722975 -4.02547375 -4.02547375 -4.02547375\n",
" -4.02547375 -4.02547375 -4.02547375 -4.02547375 -4.02547375 -4.02547375\n",
" -4.17872998 -4.17872998 -4.1794137 -4.18153603 -4.18153603 -4.18153603\n",
" -4.18153603 -4.18153603 -4.18153603 -4.29228886 -4.27823005 -4.27378326\n",
" -4.27592366 -4.27917056 -4.23213065 -4.22140931 -4.2203268 -4.216674\n",
" -4.22464637 -4.33616785 -4.27369282 -4.25199473 -4.31065536 -4.27417069\n",
" -4.24259248 -4.2377248 -4.2476017 -4.32852991 -4.31823651 -4.28553229\n",
" -4.34177931 -4.30250957 -4.28753163 -4.28076012 -4.25271136 -4.25347712\n",
" -4.3969495 -4.36131695 -4.32082651 -4.31057625 -4.33653274 -4.3316473\n",
" -4.29354845 -4.30058389 -4.2945977 -4.33373406 -4.38833647 -4.35267217\n",
" -4.33498752 -4.4013588 -4.32169037 -4.30799648 -4.30856117 -4.35205443\n",
" -4.38832687 -4.37876963 -4.36158341 -4.33729531 -4.36291152 -4.34227116\n",
" -4.33926356 -4.31523722 -4.31237617 -4.42139371 -4.35079669 -4.40994949\n",
" -4.35999754 -4.37508884 -4.39768089 -4.36889387 -4.31988427 -4.30750986\n",
" 0.2652827 0.2652827 0.2652827 0.2652827 0.2652827 0.2652827\n",
" 0.2652827 0.2652827 0.2652827 -1.9287676 -1.9287676 -1.9287676\n",
" -1.9287676 -1.9287676 -1.9287676 -1.9287676 -1.9287676 -1.9287676\n",
" -1.97020055 -1.97020055 -1.97042935 -1.97042935 -1.97004014 -1.97004014\n",
" -1.9700399 -1.99709038 -1.99715081 -2.00301099 -1.99717467 -1.9930943\n",
" -1.99098557 -1.98702489 -1.98700995 -1.98699179 -2.01302486 -2.01501513\n",
" -2.10894255 -2.10994722 -2.09952227 -2.08394081 -2.08202479 -2.07188081\n",
" -2.07112545 -2.09099692 -2.09034787 -2.15392376 -2.1403913 -2.14641712\n",
" -2.12768247 -2.11768777 -2.11183283 -2.03846712 -2.12716538 -2.1272445\n",
" -2.16307528 -2.17486177 -2.19327604 -2.16449752 -2.08104295 -2.13803493\n",
" -2.08495978 -2.14818606 -2.14245762 -2.18712444 -2.20432741 -2.18741809\n",
" -2.17137267 -2.16649999 -2.09136755 -2.16386709 -2.17518238 -2.15065344\n",
" -2.24366056 -2.14969955 -2.12174719 -2.11054372 -2.17740208 -2.15664086\n",
" -2.16327196 -2.1757021 -2.16935629 -2.20600588 -2.21021639 -2.20224851\n",
" -2.19087629 -2.18339126 -2.16722284 -2.16000468 -2.18148192 -2.17331625\n",
" -4.40722975 -4.40722975 -4.40722975 -4.40722975 -4.40722975 -4.40722975\n",
" -4.40722975 -4.40722975 -4.40722975 -4.02547375 -4.02547375 -4.02547375\n",
" -4.02547375 -4.02547375 -4.02547375 -4.02547375 -4.02547375 -4.02547375\n",
" -4.17872998 -4.17872998 -4.1794137 -4.18153603 -4.18153603 -4.18153603\n",
" -4.18153603 -4.18153603 -4.18153603 -4.31301085 -4.27817322 -4.27362102\n",
" -4.27576142 -4.27917056 -4.23213065 -4.22140931 -4.2203268 -4.216674\n",
" -4.308967 -4.33406277 -4.27031206 -4.25187889 -4.30964117 -4.27417069\n",
" -4.24259248 -4.23837042 -4.2476017 -4.30938294 -4.35432558 -4.28424212\n",
" -4.3429458 -4.3011881 -4.28789821 -4.28076012 -4.25271136 -4.25347712\n",
" -4.36334041 -4.36324461 -4.32002519 -4.3118746 -4.33707678 -4.33137066\n",
" -4.29320914 -4.30010728 -4.2945977 -4.38237244 -4.38340433 -4.34774631\n",
" -4.33676135 -4.40583985 -4.33138452 -4.30737588 -4.30825185 -4.35224119\n",
" -4.38925553 -4.35711249 -4.35484652 -4.34829607 -4.3661772 -4.35261695\n",
" -4.33890487 -4.31498202 -4.31236225 -4.43228313 -4.38839311 -4.40088779\n",
" -4.35168819 -4.38046828 -4.39657208 -4.36830249 -4.31947573 -4.30727193\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan]\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"{'criterion': 'absolute_error', 'max_depth': 1, 'min_samples_split': 2}"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"from sklearn import model_selection\n",
"\n",
"# \"poisson\" is excluded from the criteria: it requires non-negative targets,\n",
"# and every such fit failed with \"Some value(s) of y are negative\" (450 of\n",
"# 1800 fits, see the FitFailedWarning this cell previously emitted).\n",
"parameters = {\n",
"    \"criterion\": [\"squared_error\", \"absolute_error\", \"friedman_mse\"],\n",
"    \"max_depth\": np.arange(1, 21, 2).tolist(),  # odd depths 1..19\n",
"    \"min_samples_split\": np.arange(2, 20, 2).tolist(),  # even values 2..18\n",
"}\n",
"\n",
"grid = model_selection.GridSearchCV(\n",
"    tree.DecisionTreeRegressor(random_state=random_state), parameters, n_jobs=-1\n",
")\n",
"\n",
"# Ravel y to 1-D, consistent with the classifier cell above.\n",
"grid.fit(X_train, y_train.values.ravel())\n",
"grid.best_params_"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\base.py:493: UserWarning: X does not have valid feature names, but DecisionTreeRegressor was fitted with feature names\n",
" warnings.warn(\n",
"c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\base.py:493: UserWarning: X does not have valid feature names, but DecisionTreeRegressor was fitted with feature names\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"{'RMSE_test': 1.1620245531837428,\n",
" 'RMAE_test': 0.793506853815132,\n",
" 'R2_test': 0.32558767171344627}"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'fitted': DecisionTreeRegressor(criterion='absolute_error', max_depth=1, random_state=9),\n",
" 'train_preds': array([0.40003, 2.4023 , 2.4023 , ..., 0.40003, 0.40003, 0.40003]),\n",
" 'preds': array([0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003,\n",
" 0.40003, 0.40003, 2.4023 , 0.40003, 2.4023 , 2.4023 , 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 2.4023 , 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 2.4023 , 0.40003,\n",
" 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n",
" 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 2.4023 , 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 ,\n",
" 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 2.4023 , 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003, 2.4023 , 0.40003,\n",
" 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 2.4023 , 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 2.4023 , 0.40003, 2.4023 , 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 2.4023 , 0.40003, 2.4023 , 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 2.4023 , 2.4023 , 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003,\n",
" 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003,\n",
" 2.4023 , 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 2.4023 , 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 2.4023 , 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 0.40003, 2.4023 , 2.4023 , 2.4023 , 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 2.4023 , 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n",
" 0.40003, 2.4023 , 0.40003, 0.40003, 2.4023 , 0.40003, 2.4023 ,\n",
" 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n",
" 2.4023 , 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003, 2.4023 ]),\n",
" 'RMSE_train': 2.6924305509873223,\n",
" 'RMSE_test': 1.198383084740795,\n",
" 'RMAE_test': 0.8489844401433057,\n",
" 'R2_test': 0.2827241118036927}"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Compare the baseline decision tree's held-out metrics against the\n",
"# grid-searched estimator refit via run_regression.\n",
"metric_keys = (\"RMSE_test\", \"RMAE_test\", \"R2_test\")\n",
"old_metrics = {key: models[\"decision_tree\"][key] for key in metric_keys}\n",
"new_metrics = run_regression(\n",
"    grid.best_estimator_, X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test\n",
")\n",
"\n",
"display(old_metrics)\n",
"display(new_metrics)"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"|--- x81 <= -0.00\n",
"| |--- x81 <= -0.44\n",
"| | |--- x4 <= 0.58\n",
"| | | |--- x58 <= -0.08\n",
"| | | | |--- class: 0\n",
"| | | |--- x58 > -0.08\n",
"| | | | |--- class: 1\n",
"| | |--- x4 > 0.58\n",
"| | | |--- x38 <= 0.04\n",
"| | | | |--- x81 <= -0.48\n",
"| | | | | |--- class: 1\n",
"| | | | |--- x81 > -0.48\n",
"| | | | | |--- class: 0\n",
"| | | |--- x38 > 0.04\n",
"| | | | |--- class: 1\n",
"| |--- x81 > -0.44\n",
"| | |--- x44 <= -0.07\n",
"| | | |--- x40 <= 0.33\n",
"| | | | |--- x3 <= 0.72\n",
"| | | | | |--- x23 <= 0.17\n",
"| | | | | | |--- class: 0\n",
"| | | | | |--- x23 > 0.17\n",
"| | | | | | |--- x15 <= 1.22\n",
"| | | | | | | |--- class: 0\n",
"| | | | | | |--- x15 > 1.22\n",
"| | | | | | | |--- class: 1\n",
"| | | | |--- x3 > 0.72\n",
"| | | | | |--- x41 <= 9.63\n",
"| | | | | | |--- x81 <= -0.40\n",
"| | | | | | | |--- class: 0\n",
"| | | | | | |--- x81 > -0.40\n",
"| | | | | | | |--- class: 1\n",
"| | | | | |--- x41 > 9.63\n",
"| | | | | | |--- class: 0\n",
"| | | |--- x40 > 0.33\n",
"| | | | |--- x75 <= 41.70\n",
"| | | | | |--- x23 <= 0.21\n",
"| | | | | | |--- class: 1\n",
"| | | | | |--- x23 > 0.21\n",
"| | | | | | |--- class: 0\n",
"| | | | |--- x75 > 41.70\n",
"| | | | | |--- x10 <= -0.19\n",
"| | | | | | |--- class: 1\n",
"| | | | | |--- x10 > -0.19\n",
"| | | | | | |--- class: 0\n",
"| | |--- x44 > -0.07\n",
"| | | |--- x57 <= 0.18\n",
"| | | | |--- x7 <= 0.74\n",
"| | | | | |--- class: 0\n",
"| | | | |--- x7 > 0.74\n",
"| | | | | |--- x15 <= 3.68\n",
"| | | | | | |--- x28 <= -0.12\n",
"| | | | | | | |--- class: 1\n",
"| | | | | | |--- x28 > -0.12\n",
"| | | | | | | |--- class: 0\n",
"| | | | | |--- x15 > 3.68\n",
"| | | | | | |--- class: 1\n",
"| | | |--- x57 > 0.18\n",
"| | | | |--- x49 <= 2.63\n",
"| | | | | |--- class: 0\n",
"| | | | |--- x49 > 2.63\n",
"| | | | | |--- x71 <= 45.23\n",
"| | | | | | |--- class: 1\n",
"| | | | | |--- x71 > 45.23\n",
"| | | | | | |--- class: 0\n",
"|--- x81 > -0.00\n",
"| |--- x25 <= -812.33\n",
"| | |--- class: 1\n",
"| |--- x25 > -812.33\n",
"| | |--- x36 <= 0.08\n",
"| | | |--- x47 <= 17.10\n",
"| | | | |--- x43 <= 0.15\n",
"| | | | | |--- x13 <= 0.07\n",
"| | | | | | |--- class: 0\n",
"| | | | | |--- x13 > 0.07\n",
"| | | | | | |--- class: 1\n",
"| | | | |--- x43 > 0.15\n",
"| | | | | |--- x75 <= 33.09\n",
"| | | | | | |--- x42 <= -0.92\n",
"| | | | | | | |--- class: 1\n",
"| | | | | | |--- x42 > -0.92\n",
"| | | | | | | |--- class: 0\n",
"| | | | | |--- x75 > 33.09\n",
"| | | | | | |--- x60 <= 0.66\n",
"| | | | | | | |--- class: 0\n",
"| | | | | | |--- x60 > 0.66\n",
"| | | | | | | |--- class: 1\n",
"| | | |--- x47 > 17.10\n",
"| | | | |--- class: 1\n",
"| | |--- x36 > 0.08\n",
"| | | |--- x53 <= -0.02\n",
"| | | | |--- class: 1\n",
"| | | |--- x53 > -0.02\n",
"| | | | |--- x46 <= 0.03\n",
"| | | | | |--- x46 <= 0.03\n",
"| | | | | | |--- class: 0\n",
"| | | | | |--- x46 > 0.03\n",
"| | | | | | |--- class: 1\n",
"| | | | |--- x46 > 0.03\n",
"| | | | | |--- x26 <= 11.02\n",
"| | | | | | |--- x79 <= -17.05\n",
"| | | | | | | |--- class: 1\n",
"| | | | | | |--- x79 > -17.05\n",
"| | | | | | | |--- class: 0\n",
"| | | | | |--- x26 > 11.02\n",
"| | | | | | |--- x24 <= 0.82\n",
"| | | | | | | |--- class: 0\n",
"| | | | | | |--- x24 > 0.82\n",
"| | | | | | | |--- class: 0\n",
"\n"
]
}
],
"source": [
"# Render the fitted tree's split rules as indented text for manual inspection.\n",
"feature_names = X_train.columns.values.tolist()\n",
"rules = tree.export_text(fitted_model, feature_names=feature_names)\n",
"print(rules)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"# Persist the fitted baseline decision tree to disk.\n",
"# Use a context manager so the file handle is deterministically flushed and\n",
"# closed — the original bare open() was never closed (resource leak).\n",
"with open(\"data-distress/vtree.model.sav\", \"wb\") as model_file:\n",
"    pickle.dump(models[\"decision_tree\"][\"fitted\"], model_file)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}