from typing import Dict, List, Optional, Tuple

import numpy as np
import statsmodels.api as sm
from matplotlib.axes import Axes
from mlxtend.plotting import plot_decision_regions
from pandas import DataFrame, Series
from statsmodels.discrete.discrete_model import BinaryResultsWrapper


def create_decision_plot(
    X: DataFrame,
    y: Series,
    model: BinaryResultsWrapper,
    feature_index: List[str],
    feature_names: Optional[List[str]],
    highlight_index: int,
    filter: Dict[str, Tuple[int, int]],
    ax: Optional[Axes] = None,
) -> Axes:
    """Draw the decision regions of ``model`` and highlight one observation.

    A constant column is prepended to ``X`` before plotting, so every
    positional index handed to ``plot_decision_regions`` is the column's
    position in ``X`` shifted by one; index 0 is the constant itself.

    Args:
        X: Feature frame (without the constant column).
        y: Class labels passed to ``plot_decision_regions``.
        model: Classifier passed as ``clf`` to ``plot_decision_regions``.
        feature_index: Names of the columns of ``X`` to plot.
        feature_names: Axis labels; the first entry labels the x axis and the
            second the y axis. ``None`` falls back to ``feature_index``.
        highlight_index: Positional row of ``X`` to pass as ``X_highlight``.
        filter: Maps a column name to a pair whose first element is used as
            that column's filler value and whose second element is used as
            its filler range. Keys listed in ``feature_index`` are ignored.
        ax: Axes forwarded to ``plot_decision_regions``.

    Returns:
        The axes returned by ``plot_decision_regions``, with axis labels set.

    Raises:
        ValueError: If a key of ``filter`` or an entry of ``feature_index``
            is not a column of ``X``.
    """
    if feature_names is None:
        feature_names = feature_index

    # Bound before the helper below so the closure never sees an unset name.
    columns = X.columns.to_list()

    def _filler(slot: int) -> Dict[int, int]:
        """Collect element ``slot`` of every non-plotted ``filter`` entry."""
        # Index 0 is the prepended constant, which is always held at 1.
        fillers = {0: 1}
        fillers.update(
            {
                columns.index(name) + 1: pair[slot]
                for name, pair in filter.items()
                if name not in feature_index
            }
        )
        return fillers

    filler_values = _filler(0)
    filler_ranges = _filler(1)

    # The highlighted row gets the same leading constant as the design matrix.
    x_highlight = np.reshape(
        np.concatenate(([1], X.iloc[highlight_index].to_numpy())),
        (1, len(columns) + 1),
    )

    # has_constant="add" forces the constant to be prepended even when X
    # already holds a constant-valued column; the default ("skip") would
    # leave the matrix one column short and misalign every "+ 1" index used
    # in this function.
    design = sm.add_constant(X, has_constant="add").to_numpy()  # type: ignore

    ax = plot_decision_regions(
        design,
        y.to_numpy(),
        clf=model,
        feature_index=[columns.index(name) + 1 for name in feature_index],
        X_highlight=x_highlight,
        filler_feature_values=filler_values,
        filler_feature_ranges=filler_ranges,
        scatter_kwargs={"s": 48, "edgecolor": None, "alpha": 0.7},
        contourf_kwargs={"alpha": 0.2},
        legend=2,
        ax=ax,
    )
    ax.set_xlabel(feature_names[0])
    ax.set_ylabel(feature_names[1])
    return ax