diff --git a/docs/howto/grouped_data.ipynb b/docs/howto/grouped_data.ipynb new file mode 100644 index 00000000..1bfbdd23 --- /dev/null +++ b/docs/howto/grouped_data.ipynb @@ -0,0 +1,273 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How to work with grouped data\n", + "\n", + "One of the often appearing properties of the Data Science problems is the natural grouping of the data. You could for instance have multiple samples for the same customer. In such case, you need to make sure that all samples from a given group are in the same fold e.g. in Cross-Validation.\n", + "\n", + "Let's prepare a dataset with groups." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.datasets import make_classification\n", + "\n", + "X, y = make_classification(n_samples=100, n_features=10, random_state=42)\n", + "groups = [i % 5 for i in range(100)]\n", + "groups[:10]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The integers in `groups` variable indicate the group id, to which a given sample belongs.\n", + "\n", + "One of the easiest ways to ensure that the data is split using the information about groups is using `from sklearn.model_selection import GroupKFold`. You can also read more about other ways of splitting data with groups in sklearn [here](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation-iterators-for-grouped-data)." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.model_selection import GroupKFold\n", + "\n", + "cv = GroupKFold(n_splits=5).split(X, y, groups=groups)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Such variable can be passed to the `cv` parameter in `probatus` functions" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from probatus.feature_elimination import ShapRFECV\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.model_selection import RandomizedSearchCV\n", + "\n", + "clf = RandomForestClassifier(random_state=42)\n", + "\n", + "param_grid = {\n", + " 'n_estimators': [5, 7, 10],\n", + " 'max_leaf_nodes': [3, 5, 7, 10],\n", + "}\n", + "search = RandomizedSearchCV(clf, param_grid, n_iter=1, random_state=42)\n", + "\n", + "shap_elimination = ShapRFECV(\n", + " clf=search, step=0.2, cv=cv, scoring='roc_auc', n_jobs=3, random_state=42)\n", + "report = shap_elimination.fit_compute(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " | num_features | \n", + "features_set | \n", + "eliminated_features | \n", + "train_metric_mean | \n", + "train_metric_std | \n", + "val_metric_mean | \n", + "val_metric_std | \n", + "
---|---|---|---|---|---|---|---|
1 | \n", + "10 | \n", + "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] | \n", + "[8, 7] | \n", + "1.000 | \n", + "0.001 | \n", + "0.957 | \n", + "0.086 | \n", + "
2 | \n", + "8 | \n", + "[0, 1, 2, 3, 4, 5, 6, 9] | \n", + "[5] | \n", + "0.999 | \n", + "0.001 | \n", + "0.966 | \n", + "0.055 | \n", + "
3 | \n", + "7 | \n", + "[0, 1, 2, 3, 4, 6, 9] | \n", + "[4] | \n", + "1.000 | \n", + "0.000 | \n", + "0.942 | \n", + "0.114 | \n", + "
4 | \n", + "6 | \n", + "[0, 1, 2, 3, 6, 9] | \n", + "[9] | \n", + "0.999 | \n", + "0.001 | \n", + "0.980 | \n", + "0.032 | \n", + "
5 | \n", + "5 | \n", + "[0, 1, 2, 3, 6] | \n", + "[6] | \n", + "1.000 | \n", + "0.000 | \n", + "0.960 | \n", + "0.073 | \n", + "
6 | \n", + "4 | \n", + "[0, 1, 2, 3] | \n", + "[1] | \n", + "0.999 | \n", + "0.001 | \n", + "0.951 | \n", + "0.091 | \n", + "
7 | \n", + "3 | \n", + "[0, 2, 3] | \n", + "[3] | \n", + "0.999 | \n", + "0.001 | \n", + "0.971 | \n", + "0.052 | \n", + "
8 | \n", + "2 | \n", + "[0, 2] | \n", + "[0] | \n", + "0.998 | \n", + "0.002 | \n", + "0.925 | \n", + "0.122 | \n", + "
9 | \n", + "1 | \n", + "[2] | \n", + "[] | \n", + "0.998 | \n", + "0.002 | \n", + "0.938 | \n", + "0.098 | \n", + "