55 lines
1.7 KiB
Python
55 lines
1.7 KiB
Python
from typing import Dict, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import statsmodels.api as sm
|
|
from matplotlib.axes import Axes
|
|
from mlxtend.plotting import plot_decision_regions
|
|
from pandas import DataFrame, Series
|
|
from statsmodels.discrete.discrete_model import BinaryResultsWrapper
|
|
|
|
|
|
def create_decision_plot(
    X: DataFrame,
    y: Series,
    model: BinaryResultsWrapper,
    feature_index: List[str],
    feature_names: Optional[List[str]],
    highlight_index: int,
    # NOTE: shadows the builtin ``filter``; name kept because callers may
    # pass it by keyword.
    filter: Dict[str, Tuple[int, int]],
    ax: Optional[Axes] = None,
) -> Axes:
    """Plot the decision regions of a fitted binary model over two features.

    The model is evaluated on a design matrix whose column 0 is an intercept
    column of ones, followed by the columns of ``X`` in ``X``'s order. Every
    feature that is not plotted is held at a filler value taken from
    ``filter``.

    Args:
        X: Feature matrix, without an intercept column.
        y: Target labels aligned with the rows of ``X``.
        model: Fitted statsmodels binary-model results, passed to
            ``mlxtend.plotting.plot_decision_regions`` as ``clf``.
            NOTE(review): statsmodels discrete-model ``predict`` returns
            probabilities rather than class labels. Confirm this is the
            intended decision surface, or that callers pass a wrapped model.
        feature_index: Names of the two columns of ``X`` to plot
            (x-axis first, then y-axis).
        feature_names: Axis labels for the two plotted features. When
            ``None``, the column names in ``feature_index`` are used.
        highlight_index: Positional (``iloc``) row index of the sample in
            ``X`` to highlight.
        filter: Maps column names to ``(value, range)`` pairs. For each
            column not in ``feature_index``, ``value`` is passed to mlxtend
            as that feature's filler value and ``range`` as its filler range.
            Entries for plotted columns are ignored.
        ax: Existing Axes to draw on. Forwarded unchanged to mlxtend.

    Returns:
        The Axes returned by ``plot_decision_regions``, with both axis
        labels set.

    Raises:
        ValueError: If a name in ``feature_index`` or in ``filter`` is not a
            column of ``X``.
    """
    columns = X.columns.to_list()

    def _design_position(name: str) -> int:
        """Return the design-matrix column of ``name``; the intercept is column 0."""
        try:
            # First occurrence wins, matching ``list.index`` semantics.
            return columns.index(name) + 1
        except ValueError:
            raise ValueError(
                f"{name!r} is not a column of X; available columns: {columns}"
            ) from None

    def _filler(slot: int) -> Dict[int, int]:
        """Map positions of non-plotted features to ``filter[name][slot]``."""
        # The intercept (column 0) is always held at 1, with range 1.
        fillers: Dict[int, int] = {0: 1}
        for name, spec in filter.items():
            if name not in feature_index:
                fillers[_design_position(name)] = spec[slot]
        return fillers

    if feature_names is None:
        feature_names = feature_index

    filler_values = _filler(0)  # slot 0 of each filter tuple: the value
    filler_ranges = _filler(1)  # slot 1 of each filter tuple: the range

    # The highlighted sample must have the same layout as the design matrix:
    # a leading 1 for the intercept, then the row's feature values.
    x_highlight = np.concatenate(
        ([1], X.iloc[highlight_index].to_numpy())
    ).reshape(1, len(columns) + 1)

    # has_constant="add" guarantees the intercept column is always prepended.
    # The default ("skip") silently omits it when X already contains a
    # constant-valued column. That would shift every column and desynchronise
    # the +1 positions and the highlight row built above.
    design = sm.add_constant(  # type: ignore
        X, prepend=True, has_constant="add"
    ).to_numpy()

    ax = plot_decision_regions(
        design,
        y.to_numpy(),
        clf=model,
        feature_index=[_design_position(name) for name in feature_index],
        X_highlight=x_highlight,
        filler_feature_values=filler_values,
        filler_feature_ranges=filler_ranges,
        scatter_kwargs={"s": 48, "edgecolor": None, "alpha": 0.7},
        contourf_kwargs={"alpha": 0.2},
        legend=2,  # legend location code forwarded to mlxtend
        ax=ax,
    )
    ax.set_xlabel(feature_names[0])
    ax.set_ylabel(feature_names[1])
    return ax
|