# ckexp/plots.py
#
# 55 lines
# 1.7 KiB
# Python

from typing import Dict, List, Optional, Tuple
import numpy as np
import statsmodels.api as sm
from matplotlib.axes import Axes
from mlxtend.plotting import plot_decision_regions
from pandas import DataFrame, Series
from statsmodels.discrete.discrete_model import BinaryResultsWrapper
def create_decision_plot(
    X: DataFrame,
    y: Series,
    model: BinaryResultsWrapper,
    feature_index: List[str],
    feature_names: Optional[List[str]],
    highlight_index: int,
    filter: Dict[str, Tuple[int, int]],
    ax: Optional[Axes] = None,
) -> Axes:
    """Plot the decision regions of a fitted binary model over two features.

    The model is evaluated on a design matrix made of an intercept column
    followed by the columns of ``X`` (``[1, *X.columns]``). Every feature
    index handed to mlxtend is therefore the ``X`` column position plus one.

    Args:
        X: Feature matrix; its column order defines the feature positions.
        y: Class labels, one per row of ``X``.
        model: Fitted statsmodels binary-model results, passed to mlxtend as
            the classifier (``clf``).
        feature_index: Names of the two ``X`` columns drawn on the x and y
            axes.
        feature_names: Axis labels for the plotted features; ``None`` falls
            back to ``feature_index``.
        highlight_index: Positional (``iloc``) row of ``X`` to highlight.
        filter: Maps non-plotted column names to ``(value, range)``. Index 0
            is passed to mlxtend as the filler feature value and index 1 as
            the filler feature range for that column. Entries for columns in
            ``feature_index`` are ignored.
        ax: Existing axes to draw into, forwarded to mlxtend.

    Returns:
        The axes returned by ``plot_decision_regions``, with x/y labels set.
    """
    if feature_names is None:
        feature_names = feature_index

    columns = X.columns.to_list()

    def _design_index(name: str) -> int:
        # +1 because column 0 of the design matrix is the intercept.
        return columns.index(name) + 1

    def _get_from_filter(is_value: bool = True) -> Dict[int, int]:
        # Tuple slot 0 holds the filler value, slot 1 the filler range.
        slot = 0 if is_value else 1
        filler = {
            _design_index(name): bounds[slot]
            for name, bounds in filter.items()
            if name not in feature_index
        }
        # The intercept column (index 0) is 1 for every row; it gets 1 as both
        # its filler value and its filler range.
        return {0: 1, **filler}

    filler_values = _get_from_filter()
    filler_ranges = _get_from_filter(False)

    # Highlighted row as a single design row: intercept followed by features.
    x_highlight = np.concatenate(([1], X.iloc[highlight_index].to_numpy()))
    x_highlight = x_highlight.reshape(1, len(columns) + 1)

    # has_constant="add" forces the intercept column even when X already has
    # a constant column (e.g. a single row or a filtered subset). The default
    # "skip" would omit it and shift every feature index above by one.
    design = sm.add_constant(X, has_constant="add").to_numpy()  # type: ignore

    ax = plot_decision_regions(
        design,
        y.to_numpy(),
        clf=model,
        feature_index=[_design_index(name) for name in feature_index],
        X_highlight=x_highlight,
        filler_feature_values=filler_values,
        filler_feature_ranges=filler_ranges,
        scatter_kwargs={"s": 48, "edgecolor": None, "alpha": 0.7},
        contourf_kwargs={"alpha": 0.2},
        legend=2,
        ax=ax,
    )
    ax.set_xlabel(feature_names[0])
    ax.set_ylabel(feature_names[1])
    return ax