diff --git a/assets/lec2-split.png b/assets/lec2-split.png new file mode 100644 index 0000000..fad8160 Binary files /dev/null and b/assets/lec2-split.png differ diff --git a/lec2.ipynb b/lec2.ipynb new file mode 100644 index 0000000..361bd66 --- /dev/null +++ b/lec2.ipynb @@ -0,0 +1,1187 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Загрузка данных в DataFrame" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Index: 891 entries, 1 to 891\n", + "Data columns (total 11 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 Survived 891 non-null int64 \n", + " 1 Pclass 891 non-null int64 \n", + " 2 Name 891 non-null object \n", + " 3 Sex 891 non-null object \n", + " 4 Age 714 non-null float64\n", + " 5 SibSp 891 non-null int64 \n", + " 6 Parch 891 non-null int64 \n", + " 7 Ticket 891 non-null object \n", + " 8 Fare 891 non-null float64\n", + " 9 Cabin 204 non-null object \n", + " 10 Embarked 889 non-null object \n", + "dtypes: float64(2), int64(4), object(5)\n", + "memory usage: 83.5+ KB\n" + ] + }, + { + "data": { + "text/plain": [ + "(891, 11)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
SurvivedPclassNameSexAgeSibSpParchTicketFareCabinEmbarked
PassengerId
103Braund, Mr. Owen Harrismale22.010A/5 211717.2500NaNS
211Cumings, Mrs. John Bradley (Florence Briggs Th...female38.010PC 1759971.2833C85C
313Heikkinen, Miss. Lainafemale26.000STON/O2. 31012827.9250NaNS
411Futrelle, Mrs. Jacques Heath (Lily May Peel)female35.01011380353.1000C123S
503Allen, Mr. William Henrymale35.0003734508.0500NaNS
\n", + "
" + ], + "text/plain": [ + " Survived Pclass \\\n", + "PassengerId \n", + "1 0 3 \n", + "2 1 1 \n", + "3 1 3 \n", + "4 1 1 \n", + "5 0 3 \n", + "\n", + " Name Sex Age \\\n", + "PassengerId \n", + "1 Braund, Mr. Owen Harris male 22.0 \n", + "2 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 \n", + "3 Heikkinen, Miss. Laina female 26.0 \n", + "4 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 \n", + "5 Allen, Mr. William Henry male 35.0 \n", + "\n", + " SibSp Parch Ticket Fare Cabin Embarked \n", + "PassengerId \n", + "1 1 0 A/5 21171 7.2500 NaN S \n", + "2 1 0 PC 17599 71.2833 C85 C \n", + "3 0 0 STON/O2. 3101282 7.9250 NaN S \n", + "4 1 0 113803 53.1000 C123 S \n", + "5 0 0 373450 8.0500 NaN S " + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.read_csv(\"data/titanic.csv\", index_col=\"PassengerId\")\n", + "\n", + "df.info()\n", + "\n", + "display(df.shape)\n", + "\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Получение сведений о пропущенных данных" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Типы пропущенных данных:\n", + "- None - представление пустых данных в Python\n", + "- NaN - представление пустых данных в Pandas\n", + "- '' - пустая строка" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Survived 0\n", + "Pclass 0\n", + "Name 0\n", + "Sex 0\n", + "Age 177\n", + "SibSp 0\n", + "Parch 0\n", + "Ticket 0\n", + "Fare 0\n", + "Cabin 687\n", + "Embarked 2\n", + "dtype: int64" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Survived False\n", + "Pclass False\n", + "Name False\n", + "Sex False\n", + "Age True\n", + "SibSp False\n", + "Parch False\n", + "Ticket False\n", + "Fare False\n", + "Cabin True\n", + "Embarked True\n", + 
"dtype: bool" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Age процент пустых значений: %19.87'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Cabin процент пустых значений: %77.10'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Embarked процент пустых значений: %0.22'" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Количество пустых значений признаков\n", + "display(df.isnull().sum())\n", + "display()\n", + "\n", + "# Есть ли пустые значения признаков\n", + "display(df.isnull().any())\n", + "display()\n", + "\n", + "# Процент пустых значений признаков\n", + "for i in df.columns:\n", + " null_rate = df[i].isnull().sum() / len(df) * 100\n", + " if null_rate > 0:\n", + " display(f\"{i} процент пустых значений: %{null_rate:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Заполнение пропущенных данных\n", + "\n", + "https://pythonmldaily.com/posts/pandas-dataframes-search-drop-empty-values\n", + "\n", + "https://scales.arabpsychology.com/stats/how-to-fill-nan-values-with-median-in-pandas/" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(891, 11)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Survived False\n", + "Pclass False\n", + "Name False\n", + "Sex False\n", + "Age False\n", + "SibSp False\n", + "Parch False\n", + "Ticket False\n", + "Fare False\n", + "Cabin False\n", + "Embarked False\n", + "dtype: bool" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
SurvivedPclassNameSexAgeSibSpParchTicketFareCabinEmbarkedAgeFillNAAgeFillMedian
PassengerId
88702Montvila, Rev. Juozasmale27.00021153613.00NaNS27.027.0
88811Graham, Miss. Margaret Edithfemale19.00011205330.00B42S19.019.0
88903Johnston, Miss. Catherine Helen \"Carrie\"femaleNaN12W./C. 660723.45NaNS0.028.0
89011Behr, Mr. Karl Howellmale26.00011136930.00C148C26.026.0
89103Dooley, Mr. Patrickmale32.0003703767.75NaNQ32.032.0
\n", + "
" + ], + "text/plain": [ + " Survived Pclass Name \\\n", + "PassengerId \n", + "887 0 2 Montvila, Rev. Juozas \n", + "888 1 1 Graham, Miss. Margaret Edith \n", + "889 0 3 Johnston, Miss. Catherine Helen \"Carrie\" \n", + "890 1 1 Behr, Mr. Karl Howell \n", + "891 0 3 Dooley, Mr. Patrick \n", + "\n", + " Sex Age SibSp Parch Ticket Fare Cabin Embarked \\\n", + "PassengerId \n", + "887 male 27.0 0 0 211536 13.00 NaN S \n", + "888 female 19.0 0 0 112053 30.00 B42 S \n", + "889 female NaN 1 2 W./C. 6607 23.45 NaN S \n", + "890 male 26.0 0 0 111369 30.00 C148 C \n", + "891 male 32.0 0 0 370376 7.75 NaN Q \n", + "\n", + " AgeFillNA AgeFillMedian \n", + "PassengerId \n", + "887 27.0 27.0 \n", + "888 19.0 19.0 \n", + "889 0.0 28.0 \n", + "890 26.0 26.0 \n", + "891 32.0 32.0 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fillna_df = df.fillna(0)\n", + "\n", + "display(fillna_df.shape)\n", + "\n", + "display(fillna_df.isnull().any())\n", + "\n", + "# Замена пустых данных на 0\n", + "df[\"AgeFillNA\"] = df[\"Age\"].fillna(0)\n", + "\n", + "# Замена пустых данных на медиану\n", + "df[\"AgeFillMedian\"] = df[\"Age\"].fillna(df[\"Age\"].median())\n", + "\n", + "df.tail()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
SurvivedPclassNameSexAgeSibSpParchTicketFareCabinEmbarkedAgeFillNAAgeFillMedianAgeCopy
PassengerId
88702Montvila, Rev. Juozasmale27.00021153613.00NaNS27.027.027.0
88811Graham, Miss. Margaret Edithfemale19.00011205330.00B42S19.019.019.0
88903Johnston, Miss. Catherine Helen \"Carrie\"femaleNaN12W./C. 660723.45NaNS0.028.00.0
89011Behr, Mr. Karl Howellmale26.00011136930.00C148C26.026.026.0
89103Dooley, Mr. Patrickmale32.0003703767.75NaNQ32.032.032.0
\n", + "
" + ], + "text/plain": [ + " Survived Pclass Name \\\n", + "PassengerId \n", + "887 0 2 Montvila, Rev. Juozas \n", + "888 1 1 Graham, Miss. Margaret Edith \n", + "889 0 3 Johnston, Miss. Catherine Helen \"Carrie\" \n", + "890 1 1 Behr, Mr. Karl Howell \n", + "891 0 3 Dooley, Mr. Patrick \n", + "\n", + " Sex Age SibSp Parch Ticket Fare Cabin Embarked \\\n", + "PassengerId \n", + "887 male 27.0 0 0 211536 13.00 NaN S \n", + "888 female 19.0 0 0 112053 30.00 B42 S \n", + "889 female NaN 1 2 W./C. 6607 23.45 NaN S \n", + "890 male 26.0 0 0 111369 30.00 C148 C \n", + "891 male 32.0 0 0 370376 7.75 NaN Q \n", + "\n", + " AgeFillNA AgeFillMedian AgeCopy \n", + "PassengerId \n", + "887 27.0 27.0 27.0 \n", + "888 19.0 19.0 19.0 \n", + "889 0.0 28.0 0.0 \n", + "890 26.0 26.0 26.0 \n", + "891 32.0 32.0 32.0 " + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df[\"AgeCopy\"] = df[\"Age\"]\n", + "\n", + "# Замена данных сразу в DataFrame без копирования\n", + "df.fillna({\"AgeCopy\": 0}, inplace=True)\n", + "\n", + "df.tail()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Удаление наблюдений с пропусками" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(183, 14)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Survived False\n", + "Pclass False\n", + "Name False\n", + "Sex False\n", + "Age False\n", + "SibSp False\n", + "Parch False\n", + "Ticket False\n", + "Fare False\n", + "Cabin False\n", + "Embarked False\n", + "dtype: bool" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "dropna_df = df.dropna()\n", + "\n", + "display(dropna_df.shape)\n", + "\n", + "display(fillna_df.isnull().any())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Создание выборок данных\n", + "\n", + "Библиотека 
scikit-learn\n", + "\n", + "https://scikit-learn.org/stable/index.html" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Pclass\n", + "3 491\n", + "1 216\n", + "2 184\n", + "Name: count, dtype: int64" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Обучающая выборка: '" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(534, 3)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Pclass\n", + "3 294\n", + "1 130\n", + "2 110\n", + "Name: count, dtype: int64" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Контрольная выборка: '" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(178, 3)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Pclass\n", + "3 98\n", + "1 43\n", + "2 37\n", + "Name: count, dtype: int64" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Тестовая выборка: '" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(179, 3)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Pclass\n", + "3 99\n", + "1 43\n", + "2 37\n", + "Name: count, dtype: int64" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Вывод распределения количества наблюдений по меткам (классам)\n", + "from src.utils import split_stratified_into_train_val_test\n", + "\n", + "\n", + "display(df.Pclass.value_counts())\n", + "display()\n", + "\n", + "data = df[[\"Pclass\", \"Survived\", \"AgeFillMedian\"]].copy()\n", + "\n", + "df_train, 
df_val, df_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n", + " data, stratify_colname=\"Pclass\", frac_train=0.60, frac_val=0.20, frac_test=0.20\n", + ")\n", + "\n", + "display(\"Обучающая выборка: \", df_train.shape)\n", + "display(df_train.Pclass.value_counts())\n", + "\n", + "display(\"Контрольная выборка: \", df_val.shape)\n", + "display(df_val.Pclass.value_counts())\n", + "\n", + "display(\"Тестовая выборка: \", df_test.shape)\n", + "display(df_test.Pclass.value_counts())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Выборка с избытком (oversampling)\n", + "\n", + "https://www.blog.trainindata.com/oversampling-techniques-for-imbalanced-data/\n", + "\n", + "https://datacrayon.com/machine-learning/class-imbalance-and-oversampling/\n", + "\n", + "Выборка с недостатком (undersampling)\n", + "\n", + "https://machinelearningmastery.com/random-oversampling-and-undersampling-for-imbalanced-classification/\n", + "\n", + "Библиотека imbalanced-learn\n", + "\n", + "https://imbalanced-learn.org/stable/" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Обучающая выборка: '" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(534, 3)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Pclass\n", + "3 294\n", + "1 130\n", + "2 110\n", + "Name: count, dtype: int64" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Обучающая выборка после oversampling: '" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(860, 3)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Pclass\n", + "3 294\n", + "2 288\n", + "1 278\n", + "Name: count, dtype: int64" + ] + }, + "metadata": {}, + "output_type": 
"display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
PclassSurvivedAgeFillMedian
03028.000000
13030.000000
21141.000000
33016.000000
43132.000000
............
8552025.000000
8562149.895392
8572049.548159
8582049.410133
8592050.944726
\n", + "

860 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " Pclass Survived AgeFillMedian\n", + "0 3 0 28.000000\n", + "1 3 0 30.000000\n", + "2 1 1 41.000000\n", + "3 3 0 16.000000\n", + "4 3 1 32.000000\n", + ".. ... ... ...\n", + "855 2 0 25.000000\n", + "856 2 1 49.895392\n", + "857 2 0 49.548159\n", + "858 2 0 49.410133\n", + "859 2 0 50.944726\n", + "\n", + "[860 rows x 3 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from imblearn.over_sampling import ADASYN\n", + "\n", + "ada = ADASYN()\n", + "\n", + "display(\"Обучающая выборка: \", df_train.shape)\n", + "display(df_train.Pclass.value_counts())\n", + "\n", + "X_resampled, y_resampled = ada.fit_resample(df_train, df_train[\"Pclass\"]) # type: ignore\n", + "df_train_adasyn = pd.DataFrame(X_resampled)\n", + "\n", + "display(\"Обучающая выборка после oversampling: \", df_train_adasyn.shape)\n", + "display(df_train_adasyn.Pclass.value_counts())\n", + "\n", + "df_train_adasyn" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/poetry.lock b/poetry.lock index 375a1b5..4c0f1c5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -763,6 +763,30 @@ files = [ [package.extras] all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] +[[package]] +name = "imbalanced-learn" +version = "0.12.4" +description = "Toolbox for imbalanced dataset in machine learning." 
+optional = false +python-versions = "*" +files = [ + {file = "imbalanced-learn-0.12.4.tar.gz", hash = "sha256:8153ba385d296b07d97e0901a2624a86c06b48c94c2f92da3a5354827697b7a3"}, + {file = "imbalanced_learn-0.12.4-py3-none-any.whl", hash = "sha256:d47fc599160d3ea882e712a3a6b02bdd353c1a6436d8d68d41b1922e6ee4a703"}, +] + +[package.dependencies] +joblib = ">=1.1.1" +numpy = ">=1.17.3" +scikit-learn = ">=1.0.2" +scipy = ">=1.5.0" +threadpoolctl = ">=2.0.0" + +[package.extras] +docs = ["keras (>=2.4.3)", "matplotlib (>=3.1.2)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.5.0)", "pandas (>=1.0.5)", "pydata-sphinx-theme (>=0.13.3)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-gallery (>=0.13.0)", "sphinxcontrib-bibtex (>=2.4.1)", "tensorflow (>=2.4.3)"] +examples = ["keras (>=2.4.3)", "matplotlib (>=3.1.2)", "pandas (>=1.0.5)", "seaborn (>=0.9.0)", "tensorflow (>=2.4.3)"] +optional = ["keras (>=2.4.3)", "pandas (>=1.0.5)", "tensorflow (>=2.4.3)"] +tests = ["black (>=23.3.0)", "flake8 (>=3.8.2)", "keras (>=2.4.3)", "mypy (>=1.3.0)", "pandas (>=1.0.5)", "pytest (>=5.0.1)", "pytest-cov (>=2.9.0)", "tensorflow (>=2.4.3)"] + [[package]] name = "ipykernel" version = "6.29.5" @@ -903,6 +927,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "joblib" +version = "1.4.2" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.8" +files = [ + {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"}, + {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, +] + [[package]] name = "json5" version = "0.10.0" @@ -2513,6 +2548,101 @@ files = [ {file = "rpds_py-0.22.0.tar.gz", hash = "sha256:32de71c393f126d8203e9815557c7ff4d72ed1ad3aa3f52f6c7938413176750a"}, ] +[[package]] +name = "scikit-learn" +version = "1.5.2" 
+description = "A set of python modules for machine learning and data mining" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:299406827fb9a4f862626d0fe6c122f5f87f8910b86fe5daa4c32dcd742139b6"}, + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2d4cad1119c77930b235579ad0dc25e65c917e756fe80cab96aa3b9428bd3fb0"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c412ccc2ad9bf3755915e3908e677b367ebc8d010acbb3f182814524f2e5540"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a686885a4b3818d9e62904d91b57fa757fc2bed3e465c8b177be652f4dd37c8"}, + {file = "scikit_learn-1.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:c15b1ca23d7c5f33cc2cb0a0d6aaacf893792271cddff0edbd6a40e8319bc113"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03b6158efa3faaf1feea3faa884c840ebd61b6484167c711548fce208ea09445"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1ff45e26928d3b4eb767a8f14a9a6efbf1cbff7c05d1fb0f95f211a89fd4f5de"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f763897fe92d0e903aa4847b0aec0e68cadfff77e8a0687cabd946c89d17e675"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8b0ccd4a902836493e026c03256e8b206656f91fbcc4fde28c57a5b752561f1"}, + {file = "scikit_learn-1.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6c16d84a0d45e4894832b3c4d0bf73050939e21b99b01b6fd59cbb0cf39163b6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = 
"sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca64b3089a6d9b9363cd3546f8978229dcbb737aceb2c12144ee3f70f95684b7"}, + {file = "scikit_learn-1.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:3bed4909ba187aca80580fe2ef370d9180dcf18e621a27c4cf2ef10d279a7efe"}, + {file = "scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d"}, +] + +[package.dependencies] +joblib = ">=1.2.0" +numpy = ">=1.19.5" +scipy = ">=1.6.0" +threadpoolctl = ">=3.1.0" + +[package.extras] +benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] +build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme 
(>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] +examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] +install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] +maintenance = ["conda-lock (==2.5.6)"] +tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] + +[[package]] +name = "scipy" +version = "1.14.1" +description = "Fundamental algorithms for scientific computing in Python" +optional = false +python-versions = ">=3.10" +files = [ + {file = "scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69"}, + {file 
= "scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad"}, + {file = "scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2"}, + {file = "scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2"}, + {file = "scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = 
"sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066"}, + {file = "scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1"}, + {file = "scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e"}, + {file = "scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06"}, + {file = "scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84"}, + {file = "scipy-1.14.1.tar.gz", hash = 
"sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417"}, +] + +[package.dependencies] +numpy = ">=1.23.5,<2.3" + +[package.extras] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"] +test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] + [[package]] name = "send2trash" version = "1.8.3" @@ -2622,6 +2752,17 @@ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"] typing = ["mypy (>=1.6,<2.0)", "traitlets (>=5.11.1)"] +[[package]] +name = "threadpoolctl" +version = "3.5.0" +description = "threadpoolctl" +optional = false +python-versions = ">=3.8" +files = [ + {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"}, + {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, +] + [[package]] name = "tinycss2" version = "1.4.0" @@ -2791,4 +2932,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "499a93cb5bb093f5378ad4a2e77f4f895221c389934fb4a15e5a5127db419128" +content-hash = "d9de29d5a54172d74c1c4d32cc992d85f9ada806a82846ab228dca34419bba41" diff --git a/pyproject.toml b/pyproject.toml index dd58b4a..012070f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ jupyter = "^1.1.1" numpy = "^2.1.0" pandas = "^2.2.2" matplotlib = "^3.9.2" +imbalanced-learn = "^0.12.3" [build-system] diff --git a/src/utils.py 
def split_stratified_into_train_val_test(
    df_input,
    stratify_colname="y",
    frac_train=0.6,
    frac_val=0.15,
    frac_test=0.25,
    random_state=None,
) -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:
    """
    Split a dataframe into stratified train, validation and test subsets.

    The split is performed by calling train_test_split() twice: first to
    separate the train subset from the remainder, then to divide the
    remainder into validation and test subsets. Both calls stratify on
    ``stratify_colname``.

    Parameters
    ----------
    df_input : Pandas dataframe
        Input dataframe to be split. All of its columns (including the
        stratification column) are kept in the returned feature frames.
    stratify_colname : str
        The name of the column that will be used for stratification. Usually
        this column would be for the label.
    frac_train : float
    frac_val : float
    frac_test : float
        The ratios with which the dataframe will be split into train, val, and
        test data. The values should be expressed as float fractions and should
        sum to 1.0 (compared with a small floating-point tolerance).
        If ``frac_val <= 0`` the validation outputs are empty dataframes and
        the whole remainder becomes the test subset; if ``frac_test <= 0`` the
        test outputs are empty and the whole remainder becomes the validation
        subset.
    random_state : int, None, or RandomStateInstance
        Value to be passed to train_test_split().

    Returns
    -------
    df_train, df_val, df_test, y_train, y_val, y_test :
        The first three are dataframes containing the three splits with all
        columns of ``df_input``. The last three are single-column dataframes
        holding ``stratify_colname`` for the corresponding split.

    Raises
    ------
    ValueError
        If the fractions do not add up to 1.0, or if ``stratify_colname`` is
        not a column of ``df_input``.
    """
    # Local import: keeps this block self-contained without touching the
    # module-level import section.
    import math

    # Exact float equality would reject valid inputs such as 0.7/0.2/0.1,
    # whose sum is 0.9999999999999999 in binary floating point.
    if not math.isclose(frac_train + frac_val + frac_test, 1.0, rel_tol=1e-9):
        raise ValueError(
            f"fractions {frac_train:f}, {frac_val:f}, {frac_test:f} "
            "do not add up to 1.0"
        )

    if stratify_colname not in df_input.columns:
        raise ValueError(f"{stratify_colname} is not a column in the dataframe")

    X = df_input  # Contains all columns.
    # Single-column dataframe holding the values on which to stratify.
    y = df_input[[stratify_colname]]

    # Split original dataframe into train and temp dataframes.
    df_train, df_temp, y_train, y_temp = train_test_split(
        X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state
    )
    # Sanity check: the first split must partition the input.
    assert len(df_input) == len(df_train) + len(df_temp)

    # Two-way splits: the whole remainder goes to the one requested subset,
    # the other is returned as an empty dataframe.
    if frac_val <= 0:
        return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp
    if frac_test <= 0:
        # A second train_test_split() call would be asked for test_size=0.0
        # here, so hand the remainder to validation directly.
        return df_train, df_temp, pd.DataFrame(), y_train, y_temp, pd.DataFrame()

    # Split the temp dataframe into val and test dataframes; the test share
    # is expressed relative to the remainder, not to the full input.
    relative_frac_test = frac_test / (frac_val + frac_test)
    df_val, df_test, y_val, y_test = train_test_split(
        df_temp,
        y_temp,
        stratify=y_temp,
        test_size=relative_frac_test,
        random_state=random_state,
    )

    # Sanity check: the three subsets must partition the input.
    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)
    return df_train, df_val, df_test, y_train, y_val, y_test