diff --git a/.gitignore b/.gitignore index 6adbb7b..75baee6 100644 --- a/.gitignore +++ b/.gitignore @@ -276,4 +276,4 @@ cython_debug/ node_modules/ data/aa-domestic-delays-2018.csv.zip -data/aa-domestic-delays-2018.csv/ \ No newline at end of file +data/aa-domestic-delays-2018.csv \ No newline at end of file diff --git a/lec4.ipynb b/lec4.ipynb new file mode 100644 index 0000000..66b82bf --- /dev/null +++ b/lec4.ipynb @@ -0,0 +1,3413 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Y5dMmHXIRYEg" + }, + "source": [ + "#### Загрузка и распаковка данных" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from urllib.request import urlretrieve\n", + "from zipfile import ZipFile\n", + "\n", + "ds_url = \"https://github.com/PacktPublishing/Interpretable-Machine-Learning-with-Python/raw/master/datasets/aa-domestic-delays-2018.csv.zip\"\n", + "ds_zip_filename = \"data/aa-domestic-delays-2018.csv.zip\"\n", + "urlretrieve(ds_url, ds_zip_filename)\n", + "\n", + "with ZipFile(ds_zip_filename, \"r\") as zObject:\n", + " zObject.extractall(path=\"data\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Загрузка данных в Dataframe" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YvuKMosoRY7K", + "outputId": "f9f05784-9c91-4869-a02e-e1ccc8115a8d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 899527 entries, 0 to 899526\n", + "Data columns (total 23 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 FL_NUM 899527 non-null int64 \n", + " 1 ORIGIN 899527 non-null object \n", + " 2 DEST 899527 non-null object \n", + " 3 PLANNED_DEP_DATETIME 899527 non-null object \n", + " 4 CRS_DEP_TIME 899527 non-null int64 \n", + " 5 DEP_TIME 899527 non-null float64\n", + " 6 DEP_DELAY 899527 non-null float64\n", + " 7 DEP_AFPH 899527 non-null float64\n", + " 8 DEP_RFPH 899527 non-null float64\n", + " 9 TAXI_OUT 899527 non-null float64\n", + " 10 WHEELS_OFF 899527 non-null float64\n", + " 11 CRS_ELAPSED_TIME 899527 non-null float64\n", + " 12 PCT_ELAPSED_TIME 899527 non-null float64\n", + " 13 DISTANCE 899527 non-null float64\n", + " 14 CRS_ARR_TIME 899527 non-null int64 \n", + " 15 ARR_AFPH 899527 non-null float64\n", + " 16 ARR_RFPH 899527 non-null float64\n", + " 17 ARR_DELAY 899527 non-null float64\n", + " 18 CARRIER_DELAY 899527 non-null float64\n", + " 19 WEATHER_DELAY 899527 non-null float64\n", + " 20 NAS_DELAY 899527 non-null float64\n", + " 21 SECURITY_DELAY 899527 non-null float64\n", + " 22 LATE_AIRCRAFT_DELAY 899527 non-null float64\n", + "dtypes: float64(17), int64(3), object(3)\n", + "memory usage: 157.8+ MB\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "orig_df = pd.read_csv(\"data/aa-domestic-delays-2018.csv\")\n", + "orig_df.info()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Hen4OWkxSEsb" + }, + "source": [ + "#### Подготовка данных и конструирование признаков" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "id": "qnyD6ZeLSGwL" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
CRS_DEP_TIMEDEP_TIMEDEP_DELAYDEP_AFPHDEP_RFPHTAXI_OUTWHEELS_OFFCRS_ELAPSED_TIMEPCT_ELAPSED_TIMEDISTANCE...ARR_RFPHCARRIER_DELAYWEATHER_DELAYNAS_DELAYSECURITY_DELAYLATE_AIRCRAFT_DELAYDEP_MONTHDEP_DOWORIGIN_HUBDEST_HUB
011551149.0-6.034.4444440.95679014.01203.0219.00.9634701192.0...0.8545730.00.00.00.00.01011
1705700.0-5.017.4545450.24242416.0716.0171.00.9181291192.0...0.7317070.00.00.00.00.01011
211481145.0-3.094.7368420.94736814.01159.0212.00.9716981558.0...1.0924370.00.00.00.00.01001
3825824.0-1.033.5593220.86049516.0840.0271.00.9188191558.0...0.8673790.00.00.00.00.01010
411551147.0-8.033.4615380.92948713.01200.099.00.969697331.0...1.0068030.00.00.00.00.01011
..................................................................
89952215341530.0-4.035.3571430.82225920.01550.0100.00.990000331.0...0.8379450.00.00.00.00.012011
89952317511757.06.071.8181821.04084318.01815.0181.00.972376936.0...0.6976740.00.00.00.00.012011
89952420152010.0-5.063.2727271.19382536.02046.0112.01.142857511.0...0.4828970.00.00.00.00.012010
89952513001323.023.070.8433730.77003711.01334.050.00.820000130.0...0.8880310.00.00.00.00.012010
89952614351443.08.019.4117650.9243708.01451.071.00.830986130.0...1.0119050.00.00.00.00.012001
\n", + "

899527 rows × 22 columns

\n", + "
" + ], + "text/plain": [ + " CRS_DEP_TIME DEP_TIME DEP_DELAY DEP_AFPH DEP_RFPH TAXI_OUT \\\n", + "0 1155 1149.0 -6.0 34.444444 0.956790 14.0 \n", + "1 705 700.0 -5.0 17.454545 0.242424 16.0 \n", + "2 1148 1145.0 -3.0 94.736842 0.947368 14.0 \n", + "3 825 824.0 -1.0 33.559322 0.860495 16.0 \n", + "4 1155 1147.0 -8.0 33.461538 0.929487 13.0 \n", + "... ... ... ... ... ... ... \n", + "899522 1534 1530.0 -4.0 35.357143 0.822259 20.0 \n", + "899523 1751 1757.0 6.0 71.818182 1.040843 18.0 \n", + "899524 2015 2010.0 -5.0 63.272727 1.193825 36.0 \n", + "899525 1300 1323.0 23.0 70.843373 0.770037 11.0 \n", + "899526 1435 1443.0 8.0 19.411765 0.924370 8.0 \n", + "\n", + " WHEELS_OFF CRS_ELAPSED_TIME PCT_ELAPSED_TIME DISTANCE ... \\\n", + "0 1203.0 219.0 0.963470 1192.0 ... \n", + "1 716.0 171.0 0.918129 1192.0 ... \n", + "2 1159.0 212.0 0.971698 1558.0 ... \n", + "3 840.0 271.0 0.918819 1558.0 ... \n", + "4 1200.0 99.0 0.969697 331.0 ... \n", + "... ... ... ... ... ... \n", + "899522 1550.0 100.0 0.990000 331.0 ... \n", + "899523 1815.0 181.0 0.972376 936.0 ... \n", + "899524 2046.0 112.0 1.142857 511.0 ... \n", + "899525 1334.0 50.0 0.820000 130.0 ... \n", + "899526 1451.0 71.0 0.830986 130.0 ... \n", + "\n", + " ARR_RFPH CARRIER_DELAY WEATHER_DELAY NAS_DELAY SECURITY_DELAY \\\n", + "0 0.854573 0.0 0.0 0.0 0.0 \n", + "1 0.731707 0.0 0.0 0.0 0.0 \n", + "2 1.092437 0.0 0.0 0.0 0.0 \n", + "3 0.867379 0.0 0.0 0.0 0.0 \n", + "4 1.006803 0.0 0.0 0.0 0.0 \n", + "... ... ... ... ... ... \n", + "899522 0.837945 0.0 0.0 0.0 0.0 \n", + "899523 0.697674 0.0 0.0 0.0 0.0 \n", + "899524 0.482897 0.0 0.0 0.0 0.0 \n", + "899525 0.888031 0.0 0.0 0.0 0.0 \n", + "899526 1.011905 0.0 0.0 0.0 0.0 \n", + "\n", + " LATE_AIRCRAFT_DELAY DEP_MONTH DEP_DOW ORIGIN_HUB DEST_HUB \n", + "0 0.0 1 0 1 1 \n", + "1 0.0 1 0 1 1 \n", + "2 0.0 1 0 0 1 \n", + "3 0.0 1 0 1 0 \n", + "4 0.0 1 0 1 1 \n", + "... ... ... ... ... ... \n", + "899522 0.0 12 0 1 1 \n", + "899523 0.0 12 0 1 1 \n", + "899524 0.0 12 0 1 0 \n", + "899525 0.0 12 0 1 0 \n", + "899526 0.0 12 0 0 1 \n", + "\n", + "[899527 rows x 22 columns]" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = orig_df.copy()\n", + "# Преобразование даты из строки в datetime\n", + "df[\"PLANNED_DEP_DATETIME\"] = pd.to_datetime(df[\"PLANNED_DEP_DATETIME\"])\n", + "# Получение месяца и дня недели вылета из даты для учета сезонности и особенностей дня недели\n", + "df[\"DEP_MONTH\"] = df[\"PLANNED_DEP_DATETIME\"].dt.month\n", + "df[\"DEP_DOW\"] = df[\"PLANNED_DEP_DATETIME\"].dt.dayofweek\n", + "# Удаление столбца с датой\n", + "df = df.drop([\"PLANNED_DEP_DATETIME\"], axis=1)\n", + "# Список аэропортов-хабов\n", + "hubs = [\"CLT\", \"ORD\", \"DFW\", \"LAX\", \"MIA\", \"JFK\", \"LGA\", \"PHL\", \"PHX\", \"DCA\"]\n", + "# Определение признака хаба для аэропортов вылета и назначения\n", + "is_origin_hub = df[\"ORIGIN\"].isin(hubs)\n", + "is_dest_hub = df[\"DEST\"].isin(hubs)\n", + "# Установка признака хаба для данных\n", + "df[\"ORIGIN_HUB\"] = 0\n", + "df.loc[is_origin_hub, \"ORIGIN_HUB\"] = 1\n", + "df[\"DEST_HUB\"] = 0\n", + "df.loc[is_dest_hub, \"DEST_HUB\"] = 1\n", + "# Удаление лишних столбцов\n", + "df = df.drop([\"FL_NUM\", \"ORIGIN\", \"DEST\"], axis=1)\n", + "# Удаление столбца с общим временем задержки прибытия, так как данные значения будут иметь сильное влияние на результат\n", + "df = df.drop([\"ARR_DELAY\"], axis=1)\n", + "df" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y6f3Z4UwUAUI" + }, + "source": [ + "#### Формирование тестовой и обучающей выборок данных" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "id": "6LudeeYUUEPt" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
CRS_DEP_TIMEDEP_TIMEDEP_DELAYDEP_AFPHDEP_RFPHTAXI_OUTWHEELS_OFFCRS_ELAPSED_TIMEPCT_ELAPSED_TIMEDISTANCE...ARR_AFPHARR_RFPHWEATHER_DELAYNAS_DELAYSECURITY_DELAYLATE_AIRCRAFT_DELAYDEP_MONTHDEP_DOWORIGIN_HUBDEST_HUB
31121845842.0-3.016.8421050.44321321.0903.0106.00.886792331.0...85.3333331.1454140.00.00.00.01611
63350013151316.01.04.9180330.9836079.01325.0121.01.107438624.0...111.8918921.2861140.00.00.00.09401
74773717101704.0-6.055.5555561.02880714.01718.067.01.074627304.0...17.2881360.7516580.00.00.00.010110
29894318401920.040.028.2000000.58750018.01938.0161.00.888199852.0...38.7804881.15762722.00.00.00.05501
84393218301822.0-8.028.8461540.90144213.01835.0215.00.9302331192.0...90.8108111.1211210.00.00.00.012511
..................................................................
72082213591410.011.053.2394370.99513026.01436.0147.01.006803814.0...25.0000001.1363640.00.00.00.010410
45925322092207.0-2.080.6896551.02138816.02223.086.00.918605413.0...2.3529411.1764710.00.00.00.07610
711294530523.0-7.012.4528300.83018917.0540.0119.00.957983666.0...37.5000000.8928570.00.00.00.010101
872796709706.0-3.0120.0000001.03004312.0718.0169.01.2958581120.0...12.3364490.8507900.047.00.00.012310
51647819252030.065.014.8800001.3527278.02038.0135.00.977778761.0...93.4426231.1124120.00.00.062.07001
\n", + "

764597 rows × 21 columns

\n", + "
" + ], + "text/plain": [ + " CRS_DEP_TIME DEP_TIME DEP_DELAY DEP_AFPH DEP_RFPH TAXI_OUT \\\n", + "31121 845 842.0 -3.0 16.842105 0.443213 21.0 \n", + "633500 1315 1316.0 1.0 4.918033 0.983607 9.0 \n", + "747737 1710 1704.0 -6.0 55.555556 1.028807 14.0 \n", + "298943 1840 1920.0 40.0 28.200000 0.587500 18.0 \n", + "843932 1830 1822.0 -8.0 28.846154 0.901442 13.0 \n", + "... ... ... ... ... ... ... \n", + "720822 1359 1410.0 11.0 53.239437 0.995130 26.0 \n", + "459253 2209 2207.0 -2.0 80.689655 1.021388 16.0 \n", + "711294 530 523.0 -7.0 12.452830 0.830189 17.0 \n", + "872796 709 706.0 -3.0 120.000000 1.030043 12.0 \n", + "516478 1925 2030.0 65.0 14.880000 1.352727 8.0 \n", + "\n", + " WHEELS_OFF CRS_ELAPSED_TIME PCT_ELAPSED_TIME DISTANCE ... \\\n", + "31121 903.0 106.0 0.886792 331.0 ... \n", + "633500 1325.0 121.0 1.107438 624.0 ... \n", + "747737 1718.0 67.0 1.074627 304.0 ... \n", + "298943 1938.0 161.0 0.888199 852.0 ... \n", + "843932 1835.0 215.0 0.930233 1192.0 ... \n", + "... ... ... ... ... ... \n", + "720822 1436.0 147.0 1.006803 814.0 ... \n", + "459253 2223.0 86.0 0.918605 413.0 ... \n", + "711294 540.0 119.0 0.957983 666.0 ... \n", + "872796 718.0 169.0 1.295858 1120.0 ... \n", + "516478 2038.0 135.0 0.977778 761.0 ... \n", + "\n", + " ARR_AFPH ARR_RFPH WEATHER_DELAY NAS_DELAY SECURITY_DELAY \\\n", + "31121 85.333333 1.145414 0.0 0.0 0.0 \n", + "633500 111.891892 1.286114 0.0 0.0 0.0 \n", + "747737 17.288136 0.751658 0.0 0.0 0.0 \n", + "298943 38.780488 1.157627 22.0 0.0 0.0 \n", + "843932 90.810811 1.121121 0.0 0.0 0.0 \n", + "... ... ... ... ... ... \n", + "720822 25.000000 1.136364 0.0 0.0 0.0 \n", + "459253 2.352941 1.176471 0.0 0.0 0.0 \n", + "711294 37.500000 0.892857 0.0 0.0 0.0 \n", + "872796 12.336449 0.850790 0.0 47.0 0.0 \n", + "516478 93.442623 1.112412 0.0 0.0 0.0 \n", + "\n", + " LATE_AIRCRAFT_DELAY DEP_MONTH DEP_DOW ORIGIN_HUB DEST_HUB \n", + "31121 0.0 1 6 1 1 \n", + "633500 0.0 9 4 0 1 \n", + "747737 0.0 10 1 1 0 \n", + "298943 0.0 5 5 0 1 \n", + "843932 0.0 12 5 1 1 \n", + "... ... ... ... ... ... \n", + "720822 0.0 10 4 1 0 \n", + "459253 0.0 7 6 1 0 \n", + "711294 0.0 10 1 0 1 \n", + "872796 0.0 12 3 1 0 \n", + "516478 62.0 7 0 0 1 \n", + "\n", + "[764597 rows x 21 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "31121 0.0\n", + "633500 0.0\n", + "747737 0.0\n", + "298943 0.0\n", + "843932 0.0\n", + " ... \n", + "720822 0.0\n", + "459253 0.0\n", + "711294 0.0\n", + "872796 0.0\n", + "516478 0.0\n", + "Name: CARRIER_DELAY, Length: 764597, dtype: float64" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "31121 0\n", + "633500 0\n", + "747737 0\n", + "298943 0\n", + "843932 0\n", + " ..\n", + "720822 0\n", + "459253 0\n", + "711294 0\n", + "872796 0\n", + "516478 0\n", + "Name: CARRIER_DELAY, Length: 764597, dtype: int64" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Int64Index: 764597 entries, 31121 to 516478\n", + "Data columns (total 21 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 CRS_DEP_TIME 764597 non-null int64 \n", + " 1 DEP_TIME 764597 non-null float64\n", + " 2 DEP_DELAY 764597 non-null float64\n", + " 3 DEP_AFPH 764597 non-null float64\n", + " 4 DEP_RFPH 764597 non-null float64\n", + " 5 TAXI_OUT 764597 non-null float64\n", + " 6 WHEELS_OFF 764597 non-null float64\n", + " 7 CRS_ELAPSED_TIME 764597 non-null float64\n", + " 8 PCT_ELAPSED_TIME 764597 non-null float64\n", + " 9 DISTANCE 764597 non-null float64\n", + " 10 CRS_ARR_TIME 764597 non-null int64 \n", + " 11 ARR_AFPH 764597 non-null float64\n", + " 12 ARR_RFPH 764597 non-null float64\n", + " 13 WEATHER_DELAY 764597 non-null float64\n", + " 14 NAS_DELAY 764597 non-null float64\n", + " 15 SECURITY_DELAY 764597 non-null float64\n", + " 16 LATE_AIRCRAFT_DELAY 764597 non-null float64\n", + " 17 DEP_MONTH 764597 non-null int64 \n", + " 18 DEP_DOW 764597 non-null int64 \n", + " 19 ORIGIN_HUB 764597 non-null int64 \n", + " 20 DEST_HUB 764597 non-null int64 \n", + "dtypes: float64(15), int64(6)\n", + "memory usage: 128.3 MB\n" + ] + } + ], + "source": [ + "from sklearn.model_selection import train_test_split\n", + "\n", + "# Задание фиксированного случайного состояния для воспроизводимости результатов\n", + "rand = 9\n", + "# Выделение признака, который модель должна предсказать\n", + "y = df[\"CARRIER_DELAY\"]\n", + "# Формирование множества признаков, на основе которых модель будет обучаться (удаление столбца с y)\n", + "X = df.drop([\"CARRIER_DELAY\"], axis=1).copy()\n", + "X_train, X_test, y_train_reg, y_test_reg = train_test_split(\n", + " X, y, test_size=0.15, random_state=rand\n", + ")\n", + "# Создание классов для классификаторов в виде двоичных меток (опоздание свыше 15 минут - 1, иначе - 0)\n", + "y_train_class = y_train_reg.apply(lambda x: 1 if x > 15 else 0)\n", + "y_test_class = y_test_reg.apply(lambda x: 1 if x > 15 else 0)\n", + "\n", + "display(X_train)\n", + "display(y_train_reg)\n", + "display(y_train_class)\n", + "X_train.info()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aVjZol-yrYbH" + }, + "source": [ + "#### Определение линейной корреляции признаков с целевым признаком с помощью корреляции Пирсона" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_1e7XE_Orf0r", + "outputId": "efbb620c-c92b-481e-d613-a34633a4ba86" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "CARRIER_DELAY 1.000000\n", + "DEP_DELAY 0.703935\n", + "ARR_RFPH 0.101742\n", + "LATE_AIRCRAFT_DELAY 0.083166\n", + "DEP_RFPH 0.058659\n", + "ARR_AFPH 0.035135\n", + "DEP_TIME 0.030941\n", + "NAS_DELAY 0.026792\n", + "WHEELS_OFF 0.026787\n", + "TAXI_OUT 0.024635\n", + "PCT_ELAPSED_TIME 0.020980\n", + "CRS_DEP_TIME 0.016032\n", + "ORIGIN_HUB 0.015334\n", + "DEST_HUB 0.013932\n", + "DISTANCE 0.010680\n", + "DEP_MONTH 0.009728\n", + "CRS_ELAPSED_TIME 0.008801\n", + "DEP_DOW 0.007043\n", + "CRS_ARR_TIME 0.007029\n", + "DEP_AFPH 0.006053\n", + "WEATHER_DELAY 0.003002\n", + "SECURITY_DELAY 0.000460\n", + "Name: CARRIER_DELAY, dtype: float64" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "corr = df.corr()\n", + "abs(corr[\"CARRIER_DELAY\"]).sort_values(ascending=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xxPlvpOdrnY_" + }, + "source": [ + "#### Использование регрессионных моделей для предсказания задержки рейса" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "id": "2m81QO87rzal" + }, + "outputs": [], + "source": [ + "from sklearn.pipeline import make_pipeline\n", + "from sklearn.preprocessing import PolynomialFeatures, StandardScaler\n", + "from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n", + "\n", + "reg_models = {\n", + " # Обобщенные линейные модели (GLM-модели)\n", + " \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n", + " \"linear_poly\": {\n", + " \"model\": make_pipeline(\n", + " PolynomialFeatures(degree=2, interaction_only=False),\n", + " linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n", + " memory=None\n", + " )\n", + " },\n", + " \"linear_interact\": {\n", + " \"model\": make_pipeline(\n", + " PolynomialFeatures(degree=2, interaction_only=True),\n", + " linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n", + " memory=None\n", + " )\n", + " },\n", + " \"ridge\": {\"model\": linear_model.RidgeCV()},\n", + " # Деревья\n", + " \"decision_tree\": {\n", + " \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=rand)\n", + " },\n", + " # Ближайшие соседи\n", + " \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n", + " # Ансамблевые методы\n", + " \"random_forest\": {\n", + " \"model\": ensemble.RandomForestRegressor(\n", + " max_depth=7, random_state=rand, n_jobs=-1\n", + " )\n", + " },\n", + " # Нейронные сети\n", + " \"mlp\": {\n", + " \"model\": neural_network.MLPRegressor(\n", + " hidden_layer_sizes=(21,),\n", + " max_iter=500,\n", + " early_stopping=True,\n", + " random_state=rand,\n", + " )\n", + " },\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DFJudSP_tFec" + }, + "source": [ + "#### Обучение и оценка регрессионных моделей" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "XZEDVAR0tIzN", + "outputId": "4f1f55ff-85dd-49d9-ddaf-3758419bdd6a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: linear\n", + "Model: linear_poly\n", + "Model: linear_interact\n", + "Model: ridge\n", + "Model: decision_tree\n", + "Model: knn\n", + "Model: random_forest\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "0.00s - Debugger warning: It seems that frozen modules are being used, which may\n", + "0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off\n", + "0.00s - to python to disable frozen modules.\n", + "0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: mlp\n" + ] + } + ], + "source": [ + "import math\n", + "from sklearn import metrics\n", + "\n", + "for model_name in reg_models.keys():\n", + " print(f'Model: {model_name}')\n", + " fitted_model = reg_models[model_name][\"model\"].fit(\n", + " X_train.values, y_train_reg.to_numpy().ravel()\n", + " )\n", + " y_train_pred = fitted_model.predict(X_train.values)\n", + " y_test_pred = fitted_model.predict(X_test.values)\n", + " reg_models[model_name][\"fitted\"] = fitted_model\n", + " reg_models[model_name][\"preds\"] = y_test_pred\n", + " reg_models[model_name][\"RMSE_train\"] = math.sqrt(\n", + " metrics.mean_squared_error(y_train_reg, y_train_pred)\n", + " )\n", + " reg_models[model_name][\"RMSE_test\"] = math.sqrt(\n", + " metrics.mean_squared_error(y_test_reg, y_test_pred)\n", + " )\n", + " reg_models[model_name][\"R2_test\"] = metrics.r2_score(y_test_reg, y_test_pred)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6GD_HZhGHPXK" + }, + "source": [ + "#### Вывод оценки в виде таблицы" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "id": "lvcbKDfmHQ6p" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 RMSE_trainRMSE_testR2_test
mlp3.2435163.3085970.987025
random_forest5.1432676.0882490.956065
linear_poly6.2140106.3398430.952359
linear_interact6.4543146.5622840.948957
decision_tree6.5429247.4563350.934102
linear7.8196437.8828750.926347
ridge7.8320667.8981890.926060
knn7.3600989.2594220.898377
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg_metrics = pd.DataFrame.from_dict(reg_models, \"index\")[\n", + " [\"RMSE_train\", \"RMSE_test\", \"R2_test\"]\n", + "]\n", + "reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n", + " cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n", + ").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"R2_test\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MZrTDbnjJMrB" + }, + "source": [ + "#### Использование классификаторов для предсказания задержки рейса" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "id": "c_1jMq5IJMSL" + }, + "outputs": [], + "source": [ + "from sklearn import naive_bayes\n", + "\n", + "class_models = {\n", + " # Обобщенные линейные модели (GLM-модели)\n", + " \"logistic\": {\"model\": linear_model.LogisticRegression()},\n", + " \"ridge\": {\n", + " \"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")\n", + " },\n", + " # Дерево\n", + " \"decision_tree\": {\n", + " \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=rand)\n", + " },\n", + " # Ближайшие соседи\n", + " \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n", + " # Наивный Байес\n", + " \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n", + " # Ансамблевые методы\n", + " \"gradient_boosting\": {\n", + " \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n", + " },\n", + " \"random_forest\": {\n", + " \"model\": ensemble.RandomForestClassifier(\n", + " max_depth=11, class_weight=\"balanced\", random_state=rand\n", + " )\n", + " },\n", + " # Нейронные сети\n", + " \"mlp\": {\n", + " \"model\": make_pipeline(\n", + " StandardScaler(),\n", + " neural_network.MLPClassifier(\n", + " hidden_layer_sizes=(7,),\n", + " max_iter=500,\n", + " early_stopping=True,\n", + " random_state=rand,\n", + " ),\n", + " memory=None\n", + " )\n", + " },\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rMzuL4OxKAns" + }, + "source": [ + "#### Определение сбалансированности выборки для классификации" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "id": "BOVmTCPHJ5Au" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.061283264255549" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_train_class[y_train_class == 1].shape[0] / y_train_class.shape[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "H8Z7-KugJ7xR" + }, + "source": [ + "#### Обучение и оценка классификаторов" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "id": "RyR1_5m4KBTs" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: logistic\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/user/Projects/python/ckexp/.venv/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:465: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: ridge\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/user/Projects/python/ckexp/.venv/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:465: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: decision_tree\n", + "Model: knn\n", + "Model: naive_bayes\n", + "Model: gradient_boosting\n", + "Model: random_forest\n", + "Model: mlp\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "for model_name in class_models.keys():\n", + " print(f\"Model: {model_name}\")\n", + " fitted_model = class_models[model_name][\"model\"].fit(\n", + " X_train.values,\n", + " y_train_class.to_numpy().ravel(),\n", + " )\n", + " y_train_pred = fitted_model.predict(X_train.values)\n", + " y_test_prob = fitted_model.predict_proba(X_test.values)[:, 1]\n", + " y_test_pred = fitted_model.predict(X_test.values)\n", + "\n", + " class_models[model_name][\"fitted\"] = fitted_model\n", + " class_models[model_name][\"probs\"] = y_test_prob\n", + " class_models[model_name][\"preds\"] = y_test_pred\n", + "\n", + " class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n", + " y_train_class, y_train_pred\n", + " )\n", + " class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n", + " y_test_class, y_test_pred\n", + " )\n", + " class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n", + " y_train_class, y_train_pred\n", + " )\n", + " class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n", + " y_test_class, y_test_pred\n", + " )\n", + " class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n", + " y_test_class, y_test_prob\n", + " )\n", + " class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test_class, y_test_pred)\n", + " class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n", + " y_test_class, y_test_pred\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NzCyXXXFKdx2" + }, + "source": [ + "#### Вывод оценки в виде таблицы" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "id": "VOhaFiEYKeN5" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 Accuracy_trainAccuracy_testRecall_trainRecall_testROC_AUC_testF1_testMCC_test
mlp0.9984820.9985550.9871310.9888650.9998770.9882070.987437
gradient_boosting0.9917250.9916620.8929300.8938510.9988850.9292230.925619
random_forest0.9411660.9403250.9995520.9923750.9951450.6706750.685702
decision_tree0.9832970.9828950.8569690.8522150.9949320.8591820.850110
ridge0.9434530.9425260.9455790.9409340.9837770.6672100.673110
logistic0.9750540.9750310.6832920.6808280.9602870.7695460.763854
knn0.9728860.9651230.6806450.6077220.9483870.6809060.668176
naive_bayes0.9251190.9255390.2791260.2742680.8118690.3108580.274984
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n", + " [\n", + " \"Accuracy_train\",\n", + " \"Accuracy_test\",\n", + " \"Recall_train\",\n", + " \"Recall_test\",\n", + " \"ROC_AUC_test\",\n", + " \"F1_test\",\n", + " \"MCC_test\",\n", + " ]\n", + "]\n", + "class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n", + " cmap=\"plasma\", low=0.3, high=1, subset=[\"Accuracy_train\", \"Accuracy_test\"]\n", + ").background_gradient(\n", + " cmap=\"viridis\",\n", + " low=1,\n", + " high=0.3,\n", + " subset=[\"Recall_train\", \"Recall_test\", \"ROC_AUC_test\", \"F1_test\", \"MCC_test\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Интерпретация результатов для моделей на основе \"белого ящика\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Линейная регрессия" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "coefficients:\t[ 4.54984539e-03 -5.25067742e-03 8.94125541e-01 -1.52961053e-02\n", + " -4.69623002e-01 1.25277815e-01 -6.46744472e-04 -1.26240049e-02\n", + " 4.50112895e+01 6.76385421e-04 -3.69920254e-04 5.47855860e-04\n", + " 3.73866548e-01 -9.06364154e-01 -6.74052666e-01 -9.17411191e-01\n", + " -9.29843952e-01 -3.96621856e-02 -1.79666480e-02 -1.02912927e+00\n", + " -3.94934854e-01]\n", + "intercept:\t-37.86177932752649\n", + "y = -37.86 + 0.0045X1 + -0.0053X2 + 0.894X3 + ...\n" + ] + } + ], + "source": [ + "coefs_lm = reg_models[\"linear\"][\"fitted\"].coef_\n", + "intercept_lm = reg_models[\"linear\"][\"fitted\"].intercept_\n", + "print(\"coefficients:\\t%s\" % coefs_lm)\n", + "print(\"intercept:\\t%s\" % intercept_lm)\n", + "print(\n", + " \"y = %0.2f + %0.4fX1 + %0.4fX2 + %0.3fX3 + ...\"\n", + " % (intercept_lm, coefs_lm[0], coefs_lm[1], coefs_lm[2])\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
featurecoef
0CRS_DEP_TIME0.004550
1DEP_TIME-0.005251
2DEP_DELAY0.894126
3DEP_AFPH-0.015296
4DEP_RFPH-0.469623
5TAXI_OUT0.125278
6WHEELS_OFF-0.000647
7CRS_ELAPSED_TIME-0.012624
8PCT_ELAPSED_TIME45.011289
9DISTANCE0.000676
10CRS_ARR_TIME-0.000370
11ARR_AFPH0.000548
12ARR_RFPH0.373867
13WEATHER_DELAY-0.906364
14NAS_DELAY-0.674053
15SECURITY_DELAY-0.917411
16LATE_AIRCRAFT_DELAY-0.929844
17DEP_MONTH-0.039662
18DEP_DOW-0.017967
19ORIGIN_HUB-1.029129
20DEST_HUB-0.394935
\n", + "
" + ], + "text/plain": [ + " feature coef\n", + "0 CRS_DEP_TIME 0.004550\n", + "1 DEP_TIME -0.005251\n", + "2 DEP_DELAY 0.894126\n", + "3 DEP_AFPH -0.015296\n", + "4 DEP_RFPH -0.469623\n", + "5 TAXI_OUT 0.125278\n", + "6 WHEELS_OFF -0.000647\n", + "7 CRS_ELAPSED_TIME -0.012624\n", + "8 PCT_ELAPSED_TIME 45.011289\n", + "9 DISTANCE 0.000676\n", + "10 CRS_ARR_TIME -0.000370\n", + "11 ARR_AFPH 0.000548\n", + "12 ARR_RFPH 0.373867\n", + "13 WEATHER_DELAY -0.906364\n", + "14 NAS_DELAY -0.674053\n", + "15 SECURITY_DELAY -0.917411\n", + "16 LATE_AIRCRAFT_DELAY -0.929844\n", + "17 DEP_MONTH -0.039662\n", + "18 DEP_DOW -0.017967\n", + "19 ORIGIN_HUB -1.029129\n", + "20 DEST_HUB -0.394935" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "coef_df = pd.DataFrame({\"feature\": X_train.columns.values.tolist(), \"coef\": coefs_lm})\n", + "display(coef_df)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 featureCoef.Std.Err.tP>|t|[0.0250.975]t_abs
2DEP_DELAY0.8941260.0003032951.0559780.0000000.8935320.8947192951.055978
16LATE_AIRCRAFT_DELAY-0.9298440.000509-1827.0180820.000000-0.930841-0.9288461827.018082
13WEATHER_DELAY-0.9063640.000911-995.3664230.000000-0.908149-0.904579995.366423
14NAS_DELAY-0.6740530.000813-829.1286570.000000-0.675646-0.672459829.128657
8PCT_ELAPSED_TIME45.0112890.117195384.0725660.00000044.78159245.240987384.072566
15SECURITY_DELAY-0.9174110.005465-167.8570850.000000-0.928123-0.906699167.857085
5TAXI_OUT0.1252780.001203104.1195790.0000000.1229200.127636104.119579
0CRS_DEP_TIME0.0045500.00007262.8716930.0000000.0044080.00469262.871693
1DEP_TIME-0.0052510.000092-57.1158950.000000-0.005431-0.00507057.115895
3DEP_AFPH-0.0152960.000321-47.7245060.000000-0.015924-0.01466847.724506
19ORIGIN_HUB-1.0291290.026669-38.5894110.000000-1.081399-0.97686038.589411
12ARR_RFPH0.3738670.01317128.3860310.0000000.3480520.39968128.386031
4DEP_RFPH-0.4696230.017169-27.3531790.000000-0.503273-0.43597327.353179
7CRS_ELAPSED_TIME-0.0126240.000660-19.1315160.000000-0.013917-0.01133119.131516
10CRS_ARR_TIME-0.0003700.000022-16.9386610.000000-0.000413-0.00032716.938661
20DEST_HUB-0.3949350.026256-15.0414590.000000-0.446397-0.34347315.041459
17DEP_MONTH-0.0396620.002641-15.0188080.000000-0.044838-0.03448615.018808
6WHEELS_OFF-0.0006470.000067-9.6461040.000000-0.000778-0.0005159.646104
9DISTANCE0.0006760.0000808.4288350.0000000.0005190.0008348.428835
18DEP_DOW-0.0179670.004487-4.0045610.000062-0.026760-0.0091734.004561
11ARR_AFPH0.0005480.0003321.6507880.098782-0.0001030.0011981.650788
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import statsmodels.api as sm\n", + "\n", + "linreg_mdl = sm.OLS(y_train_reg, sm.add_constant(X_train))\n", + "linreg_mdl = linreg_mdl.fit()\n", + "summary_df = linreg_mdl.summary2().tables[1]\n", + "summary_df = (\n", + " summary_df.drop([\"const\"]).reset_index().rename(columns={\"index\": \"feature\"})\n", + ")\n", + "summary_df[\"t_abs\"] = abs(summary_df[\"t\"])\n", + "summary_df.sort_values(by=\"t_abs\", ascending=False).style.background_gradient(\n", + " cmap=\"plasma_r\", low=0, high=0.1, subset=[\"P>|t|\"]\n", + ").background_gradient(cmap=\"plasma_r\", low=0, high=0.1, subset=[\"t_abs\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Гребневая регрессия" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 featurecoef_linearcoef_ridge
0CRS_DEP_TIME0.0045500.004275
1DEP_TIME-0.005251-0.005485
2DEP_DELAY0.8941260.894229
3DEP_AFPH-0.015296-0.015304
4DEP_RFPH-0.469623-0.469623
5TAXI_OUT0.1252780.125284
6WHEELS_OFF-0.000647-0.000889
7CRS_ELAPSED_TIME-0.012624-0.012618
8PCT_ELAPSED_TIME45.01128945.010279
9DISTANCE0.0006760.000718
10CRS_ARR_TIME-0.000370-0.000546
11ARR_AFPH0.0005480.000550
12ARR_RFPH0.3738670.373865
13WEATHER_DELAY-0.906364-0.906358
14NAS_DELAY-0.674053-0.674045
15SECURITY_DELAY-0.917411-0.917411
16LATE_AIRCRAFT_DELAY-0.929844-0.929805
17DEP_MONTH-0.039662-0.039661
18DEP_DOW-0.017967-0.017967
19ORIGIN_HUB-1.029129-1.029140
20DEST_HUB-0.394935-0.394948
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "coefs_ridge = reg_models[\"ridge\"][\"fitted\"].coef_\n", + "coef_ridge_df = pd.DataFrame(\n", + " {\n", + " \"feature\": X_train.columns.values.tolist(),\n", + " \"coef_linear\": coefs_lm,\n", + " \"coef_ridge\": coefs_ridge,\n", + " }\n", + ")\n", + "coef_ridge_df.style.background_gradient(cmap=\"viridis_r\", low=0.3, high=0.2, axis=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Полиномиальная регрессия" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "253" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "232" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display(\n", + " reg_models[\"linear_poly\"][\"fitted\"].get_params()[\"linearregression\"].coef_.shape[0]\n", + ")\n", + "display(\n", + " reg_models[\"linear_interact\"][\"fitted\"]\n", + " .get_params()[\"linearregression\"]\n", + " .coef_.shape[0]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Логистическая регрессия" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "coefficients:\t[[-0.00132811 0.00034525 0.15746107 0.00349808 -0.00215053 -0.00445293\n", + " 0.00029184 -0.05167613 -0.00175222 0.0055682 -0.00031922 -0.00757532\n", + " -0.00273998 -0.15351444 -0.12133964 -0.00595224 -0.16451117 -0.01303235\n", + " -0.0052911 0.00048854 -0.00206977]]\n", + "intercept:\t[-0.00229272]\n" + ] + }, + { + "data": { + "text/plain": [ + "DEP_DELAY 6.969920\n", + "CRS_ELAPSED_TIME 4.101834\n", + "LATE_AIRCRAFT_DELAY 4.065346\n", + "DISTANCE 3.616141\n", + "NAS_DELAY 1.672065\n", + "WEATHER_DELAY 1.604186\n", + "CRS_DEP_TIME 0.665926\n", + "ARR_AFPH 0.267888\n", + "DEP_TIME 0.177772\n", + "CRS_ARR_TIME 0.168589\n", + "WHEELS_OFF 0.150765\n", + "DEP_AFPH 0.124024\n", + "DEP_MONTH 0.044475\n", + "TAXI_OUT 0.043947\n", + "DEP_DOW 0.010574\n", + "SECURITY_DELAY 0.009756\n", + "ARR_RFPH 0.001976\n", + "DEP_RFPH 0.001215\n", + "DEST_HUB 0.001007\n", + "ORIGIN_HUB 0.000238\n", + "PCT_ELAPSED_TIME 0.000185\n", + "dtype: float64" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "coefs_log = class_models[\"logistic\"][\"fitted\"].coef_\n", + "intercept_log = class_models[\"logistic\"][\"fitted\"].intercept_\n", + "print(\"coefficients:\\t%s\" % coefs_log)\n", + "print(\"intercept:\\t%s\" % intercept_log)\n", + "stdv = np.std(X_train, 0)\n", + "abs(\n", + " coefs_log.reshape(\n", + " 21,\n", + " )\n", + " * stdv\n", + ").sort_values(ascending=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Дерево решений" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "|--- DEP_DELAY <= 20.50\n", + "| |--- DEP_DELAY <= 15.50\n", + "| | |--- class: 0\n", + "| |--- DEP_DELAY > 15.50\n", + "| | |--- PCT_ELAPSED_TIME <= 0.99\n", + "| | | |--- PCT_ELAPSED_TIME <= 0.98\n", + "| | | | |--- PCT_ELAPSED_TIME <= 0.96\n", + "| | | | | |--- CRS_ELAPSED_TIME <= 65.50\n", + "| | | | | | |--- PCT_ELAPSED_TIME <= 0.94\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- PCT_ELAPSED_TIME > 0.94\n", + "| | | | | | | |--- class: 0\n", + "| | | | | |--- CRS_ELAPSED_TIME > 65.50\n", + "| | | | | | |--- PCT_ELAPSED_TIME <= 0.95\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- PCT_ELAPSED_TIME > 0.95\n", + "| | | | | | | |--- class: 0\n", + "| | | | |--- PCT_ELAPSED_TIME > 0.96\n", + "| | | | | |--- CRS_ELAPSED_TIME <= 140.50\n", + "| | | | | | |--- DEP_DELAY <= 18.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 18.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | |--- CRS_ELAPSED_TIME > 140.50\n", + "| | | | | | |--- DEP_DELAY <= 19.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 19.50\n", + "| | | | | | | |--- class: 0\n", + "| | | |--- PCT_ELAPSED_TIME > 0.98\n", + "| | | | |--- DEP_DELAY <= 18.50\n", + "| | | | | |--- DISTANCE <= 326.50\n", + "| | | | | | |--- LATE_AIRCRAFT_DELAY <= 0.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- LATE_AIRCRAFT_DELAY > 0.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | |--- DISTANCE > 326.50\n", + "| | | | | | |--- DEP_DELAY <= 17.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 17.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | |--- DEP_DELAY > 18.50\n", + "| | | | | |--- LATE_AIRCRAFT_DELAY <= 1.50\n", + "| | | | | | |--- DISTANCE <= 1358.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- DISTANCE > 1358.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | |--- LATE_AIRCRAFT_DELAY > 1.50\n", + "| | | | | | |--- class: 0\n", + "| | |--- PCT_ELAPSED_TIME > 0.99\n", + "| | | |--- LATE_AIRCRAFT_DELAY <= 1.50\n", + "| | | | |--- WEATHER_DELAY <= 2.00\n", + "| | | | | |--- NAS_DELAY <= 17.50\n", + "| | | | | | |--- PCT_ELAPSED_TIME <= 1.00\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- PCT_ELAPSED_TIME > 1.00\n", + "| | | | | | | |--- class: 1\n", + "| | | | | |--- NAS_DELAY > 17.50\n", + "| | | | | | |--- PCT_ELAPSED_TIME <= 1.09\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- PCT_ELAPSED_TIME > 1.09\n", + "| | | | | | | |--- class: 1\n", + "| | | | |--- WEATHER_DELAY > 2.00\n", + "| | | | | |--- class: 0\n", + "| | | |--- LATE_AIRCRAFT_DELAY > 1.50\n", + "| | | | |--- LATE_AIRCRAFT_DELAY <= 3.50\n", + "| | | | | |--- DEP_DELAY <= 18.50\n", + "| | | | | | |--- DISTANCE <= 153.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- DISTANCE > 153.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | |--- DEP_DELAY > 18.50\n", + "| | | | | | |--- WEATHER_DELAY <= 2.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- WEATHER_DELAY > 2.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | |--- LATE_AIRCRAFT_DELAY > 3.50\n", + "| | | | | |--- LATE_AIRCRAFT_DELAY <= 4.50\n", + "| | | | | | |--- DEP_DELAY <= 19.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 19.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | |--- LATE_AIRCRAFT_DELAY > 4.50\n", + "| | | | | | |--- class: 0\n", + "|--- DEP_DELAY > 20.50\n", + "| |--- LATE_AIRCRAFT_DELAY <= 11.50\n", + "| | |--- NAS_DELAY <= 27.50\n", + "| | | |--- DEP_DELAY <= 35.50\n", + "| | | | |--- PCT_ELAPSED_TIME <= 0.96\n", + "| | | | | |--- DEP_DELAY <= 28.50\n", + "| | | | | | |--- PCT_ELAPSED_TIME <= 0.93\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- PCT_ELAPSED_TIME > 0.93\n", + "| | | | | | | |--- class: 0\n", + "| | | | | |--- DEP_DELAY > 28.50\n", + "| | | | | | |--- PCT_ELAPSED_TIME <= 0.92\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- PCT_ELAPSED_TIME > 0.92\n", + "| | | | | | | |--- class: 1\n", + "| | | | |--- PCT_ELAPSED_TIME > 0.96\n", + "| | | | | |--- WEATHER_DELAY <= 4.50\n", + "| | | | | | |--- LATE_AIRCRAFT_DELAY <= 6.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- LATE_AIRCRAFT_DELAY > 6.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | |--- WEATHER_DELAY > 4.50\n", + "| | | | | | |--- WEATHER_DELAY <= 10.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- WEATHER_DELAY > 10.50\n", + "| | | | | | | |--- class: 0\n", + "| | | |--- DEP_DELAY > 35.50\n", + "| | | | |--- WEATHER_DELAY <= 16.50\n", + "| | | | | |--- DEP_DELAY <= 44.50\n", + "| | | | | | |--- PCT_ELAPSED_TIME <= 0.93\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- PCT_ELAPSED_TIME > 0.93\n", + "| | | | | | | |--- class: 1\n", + "| | | | | |--- DEP_DELAY > 44.50\n", + "| | | | | | |--- SECURITY_DELAY <= 20.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- SECURITY_DELAY > 20.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | |--- WEATHER_DELAY > 16.50\n", + "| | | | | |--- WEATHER_DELAY <= 23.50\n", + "| | | | | | |--- DEP_DELAY <= 57.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 57.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | |--- WEATHER_DELAY > 23.50\n", + "| | | | | | |--- DEP_DELAY <= 88.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 88.50\n", + "| | | | | | | |--- class: 0\n", + "| | |--- NAS_DELAY > 27.50\n", + "| | | |--- PCT_ELAPSED_TIME <= 1.11\n", + "| | | | |--- NAS_DELAY <= 31.50\n", + "| | | | | |--- PCT_ELAPSED_TIME <= 1.07\n", + "| | | | | | |--- DEP_DELAY <= 69.00\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 69.00\n", + "| | | | | | | |--- class: 1\n", + "| | | | | |--- PCT_ELAPSED_TIME > 1.07\n", + "| | | | | | |--- WEATHER_DELAY <= 10.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- WEATHER_DELAY > 10.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | |--- NAS_DELAY > 31.50\n", + "| | | | | |--- DEP_DELAY <= 471.50\n", + "| | | | | | |--- CRS_ELAPSED_TIME <= 420.00\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- CRS_ELAPSED_TIME > 420.00\n", + "| | | | | | | |--- class: 1\n", + "| | | | | |--- DEP_DELAY > 471.50\n", + "| | | | | | |--- NAS_DELAY <= 388.00\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- NAS_DELAY > 388.00\n", + "| | | | | | | |--- class: 0\n", + "| | | |--- PCT_ELAPSED_TIME > 1.11\n", + "| | | | |--- NAS_DELAY <= 64.50\n", + "| | | | | |--- WEATHER_DELAY <= 20.50\n", + "| | | | | | |--- DEP_DELAY <= 43.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- DEP_DELAY > 43.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | |--- WEATHER_DELAY > 20.50\n", + "| | | | | | |--- WHEELS_OFF <= 36.00\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- WHEELS_OFF > 36.00\n", + "| | | | | | | |--- class: 0\n", + "| | | | |--- NAS_DELAY > 64.50\n", + "| | | | | |--- PCT_ELAPSED_TIME <= 1.44\n", + "| | | | | | |--- NAS_DELAY <= 78.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- NAS_DELAY > 78.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | |--- PCT_ELAPSED_TIME > 1.44\n", + "| | | | | | |--- NAS_DELAY <= 119.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- NAS_DELAY > 119.50\n", + "| | | | | | | |--- class: 0\n", + "| |--- LATE_AIRCRAFT_DELAY > 11.50\n", + "| | |--- DEP_DELAY <= 75.50\n", + "| | | |--- DEP_DELAY <= 41.50\n", + "| | | | |--- LATE_AIRCRAFT_DELAY <= 14.50\n", + "| | | | | |--- DEP_DELAY <= 29.50\n", + "| | | | | | |--- DEP_DELAY <= 27.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 27.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | |--- DEP_DELAY > 29.50\n", + "| | | | | | |--- PCT_ELAPSED_TIME <= 0.97\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- PCT_ELAPSED_TIME > 0.97\n", + "| | | | | | | |--- class: 1\n", + "| | | | |--- LATE_AIRCRAFT_DELAY > 14.50\n", + "| | | | | |--- LATE_AIRCRAFT_DELAY <= 20.50\n", + "| | | | | | |--- DEP_DELAY <= 32.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 32.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | |--- LATE_AIRCRAFT_DELAY > 20.50\n", + "| | | | | | |--- DEP_DELAY <= 38.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 38.50\n", + "| | | | | | | |--- class: 0\n", + "| | | |--- DEP_DELAY > 41.50\n", + "| | | | |--- LATE_AIRCRAFT_DELAY <= 29.50\n", + "| | | | | |--- PCT_ELAPSED_TIME <= 0.94\n", + "| | | | | | |--- DEP_DELAY <= 55.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 55.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | |--- PCT_ELAPSED_TIME > 0.94\n", + "| | | | | | |--- WEATHER_DELAY <= 0.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- WEATHER_DELAY > 0.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | |--- LATE_AIRCRAFT_DELAY > 29.50\n", + "| | | | | |--- LATE_AIRCRAFT_DELAY <= 38.50\n", + "| | | | | | |--- DEP_DELAY <= 59.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 59.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | |--- LATE_AIRCRAFT_DELAY > 38.50\n", + "| | | | | | |--- DEP_DELAY <= 60.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 60.50\n", + "| | | | | | | |--- class: 0\n", + "| | |--- DEP_DELAY > 75.50\n", + "| | | |--- LATE_AIRCRAFT_DELAY <= 60.50\n", + "| | | | |--- WEATHER_DELAY <= 0.50\n", + "| | | | | |--- NAS_DELAY <= 38.50\n", + "| | | | | | |--- DEP_DELAY <= 88.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- DEP_DELAY > 88.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | |--- NAS_DELAY > 38.50\n", + "| | | | | | |--- TAXI_OUT <= 63.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- TAXI_OUT > 63.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | |--- WEATHER_DELAY > 0.50\n", + "| | | | | |--- WEATHER_DELAY <= 18.50\n", + "| | | | | | |--- LATE_AIRCRAFT_DELAY <= 31.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- LATE_AIRCRAFT_DELAY > 31.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | |--- WEATHER_DELAY > 18.50\n", + "| | | | | | |--- DEP_AFPH <= 99.64\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_AFPH > 99.64\n", + "| | | | | | | |--- class: 0\n", + "| | | |--- LATE_AIRCRAFT_DELAY > 60.50\n", + "| | | | |--- DEP_DELAY <= 114.50\n", + "| | | | | |--- LATE_AIRCRAFT_DELAY <= 71.50\n", + "| | | | | | |--- DEP_DELAY <= 95.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 95.50\n", + "| | | | | | | |--- class: 1\n", + "| | | | | |--- LATE_AIRCRAFT_DELAY > 71.50\n", + "| | | | | | |--- DEP_DELAY <= 96.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 96.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | |--- DEP_DELAY > 114.50\n", + "| | | | | |--- LATE_AIRCRAFT_DELAY <= 98.50\n", + "| | | | | | |--- WEATHER_DELAY <= 1.00\n", + "| | | | | | | |--- class: 1\n", + "| | | | | | |--- WEATHER_DELAY > 1.00\n", + "| | | | | | | |--- class: 0\n", + "| | | | | |--- LATE_AIRCRAFT_DELAY > 98.50\n", + "| | | | | | |--- DEP_DELAY <= 171.50\n", + "| | | | | | | |--- class: 0\n", + "| | | | | | |--- DEP_DELAY > 171.50\n", + "| | | | | | | |--- class: 0\n", + "\n" + ] + } + ], + "source": [ + "text_tree = tree.export_text(\n", + " class_models[\"decision_tree\"][\"fitted\"],\n", + " feature_names=X_train.columns.values.tolist(),\n", + ")\n", + "print(text_tree)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
featureimportance
2DEP_DELAY0.527482
16LATE_AIRCRAFT_DELAY0.199153
8PCT_ELAPSED_TIME0.105381
13WEATHER_DELAY0.101649
14NAS_DELAY0.062732
15SECURITY_DELAY0.001998
9DISTANCE0.001019
7CRS_ELAPSED_TIME0.000281
5TAXI_OUT0.000239
6WHEELS_OFF0.000035
3DEP_AFPH0.000031
0CRS_DEP_TIME0.000000
19ORIGIN_HUB0.000000
18DEP_DOW0.000000
17DEP_MONTH0.000000
10CRS_ARR_TIME0.000000
12ARR_RFPH0.000000
11ARR_AFPH0.000000
1DEP_TIME0.000000
4DEP_RFPH0.000000
20DEST_HUB0.000000
\n", + "
" + ], + "text/plain": [ + " feature importance\n", + "2 DEP_DELAY 0.527482\n", + "16 LATE_AIRCRAFT_DELAY 0.199153\n", + "8 PCT_ELAPSED_TIME 0.105381\n", + "13 WEATHER_DELAY 0.101649\n", + "14 NAS_DELAY 0.062732\n", + "15 SECURITY_DELAY 0.001998\n", + "9 DISTANCE 0.001019\n", + "7 CRS_ELAPSED_TIME 0.000281\n", + "5 TAXI_OUT 0.000239\n", + "6 WHEELS_OFF 0.000035\n", + "3 DEP_AFPH 0.000031\n", + "0 CRS_DEP_TIME 0.000000\n", + "19 ORIGIN_HUB 0.000000\n", + "18 DEP_DOW 0.000000\n", + "17 DEP_MONTH 0.000000\n", + "10 CRS_ARR_TIME 0.000000\n", + "12 ARR_RFPH 0.000000\n", + "11 ARR_AFPH 0.000000\n", + "1 DEP_TIME 0.000000\n", + "4 DEP_RFPH 0.000000\n", + "20 DEST_HUB 0.000000" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dt_imp_df = pd.DataFrame(\n", + " {\n", + " \"feature\": X_train.columns.values.tolist(),\n", + " \"importance\": class_models[\"decision_tree\"][\"fitted\"].feature_importances_,\n", + " }\n", + ").sort_values(by=\"importance\", ascending=False)\n", + "dt_imp_df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "k ближайших соседей" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CRS_DEP_TIME 655.000000\n", + "DEP_TIME 1055.000000\n", + "DEP_DELAY 240.000000\n", + "DEP_AFPH 90.800000\n", + "DEP_RFPH 0.890196\n", + "TAXI_OUT 35.000000\n", + "WHEELS_OFF 1130.000000\n", + "CRS_ELAPSED_TIME 259.000000\n", + "PCT_ELAPSED_TIME 1.084942\n", + "DISTANCE 1660.000000\n", + "CRS_ARR_TIME 914.000000\n", + "ARR_AFPH 40.434783\n", + "ARR_RFPH 1.064073\n", + "WEATHER_DELAY 0.000000\n", + "NAS_DELAY 22.000000\n", + "SECURITY_DELAY 0.000000\n", + "LATE_AIRCRAFT_DELAY 221.000000\n", + "DEP_MONTH 10.000000\n", + "DEP_DOW 4.000000\n", + "ORIGIN_HUB 1.000000\n", + "DEST_HUB 0.000000\n", + "Name: 721043, dtype: float64\n" + ] + } + ], + "source": [ + "print(X_test.loc[721043, :])" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([[143.3160128 , 173.90740076, 192.66705727, 211.57109221,\n", + " 243.57211853, 259.61593993, 259.77507391]]),\n", + " array([[105172, 571912, 73409, 89450, 77474, 705972, 706911]]))" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class_models[\"knn\"][\"fitted\"].kneighbors(\n", + " X_test.loc[721043, :].values.reshape(1, 21), 7\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3813 0\n", + "229062 1\n", + "283316 0\n", + "385831 0\n", + "581905 1\n", + "726784 1\n", + "179364 0\n", + "Name: CARRIER_DELAY, dtype: int64" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_train_class.iloc[[105172, 571912, 73409, 89450, 77474, 705972, 706911]]" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'euclidean'" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class_models[\"knn\"][\"fitted\"].effective_metric_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Гауссов наивный Байес" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.93871674, 0.06128326])" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class_models[\"naive_bayes\"][\"fitted\"].class_prior_" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[2.50123026e+05, 2.61324730e+05, 9.21572605e+02, 1.26123968e+03,\n", + " 2.08339528e-01, 9.58074414e+01, 2.62606651e+05, 6.30102550e+03,\n", + " 1.13475535e-02, 4.22470414e+05, 2.75433641e+05, 1.25314386e+03,\n", + " 3.48655340e-01, 1.11234714e+02, 1.91877186e+02, 2.80302201e+00,\n", + " 5.06561612e+02, 1.17346654e+01, 3.99122491e+00, 2.39015406e-01,\n", + " 2.34996222e-01],\n", + " [2.60629652e+05, 2.96009867e+05, 1.19307931e+04, 1.14839167e+03,\n", + " 1.99929921e+00, 1.20404927e+02, 3.08568277e+05, 6.29066219e+03,\n", + " 1.38936741e-02, 4.10198938e+05, 3.28574000e+05, 1.09023147e+03,\n", + " 3.08997044e+00, 7.79140423e+01, 1.56184090e+02, 9.12112286e-01,\n", + " 2.11279954e+03, 1.02712368e+01, 4.02943162e+00, 1.77750796e-01,\n", + " 2.50208354e-01]])" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class_models[\"naive_bayes\"][\"fitted\"].var_" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1.30740577e+03, 1.31006271e+03, 5.14196506e+00, 5.45864877e+01,\n", + " 1.09377996e+00, 1.87120810e+01, 1.33552258e+03, 1.70734929e+02,\n", + " 9.71131781e-01, 1.01824369e+03, 1.48438931e+03, 5.39873058e+01,\n", + " 1.09644787e+00, 7.39971299e-01, 2.85434558e+00, 2.41814585e-02,\n", + " 4.14674395e+00, 6.55045281e+00, 2.95035528e+00, 6.06800513e-01,\n", + " 6.24199571e-01],\n", + " [1.41305545e+03, 1.48087887e+03, 8.45867640e+01, 6.14731036e+01,\n", + " 1.25429654e+00, 1.99378321e+01, 1.49409412e+03, 1.72229998e+02,\n", + " 9.83974416e-01, 1.04363666e+03, 1.54821862e+03, 4.26486417e+01,\n", + " 1.36373798e+00, 4.50733082e-01, 4.71991378e+00, 2.11281132e-02,\n", + " 1.40744819e+01, 6.73367907e+00, 3.04251232e+00, 7.69575517e-01,\n", + " 4.85391724e-01]])" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class_models[\"naive_bayes\"][\"fitted\"].theta_" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": ".venv (3.11.12)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.12" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/poetry.lock b/poetry.lock index 2c35213..006e3a7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. [[package]] name = "anyio" @@ -624,23 +624,6 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] -[[package]] -name = "filelock" -version = "3.18.0" -description = "A platform independent file lock." -optional = false -python-versions = ">=3.9" -groups = ["main"] -files = [ - {file = "filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de"}, - {file = "filelock-3.18.0.tar.gz", hash = "sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2"}, -] - -[package.extras] -docs = ["furo (>=2024.8.6)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.6.10)", "diff-cover (>=9.2.1)", "pytest (>=8.3.4)", "pytest-asyncio (>=0.25.2)", "pytest-cov (>=6)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.28.1)"] -typing = ["typing-extensions (>=4.12.2) ; python_version < \"3.11\""] - [[package]] name = "fonttools" version = "4.56.0" @@ -727,46 +710,6 @@ files = [ {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, ] -[[package]] -name = "fsspec" -version = "2025.3.0" -description = "File-system specification" -optional = false -python-versions = ">=3.8" -groups = ["main"] -files = [ - {file = "fsspec-2025.3.0-py3-none-any.whl", hash = "sha256:efb87af3efa9103f94ca91a7f8cb7a4df91af9f74fc106c9c7ea0efd7277c1b3"}, - {file = "fsspec-2025.3.0.tar.gz", hash = "sha256:a935fd1ea872591f2b5148907d103488fc523295e6c64b835cfad8c3eca44972"}, -] - -[package.extras] -abfs = ["adlfs"] -adl = ["adlfs"] -arrow = ["pyarrow (>=1)"] -dask = ["dask", "distributed"] -dev = ["pre-commit", "ruff"] -doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"] -dropbox = ["dropbox", "dropboxdrivefs", "requests"] -full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] -fuse = ["fusepy"] -gcs = ["gcsfs"] -git = ["pygit2"] -github = ["requests"] -gs = ["gcsfs"] -gui = ["panel"] -hdfs = ["pyarrow (>=1)"] -http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"] -libarchive = ["libarchive-c"] -oci = ["ocifs"] -s3 = ["s3fs"] -sftp = ["paramiko"] -smb = ["smbprotocol"] -ssh = ["paramiko"] -test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"] -test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"] -test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] -tqdm = ["tqdm"] - [[package]] name = "h11" version = "0.14.0" @@ -1444,29 +1387,6 @@ files = [ {file = "kiwisolver-1.4.8.tar.gz", hash = "sha256:23d5f023bdc8c7e54eb65f03ca5d5bb25b601eac4d7f1a042888a1f45237987e"}, ] -[[package]] -name = "machine-learning-datasets" -version = "0.1.23" -description = "A simple library for loading machine learning datasets and performing some common machine learning interpretation functions. Built for the book \"Interpretable Machine Learning with Python, 2E\"." -optional = false -python-versions = ">=3.9,<3.13" -groups = ["main"] -files = [ - {file = "machine_learning_datasets-0.1.23-py3-none-any.whl", hash = "sha256:0af33f197e8bc6451b00f25c22d260a909ae5f93d79395762dc049bfe155e0f6"}, - {file = "machine_learning_datasets-0.1.23.tar.gz", hash = "sha256:9e27cdec597718e659658e95401e8efc2381ee8ea0997fe6539b13cbec227dca"}, -] - -[package.dependencies] -matplotlib = ">=3.7.1,<4.0.0" -numpy = ">=1.23.5,<2.0.0" -opencv-python = ">=4.5.1,<5.0.0" -pandas = ">=1.5.3,<2.0.0" -requests = ">=2.31.0" -scikit-learn = ">=1.2.2,<2.0.0" -scipy = ">=1.11.3,<2.0.0" -seaborn = ">=0.12.2,<0.13.0" -torchvision = ">=0.16.0,<0.17.0" - [[package]] name = "markupsafe" version = "3.0.2" @@ -1647,24 +1567,6 @@ scipy = ">=1.2.1" docs = ["mkdocs", "mkdocs-bootswatch", "nbconvert", "python-markdown-math"] testing = ["pytest"] -[[package]] -name = "mpmath" -version = "1.3.0" -description = "Python library for arbitrary-precision floating-point arithmetic" -optional = false -python-versions = "*" -groups = ["main"] -files = [ - {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, - {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, -] - -[package.extras] -develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] -docs = ["sphinx"] -gmpy = ["gmpy2 (>=2.1.0a4) ; platform_python_implementation != \"PyPy\""] -tests = ["pytest (>=4.6)"] - [[package]] name = "nbclient" version = "0.10.2" @@ -1759,26 +1661,6 @@ files = [ {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, ] -[[package]] -name = "networkx" -version = "3.4.2" -description = "Python package for creating and manipulating graphs and networks" -optional = false -python-versions = ">=3.10" -groups = ["main"] -files = [ - {file = "networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f"}, - {file = "networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1"}, -] - -[package.extras] -default = ["matplotlib (>=3.7)", "numpy (>=1.24)", "pandas (>=2.0)", "scipy (>=1.10,!=1.11.0,!=1.11.1)"] -developer = ["changelist (==0.5)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] -doc = ["intersphinx-registry", "myst-nb (>=1.1)", "numpydoc (>=1.8.0)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.15)", "sphinx (>=7.3)", "sphinx-gallery (>=0.16)", "texext (>=0.6.7)"] -example = ["cairocffi (>=1.7)", "contextily (>=1.6)", "igraph (>=0.11)", "momepy (>=0.7.2)", "osmnx (>=1.9)", "scikit-learn (>=1.5)", "seaborn (>=0.13)"] -extra = ["lxml (>=4.6)", "pydot (>=3.0.1)", "pygraphviz (>=1.14)", "sympy (>=1.10)"] -test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] - [[package]] name = "notebook" version = "7.3.3" @@ -1867,192 +1749,6 @@ files = [ {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, ] -[[package]] -name = "nvidia-cublas-cu12" -version = "12.1.3.1" -description = "CUBLAS native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, - {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, -] - -[[package]] -name = "nvidia-cuda-cupti-cu12" -version = "12.1.105" -description = "CUDA profiling tools runtime libs." -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, - {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, -] - -[[package]] -name = "nvidia-cuda-nvrtc-cu12" -version = "12.1.105" -description = "NVRTC native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, - {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, -] - -[[package]] -name = "nvidia-cuda-runtime-cu12" -version = "12.1.105" -description = "CUDA Runtime native Libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, - {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, -] - -[[package]] -name = "nvidia-cudnn-cu12" -version = "8.9.2.26" -description = "cuDNN runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, -] - -[package.dependencies] -nvidia-cublas-cu12 = "*" - -[[package]] -name = "nvidia-cufft-cu12" -version = "11.0.2.54" -description = "CUFFT native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, - {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, -] - -[[package]] -name = "nvidia-curand-cu12" -version = "10.3.2.106" -description = "CURAND native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, - {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, -] - -[[package]] -name = "nvidia-cusolver-cu12" -version = "11.4.5.107" -description = "CUDA solver native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, - {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, -] - -[package.dependencies] -nvidia-cublas-cu12 = "*" -nvidia-cusparse-cu12 = "*" -nvidia-nvjitlink-cu12 = "*" - -[[package]] -name = "nvidia-cusparse-cu12" -version = "12.1.0.106" -description = "CUSPARSE native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, - {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, -] - -[package.dependencies] -nvidia-nvjitlink-cu12 = "*" - -[[package]] -name = "nvidia-nccl-cu12" -version = "2.18.1" -description = "NVIDIA Collective Communication Library (NCCL) Runtime" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:1a6c4acefcbebfa6de320f412bf7866de856e786e0462326ba1bac40de0b5e71"}, -] - -[[package]] -name = "nvidia-nvjitlink-cu12" -version = "12.8.93" -description = "Nvidia JIT LTO Library" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88"}, - {file = "nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7"}, - {file = "nvidia_nvjitlink_cu12-12.8.93-py3-none-win_amd64.whl", hash = "sha256:bd93fbeeee850917903583587f4fc3a4eafa022e34572251368238ab5e6bd67f"}, -] - -[[package]] -name = "nvidia-nvtx-cu12" -version = "12.1.105" -description = "NVIDIA Tools Extension" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, - {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, -] - -[[package]] -name = "opencv-python" -version = "4.11.0.86" -description = "Wrapper package for OpenCV python bindings." -optional = false -python-versions = ">=3.6" -groups = ["main"] -files = [ - {file = "opencv-python-4.11.0.86.tar.gz", hash = "sha256:03d60ccae62304860d232272e4a4fda93c39d595780cb40b161b310244b736a4"}, - {file = "opencv_python-4.11.0.86-cp37-abi3-macosx_13_0_arm64.whl", hash = "sha256:432f67c223f1dc2824f5e73cdfcd9db0efc8710647d4e813012195dc9122a52a"}, - {file = "opencv_python-4.11.0.86-cp37-abi3-macosx_13_0_x86_64.whl", hash = "sha256:9d05ef13d23fe97f575153558653e2d6e87103995d54e6a35db3f282fe1f9c66"}, - {file = "opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b92ae2c8852208817e6776ba1ea0d6b1e0a1b5431e971a2a0ddd2a8cc398202"}, - {file = "opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b02611523803495003bd87362db3e1d2a0454a6a63025dc6658a9830570aa0d"}, - {file = "opencv_python-4.11.0.86-cp37-abi3-win32.whl", hash = "sha256:810549cb2a4aedaa84ad9a1c92fbfdfc14090e2749cedf2c1589ad8359aa169b"}, - {file = "opencv_python-4.11.0.86-cp37-abi3-win_amd64.whl", hash = "sha256:085ad9b77c18853ea66283e98affefe2de8cc4c1f43eda4c100cf9b2721142ec"}, -] - -[package.dependencies] -numpy = {version = ">=1.23.5", markers = "python_version >= \"3.11\""} - [[package]] name = "overrides" version = "7.7.0" @@ -3147,24 +2843,6 @@ build = ["cython (>=3.0.10)"] develop = ["colorama", "cython (>=3.0.10)", "cython (>=3.0.10,<4)", "flake8", "isort", "joblib", "matplotlib (>=3)", "pytest (>=7.3.0,<8)", "pytest-cov", "pytest-randomly", "pytest-xdist", "pywinpty ; os_name == \"nt\"", "setuptools-scm[toml] (>=8.0,<9.0)"] docs = ["ipykernel", "jupyter-client", "matplotlib", "nbconvert", "nbformat", "numpydoc", "pandas-datareader", "sphinx"] -[[package]] -name = "sympy" -version = "1.13.3" -description = "Computer algebra system (CAS) in Python" -optional = false -python-versions = ">=3.8" -groups = ["main"] -files = [ - {file = "sympy-1.13.3-py3-none-any.whl", hash = "sha256:54612cf55a62755ee71824ce692986f23c88ffa77207b30c1368eda4a7060f73"}, - {file = "sympy-1.13.3.tar.gz", hash = "sha256:b27fd2c6530e0ab39e275fc9b683895367e51d5da91baa8d3d64db2565fec4d9"}, -] - -[package.dependencies] -mpmath = ">=1.1.0,<1.4" - -[package.extras] -dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] - [[package]] name = "terminado" version = "0.18.1" @@ -3218,99 +2896,6 @@ webencodings = ">=0.4" doc = ["sphinx", "sphinx_rtd_theme"] test = ["pytest", "ruff"] -[[package]] -name = "torch" -version = "2.1.2" -description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" -optional = false -python-versions = ">=3.8.0" -groups = ["main"] -files = [ - {file = "torch-2.1.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:3a871edd6c02dae77ad810335c0833391c1a4ce49af21ea8cf0f6a5d2096eea8"}, - {file = "torch-2.1.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076"}, - {file = "torch-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:0e13034fd5fb323cbbc29e56d0637a3791e50dd589616f40c79adfa36a5a35a1"}, - {file = "torch-2.1.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:d9b535cad0df3d13997dbe8bd68ac33e0e3ae5377639c9881948e40794a61403"}, - {file = "torch-2.1.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:f9a55d55af02826ebfbadf4e9b682f0f27766bc33df8236b48d28d705587868f"}, - {file = "torch-2.1.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:a6ebbe517097ef289cc7952783588c72de071d4b15ce0f8b285093f0916b1162"}, - {file = "torch-2.1.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:8f32ce591616a30304f37a7d5ea80b69ca9e1b94bba7f308184bf616fdaea155"}, - {file = "torch-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:e0ee6cf90c8970e05760f898d58f9ac65821c37ffe8b04269ec787aa70962b69"}, - {file = "torch-2.1.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:76d37967c31c99548ad2c4d3f2cf191db48476f2e69b35a0937137116da356a1"}, - {file = "torch-2.1.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e2d83f07b4aac983453ea5bf8f9aa9dacf2278a8d31247f5d9037f37befc60e4"}, - {file = "torch-2.1.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f41fe0c7ecbf903a568c73486139a75cfab287a0f6c17ed0698fdea7a1e8641d"}, - {file = "torch-2.1.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e3225f47d50bb66f756fe9196a768055d1c26b02154eb1f770ce47a2578d3aa7"}, - {file = "torch-2.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:33d59cd03cb60106857f6c26b36457793637512998666ee3ce17311f217afe2b"}, - {file = "torch-2.1.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:8e221deccd0def6c2badff6be403e0c53491805ed9915e2c029adbcdb87ab6b5"}, - {file = "torch-2.1.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:05b18594f60a911a0c4f023f38a8bda77131fba5fd741bda626e97dcf5a3dd0a"}, - {file = "torch-2.1.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ca96253b761e9aaf8e06fb30a66ee301aecbf15bb5a303097de1969077620b6"}, - {file = "torch-2.1.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d93ba70f67b08c2ae5598ee711cbc546a1bc8102cef938904b8c85c2089a51a0"}, - {file = "torch-2.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:255b50bc0608db177e6a3cc118961d77de7e5105f07816585fa6f191f33a9ff3"}, - {file = "torch-2.1.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6984cd5057c0c977b3c9757254e989d3f1124f4ce9d07caa6cb637783c71d42a"}, - {file = "torch-2.1.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:bc195d7927feabc0eb7c110e457c955ed2ab616f3c7c28439dd4188cf589699f"}, -] - -[package.dependencies] -filelock = "*" -fsspec = "*" -jinja2 = "*" -networkx = "*" -nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nccl-cu12 = {version = "2.18.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -sympy = "*" -triton = {version = "2.1.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -typing-extensions = "*" - -[package.extras] -dynamo = ["jinja2"] -opt-einsum = ["opt-einsum (>=3.3)"] - -[[package]] -name = "torchvision" -version = "0.16.2" -description = "image and video datasets and models for torch deep learning" -optional = false -python-versions = ">=3.8" -groups = ["main"] -files = [ - {file = "torchvision-0.16.2-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:bc86f2800cb2c0c1a09c581409cdd6bff66e62f103dc83fc63f73346264c3756"}, - {file = "torchvision-0.16.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b024bd412df6d3a007dcebf311a894eb3c5c21e1af80d12be382bbcb097a7c3a"}, - {file = "torchvision-0.16.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:e89f10f3c8351972b6e3fda95bc3e479ea8dbfc9dfcfd2c32902dbad4ba5cfc5"}, - {file = "torchvision-0.16.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:96c7583700112a410bdc4e1e4f118c429dab49c29c9a31a2cc3579bc9b08b19d"}, - {file = "torchvision-0.16.2-cp310-cp310-win_amd64.whl", hash = "sha256:9f4032ebb3277fb07ff6a9b818d50a547fb8fcd89d958cfd9e773322454bb688"}, - {file = "torchvision-0.16.2-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:67b1aaf8b8cb02ce75dd445f291a27c8036a502f8c0aa76e28c37a0faac2e153"}, - {file = "torchvision-0.16.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bef30d03e1d1c629761f4dca51d3b7d8a0dc0acce6f4068ab2a1634e8e7b64e0"}, - {file = "torchvision-0.16.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:e59cc7b2bd1ab5c0ce4ae382e4e37be8f1c174e8b5de2f6a23c170de9ae28495"}, - {file = "torchvision-0.16.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e130b08cc9b3cc73a6c59d6edf032394a322f9579bfd21d14bc2e1d0999aa758"}, - {file = "torchvision-0.16.2-cp311-cp311-win_amd64.whl", hash = "sha256:8692ab1e48807e9604046a6f4beeb67b523294cee1b00828654bb0df2cfce2b2"}, - {file = "torchvision-0.16.2-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:b82732dcf876a37c852772342aa6ee3480c03bb3e2a802ae109fc5f7e28d26e9"}, - {file = "torchvision-0.16.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4b065143d1a720fe8a9077fd4be35d491f98819ec80b3dbbc3ec64d0b707a906"}, - {file = "torchvision-0.16.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:bc5f274e4ecd1b86062063cdf4fd385a1d39d147a3a2685fbbde9ff08bb720b8"}, - {file = "torchvision-0.16.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:335959c43b371c0474af34c1ef2a52efdc7603c45700d29e4475eeb02984170c"}, - {file = "torchvision-0.16.2-cp38-cp38-win_amd64.whl", hash = "sha256:7fd22d86e08eba321af70cad291020c2cdeac069b00ce88b923ca52e06174769"}, - {file = "torchvision-0.16.2-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:56115268b37f0b75364e3654e47ad9abc66ac34c1f9e5e3dfa89a22d6a40017a"}, - {file = "torchvision-0.16.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:82805f8445b094f9d1e770390ee6cc86855e89955e08ce34af2e2274fc0e5c45"}, - {file = "torchvision-0.16.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:3f4bd5fcbc361476e2e78016636ac7d5509e59d9962521f06eb98e6803898182"}, - {file = "torchvision-0.16.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:8199acdf8ab066a28b84a5b6f4d97b58976d9e164b1acc3a9d14fccfaf74bb3a"}, - {file = "torchvision-0.16.2-cp39-cp39-win_amd64.whl", hash = "sha256:41dd4fa9f176d563fe9f1b9adef3b7e582cdfb60ce8c9bc51b094a025be687c9"}, -] - -[package.dependencies] -numpy = "*" -pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0" -requests = "*" -torch = "2.1.2" - -[package.extras] -scipy = ["scipy"] - [[package]] name = "tornado" version = "6.4.2" @@ -3348,33 +2933,6 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"] -[[package]] -name = "triton" -version = "2.1.0" -description = "A language and compiler for custom Deep Learning operations" -optional = false -python-versions = "*" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:66439923a30d5d48399b08a9eae10370f6c261a5ec864a64983bae63152d39d7"}, - {file = "triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:919b06453f0033ea52c13eaf7833de0e57db3178d23d4e04f9fc71c4f2c32bf8"}, - {file = "triton-2.1.0-0-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ae4bb8a91de790e1866405211c4d618379781188f40d5c4c399766914e84cd94"}, - {file = "triton-2.1.0-0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39f6fb6bdccb3e98f3152e3fbea724f1aeae7d749412bbb1fa9c441d474eba26"}, - {file = "triton-2.1.0-0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:21544e522c02005a626c8ad63d39bdff2f31d41069592919ef281e964ed26446"}, - {file = "triton-2.1.0-0-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:143582ca31dd89cd982bd3bf53666bab1c7527d41e185f9e3d8a3051ce1b663b"}, - {file = "triton-2.1.0-0-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:82fc5aeeedf6e36be4e4530cbdcba81a09d65c18e02f52dc298696d45721f3bd"}, - {file = "triton-2.1.0-0-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:81a96d110a738ff63339fc892ded095b31bd0d205e3aace262af8400d40b6fa8"}, -] - -[package.dependencies] -filelock = "*" - -[package.extras] -build = ["cmake (>=3.18)", "lit"] -tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] -tutorials = ["matplotlib", "pandas", "tabulate"] - [[package]] name = "types-python-dateutil" version = "2.9.0.20241206" @@ -3500,4 +3058,4 @@ files = [ [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.12" -content-hash = "2cec5dd5e3b848faf739f79116a9720df2e37d7ead5b32e2f4242d31a0087d25" +content-hash = "99801b0ad912851f8cd3d6b085b7210017cc1b5dcf58785869a72db48f6b7fea" diff --git a/pyproject.toml b/pyproject.toml index 314e98b..f3a0870 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,9 +17,9 @@ numpy = "^1.26.4" pandas = "^1.5.3" scikit-learn = "^1.6.1" matplotlib= "^3.10.1" -machine-learning-datasets = "^0.1.23" statsmodels = "^0.14.4" mlxtend = "^0.23.4" +seaborn = "^0.12.2" [build-system] requires = ["poetry-core"]