Add completed lec3 examples
This commit is contained in:
parent
648718c5e3
commit
224e345b46
70001
data/cardio_cleared.csv
Normal file
70001
data/cardio_cleared.csv
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
1002
lec3-3-results.ipynb
Normal file
1002
lec3-3-results.ipynb
Normal file
File diff suppressed because one or more lines are too long
54
plots.py
Normal file
54
plots.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import statsmodels.api as sm
|
||||||
|
from matplotlib.axes import Axes
|
||||||
|
from mlxtend.plotting import plot_decision_regions
|
||||||
|
from pandas import DataFrame, Series
|
||||||
|
from statsmodels.discrete.discrete_model import BinaryResultsWrapper
|
||||||
|
|
||||||
|
|
||||||
|
def create_decision_plot(
|
||||||
|
X: DataFrame,
|
||||||
|
y: Series,
|
||||||
|
model: BinaryResultsWrapper,
|
||||||
|
feature_index: List[str],
|
||||||
|
feature_names: List[str],
|
||||||
|
highlight_index: int,
|
||||||
|
filter: Dict[str, Tuple[int, int]],
|
||||||
|
ax: Optional[Axes] = None,
|
||||||
|
) -> Axes:
|
||||||
|
def _get_from_filter(is_value=True) -> Dict[int, int]:
|
||||||
|
filler = dict(
|
||||||
|
(columns.index(k) + 1, filter[k][0 if is_value else 1])
|
||||||
|
for k in filter.keys()
|
||||||
|
if k not in feature_index
|
||||||
|
)
|
||||||
|
constant = {0: 1}
|
||||||
|
return {**constant, **filler}
|
||||||
|
|
||||||
|
if feature_names is None:
|
||||||
|
feature_names = feature_index
|
||||||
|
columns = X.columns.to_list()
|
||||||
|
filler_values = _get_from_filter()
|
||||||
|
filler_ranges = _get_from_filter(False)
|
||||||
|
x_highlight = np.reshape(
|
||||||
|
np.concatenate(([1], X.iloc[highlight_index].to_numpy())),
|
||||||
|
(1, len(columns) + 1),
|
||||||
|
)
|
||||||
|
ax = plot_decision_regions(
|
||||||
|
sm.add_constant(X).to_numpy(), # type: ignore
|
||||||
|
y.to_numpy(),
|
||||||
|
clf=model,
|
||||||
|
feature_index=[columns.index(k) + 1 for k in feature_index],
|
||||||
|
X_highlight=x_highlight,
|
||||||
|
filler_feature_values=filler_values,
|
||||||
|
filler_feature_ranges=filler_ranges,
|
||||||
|
scatter_kwargs={"s": 48, "edgecolor": None, "alpha": 0.7},
|
||||||
|
contourf_kwargs={"alpha": 0.2},
|
||||||
|
legend=2,
|
||||||
|
ax=ax,
|
||||||
|
)
|
||||||
|
ax.set_xlabel(feature_names[0])
|
||||||
|
ax.set_ylabel(feature_names[1])
|
||||||
|
return ax
|
Loading…
x
Reference in New Issue
Block a user