Compare commits
No commits in common. "b40184133948f058955e015ceb34ce4048427047" and "11632e1d53c16035da99f9eb0d6465d90bb0d5fa" have entirely different histories.
b401841339
...
11632e1d53
File diff suppressed because it is too large
Load Diff
Binary file not shown.
3061
distress.ipynb
3061
distress.ipynb
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
115
poetry.lock
generated
115
poetry.lock
generated
@ -777,30 +777,6 @@ files = [
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"]
|
all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "imbalanced-learn"
|
|
||||||
version = "0.12.4"
|
|
||||||
description = "Toolbox for imbalanced dataset in machine learning."
|
|
||||||
optional = false
|
|
||||||
python-versions = "*"
|
|
||||||
files = [
|
|
||||||
{file = "imbalanced-learn-0.12.4.tar.gz", hash = "sha256:8153ba385d296b07d97e0901a2624a86c06b48c94c2f92da3a5354827697b7a3"},
|
|
||||||
{file = "imbalanced_learn-0.12.4-py3-none-any.whl", hash = "sha256:d47fc599160d3ea882e712a3a6b02bdd353c1a6436d8d68d41b1922e6ee4a703"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
joblib = ">=1.1.1"
|
|
||||||
numpy = ">=1.17.3"
|
|
||||||
scikit-learn = ">=1.0.2"
|
|
||||||
scipy = ">=1.5.0"
|
|
||||||
threadpoolctl = ">=2.0.0"
|
|
||||||
|
|
||||||
[package.extras]
|
|
||||||
docs = ["keras (>=2.4.3)", "matplotlib (>=3.1.2)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.5.0)", "pandas (>=1.0.5)", "pydata-sphinx-theme (>=0.13.3)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-gallery (>=0.13.0)", "sphinxcontrib-bibtex (>=2.4.1)", "tensorflow (>=2.4.3)"]
|
|
||||||
examples = ["keras (>=2.4.3)", "matplotlib (>=3.1.2)", "pandas (>=1.0.5)", "seaborn (>=0.9.0)", "tensorflow (>=2.4.3)"]
|
|
||||||
optional = ["keras (>=2.4.3)", "pandas (>=1.0.5)", "tensorflow (>=2.4.3)"]
|
|
||||||
tests = ["black (>=23.3.0)", "flake8 (>=3.8.2)", "keras (>=2.4.3)", "mypy (>=1.3.0)", "pandas (>=1.0.5)", "pytest (>=5.0.1)", "pytest-cov (>=2.9.0)", "tensorflow (>=2.4.3)"]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ipykernel"
|
name = "ipykernel"
|
||||||
version = "6.29.5"
|
version = "6.29.5"
|
||||||
@ -1899,23 +1875,6 @@ files = [
|
|||||||
qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"]
|
qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"]
|
||||||
testing = ["docopt", "pytest"]
|
testing = ["docopt", "pytest"]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "patsy"
|
|
||||||
version = "1.0.1"
|
|
||||||
description = "A Python package for describing statistical models and for building design matrices."
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.6"
|
|
||||||
files = [
|
|
||||||
{file = "patsy-1.0.1-py2.py3-none-any.whl", hash = "sha256:751fb38f9e97e62312e921a1954b81e1bb2bcda4f5eeabaf94db251ee791509c"},
|
|
||||||
{file = "patsy-1.0.1.tar.gz", hash = "sha256:e786a9391eec818c054e359b737bbce692f051aee4c661f4141cc88fb459c0c4"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
numpy = ">=1.4"
|
|
||||||
|
|
||||||
[package.extras]
|
|
||||||
test = ["pytest", "pytest-cov", "scipy"]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pexpect"
|
name = "pexpect"
|
||||||
version = "4.9.0"
|
version = "4.9.0"
|
||||||
@ -2700,27 +2659,6 @@ dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodest
|
|||||||
doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"]
|
doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"]
|
||||||
test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
|
test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "seaborn"
|
|
||||||
version = "0.13.2"
|
|
||||||
description = "Statistical data visualization"
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.8"
|
|
||||||
files = [
|
|
||||||
{file = "seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987"},
|
|
||||||
{file = "seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
matplotlib = ">=3.4,<3.6.1 || >3.6.1"
|
|
||||||
numpy = ">=1.20,<1.24.0 || >1.24.0"
|
|
||||||
pandas = ">=1.2"
|
|
||||||
|
|
||||||
[package.extras]
|
|
||||||
dev = ["flake8", "flit", "mypy", "pandas-stubs", "pre-commit", "pytest", "pytest-cov", "pytest-xdist"]
|
|
||||||
docs = ["ipykernel", "nbconvert", "numpydoc", "pydata_sphinx_theme (==0.10.0rc2)", "pyyaml", "sphinx (<6.0.0)", "sphinx-copybutton", "sphinx-design", "sphinx-issues"]
|
|
||||||
stats = ["scipy (>=1.7)", "statsmodels (>=0.12)"]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "send2trash"
|
name = "send2trash"
|
||||||
version = "1.8.3"
|
version = "1.8.3"
|
||||||
@ -2809,57 +2747,6 @@ pure-eval = "*"
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
|
tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "statsmodels"
|
|
||||||
version = "0.14.4"
|
|
||||||
description = "Statistical computations and models for Python"
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.9"
|
|
||||||
files = [
|
|
||||||
{file = "statsmodels-0.14.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7a62f1fc9086e4b7ee789a6f66b3c0fc82dd8de1edda1522d30901a0aa45e42b"},
|
|
||||||
{file = "statsmodels-0.14.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:46ac7ddefac0c9b7b607eed1d47d11e26fe92a1bc1f4d9af48aeed4e21e87981"},
|
|
||||||
{file = "statsmodels-0.14.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a337b731aa365d09bb0eab6da81446c04fde6c31976b1d8e3d3a911f0f1e07b"},
|
|
||||||
{file = "statsmodels-0.14.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:631bb52159117c5da42ba94bd94859276b68cab25dc4cac86475bc24671143bc"},
|
|
||||||
{file = "statsmodels-0.14.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3bb2e580d382545a65f298589809af29daeb15f9da2eb252af8f79693e618abc"},
|
|
||||||
{file = "statsmodels-0.14.4-cp310-cp310-win_amd64.whl", hash = "sha256:9729642884147ee9db67b5a06a355890663d21f76ed608a56ac2ad98b94d201a"},
|
|
||||||
{file = "statsmodels-0.14.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5ed7e118e6e3e02d6723a079b8c97eaadeed943fa1f7f619f7148dfc7862670f"},
|
|
||||||
{file = "statsmodels-0.14.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f5f537f7d000de4a1708c63400755152b862cd4926bb81a86568e347c19c364b"},
|
|
||||||
{file = "statsmodels-0.14.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa74aaa26eaa5012b0a01deeaa8a777595d0835d3d6c7175f2ac65435a7324d2"},
|
|
||||||
{file = "statsmodels-0.14.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e332c2d9b806083d1797231280602340c5c913f90d4caa0213a6a54679ce9331"},
|
|
||||||
{file = "statsmodels-0.14.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d9c8fa28dfd75753d9cf62769ba1fecd7e73a0be187f35cc6f54076f98aa3f3f"},
|
|
||||||
{file = "statsmodels-0.14.4-cp311-cp311-win_amd64.whl", hash = "sha256:a6087ecb0714f7c59eb24c22781491e6f1cfffb660b4740e167625ca4f052056"},
|
|
||||||
{file = "statsmodels-0.14.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5221dba7424cf4f2561b22e9081de85f5bb871228581124a0d1b572708545199"},
|
|
||||||
{file = "statsmodels-0.14.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:17672b30c6b98afe2b095591e32d1d66d4372f2651428e433f16a3667f19eabb"},
|
|
||||||
{file = "statsmodels-0.14.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab5e6312213b8cfb9dca93dd46a0f4dccb856541f91d3306227c3d92f7659245"},
|
|
||||||
{file = "statsmodels-0.14.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4bbb150620b53133d6cd1c5d14c28a4f85701e6c781d9b689b53681effaa655f"},
|
|
||||||
{file = "statsmodels-0.14.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bb695c2025d122a101c2aca66d2b78813c321b60d3a7c86bb8ec4467bb53b0f9"},
|
|
||||||
{file = "statsmodels-0.14.4-cp312-cp312-win_amd64.whl", hash = "sha256:7f7917a51766b4e074da283c507a25048ad29a18e527207883d73535e0dc6184"},
|
|
||||||
{file = "statsmodels-0.14.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b5a24f5d2c22852d807d2b42daf3a61740820b28d8381daaf59dcb7055bf1a79"},
|
|
||||||
{file = "statsmodels-0.14.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df4f7864606fa843d7e7c0e6af288f034a2160dba14e6ccc09020a3cf67cb092"},
|
|
||||||
{file = "statsmodels-0.14.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91341cbde9e8bea5fb419a76e09114e221567d03f34ca26e6d67ae2c27d8fe3c"},
|
|
||||||
{file = "statsmodels-0.14.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1322286a7bfdde2790bf72d29698a1b76c20b8423a55bdcd0d457969d0041f72"},
|
|
||||||
{file = "statsmodels-0.14.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e31b95ac603415887c9f0d344cb523889cf779bc52d68e27e2d23c358958fec7"},
|
|
||||||
{file = "statsmodels-0.14.4-cp313-cp313-win_amd64.whl", hash = "sha256:81030108d27aecc7995cac05aa280cf8c6025f6a6119894eef648997936c2dd0"},
|
|
||||||
{file = "statsmodels-0.14.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4793b01b7a5f5424f5a1dbcefc614c83c7608aa2b035f087538253007c339d5d"},
|
|
||||||
{file = "statsmodels-0.14.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d330da34f59f1653c5193f9fe3a3a258977c880746db7f155fc33713ea858db5"},
|
|
||||||
{file = "statsmodels-0.14.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6e9ddefba1d4e1107c1f20f601b0581421ea3ad9fd75ce3c2ba6a76b6dc4682c"},
|
|
||||||
{file = "statsmodels-0.14.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f43da7957e00190104c5dd0f661bfc6dfc68b87313e3f9c4dbd5e7d222e0aeb"},
|
|
||||||
{file = "statsmodels-0.14.4-cp39-cp39-win_amd64.whl", hash = "sha256:8286f69a5e1d0e0b366ffed5691140c83d3efc75da6dbf34a3d06e88abfaaab6"},
|
|
||||||
{file = "statsmodels-0.14.4.tar.gz", hash = "sha256:5d69e0f39060dc72c067f9bb6e8033b6dccdb0bae101d76a7ef0bcc94e898b67"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
numpy = ">=1.22.3,<3"
|
|
||||||
packaging = ">=21.3"
|
|
||||||
pandas = ">=1.4,<2.1.0 || >2.1.0"
|
|
||||||
patsy = ">=0.5.6"
|
|
||||||
scipy = ">=1.8,<1.9.2 || >1.9.2"
|
|
||||||
|
|
||||||
[package.extras]
|
|
||||||
build = ["cython (>=3.0.10)"]
|
|
||||||
develop = ["colorama", "cython (>=3.0.10)", "cython (>=3.0.10,<4)", "flake8", "isort", "joblib", "matplotlib (>=3)", "pytest (>=7.3.0,<8)", "pytest-cov", "pytest-randomly", "pytest-xdist", "pywinpty", "setuptools-scm[toml] (>=8.0,<9.0)"]
|
|
||||||
docs = ["ipykernel", "jupyter-client", "matplotlib", "nbconvert", "nbformat", "numpydoc", "pandas-datareader", "sphinx"]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "terminado"
|
name = "terminado"
|
||||||
version = "0.18.1"
|
version = "0.18.1"
|
||||||
@ -3065,4 +2952,4 @@ files = [
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.12"
|
python-versions = "^3.12"
|
||||||
content-hash = "f75f65af68bf9b9ae845851b540a182e59667b51060659ab9f4a28bb09ff9880"
|
content-hash = "ee2a6fb525a389426dad97d7f57b9ed7a8187d5892d01cc39a54830695f83a95"
|
||||||
|
@ -15,9 +15,6 @@ matplotlib = "^3.9.2"
|
|||||||
scikit-learn = "^1.5.2"
|
scikit-learn = "^1.5.2"
|
||||||
scikit-fuzzy = "^0.5.0"
|
scikit-fuzzy = "^0.5.0"
|
||||||
networkx = "^3.4.2"
|
networkx = "^3.4.2"
|
||||||
imbalanced-learn = "^0.12.3"
|
|
||||||
seaborn = "^0.13.2"
|
|
||||||
statsmodels = "^0.14.4"
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
|
137
src/utils.py
137
src/utils.py
@ -1,137 +0,0 @@
|
|||||||
import math
|
|
||||||
from typing import Dict, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from pandas import DataFrame
|
|
||||||
from sklearn import metrics
|
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
from sklearn.pipeline import Pipeline
|
|
||||||
|
|
||||||
|
|
||||||
def split_stratified_into_train_val_test(
|
|
||||||
df_input,
|
|
||||||
stratify_colname="y",
|
|
||||||
frac_train=0.6,
|
|
||||||
frac_val=0.15,
|
|
||||||
frac_test=0.25,
|
|
||||||
random_state=None,
|
|
||||||
) -> Tuple:
|
|
||||||
"""
|
|
||||||
Splits a Pandas dataframe into three subsets (train, val, and test)
|
|
||||||
following fractional ratios provided by the user, where each subset is
|
|
||||||
stratified by the values in a specific column (that is, each subset has
|
|
||||||
the same relative frequency of the values in the column). It performs this
|
|
||||||
splitting by running train_test_split() twice.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
df_input : Pandas dataframe
|
|
||||||
Input dataframe to be split.
|
|
||||||
stratify_colname : str
|
|
||||||
The name of the column that will be used for stratification. Usually
|
|
||||||
this column would be for the label.
|
|
||||||
frac_train : float
|
|
||||||
frac_val : float
|
|
||||||
frac_test : float
|
|
||||||
The ratios with which the dataframe will be split into train, val, and
|
|
||||||
test data. The values should be expressed as float fractions and should
|
|
||||||
sum to 1.0.
|
|
||||||
random_state : int, None, or RandomStateInstance
|
|
||||||
Value to be passed to train_test_split().
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
df_train, df_val, df_test :
|
|
||||||
Dataframes containing the three splits.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if frac_train + frac_val + frac_test != 1.0:
|
|
||||||
raise ValueError(
|
|
||||||
"fractions %f, %f, %f do not add up to 1.0"
|
|
||||||
% (frac_train, frac_val, frac_test)
|
|
||||||
)
|
|
||||||
|
|
||||||
if stratify_colname not in df_input.columns:
|
|
||||||
raise ValueError("%s is not a column in the dataframe" % (stratify_colname))
|
|
||||||
|
|
||||||
# Contains all columns.
|
|
||||||
X = df_input.drop([stratify_colname], axis=1)
|
|
||||||
# Dataframe of just the column on which to stratify.
|
|
||||||
y = df_input[[stratify_colname]]
|
|
||||||
|
|
||||||
# Split original dataframe into train and temp dataframes.
|
|
||||||
df_train, df_temp, y_train, y_temp = train_test_split(
|
|
||||||
X, y, test_size=(1.0 - frac_train), random_state=random_state
|
|
||||||
)
|
|
||||||
|
|
||||||
if frac_val <= 0:
|
|
||||||
assert len(df_input) == len(df_train) + len(df_temp)
|
|
||||||
return df_train, df_temp, y_train, y_temp
|
|
||||||
|
|
||||||
# Split the temp dataframe into val and test dataframes.
|
|
||||||
relative_frac_test = frac_test / (frac_val + frac_test)
|
|
||||||
df_val, df_test, y_val, y_test = train_test_split(
|
|
||||||
df_temp,
|
|
||||||
y_temp,
|
|
||||||
stratify=y_temp,
|
|
||||||
test_size=relative_frac_test,
|
|
||||||
random_state=random_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(df_input) == len(df_train) + len(df_val) + len(df_test)
|
|
||||||
return df_train, df_val, df_test, y_train, y_val, y_test
|
|
||||||
|
|
||||||
|
|
||||||
def run_classification(
|
|
||||||
model: Pipeline,
|
|
||||||
X_train: DataFrame,
|
|
||||||
X_test: DataFrame,
|
|
||||||
y_train: DataFrame,
|
|
||||||
y_test: DataFrame,
|
|
||||||
) -> Dict:
|
|
||||||
result = {}
|
|
||||||
y_train_predict = model.predict(X_train)
|
|
||||||
y_test_probs = model.predict_proba(X_test)[:, 0]
|
|
||||||
y_test_predict = np.where(y_test_probs > 0.5, 1, 0)
|
|
||||||
|
|
||||||
result["pipeline"] = model
|
|
||||||
result["probs"] = y_test_probs
|
|
||||||
result["preds"] = y_test_predict
|
|
||||||
|
|
||||||
result["Precision_train"] = metrics.precision_score(y_train, y_train_predict)
|
|
||||||
result["Precision_test"] = metrics.precision_score(y_test, y_test_predict)
|
|
||||||
result["Recall_train"] = metrics.recall_score(y_train, y_train_predict)
|
|
||||||
result["Recall_test"] = metrics.recall_score(y_test, y_test_predict)
|
|
||||||
result["Accuracy_train"] = metrics.accuracy_score(y_train, y_train_predict)
|
|
||||||
result["Accuracy_test"] = metrics.accuracy_score(y_test, y_test_predict)
|
|
||||||
result["ROC_AUC_test"] = metrics.roc_auc_score(y_test, y_test_probs)
|
|
||||||
result["F1_train"] = metrics.f1_score(y_train, y_train_predict)
|
|
||||||
result["F1_test"] = metrics.f1_score(y_test, y_test_predict)
|
|
||||||
result["MCC_test"] = metrics.matthews_corrcoef(y_test, y_test_predict)
|
|
||||||
result["Cohen_kappa_test"] = metrics.cohen_kappa_score(y_test, y_test_predict)
|
|
||||||
result["Confusion_matrix"] = metrics.confusion_matrix(y_test, y_test_predict)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def run_regression(
|
|
||||||
model: Pipeline,
|
|
||||||
X_train: DataFrame,
|
|
||||||
X_test: DataFrame,
|
|
||||||
y_train: DataFrame,
|
|
||||||
y_test: DataFrame,
|
|
||||||
) -> Dict:
|
|
||||||
result = {}
|
|
||||||
y_train_pred = model.predict(X_train.values)
|
|
||||||
y_test_pred = model.predict(X_test.values)
|
|
||||||
|
|
||||||
result["fitted"] = model
|
|
||||||
result["train_preds"] = y_train_pred
|
|
||||||
result["preds"] = y_test_pred
|
|
||||||
|
|
||||||
result["RMSE_train"] = math.sqrt(metrics.mean_squared_error(y_train, y_train_pred))
|
|
||||||
result["RMSE_test"] = math.sqrt(metrics.mean_squared_error(y_test, y_test_pred))
|
|
||||||
result["RMAE_test"] = math.sqrt(metrics.mean_absolute_error(y_test, y_test_pred))
|
|
||||||
result["R2_test"] = metrics.r2_score(y_test, y_test_pred)
|
|
||||||
|
|
||||||
return result
|
|
Loading…
x
Reference in New Issue
Block a user