{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Financial Distressx1x2x3x4x5x6x7x8x9...x73x74x75x76x77x78x79x81x82x83
001.28100.0229340.874541.216400.0609400.1882700.525100.0188540.182790...36.085.43727.0726.10216.00016.00.20.0603903049
101.27000.0064540.820671.00490-0.0140800.1810400.622880.0064230.035991...36.0107.09031.3130.19417.00016.00.40.0106363150
201.0529-0.0593790.922420.729260.0204760.0448650.43292-0.081423-0.765400...35.0120.87036.0735.27317.00015.0-0.2-0.4559703251
311.1131-0.0152290.858880.809740.0760370.0910330.67546-0.018807-0.107910...33.054.80639.8038.37717.16716.05.6-0.3253903352
401.06230.1070200.814600.835930.1999600.0478000.742000.1280300.577250...36.085.43727.0726.10216.00016.00.21.251000727
..................................................................
366702.26050.2028900.160370.185880.1759700.1984002.223601.0915000.241640...22.0100.000100.00100.00017.12514.5-7.00.436380441
366801.96150.2164400.200950.216420.2035900.1898701.938201.0001000.270870...28.091.500130.50132.40020.00014.5-16.00.438020542
366901.70990.2079700.261360.213990.1936700.1838901.689800.9718600.281560...32.087.100175.90178.10020.00014.5-20.20.482410643
367001.55900.1854500.307280.193070.1721400.1706801.538900.9605700.267720...30.092.900203.20204.50022.00022.06.40.500770744
367101.61480.1767600.363690.184420.1695500.1978601.584200.9584500.277780...29.091.700227.50214.50021.00020.58.60.611030845
\n", "

3672 rows × 83 columns

\n", "
" ], "text/plain": [ " Financial Distress x1 x2 x3 x4 x5 \\\n", "0 0 1.2810 0.022934 0.87454 1.21640 0.060940 \n", "1 0 1.2700 0.006454 0.82067 1.00490 -0.014080 \n", "2 0 1.0529 -0.059379 0.92242 0.72926 0.020476 \n", "3 1 1.1131 -0.015229 0.85888 0.80974 0.076037 \n", "4 0 1.0623 0.107020 0.81460 0.83593 0.199960 \n", "... ... ... ... ... ... ... \n", "3667 0 2.2605 0.202890 0.16037 0.18588 0.175970 \n", "3668 0 1.9615 0.216440 0.20095 0.21642 0.203590 \n", "3669 0 1.7099 0.207970 0.26136 0.21399 0.193670 \n", "3670 0 1.5590 0.185450 0.30728 0.19307 0.172140 \n", "3671 0 1.6148 0.176760 0.36369 0.18442 0.169550 \n", "\n", " x6 x7 x8 x9 ... x73 x74 x75 \\\n", "0 0.188270 0.52510 0.018854 0.182790 ... 36.0 85.437 27.07 \n", "1 0.181040 0.62288 0.006423 0.035991 ... 36.0 107.090 31.31 \n", "2 0.044865 0.43292 -0.081423 -0.765400 ... 35.0 120.870 36.07 \n", "3 0.091033 0.67546 -0.018807 -0.107910 ... 33.0 54.806 39.80 \n", "4 0.047800 0.74200 0.128030 0.577250 ... 36.0 85.437 27.07 \n", "... ... ... ... ... ... ... ... ... \n", "3667 0.198400 2.22360 1.091500 0.241640 ... 22.0 100.000 100.00 \n", "3668 0.189870 1.93820 1.000100 0.270870 ... 28.0 91.500 130.50 \n", "3669 0.183890 1.68980 0.971860 0.281560 ... 32.0 87.100 175.90 \n", "3670 0.170680 1.53890 0.960570 0.267720 ... 30.0 92.900 203.20 \n", "3671 0.197860 1.58420 0.958450 0.277780 ... 29.0 91.700 227.50 \n", "\n", " x76 x77 x78 x79 x81 x82 x83 \n", "0 26.102 16.000 16.0 0.2 0.060390 30 49 \n", "1 30.194 17.000 16.0 0.4 0.010636 31 50 \n", "2 35.273 17.000 15.0 -0.2 -0.455970 32 51 \n", "3 38.377 17.167 16.0 5.6 -0.325390 33 52 \n", "4 26.102 16.000 16.0 0.2 1.251000 7 27 \n", "... ... ... ... ... ... ... ... 
\n", "3667 100.000 17.125 14.5 -7.0 0.436380 4 41 \n", "3668 132.400 20.000 14.5 -16.0 0.438020 5 42 \n", "3669 178.100 20.000 14.5 -20.2 0.482410 6 43 \n", "3670 204.500 22.000 22.0 6.4 0.500770 7 44 \n", "3671 214.500 21.000 20.5 8.6 0.611030 8 45 \n", "\n", "[3672 rows x 83 columns]" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "random_state = 9\n", "\n", "def get_class(row):\n", " return 0 if row[\"Financial Distress\"] > -0.5 else 1\n", "\n", "\n", "df = pd.read_csv(\"data-distress/FinancialDistress.csv\").drop(\n", " [\"Company\", \"Time\", \"x80\"], axis=1\n", ")\n", "df[\"Financial Distress\"] = df.apply(get_class, axis=1)\n", "df" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
x1x2x3x4x5x6x7x8x9x10...x73x74x75x76x77x78x79x81x82x83
01.09020.0817150.710560.934460.144450.0600420.740480.0874470.282320.14572...28.079.95166.1259.47118.012.0-13.41.11521328
11.67110.4452500.211040.595230.309980.1332601.340600.7480400.564360.48288...32.087.100175.90178.10020.014.5-20.21.7354218
21.63210.3757900.460720.903270.285630.2874401.358900.4160400.696840.45008...30.092.900203.20204.50022.022.06.45.5809817
\n", "

3 rows × 82 columns

\n", "
" ], "text/plain": [ " x1 x2 x3 x4 x5 x6 x7 x8 \\\n", "0 1.0902 0.081715 0.71056 0.93446 0.14445 0.060042 0.74048 0.087447 \n", "1 1.6711 0.445250 0.21104 0.59523 0.30998 0.133260 1.34060 0.748040 \n", "2 1.6321 0.375790 0.46072 0.90327 0.28563 0.287440 1.35890 0.416040 \n", "\n", " x9 x10 ... x73 x74 x75 x76 x77 x78 x79 \\\n", "0 0.28232 0.14572 ... 28.0 79.951 66.12 59.471 18.0 12.0 -13.4 \n", "1 0.56436 0.48288 ... 32.0 87.100 175.90 178.100 20.0 14.5 -20.2 \n", "2 0.69684 0.45008 ... 30.0 92.900 203.20 204.500 22.0 22.0 6.4 \n", "\n", " x81 x82 x83 \n", "0 1.1152 13 28 \n", "1 1.7354 2 18 \n", "2 5.5809 8 17 \n", "\n", "[3 rows x 82 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Financial Distress
00
10
20
\n", "
" ], "text/plain": [ " Financial Distress\n", "0 0\n", "1 0\n", "2 0" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
x1x2x3x4x5x6x7x8x9x10...x73x74x75x76x77x78x79x81x82x83
33790.985690.129120.622660.743770.13716-0.0060400.651750.1736000.342180.008050...35.5113.98033.69032.73417.00015.50.100000.045774919
1560.910840.048890.791080.686150.10943-0.0503110.606330.0712530.234010.011391...35.0120.87036.07035.27317.00015.0-0.200000.152330729
22151.433500.200680.465380.541460.251400.0962120.822710.3706300.375380.187750...36.098.06629.54328.48916.58316.00.316670.563030119
\n", "

3 rows × 82 columns

\n", "
" ], "text/plain": [ " x1 x2 x3 x4 x5 x6 x7 \\\n", "3379 0.98569 0.12912 0.62266 0.74377 0.13716 -0.006040 0.65175 \n", "156 0.91084 0.04889 0.79108 0.68615 0.10943 -0.050311 0.60633 \n", "2215 1.43350 0.20068 0.46538 0.54146 0.25140 0.096212 0.82271 \n", "\n", " x8 x9 x10 ... x73 x74 x75 x76 x77 \\\n", "3379 0.173600 0.34218 0.008050 ... 35.5 113.980 33.690 32.734 17.000 \n", "156 0.071253 0.23401 0.011391 ... 35.0 120.870 36.070 35.273 17.000 \n", "2215 0.370630 0.37538 0.187750 ... 36.0 98.066 29.543 28.489 16.583 \n", "\n", " x78 x79 x81 x82 x83 \n", "3379 15.5 0.10000 0.045774 9 19 \n", "156 15.0 -0.20000 0.152330 7 29 \n", "2215 16.0 0.31667 0.563030 1 19 \n", "\n", "[3 rows x 82 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Financial Distress
33790
1560
22150
\n", "
" ], "text/plain": [ " Financial Distress\n", "3379 0\n", "156 0\n", "2215 0" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from collections import Counter\n", "from src.utils import split_stratified_into_train_val_test\n", "from imblearn.over_sampling import ADASYN\n", "\n", "X_train, X_test, y_train, y_test = split_stratified_into_train_val_test(\n", " df,\n", " stratify_colname=\"Financial Distress\",\n", " frac_train=0.8,\n", " frac_val=0,\n", " frac_test=0.2,\n", " random_state=random_state,\n", ")\n", "\n", "ada = ADASYN()\n", "\n", "X_train, y_train = ada.fit_resample(X_train, y_train)\n", "\n", "\n", "# print(f\"Original dataset shape {len(df[[\"Financial Distress\"]])}\")\n", "# X, y = reducer.fit_resample(\n", "# df.drop([\"Financial Distress\"], axis=1), df[[\"Financial Distress\"]]\n", "# )\n", "# print(f\"Original dataset shape {len(y)}\")\n", "\n", "\n", "# X_train = pd.DataFrame(\n", "# sc.fit_transform(X_train.values), columns=X_train.columns, index=X_train.index\n", "# )\n", "# X_test = pd.DataFrame(\n", "# sc.fit_transform(X_test.values), columns=X_test.columns, index=X_test.index\n", "# )\n", "\n", "\n", "display(X_train.head(3))\n", "display(y_train.head(3))\n", "display(X_test.head(3))\n", "display(y_test.head(3))" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\base.py:486: UserWarning: X has feature names, but DecisionTreeClassifier was fitted without feature names\n", " warnings.warn(\n", "c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\base.py:486: UserWarning: X has feature names, but DecisionTreeClassifier was fitted without feature names\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "{'pipeline': DecisionTreeClassifier(max_depth=7, random_state=9),\n", " 'probs': array([1. , 1. , 1. , 0.04451683, 1. ,\n", " 1. 
, 1. , 1. , 0.98984772, 1. ,\n", " 0.04451683, 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 0.04451683, 1. , 1. , 0.04451683, 0.98984772,\n", " 0.98984772, 1. , 1. , 0.04451683, 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 0.98984772, 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 0.15901814, 1. , 1. , 1. ,\n", " 1. , 0.98984772, 0.04451683, 0.98984772, 1. ,\n", " 1. , 1. , 1. , 0.98984772, 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 0.15901814, 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 0.98984772, 1. , 1. ,\n", " 1. , 1. , 1. , 0.64285714, 0.98984772,\n", " 0.04451683, 1. , 0.98984772, 1. , 1. ,\n", " 1. , 1. , 1. , 0.98984772, 1. ,\n", " 1. , 0.64285714, 0.04451683, 1. , 1. ,\n", " 1. , 0.04451683, 1. , 1. , 0.98984772,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 0.15901814, 1. ,\n", " 0.15901814, 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 0. , 1. , 1. ,\n", " 0.98984772, 1. , 1. , 1. , 0.04451683,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 0.04451683, 1. , 0.15901814, 1. ,\n", " 1. , 0.64285714, 1. , 1. , 0.64285714,\n", " 1. , 0.15901814, 0.15901814, 0.64285714, 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 0.04451683, 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 0.04451683, 1. , 1. ,\n", " 1. , 0.04451683, 1. , 0.04451683, 0.15901814,\n", " 1. , 0.98984772, 1. , 1. , 0.625 ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 0.98984772, 1. , 1. , 1. ,\n", " 1. , 1. , 0.04451683, 0.98984772, 0.98984772,\n", " 1. , 0.04451683, 1. , 1. , 1. ,\n", " 1. , 1. , 0.04451683, 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 0.64285714, 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 0.64285714, 0.15901814, 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 0.98984772, 1. , 1. , 1. ,\n", " 0.15901814, 0.15901814, 1. , 0.98984772, 1. ,\n", " 1. , 1. , 0.15901814, 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 0.98984772, 0.15901814, 1. , 1. 
, 1. ,\n", " 1. , 0.04451683, 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 0.98984772, 1. , 0.04451683, 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 0.64285714,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 0.04451683,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 0.98984772, 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 0.04451683, 1. , 1. , 1. , 1. ,\n", " 0.15901814, 1. , 0.98984772, 0.04451683, 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 0.15901814, 1. ,\n", " 1. , 1. , 0.04451683, 1. , 1. ,\n", " 1. , 1. , 1. , 0.98984772, 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 0.04451683,\n", " 0.15901814, 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 0. , 0.15901814, 1. , 0.64285714,\n", " 0.04451683, 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 0.98984772, 1. , 0.98984772,\n", " 1. , 1. , 1. , 0.98984772, 1. ,\n", " 1. , 1. , 1. , 0.98984772, 1. ,\n", " 0.08333333, 1. , 1. , 1. , 1. ,\n", " 1. , 0. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 0.15901814,\n", " 1. , 1. , 1. , 0.98984772, 1. ,\n", " 1. , 0. , 1. , 1. , 0.64285714,\n", " 0.98984772, 1. , 0.04451683, 1. , 1. ,\n", " 0.04451683, 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 0.98984772, 1. ,\n", " 1. , 1. , 0. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 0.98984772, 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 0.98984772, 0.15901814, 1. ,\n", " 1. , 1. , 0.04451683, 1. , 1. ,\n", " 1. , 1. , 0.15901814, 1. , 0.8 ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 0.64285714, 0.64285714, 1. , 1. ,\n", " 0.08333333, 1. , 1. , 1. , 0.8 ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 0.04451683, 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 0.98984772, 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 0.04451683, 1. , 1. ,\n", " 0.04451683, 1. , 0.15901814, 1. , 0.04451683,\n", " 1. , 1. , 1. , 0.98984772, 0.15901814,\n", " 1. , 0.64285714, 1. , 1. , 1. ,\n", " 0.64285714, 1. , 1. , 0.64285714, 0. ,\n", " 0.04451683, 1. , 1. , 1. , 1. 
,\n", " 1. , 0.04451683, 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 0.15901814,\n", " 0.98984772, 1. , 0.15901814, 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 0.04451683, 1. , 0.98984772, 0.04451683,\n", " 1. , 0.98984772, 1. , 1. , 0.04451683,\n", " 1. , 1. , 1. , 0.64285714, 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 0.98984772, 1. , 1. , 1. , 1. ,\n", " 1. , 0.98984772, 1. , 1. , 1. ,\n", " 0.35294118, 0.15901814, 1. , 1. , 0.04451683,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 0.98984772, 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 0.04451683, 1. , 1. , 0.04451683, 1. ,\n", " 0.15901814, 1. , 0.04451683, 1. , 1. ,\n", " 1. , 0.64285714, 1. , 0.04451683, 1. ,\n", " 0.15901814, 1. , 1. , 0.15901814, 1. ,\n", " 1. , 0.04451683, 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 0.04451683, 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 0.04451683, 1. , 1. ,\n", " 1. , 1. , 0.98984772, 0.04451683, 0.04451683,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 0.04451683, 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 0.98984772,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 0.98984772, 0.04451683, 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 0.15901814, 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 0.15901814, 1. , 1. ,\n", " 1. , 0.98984772, 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 0. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 0.15901814,\n", " 0.04451683, 1. , 1. , 0.98984772, 1. 
]),\n", " 'preds': array([1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n", " 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1,\n", " 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,\n", " 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n", " 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,\n", " 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0,\n", " 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0,\n", " 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1,\n", " 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
from src.utils import run_classification, run_regression
from sklearn import tree

# Fit on the DataFrame itself rather than `.values` so the estimator records
# feature names; this silences the "X has feature names, but
# DecisionTreeClassifier was fitted without feature names" UserWarning that
# appeared when run_classification later predicted on DataFrames.
fitted_model = tree.DecisionTreeClassifier(
    max_depth=7, random_state=random_state
).fit(X_train, y_train.values.ravel())
result = run_classification(
    fitted_model, X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test
)
result
try to debug them by setting error_score='raise'.\n", "\n", "Below are more details about the failures:\n", "--------------------------------------------------------------------------------\n", "450 fits failed with the following error:\n", "Traceback (most recent call last):\n", " File \"c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n", " estimator.fit(X_train, y_train, **fit_params)\n", " File \"c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n", " return fit_method(estimator, *args, **kwargs)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\tree\\_classes.py\", line 1377, in fit\n", " super()._fit(\n", " File \"c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\tree\\_classes.py\", line 269, in _fit\n", " raise ValueError(\n", "ValueError: Some value(s) of y are negative which is not allowed for Poisson regression.\n", "\n", " warnings.warn(some_fits_failed_message, FitFailedWarning)\n", "c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1103: UserWarning: One or more of the test scores are non-finite: [-4.40722975 -4.40722975 -4.40722975 -4.40722975 -4.40722975 -4.40722975\n", " -4.40722975 -4.40722975 -4.40722975 -4.02547375 -4.02547375 -4.02547375\n", " -4.02547375 -4.02547375 -4.02547375 -4.02547375 -4.02547375 -4.02547375\n", " -4.17872998 -4.17872998 -4.1794137 -4.18153603 -4.18153603 -4.18153603\n", " -4.18153603 -4.18153603 -4.18153603 -4.29228886 -4.27823005 -4.27378326\n", " -4.27592366 -4.27917056 -4.23213065 -4.22140931 -4.2203268 -4.216674\n", " -4.22464637 -4.33616785 -4.27369282 -4.25199473 -4.31065536 -4.27417069\n", " -4.24259248 -4.2377248 -4.2476017 -4.32852991 -4.31823651 -4.28553229\n", " -4.34177931 -4.30250957 
-4.28753163 -4.28076012 -4.25271136 -4.25347712\n", " -4.3969495 -4.36131695 -4.32082651 -4.31057625 -4.33653274 -4.3316473\n", " -4.29354845 -4.30058389 -4.2945977 -4.33373406 -4.38833647 -4.35267217\n", " -4.33498752 -4.4013588 -4.32169037 -4.30799648 -4.30856117 -4.35205443\n", " -4.38832687 -4.37876963 -4.36158341 -4.33729531 -4.36291152 -4.34227116\n", " -4.33926356 -4.31523722 -4.31237617 -4.42139371 -4.35079669 -4.40994949\n", " -4.35999754 -4.37508884 -4.39768089 -4.36889387 -4.31988427 -4.30750986\n", " 0.2652827 0.2652827 0.2652827 0.2652827 0.2652827 0.2652827\n", " 0.2652827 0.2652827 0.2652827 -1.9287676 -1.9287676 -1.9287676\n", " -1.9287676 -1.9287676 -1.9287676 -1.9287676 -1.9287676 -1.9287676\n", " -1.97020055 -1.97020055 -1.97042935 -1.97042935 -1.97004014 -1.97004014\n", " -1.9700399 -1.99709038 -1.99715081 -2.00301099 -1.99717467 -1.9930943\n", " -1.99098557 -1.98702489 -1.98700995 -1.98699179 -2.01302486 -2.01501513\n", " -2.10894255 -2.10994722 -2.09952227 -2.08394081 -2.08202479 -2.07188081\n", " -2.07112545 -2.09099692 -2.09034787 -2.15392376 -2.1403913 -2.14641712\n", " -2.12768247 -2.11768777 -2.11183283 -2.03846712 -2.12716538 -2.1272445\n", " -2.16307528 -2.17486177 -2.19327604 -2.16449752 -2.08104295 -2.13803493\n", " -2.08495978 -2.14818606 -2.14245762 -2.18712444 -2.20432741 -2.18741809\n", " -2.17137267 -2.16649999 -2.09136755 -2.16386709 -2.17518238 -2.15065344\n", " -2.24366056 -2.14969955 -2.12174719 -2.11054372 -2.17740208 -2.15664086\n", " -2.16327196 -2.1757021 -2.16935629 -2.20600588 -2.21021639 -2.20224851\n", " -2.19087629 -2.18339126 -2.16722284 -2.16000468 -2.18148192 -2.17331625\n", " -4.40722975 -4.40722975 -4.40722975 -4.40722975 -4.40722975 -4.40722975\n", " -4.40722975 -4.40722975 -4.40722975 -4.02547375 -4.02547375 -4.02547375\n", " -4.02547375 -4.02547375 -4.02547375 -4.02547375 -4.02547375 -4.02547375\n", " -4.17872998 -4.17872998 -4.1794137 -4.18153603 -4.18153603 -4.18153603\n", " -4.18153603 -4.18153603 
import numpy as np
from sklearn import model_selection

# BUG FIX: "poisson" is removed from the criterion grid. Poisson deviance
# requires a non-negative target, and the previous run's FitFailedWarning
# shows it made 450 of 1800 fits fail with
# "ValueError: Some value(s) of y are negative ...", leaving a quarter of
# the score grid as NaN.
parameters = {
    "criterion": ["squared_error", "absolute_error", "friedman_mse"],
    # Equivalent to np.arange(1, 21).tolist()[0::2] — odd depths 1..19.
    "max_depth": list(range(1, 21, 2)),
    # Equivalent to np.arange(2, 20).tolist()[0::2] — even values 2..18.
    "min_samples_split": list(range(2, 20, 2)),
}

grid = model_selection.GridSearchCV(
    tree.DecisionTreeRegressor(random_state=random_state), parameters, n_jobs=-1
)

grid.fit(X_train, y_train)
grid.best_params_
{ "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\base.py:493: UserWarning: X does not have valid feature names, but DecisionTreeRegressor was fitted with feature names\n", " warnings.warn(\n", "c:\\Users\\user\\Projects\\python\\fuzzy\\.venv\\Lib\\site-packages\\sklearn\\base.py:493: UserWarning: X does not have valid feature names, but DecisionTreeRegressor was fitted with feature names\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "{'RMSE_test': 1.1620245531837428,\n", " 'RMAE_test': 0.793506853815132,\n", " 'R2_test': 0.32558767171344627}" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'fitted': DecisionTreeRegressor(criterion='absolute_error', max_depth=1, random_state=9),\n", " 'train_preds': array([0.40003, 2.4023 , 2.4023 , ..., 0.40003, 0.40003, 0.40003]),\n", " 'preds': array([0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003,\n", " 0.40003, 0.40003, 2.4023 , 0.40003, 2.4023 , 2.4023 , 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 2.4023 , 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003,\n", " 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 0.40003, 0.40003, 
0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003,\n", " 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003,\n", " 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 2.4023 , 0.40003,\n", " 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n", " 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 2.4023 , 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 ,\n", " 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 2.4023 , 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003,\n", " 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003, 2.4023 , 0.40003,\n", " 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 2.4023 , 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003, 0.40003,\n", " 0.40003, 0.40003, 
0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 2.4023 , 0.40003, 2.4023 , 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 0.40003, 2.4023 , 0.40003, 2.4023 , 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 0.40003, 2.4023 , 2.4023 , 0.40003, 0.40003, 0.40003,\n", " 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003,\n", " 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003,\n", " 2.4023 , 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 2.4023 , 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 
0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n", " 2.4023 , 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 2.4023 , 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 0.40003, 0.40003, 2.4023 , 2.4023 , 2.4023 , 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n", " 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 2.4023 , 2.4023 , 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 0.40003, 0.40003, 2.4023 , 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 
0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 2.4023 ,\n", " 0.40003, 2.4023 , 0.40003, 0.40003, 2.4023 , 0.40003, 2.4023 ,\n", " 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003, 0.40003,\n", " 2.4023 , 0.40003, 0.40003, 2.4023 , 2.4023 , 0.40003, 2.4023 ]),\n", " 'RMSE_train': 2.6924305509873223,\n", " 'RMSE_test': 1.198383084740795,\n", " 'RMAE_test': 0.8489844401433057,\n", " 'R2_test': 0.2827241118036927}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "old_metrics = {\n", " \"RMSE_test\": models[\"decision_tree\"][\"RMSE_test\"],\n", " \"RMAE_test\": models[\"decision_tree\"][\"RMAE_test\"],\n", " \"R2_test\": models[\"decision_tree\"][\"R2_test\"],\n", "}\n", "new_metrics = run_regression(grid.best_estimator_, X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test)\n", "\n", "display(old_metrics)\n", "display(new_metrics)" ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "|--- x81 <= -0.00\n", "| |--- x81 <= -0.44\n", "| | |--- x4 <= 0.58\n", "| | | |--- x58 <= -0.08\n", "| | | | |--- class: 0\n", "| | | |--- x58 > -0.08\n", "| | | | |--- class: 1\n", "| | |--- x4 > 0.58\n", "| | | |--- x38 <= 0.04\n", "| | | | |--- x81 <= -0.48\n", "| | | | | |--- class: 1\n", "| | | | |--- x81 > -0.48\n", "| | | | | |--- class: 0\n", "| | | |--- x38 > 0.04\n", "| | | | |--- class: 1\n", "| |--- x81 > -0.44\n", "| | |--- x44 <= -0.07\n", "| | | |--- x40 <= 0.33\n", "| | | | |--- x3 <= 0.72\n", "| | | | | |--- x23 <= 0.17\n", "| | | | | | |--- class: 0\n", "| | | | | |--- x23 > 0.17\n", "| | | | | | |--- x15 <= 1.22\n", "| | | | | | | |--- class: 0\n", "| | | | | | |--- x15 > 1.22\n", "| | | | | | | |--- class: 1\n", "| | | | |--- x3 > 0.72\n", "| | | | | |--- x41 <= 9.63\n", "| | | | | | |--- x81 <= -0.40\n", "| | | | | | | |--- class: 0\n", "| | | | | | |--- x81 > -0.40\n", "| | | | | | | |--- class: 1\n", "| | | | | |--- x41 > 9.63\n", "| | | | | 
import pickle

# Human-readable if/else rules for the fitted decision tree, using the
# training frame's column names instead of the default x0..xN labels.
rules = tree.export_text(
    fitted_model,
    feature_names=X_train.columns.values.tolist(),
)
print(rules)

# Persist the baseline decision-tree model.
# Fix: the original passed a bare open(...) to pickle.dump, leaking the
# file handle and leaving flush/close to GC timing; a context manager
# guarantees the file is closed (and fully written) even on error.
# NOTE(review): pickle files are only safe to load from trusted sources.
with open("data-distress/vtree.model.sav", "wb") as model_file:
    pickle.dump(models["decision_tree"]["fitted"], model_file)