Add completed lec3 examples
This commit is contained in:
parent
648718c5e3
commit
224e345b46
70001
data/cardio_cleared.csv
Normal file
70001
data/cardio_cleared.csv
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
1002
lec3-3-results.ipynb
Normal file
1002
lec3-3-results.ipynb
Normal file
File diff suppressed because one or more lines are too long
54
plots.py
Normal file
54
plots.py
Normal file
@ -0,0 +1,54 @@
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import statsmodels.api as sm
|
||||
from matplotlib.axes import Axes
|
||||
from mlxtend.plotting import plot_decision_regions
|
||||
from pandas import DataFrame, Series
|
||||
from statsmodels.discrete.discrete_model import BinaryResultsWrapper
|
||||
|
||||
|
||||
def create_decision_plot(
|
||||
X: DataFrame,
|
||||
y: Series,
|
||||
model: BinaryResultsWrapper,
|
||||
feature_index: List[str],
|
||||
feature_names: List[str],
|
||||
highlight_index: int,
|
||||
filter: Dict[str, Tuple[int, int]],
|
||||
ax: Optional[Axes] = None,
|
||||
) -> Axes:
|
||||
def _get_from_filter(is_value=True) -> Dict[int, int]:
|
||||
filler = dict(
|
||||
(columns.index(k) + 1, filter[k][0 if is_value else 1])
|
||||
for k in filter.keys()
|
||||
if k not in feature_index
|
||||
)
|
||||
constant = {0: 1}
|
||||
return {**constant, **filler}
|
||||
|
||||
if feature_names is None:
|
||||
feature_names = feature_index
|
||||
columns = X.columns.to_list()
|
||||
filler_values = _get_from_filter()
|
||||
filler_ranges = _get_from_filter(False)
|
||||
x_highlight = np.reshape(
|
||||
np.concatenate(([1], X.iloc[highlight_index].to_numpy())),
|
||||
(1, len(columns) + 1),
|
||||
)
|
||||
ax = plot_decision_regions(
|
||||
sm.add_constant(X).to_numpy(), # type: ignore
|
||||
y.to_numpy(),
|
||||
clf=model,
|
||||
feature_index=[columns.index(k) + 1 for k in feature_index],
|
||||
X_highlight=x_highlight,
|
||||
filler_feature_values=filler_values,
|
||||
filler_feature_ranges=filler_ranges,
|
||||
scatter_kwargs={"s": 48, "edgecolor": None, "alpha": 0.7},
|
||||
contourf_kwargs={"alpha": 0.2},
|
||||
legend=2,
|
||||
ax=ax,
|
||||
)
|
||||
ax.set_xlabel(feature_names[0])
|
||||
ax.set_ylabel(feature_names[1])
|
||||
return ax
|
Loading…
x
Reference in New Issue
Block a user