Skip to content

Commit

Permalink
Merge pull request #89 from ing-bank/support_search_cv
Browse files Browse the repository at this point in the history
Support any classifier extending basesearchcv
  • Loading branch information
timvink authored Mar 9, 2021
2 parents b9b3639 + 473017e commit e09e1c6
Show file tree
Hide file tree
Showing 2 changed files with 5,824 additions and 65 deletions.
5,838 changes: 5,794 additions & 44 deletions docs/tutorials/nb_shap_feature_elimination.ipynb

Large diffs are not rendered by default.

51 changes: 30 additions & 21 deletions probatus/feature_elimination/feature_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, check_cv
from sklearn.model_selection._search import BaseSearchCV
from sklearn.model_selection import check_cv
from sklearn.base import clone, is_classifier
from joblib import Parallel, delayed
import warnings
Expand All @@ -20,8 +21,10 @@ class ShapRFECV(BaseFitComputePlotClass):
This class performs Backwards Recursive Feature Elimination, using SHAP feature importance. At each round, for a
given feature set, starting from all available features, the following steps are applied:
1. (Optional) Tune the hyperparameters of the model using [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LassoCV.html)
or [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html?highlight=randomized#sklearn.model_selection.RandomizedSearchCV),
1. (Optional) Tune the hyperparameters of the model using sklearn compatible search CV e.g.
[GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LassoCV.html),
[RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html?highlight=randomized#sklearn.model_selection.RandomizedSearchCV), or
[BayesSearchCV](https://scikit-optimize.github.io/stable/modules/generated/skopt.BayesSearchCV.html),
2. Apply Cross-validation (CV) to estimate the SHAP feature importance on the provided dataset. In each CV
iteration, the model is fitted on the train folds, and applied on the validation fold to estimate
SHAP feature importance.
Expand All @@ -32,12 +35,13 @@ class ShapRFECV(BaseFitComputePlotClass):
The functionality is similar to [RFECV](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.RFECV.html).
The main difference is removing the lowest importance features based on SHAP features importance. It also
supports the use of [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html)
and [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)
passed as the `clf`, thanks to which` you can perform hyperparameter optimization at each step of the search.
hyperparameters of the model at each round, to tune the model for each features set. Lastly, it supports
categorical features (object and category dtype) and missing values in the data, as long as the model supports
them.
supports the use of sklearn compatible search CV for hyperparameter optimization e.g.
[GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LassoCV.html),
[RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html?highlight=randomized#sklearn.model_selection.RandomizedSearchCV), or
[BayesSearchCV](https://scikit-optimize.github.io/stable/modules/generated/skopt.BayesSearchCV.html), which
needs to be passed as the `clf`. Thanks to this you can perform hyperparameter optimization at each step of
the feature elimination. Lastly, it supports categorical features (object and category dtype) and missing values
in the data, as long as the model supports them.
We recommend using [LGBMClassifier](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html),
because by default it handles missing values and categorical features. In case of other models, make sure to
Expand Down Expand Up @@ -84,6 +88,7 @@ class ShapRFECV(BaseFitComputePlotClass):
final_features_set = shap_elimination.get_reduced_features_set(num_features=3)
```
<img src="../img/shaprfecv.png" width="500" />
"""

def __init__(
Expand All @@ -101,12 +106,14 @@ def __init__(
This method initializes the class:
Args:
clf (binary classifier, GridSearchCV or RandomizedSearchCV):
clf (binary classifier, sklearn compatible search CV e.g. GridSearchCV, RandomizedSearchCV or BayesSearchCV):
A model that will be optimized and trained at each round of features elimination. The recommended model
is [LGBMClassifier](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html),
because it by default handles the missing values and categorical variables. This parameter also supports
[GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html)
and [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html).
any hyperparameter search schema that is consistent with the sklearn API e.g.
[GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html),
[RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)
or [BayesSearchCV](https://scikit-optimize.github.io/stable/modules/generated/skopt.BayesSearchCV.html#skopt.BayesSearchCV).
step (int or float, optional):
Number of lowest importance features removed each round. If it is an int, then each round such number of
Expand Down Expand Up @@ -151,9 +158,7 @@ def __init__(
"""
self.clf = clf

if isinstance(self.clf, RandomizedSearchCV) or isinstance(
self.clf, GridSearchCV
):
if isinstance(self.clf, BaseSearchCV):
self.search_clf = True
else:
self.search_clf = False
Expand Down Expand Up @@ -373,9 +378,11 @@ def _get_feature_shap_values_per_fold(
def fit(self, X, y, columns_to_keep=None, column_names=None):
"""
Fits the object with the provided data. The algorithm starts with the entire dataset, and then sequentially
eliminates features. If [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html)
or [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)
object assigned as clf, the hyperparameter optimization is applied first. Then, the SHAP feature importance
eliminates features. If sklearn compatible search CV is passed as clf e.g.
[GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html),
[RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)
or [BayesSearchCV](https://scikit-optimize.github.io/stable/modules/generated/skopt.BayesSearchCV.html),
the hyperparameter optimization is applied at each step of the elimination. Then, the SHAP feature importance
is calculated using Cross-Validation, and `step` lowest importance features are removed.
Args:
Expand Down Expand Up @@ -558,9 +565,11 @@ def compute(self):
def fit_compute(self, X, y, columns_to_keep=None, column_names=None):
"""
Fits the object with the provided data. The algorithm starts with the entire dataset, and then sequentially
eliminates features. If [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html)
or [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)
object assigned as clf, the hyperparameter optimization is applied first. Then, the SHAP feature importance
eliminates features. If sklearn compatible search CV is passed as clf e.g.
[GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html),
[RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)
or [BayesSearchCV](https://scikit-optimize.github.io/stable/modules/generated/skopt.BayesSearchCV.html),
the hyperparameter optimization is applied at each step of the elimination. Then, the SHAP feature importance
is calculated using Cross-Validation, and `step` lowest importance features are removed. At the end, the
report containing results from each iteration is computed and returned to the user.
Expand Down

0 comments on commit e09e1c6

Please sign in to comment.