Lec2
This commit is contained in:
parent
c333fdd365
commit
cecf94809b
BIN
assets/lec2-split.png
Normal file
BIN
assets/lec2-split.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 63 KiB |
1187
lec2.ipynb
Normal file
1187
lec2.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
143
poetry.lock
generated
143
poetry.lock
generated
@ -763,6 +763,30 @@ files = [
|
||||
[package.extras]
|
||||
all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "imbalanced-learn"
|
||||
version = "0.12.4"
|
||||
description = "Toolbox for imbalanced dataset in machine learning."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "imbalanced-learn-0.12.4.tar.gz", hash = "sha256:8153ba385d296b07d97e0901a2624a86c06b48c94c2f92da3a5354827697b7a3"},
|
||||
{file = "imbalanced_learn-0.12.4-py3-none-any.whl", hash = "sha256:d47fc599160d3ea882e712a3a6b02bdd353c1a6436d8d68d41b1922e6ee4a703"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
joblib = ">=1.1.1"
|
||||
numpy = ">=1.17.3"
|
||||
scikit-learn = ">=1.0.2"
|
||||
scipy = ">=1.5.0"
|
||||
threadpoolctl = ">=2.0.0"
|
||||
|
||||
[package.extras]
|
||||
docs = ["keras (>=2.4.3)", "matplotlib (>=3.1.2)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.5.0)", "pandas (>=1.0.5)", "pydata-sphinx-theme (>=0.13.3)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-gallery (>=0.13.0)", "sphinxcontrib-bibtex (>=2.4.1)", "tensorflow (>=2.4.3)"]
|
||||
examples = ["keras (>=2.4.3)", "matplotlib (>=3.1.2)", "pandas (>=1.0.5)", "seaborn (>=0.9.0)", "tensorflow (>=2.4.3)"]
|
||||
optional = ["keras (>=2.4.3)", "pandas (>=1.0.5)", "tensorflow (>=2.4.3)"]
|
||||
tests = ["black (>=23.3.0)", "flake8 (>=3.8.2)", "keras (>=2.4.3)", "mypy (>=1.3.0)", "pandas (>=1.0.5)", "pytest (>=5.0.1)", "pytest-cov (>=2.9.0)", "tensorflow (>=2.4.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "ipykernel"
|
||||
version = "6.29.5"
|
||||
@ -903,6 +927,17 @@ MarkupSafe = ">=2.0"
|
||||
[package.extras]
|
||||
i18n = ["Babel (>=2.7)"]
|
||||
|
||||
[[package]]
|
||||
name = "joblib"
|
||||
version = "1.4.2"
|
||||
description = "Lightweight pipelining with Python functions"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"},
|
||||
{file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "json5"
|
||||
version = "0.10.0"
|
||||
@ -2513,6 +2548,101 @@ files = [
|
||||
{file = "rpds_py-0.22.0.tar.gz", hash = "sha256:32de71c393f126d8203e9815557c7ff4d72ed1ad3aa3f52f6c7938413176750a"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scikit-learn"
|
||||
version = "1.5.2"
|
||||
description = "A set of python modules for machine learning and data mining"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "scikit_learn-1.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:299406827fb9a4f862626d0fe6c122f5f87f8910b86fe5daa4c32dcd742139b6"},
|
||||
{file = "scikit_learn-1.5.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2d4cad1119c77930b235579ad0dc25e65c917e756fe80cab96aa3b9428bd3fb0"},
|
||||
{file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c412ccc2ad9bf3755915e3908e677b367ebc8d010acbb3f182814524f2e5540"},
|
||||
{file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a686885a4b3818d9e62904d91b57fa757fc2bed3e465c8b177be652f4dd37c8"},
|
||||
{file = "scikit_learn-1.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:c15b1ca23d7c5f33cc2cb0a0d6aaacf893792271cddff0edbd6a40e8319bc113"},
|
||||
{file = "scikit_learn-1.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03b6158efa3faaf1feea3faa884c840ebd61b6484167c711548fce208ea09445"},
|
||||
{file = "scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1ff45e26928d3b4eb767a8f14a9a6efbf1cbff7c05d1fb0f95f211a89fd4f5de"},
|
||||
{file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f763897fe92d0e903aa4847b0aec0e68cadfff77e8a0687cabd946c89d17e675"},
|
||||
{file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8b0ccd4a902836493e026c03256e8b206656f91fbcc4fde28c57a5b752561f1"},
|
||||
{file = "scikit_learn-1.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6c16d84a0d45e4894832b3c4d0bf73050939e21b99b01b6fd59cbb0cf39163b6"},
|
||||
{file = "scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a"},
|
||||
{file = "scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1"},
|
||||
{file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"},
|
||||
{file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"},
|
||||
{file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"},
|
||||
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"},
|
||||
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"},
|
||||
{file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"},
|
||||
{file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca64b3089a6d9b9363cd3546f8978229dcbb737aceb2c12144ee3f70f95684b7"},
|
||||
{file = "scikit_learn-1.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:3bed4909ba187aca80580fe2ef370d9180dcf18e621a27c4cf2ef10d279a7efe"},
|
||||
{file = "scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
joblib = ">=1.2.0"
|
||||
numpy = ">=1.19.5"
|
||||
scipy = ">=1.6.0"
|
||||
threadpoolctl = ">=3.1.0"
|
||||
|
||||
[package.extras]
|
||||
benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"]
|
||||
build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"]
|
||||
docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"]
|
||||
examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"]
|
||||
install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"]
|
||||
maintenance = ["conda-lock (==2.5.6)"]
|
||||
tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "scipy"
|
||||
version = "1.14.1"
|
||||
description = "Fundamental algorithms for scientific computing in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
files = [
|
||||
{file = "scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389"},
|
||||
{file = "scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3"},
|
||||
{file = "scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0"},
|
||||
{file = "scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3"},
|
||||
{file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d"},
|
||||
{file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69"},
|
||||
{file = "scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad"},
|
||||
{file = "scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5"},
|
||||
{file = "scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675"},
|
||||
{file = "scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2"},
|
||||
{file = "scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617"},
|
||||
{file = "scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8"},
|
||||
{file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37"},
|
||||
{file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2"},
|
||||
{file = "scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2"},
|
||||
{file = "scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94"},
|
||||
{file = "scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d"},
|
||||
{file = "scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07"},
|
||||
{file = "scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5"},
|
||||
{file = "scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc"},
|
||||
{file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310"},
|
||||
{file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066"},
|
||||
{file = "scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1"},
|
||||
{file = "scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f"},
|
||||
{file = "scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79"},
|
||||
{file = "scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e"},
|
||||
{file = "scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73"},
|
||||
{file = "scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e"},
|
||||
{file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d"},
|
||||
{file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e"},
|
||||
{file = "scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06"},
|
||||
{file = "scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84"},
|
||||
{file = "scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = ">=1.23.5,<2.3"
|
||||
|
||||
[package.extras]
|
||||
dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"]
|
||||
doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"]
|
||||
test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
|
||||
|
||||
[[package]]
|
||||
name = "send2trash"
|
||||
version = "1.8.3"
|
||||
@ -2622,6 +2752,17 @@ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
|
||||
test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"]
|
||||
typing = ["mypy (>=1.6,<2.0)", "traitlets (>=5.11.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "threadpoolctl"
|
||||
version = "3.5.0"
|
||||
description = "threadpoolctl"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"},
|
||||
{file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinycss2"
|
||||
version = "1.4.0"
|
||||
@ -2791,4 +2932,4 @@ files = [
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.12"
|
||||
content-hash = "499a93cb5bb093f5378ad4a2e77f4f895221c389934fb4a15e5a5127db419128"
|
||||
content-hash = "d9de29d5a54172d74c1c4d32cc992d85f9ada806a82846ab228dca34419bba41"
|
||||
|
@ -12,6 +12,7 @@ jupyter = "^1.1.1"
|
||||
numpy = "^2.1.0"
|
||||
pandas = "^2.2.2"
|
||||
matplotlib = "^3.9.2"
|
||||
imbalanced-learn = "^0.12.3"
|
||||
|
||||
|
||||
[build-system]
|
||||
|
79
src/utils.py
Normal file
79
src/utils.py
Normal file
@ -0,0 +1,79 @@
|
||||
from typing import Tuple
|
||||
|
||||
import pandas as pd
|
||||
from pandas import DataFrame
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
def split_stratified_into_train_val_test(
|
||||
df_input,
|
||||
stratify_colname="y",
|
||||
frac_train=0.6,
|
||||
frac_val=0.15,
|
||||
frac_test=0.25,
|
||||
random_state=None,
|
||||
) -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:
|
||||
"""
|
||||
Splits a Pandas dataframe into three subsets (train, val, and test)
|
||||
following fractional ratios provided by the user, where each subset is
|
||||
stratified by the values in a specific column (that is, each subset has
|
||||
the same relative frequency of the values in the column). It performs this
|
||||
splitting by running train_test_split() twice.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df_input : Pandas dataframe
|
||||
Input dataframe to be split.
|
||||
stratify_colname : str
|
||||
The name of the column that will be used for stratification. Usually
|
||||
this column would be for the label.
|
||||
frac_train : float
|
||||
frac_val : float
|
||||
frac_test : float
|
||||
The ratios with which the dataframe will be split into train, val, and
|
||||
test data. The values should be expressed as float fractions and should
|
||||
sum to 1.0.
|
||||
random_state : int, None, or RandomStateInstance
|
||||
Value to be passed to train_test_split().
|
||||
|
||||
Returns
|
||||
-------
|
||||
df_train, df_val, df_test :
|
||||
Dataframes containing the three splits.
|
||||
"""
|
||||
|
||||
if frac_train + frac_val + frac_test != 1.0:
|
||||
raise ValueError(
|
||||
"fractions %f, %f, %f do not add up to 1.0"
|
||||
% (frac_train, frac_val, frac_test)
|
||||
)
|
||||
|
||||
if stratify_colname not in df_input.columns:
|
||||
raise ValueError("%s is not a column in the dataframe" % (stratify_colname))
|
||||
|
||||
X = df_input # Contains all columns.
|
||||
y = df_input[
|
||||
[stratify_colname]
|
||||
] # Dataframe of just the column on which to stratify.
|
||||
|
||||
# Split original dataframe into train and temp dataframes.
|
||||
df_train, df_temp, y_train, y_temp = train_test_split(
|
||||
X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state
|
||||
)
|
||||
|
||||
if frac_val <= 0:
|
||||
assert len(df_input) == len(df_train) + len(df_temp)
|
||||
return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp
|
||||
|
||||
# Split the temp dataframe into val and test dataframes.
|
||||
relative_frac_test = frac_test / (frac_val + frac_test)
|
||||
df_val, df_test, y_val, y_test = train_test_split(
|
||||
df_temp,
|
||||
y_temp,
|
||||
stratify=y_temp,
|
||||
test_size=relative_frac_test,
|
||||
random_state=random_state,
|
||||
)
|
||||
|
||||
assert len(df_input) == len(df_train) + len(df_val) + len(df_test)
|
||||
return df_train, df_val, df_test, y_train, y_val, y_test
|
Loading…
Reference in New Issue
Block a user