-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #154 from ing-bank/grouped_data_notebook
Add notebook for Grouped data
- Loading branch information
Showing
2 changed files
with
275 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,273 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# How to work with grouped data\n", | ||
"\n", | ||
"One of the often appearing properties of the Data Science problems is the natural grouping of the data. You could for instance have multiple samples for the same customer. In such case, you need to make sure that all samples from a given group are in the same fold e.g. in Cross-Validation.\n", | ||
"\n", | ||
"Let's prepare a dataset with groups." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]" | ||
] | ||
}, | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"from sklearn.datasets import make_classification\n", | ||
"\n", | ||
"X, y = make_classification(n_samples=100, n_features=10, random_state=42)\n", | ||
"groups = [i % 5 for i in range(100)]\n", | ||
"groups[:10]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The integers in `groups` variable indicate the group id, to which a given sample belongs.\n", | ||
"\n", | ||
"One of the easiest ways to ensure that the data is split using the information about groups is using `from sklearn.model_selection import GroupKFold`. You can also read more about other ways of splitting data with groups in sklearn [here](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation-iterators-for-grouped-data)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from sklearn.model_selection import GroupKFold\n", | ||
"\n", | ||
"cv = GroupKFold(n_splits=5).split(X, y, groups=groups)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Such variable can be passed to the `cv` parameter in `probatus` functions" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from probatus.feature_elimination import ShapRFECV\n", | ||
"from sklearn.ensemble import RandomForestClassifier\n", | ||
"from sklearn.model_selection import RandomizedSearchCV\n", | ||
"\n", | ||
"clf = RandomForestClassifier(random_state=42)\n", | ||
"\n", | ||
"param_grid = {\n", | ||
" 'n_estimators': [5, 7, 10],\n", | ||
" 'max_leaf_nodes': [3, 5, 7, 10],\n", | ||
"}\n", | ||
"search = RandomizedSearchCV(clf, param_grid, n_iter=1, random_state=42)\n", | ||
"\n", | ||
"shap_elimination = ShapRFECV(\n", | ||
" clf=search, step=0.2, cv=cv, scoring='roc_auc', n_jobs=3, random_state=42)\n", | ||
"report = shap_elimination.fit_compute(X, y)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"<div>\n", | ||
"<style scoped>\n", | ||
" .dataframe tbody tr th:only-of-type {\n", | ||
" vertical-align: middle;\n", | ||
" }\n", | ||
"\n", | ||
" .dataframe tbody tr th {\n", | ||
" vertical-align: top;\n", | ||
" }\n", | ||
"\n", | ||
" .dataframe thead th {\n", | ||
" text-align: right;\n", | ||
" }\n", | ||
"</style>\n", | ||
"<table border=\"1\" class=\"dataframe\">\n", | ||
" <thead>\n", | ||
" <tr style=\"text-align: right;\">\n", | ||
" <th></th>\n", | ||
" <th>num_features</th>\n", | ||
" <th>features_set</th>\n", | ||
" <th>eliminated_features</th>\n", | ||
" <th>train_metric_mean</th>\n", | ||
" <th>train_metric_std</th>\n", | ||
" <th>val_metric_mean</th>\n", | ||
" <th>val_metric_std</th>\n", | ||
" </tr>\n", | ||
" </thead>\n", | ||
" <tbody>\n", | ||
" <tr>\n", | ||
" <th>1</th>\n", | ||
" <td>10</td>\n", | ||
" <td>[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]</td>\n", | ||
" <td>[8, 7]</td>\n", | ||
" <td>1.000</td>\n", | ||
" <td>0.001</td>\n", | ||
" <td>0.957</td>\n", | ||
" <td>0.086</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>2</th>\n", | ||
" <td>8</td>\n", | ||
" <td>[0, 1, 2, 3, 4, 5, 6, 9]</td>\n", | ||
" <td>[5]</td>\n", | ||
" <td>0.999</td>\n", | ||
" <td>0.001</td>\n", | ||
" <td>0.966</td>\n", | ||
" <td>0.055</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>3</th>\n", | ||
" <td>7</td>\n", | ||
" <td>[0, 1, 2, 3, 4, 6, 9]</td>\n", | ||
" <td>[4]</td>\n", | ||
" <td>1.000</td>\n", | ||
" <td>0.000</td>\n", | ||
" <td>0.942</td>\n", | ||
" <td>0.114</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>4</th>\n", | ||
" <td>6</td>\n", | ||
" <td>[0, 1, 2, 3, 6, 9]</td>\n", | ||
" <td>[9]</td>\n", | ||
" <td>0.999</td>\n", | ||
" <td>0.001</td>\n", | ||
" <td>0.980</td>\n", | ||
" <td>0.032</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>5</th>\n", | ||
" <td>5</td>\n", | ||
" <td>[0, 1, 2, 3, 6]</td>\n", | ||
" <td>[6]</td>\n", | ||
" <td>1.000</td>\n", | ||
" <td>0.000</td>\n", | ||
" <td>0.960</td>\n", | ||
" <td>0.073</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>6</th>\n", | ||
" <td>4</td>\n", | ||
" <td>[0, 1, 2, 3]</td>\n", | ||
" <td>[1]</td>\n", | ||
" <td>0.999</td>\n", | ||
" <td>0.001</td>\n", | ||
" <td>0.951</td>\n", | ||
" <td>0.091</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>7</th>\n", | ||
" <td>3</td>\n", | ||
" <td>[0, 2, 3]</td>\n", | ||
" <td>[3]</td>\n", | ||
" <td>0.999</td>\n", | ||
" <td>0.001</td>\n", | ||
" <td>0.971</td>\n", | ||
" <td>0.052</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>8</th>\n", | ||
" <td>2</td>\n", | ||
" <td>[0, 2]</td>\n", | ||
" <td>[0]</td>\n", | ||
" <td>0.998</td>\n", | ||
" <td>0.002</td>\n", | ||
" <td>0.925</td>\n", | ||
" <td>0.122</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>9</th>\n", | ||
" <td>1</td>\n", | ||
" <td>[2]</td>\n", | ||
" <td>[]</td>\n", | ||
" <td>0.998</td>\n", | ||
" <td>0.002</td>\n", | ||
" <td>0.938</td>\n", | ||
" <td>0.098</td>\n", | ||
" </tr>\n", | ||
" </tbody>\n", | ||
"</table>\n", | ||
"</div>" | ||
], | ||
"text/plain": [ | ||
" num_features features_set eliminated_features \\\n", | ||
"1 10 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [8, 7] \n", | ||
"2 8 [0, 1, 2, 3, 4, 5, 6, 9] [5] \n", | ||
"3 7 [0, 1, 2, 3, 4, 6, 9] [4] \n", | ||
"4 6 [0, 1, 2, 3, 6, 9] [9] \n", | ||
"5 5 [0, 1, 2, 3, 6] [6] \n", | ||
"6 4 [0, 1, 2, 3] [1] \n", | ||
"7 3 [0, 2, 3] [3] \n", | ||
"8 2 [0, 2] [0] \n", | ||
"9 1 [2] [] \n", | ||
"\n", | ||
" train_metric_mean train_metric_std val_metric_mean val_metric_std \n", | ||
"1 1.000 0.001 0.957 0.086 \n", | ||
"2 0.999 0.001 0.966 0.055 \n", | ||
"3 1.000 0.000 0.942 0.114 \n", | ||
"4 0.999 0.001 0.980 0.032 \n", | ||
"5 1.000 0.000 0.960 0.073 \n", | ||
"6 0.999 0.001 0.951 0.091 \n", | ||
"7 0.999 0.001 0.971 0.052 \n", | ||
"8 0.998 0.002 0.925 0.122 \n", | ||
"9 0.998 0.002 0.938 0.098 " | ||
] | ||
}, | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"report" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters