Add distress experiment
This commit is contained in:
parent
4ddb6e77a9
commit
5c948c5ece
3673
data-distress/FinancialDistress.csv
Normal file
3673
data-distress/FinancialDistress.csv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
data-distress/tree.model.sav
Normal file
BIN
data-distress/tree.model.sav
Normal file
Binary file not shown.
3061
distress.ipynb
Normal file
3061
distress.ipynb
Normal file
File diff suppressed because one or more lines are too long
1501
distress_regression.ipynb
Normal file
1501
distress_regression.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
115
poetry.lock
generated
115
poetry.lock
generated
@ -777,6 +777,30 @@ files = [
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"]
|
all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "imbalanced-learn"
|
||||||
|
version = "0.12.4"
|
||||||
|
description = "Toolbox for imbalanced dataset in machine learning."
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
files = [
|
||||||
|
{file = "imbalanced-learn-0.12.4.tar.gz", hash = "sha256:8153ba385d296b07d97e0901a2624a86c06b48c94c2f92da3a5354827697b7a3"},
|
||||||
|
{file = "imbalanced_learn-0.12.4-py3-none-any.whl", hash = "sha256:d47fc599160d3ea882e712a3a6b02bdd353c1a6436d8d68d41b1922e6ee4a703"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
joblib = ">=1.1.1"
|
||||||
|
numpy = ">=1.17.3"
|
||||||
|
scikit-learn = ">=1.0.2"
|
||||||
|
scipy = ">=1.5.0"
|
||||||
|
threadpoolctl = ">=2.0.0"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
docs = ["keras (>=2.4.3)", "matplotlib (>=3.1.2)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.5.0)", "pandas (>=1.0.5)", "pydata-sphinx-theme (>=0.13.3)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-gallery (>=0.13.0)", "sphinxcontrib-bibtex (>=2.4.1)", "tensorflow (>=2.4.3)"]
|
||||||
|
examples = ["keras (>=2.4.3)", "matplotlib (>=3.1.2)", "pandas (>=1.0.5)", "seaborn (>=0.9.0)", "tensorflow (>=2.4.3)"]
|
||||||
|
optional = ["keras (>=2.4.3)", "pandas (>=1.0.5)", "tensorflow (>=2.4.3)"]
|
||||||
|
tests = ["black (>=23.3.0)", "flake8 (>=3.8.2)", "keras (>=2.4.3)", "mypy (>=1.3.0)", "pandas (>=1.0.5)", "pytest (>=5.0.1)", "pytest-cov (>=2.9.0)", "tensorflow (>=2.4.3)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ipykernel"
|
name = "ipykernel"
|
||||||
version = "6.29.5"
|
version = "6.29.5"
|
||||||
@ -1875,6 +1899,23 @@ files = [
|
|||||||
qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"]
|
qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"]
|
||||||
testing = ["docopt", "pytest"]
|
testing = ["docopt", "pytest"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "patsy"
|
||||||
|
version = "1.0.1"
|
||||||
|
description = "A Python package for describing statistical models and for building design matrices."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.6"
|
||||||
|
files = [
|
||||||
|
{file = "patsy-1.0.1-py2.py3-none-any.whl", hash = "sha256:751fb38f9e97e62312e921a1954b81e1bb2bcda4f5eeabaf94db251ee791509c"},
|
||||||
|
{file = "patsy-1.0.1.tar.gz", hash = "sha256:e786a9391eec818c054e359b737bbce692f051aee4c661f4141cc88fb459c0c4"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
numpy = ">=1.4"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
test = ["pytest", "pytest-cov", "scipy"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pexpect"
|
name = "pexpect"
|
||||||
version = "4.9.0"
|
version = "4.9.0"
|
||||||
@ -2659,6 +2700,27 @@ dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodest
|
|||||||
doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"]
|
doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"]
|
||||||
test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
|
test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "seaborn"
|
||||||
|
version = "0.13.2"
|
||||||
|
description = "Statistical data visualization"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987"},
|
||||||
|
{file = "seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
matplotlib = ">=3.4,<3.6.1 || >3.6.1"
|
||||||
|
numpy = ">=1.20,<1.24.0 || >1.24.0"
|
||||||
|
pandas = ">=1.2"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dev = ["flake8", "flit", "mypy", "pandas-stubs", "pre-commit", "pytest", "pytest-cov", "pytest-xdist"]
|
||||||
|
docs = ["ipykernel", "nbconvert", "numpydoc", "pydata_sphinx_theme (==0.10.0rc2)", "pyyaml", "sphinx (<6.0.0)", "sphinx-copybutton", "sphinx-design", "sphinx-issues"]
|
||||||
|
stats = ["scipy (>=1.7)", "statsmodels (>=0.12)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "send2trash"
|
name = "send2trash"
|
||||||
version = "1.8.3"
|
version = "1.8.3"
|
||||||
@ -2747,6 +2809,57 @@ pure-eval = "*"
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
|
tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "statsmodels"
|
||||||
|
version = "0.14.4"
|
||||||
|
description = "Statistical computations and models for Python"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.9"
|
||||||
|
files = [
|
||||||
|
{file = "statsmodels-0.14.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7a62f1fc9086e4b7ee789a6f66b3c0fc82dd8de1edda1522d30901a0aa45e42b"},
|
||||||
|
{file = "statsmodels-0.14.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:46ac7ddefac0c9b7b607eed1d47d11e26fe92a1bc1f4d9af48aeed4e21e87981"},
|
||||||
|
{file = "statsmodels-0.14.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a337b731aa365d09bb0eab6da81446c04fde6c31976b1d8e3d3a911f0f1e07b"},
|
||||||
|
{file = "statsmodels-0.14.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:631bb52159117c5da42ba94bd94859276b68cab25dc4cac86475bc24671143bc"},
|
||||||
|
{file = "statsmodels-0.14.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3bb2e580d382545a65f298589809af29daeb15f9da2eb252af8f79693e618abc"},
|
||||||
|
{file = "statsmodels-0.14.4-cp310-cp310-win_amd64.whl", hash = "sha256:9729642884147ee9db67b5a06a355890663d21f76ed608a56ac2ad98b94d201a"},
|
||||||
|
{file = "statsmodels-0.14.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5ed7e118e6e3e02d6723a079b8c97eaadeed943fa1f7f619f7148dfc7862670f"},
|
||||||
|
{file = "statsmodels-0.14.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f5f537f7d000de4a1708c63400755152b862cd4926bb81a86568e347c19c364b"},
|
||||||
|
{file = "statsmodels-0.14.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa74aaa26eaa5012b0a01deeaa8a777595d0835d3d6c7175f2ac65435a7324d2"},
|
||||||
|
{file = "statsmodels-0.14.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e332c2d9b806083d1797231280602340c5c913f90d4caa0213a6a54679ce9331"},
|
||||||
|
{file = "statsmodels-0.14.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d9c8fa28dfd75753d9cf62769ba1fecd7e73a0be187f35cc6f54076f98aa3f3f"},
|
||||||
|
{file = "statsmodels-0.14.4-cp311-cp311-win_amd64.whl", hash = "sha256:a6087ecb0714f7c59eb24c22781491e6f1cfffb660b4740e167625ca4f052056"},
|
||||||
|
{file = "statsmodels-0.14.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5221dba7424cf4f2561b22e9081de85f5bb871228581124a0d1b572708545199"},
|
||||||
|
{file = "statsmodels-0.14.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:17672b30c6b98afe2b095591e32d1d66d4372f2651428e433f16a3667f19eabb"},
|
||||||
|
{file = "statsmodels-0.14.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab5e6312213b8cfb9dca93dd46a0f4dccb856541f91d3306227c3d92f7659245"},
|
||||||
|
{file = "statsmodels-0.14.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4bbb150620b53133d6cd1c5d14c28a4f85701e6c781d9b689b53681effaa655f"},
|
||||||
|
{file = "statsmodels-0.14.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bb695c2025d122a101c2aca66d2b78813c321b60d3a7c86bb8ec4467bb53b0f9"},
|
||||||
|
{file = "statsmodels-0.14.4-cp312-cp312-win_amd64.whl", hash = "sha256:7f7917a51766b4e074da283c507a25048ad29a18e527207883d73535e0dc6184"},
|
||||||
|
{file = "statsmodels-0.14.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b5a24f5d2c22852d807d2b42daf3a61740820b28d8381daaf59dcb7055bf1a79"},
|
||||||
|
{file = "statsmodels-0.14.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df4f7864606fa843d7e7c0e6af288f034a2160dba14e6ccc09020a3cf67cb092"},
|
||||||
|
{file = "statsmodels-0.14.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91341cbde9e8bea5fb419a76e09114e221567d03f34ca26e6d67ae2c27d8fe3c"},
|
||||||
|
{file = "statsmodels-0.14.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1322286a7bfdde2790bf72d29698a1b76c20b8423a55bdcd0d457969d0041f72"},
|
||||||
|
{file = "statsmodels-0.14.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e31b95ac603415887c9f0d344cb523889cf779bc52d68e27e2d23c358958fec7"},
|
||||||
|
{file = "statsmodels-0.14.4-cp313-cp313-win_amd64.whl", hash = "sha256:81030108d27aecc7995cac05aa280cf8c6025f6a6119894eef648997936c2dd0"},
|
||||||
|
{file = "statsmodels-0.14.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4793b01b7a5f5424f5a1dbcefc614c83c7608aa2b035f087538253007c339d5d"},
|
||||||
|
{file = "statsmodels-0.14.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d330da34f59f1653c5193f9fe3a3a258977c880746db7f155fc33713ea858db5"},
|
||||||
|
{file = "statsmodels-0.14.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6e9ddefba1d4e1107c1f20f601b0581421ea3ad9fd75ce3c2ba6a76b6dc4682c"},
|
||||||
|
{file = "statsmodels-0.14.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f43da7957e00190104c5dd0f661bfc6dfc68b87313e3f9c4dbd5e7d222e0aeb"},
|
||||||
|
{file = "statsmodels-0.14.4-cp39-cp39-win_amd64.whl", hash = "sha256:8286f69a5e1d0e0b366ffed5691140c83d3efc75da6dbf34a3d06e88abfaaab6"},
|
||||||
|
{file = "statsmodels-0.14.4.tar.gz", hash = "sha256:5d69e0f39060dc72c067f9bb6e8033b6dccdb0bae101d76a7ef0bcc94e898b67"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
numpy = ">=1.22.3,<3"
|
||||||
|
packaging = ">=21.3"
|
||||||
|
pandas = ">=1.4,<2.1.0 || >2.1.0"
|
||||||
|
patsy = ">=0.5.6"
|
||||||
|
scipy = ">=1.8,<1.9.2 || >1.9.2"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
build = ["cython (>=3.0.10)"]
|
||||||
|
develop = ["colorama", "cython (>=3.0.10)", "cython (>=3.0.10,<4)", "flake8", "isort", "joblib", "matplotlib (>=3)", "pytest (>=7.3.0,<8)", "pytest-cov", "pytest-randomly", "pytest-xdist", "pywinpty", "setuptools-scm[toml] (>=8.0,<9.0)"]
|
||||||
|
docs = ["ipykernel", "jupyter-client", "matplotlib", "nbconvert", "nbformat", "numpydoc", "pandas-datareader", "sphinx"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "terminado"
|
name = "terminado"
|
||||||
version = "0.18.1"
|
version = "0.18.1"
|
||||||
@ -2952,4 +3065,4 @@ files = [
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.12"
|
python-versions = "^3.12"
|
||||||
content-hash = "ee2a6fb525a389426dad97d7f57b9ed7a8187d5892d01cc39a54830695f83a95"
|
content-hash = "f75f65af68bf9b9ae845851b540a182e59667b51060659ab9f4a28bb09ff9880"
|
||||||
|
@ -15,6 +15,9 @@ matplotlib = "^3.9.2"
|
|||||||
scikit-learn = "^1.5.2"
|
scikit-learn = "^1.5.2"
|
||||||
scikit-fuzzy = "^0.5.0"
|
scikit-fuzzy = "^0.5.0"
|
||||||
networkx = "^3.4.2"
|
networkx = "^3.4.2"
|
||||||
|
imbalanced-learn = "^0.12.3"
|
||||||
|
seaborn = "^0.13.2"
|
||||||
|
statsmodels = "^0.14.4"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
|
137
src/utils.py
Normal file
137
src/utils.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
import math
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pandas import DataFrame
|
||||||
|
from sklearn import metrics
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.pipeline import Pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def split_stratified_into_train_val_test(
    df_input,
    stratify_colname="y",
    frac_train=0.6,
    frac_val=0.15,
    frac_test=0.25,
    random_state=None,
) -> Tuple:
    """
    Split a Pandas dataframe into train, validation and test subsets.

    The column named by ``stratify_colname`` is separated from the remaining
    columns, so features and labels are returned as separate objects. The
    split is done by calling train_test_split() up to twice: first
    train vs. the remainder, then the remainder into val vs. test.

    Only the second (val/test) split passes ``stratify=``; the first
    train/remainder split is a plain random split.
    NOTE(review): confirm whether the train split should be stratified too.

    Parameters
    ----------
    df_input : Pandas dataframe
        Input dataframe to be split.
    stratify_colname : str
        The name of the label column. It is removed from the feature frames
        and used to stratify the val/test split.
    frac_train : float
    frac_val : float
    frac_test : float
        The ratios with which the dataframe will be split into train, val, and
        test data. The values should be expressed as float fractions and should
        sum to 1.0 (compared with a floating-point tolerance).
    random_state : int, None, or RandomStateInstance
        Value to be passed to train_test_split().

    Returns
    -------
    df_train, df_val, df_test, y_train, y_val, y_test :
        Feature dataframes and single-column label dataframes for the three
        splits, when ``frac_val > 0``.
    df_train, df_test, y_train, y_test :
        A 4-tuple (no validation split) when ``frac_val <= 0``.

    Raises
    ------
    ValueError
        If the fractions do not sum to 1.0, or if ``stratify_colname`` is not
        a column of ``df_input``.
    """

    # Compare with a tolerance: sums such as 0.7 + 0.2 + 0.1 are not exactly
    # 1.0 in binary floating point and must not be rejected.
    if not math.isclose(frac_train + frac_val + frac_test, 1.0):
        raise ValueError(
            "fractions %f, %f, %f do not add up to 1.0"
            % (frac_train, frac_val, frac_test)
        )

    if stratify_colname not in df_input.columns:
        raise ValueError("%s is not a column in the dataframe" % (stratify_colname))

    # Every column except the label column.
    X = df_input.drop([stratify_colname], axis=1)
    # Single-column dataframe holding just the label column.
    y = df_input[[stratify_colname]]

    # Split original dataframe into train and temp dataframes.
    df_train, df_temp, y_train, y_temp = train_test_split(
        X, y, test_size=(1.0 - frac_train), random_state=random_state
    )

    # No validation share requested: the remainder is the test set.
    if frac_val <= 0:
        assert len(df_input) == len(df_train) + len(df_temp)
        return df_train, df_temp, y_train, y_temp

    # Split the temp dataframe into val and test dataframes; the test share
    # is rescaled relative to the size of the remainder.
    relative_frac_test = frac_test / (frac_val + frac_test)
    df_val, df_test, y_val, y_test = train_test_split(
        df_temp,
        y_temp,
        stratify=y_temp,
        test_size=relative_frac_test,
        random_state=random_state,
    )

    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)
    return df_train, df_val, df_test, y_train, y_val, y_test
|
||||||
|
|
||||||
|
|
||||||
|
def run_classification(
    model: Pipeline,
    X_train: DataFrame,
    X_test: DataFrame,
    y_train: DataFrame,
    y_test: DataFrame,
) -> Dict:
    """
    Evaluate a classifier on train and test data and collect the metrics.

    The model is only used for prediction here (no ``fit`` call), so it must
    already be fitted by the caller. Test labels are derived by thresholding
    the positive-class probability at 0.5; train labels come from
    ``model.predict``.

    Assumes a binary target encoded as 0/1, so that column 1 of
    ``predict_proba`` is the probability of class 1 -- TODO confirm against
    the callers' label encoding.

    Parameters
    ----------
    model : Pipeline
        Fitted estimator exposing ``predict`` and ``predict_proba``.
    X_train, X_test : DataFrame
        Feature data for the train and test splits.
    y_train, y_test : DataFrame
        True labels for the train and test splits.

    Returns
    -------
    Dict
        Keys: ``pipeline`` (the model), ``probs`` (positive-class test
        probabilities), ``preds`` (thresholded test labels), train/test
        precision, recall, accuracy and F1, plus ``ROC_AUC_test``,
        ``MCC_test``, ``Cohen_kappa_test`` and ``Confusion_matrix`` for the
        test split.
    """
    result = {}
    y_train_predict = model.predict(X_train)
    # predict_proba columns follow the order of model.classes_, so column 1
    # (not column 0) holds the probability of the positive class. Using
    # column 0 would invert both the thresholded labels and the ROC AUC.
    y_test_probs = model.predict_proba(X_test)[:, 1]
    y_test_predict = np.where(y_test_probs > 0.5, 1, 0)

    result["pipeline"] = model
    result["probs"] = y_test_probs
    result["preds"] = y_test_predict

    result["Precision_train"] = metrics.precision_score(y_train, y_train_predict)
    result["Precision_test"] = metrics.precision_score(y_test, y_test_predict)
    result["Recall_train"] = metrics.recall_score(y_train, y_train_predict)
    result["Recall_test"] = metrics.recall_score(y_test, y_test_predict)
    result["Accuracy_train"] = metrics.accuracy_score(y_train, y_train_predict)
    result["Accuracy_test"] = metrics.accuracy_score(y_test, y_test_predict)
    result["ROC_AUC_test"] = metrics.roc_auc_score(y_test, y_test_probs)
    result["F1_train"] = metrics.f1_score(y_train, y_train_predict)
    result["F1_test"] = metrics.f1_score(y_test, y_test_predict)
    result["MCC_test"] = metrics.matthews_corrcoef(y_test, y_test_predict)
    result["Cohen_kappa_test"] = metrics.cohen_kappa_score(y_test, y_test_predict)
    result["Confusion_matrix"] = metrics.confusion_matrix(y_test, y_test_predict)

    return result
|
||||||
|
|
||||||
|
|
||||||
|
def run_regression(
    model: Pipeline,
    X_train: DataFrame,
    X_test: DataFrame,
    y_train: DataFrame,
    y_test: DataFrame,
) -> Dict:
    """
    Evaluate a regressor on train and test data and collect the metrics.

    The model is only used for prediction here (no ``fit`` call), so it must
    already be fitted by the caller. Features are passed to ``predict`` as
    raw arrays (``.values``), not as dataframes.

    Returns a dict with the model (``fitted``), the train and test
    predictions (``train_preds``, ``preds``), and the metrics
    ``RMSE_train``, ``RMSE_test``, ``RMAE_test`` (square root of the mean
    absolute error) and ``R2_test``.
    """
    train_predictions = model.predict(X_train.values)
    test_predictions = model.predict(X_test.values)

    # Error metrics; the "R" prefix denotes the square root of the base metric.
    mse_train = metrics.mean_squared_error(y_train, train_predictions)
    rmse_train = math.sqrt(mse_train)
    mse_test = metrics.mean_squared_error(y_test, test_predictions)
    rmse_test = math.sqrt(mse_test)
    mae_test = metrics.mean_absolute_error(y_test, test_predictions)
    rmae_test = math.sqrt(mae_test)
    r2_test = metrics.r2_score(y_test, test_predictions)

    return {
        "fitted": model,
        "train_preds": train_predictions,
        "preds": test_predictions,
        "RMSE_train": rmse_train,
        "RMSE_test": rmse_test,
        "RMAE_test": rmae_test,
        "R2_test": r2_test,
    }
|
Loading…
Reference in New Issue
Block a user