-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example for unbiased naive bayes
- Loading branch information
1 parent
f42dc29
commit 9e4dd3b
Showing
1 changed file
with
354 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,354 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import sys\n", | ||
"sys.path.append(\"../\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"pd.set_option('display.width', 500)\n", | ||
"import aequitas.detection.descriptive_stats as dstats\n", | ||
"import aequitas.detection.metrics as metrics\n", | ||
"import aequitas.mitigation.data as technique\n", | ||
"import aequitas.mitigation.models as model\n", | ||
"import aequitas.tools.data_manip as dm\n", | ||
"import aequitas.tools as tools" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#Import dataset\n", | ||
"dataset_name=\"Census_Income_Dataset.csv\"\n", | ||
"dataset_directory=\"../datasets/\"+dataset_name\n", | ||
"dataset = pd.read_csv(dataset_directory)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Dataset Pre-Processing\n", | ||
"\n", | ||
"# remove fnlwgt column (per instructions)\n", | ||
"dataset = dataset.drop('fnlwgt', axis=1)\n", | ||
"\n", | ||
"# remove education column since there is an educution_num\n", | ||
"dataset = dataset.drop('education', axis=1)\n", | ||
"\n", | ||
"# impute the missing values\n", | ||
"num_data = dataset.shape[0]\n", | ||
"col_names = dataset.columns\n", | ||
"for c in col_names:\n", | ||
"\tdataset[c] = dataset[c].replace(\"?\", np.NaN)\n", | ||
"dataset = dataset.apply(lambda x:x.fillna(x.value_counts().index[0]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Dataset:\n", | ||
" Column Name Data Type Column Type (suggestion) Number_Values Values\n", | ||
"0 age int64 Continuous 74 -\n", | ||
"1 workclass text Categorical/Ordinal 8 [Private, Local-gov, Self-emp-not-inc, Federal...\n", | ||
"2 educational-num int64 Categorical/Ordinal 16 [7, 9, 12, 10, 6, 15, 4, 13, 14, 16, 3, 11, 5,...\n", | ||
"3 marital-status text Categorical/Ordinal 7 [Never-married, Married-civ-spouse, Widowed, D...\n", | ||
"4 occupation text Categorical/Ordinal 14 [Machine-op-inspct, Farming-fishing, Protectiv...\n", | ||
"5 relationship text Categorical/Ordinal 6 [Own-child, Husband, Not-in-family, Unmarried,...\n", | ||
"6 race text Categorical/Ordinal 5 [Black, White, Asian-Pac-Islander, Other, Amer...\n", | ||
"7 gender text Binary 2 [Male, Female]\n", | ||
"8 capital-gain int64 Continuous 123 -\n", | ||
"9 capital-loss int64 Continuous 99 -\n", | ||
"10 hours-per-week int64 Continuous 96 -\n", | ||
"11 native-country text Categorical/Ordinal 41 [United-States, Peru, Guatemala, Mexico, Domin...\n", | ||
"12 income text Binary 2 [<=50K, >50K]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# display dataset structure\n", | ||
"dataset_struct=dstats.analyse_dataset(dataset,verbose=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# split dataset to training and test samples\n", | ||
"training_sample,test_sample = dm.split_dataset(dataset,ratio=0.2,random_state=51)\n", | ||
"\n", | ||
"# Define appropriate transformations for dataset (feature encoding and scaling if required)\n", | ||
"# Define appropriate transformations for dataset (feature encoding and scaling if required)\n", | ||
"transform_dict = {\n", | ||
" \"income\": {\n", | ||
" \"encode\": \"labeling\",\n", | ||
" \"labels\": {\n", | ||
" \"<=50K\": 0,\n", | ||
" \">50K\": 1, \n", | ||
" }\n", | ||
" },\n", | ||
" \"gender\": {\n", | ||
" \"encode\": \"labeling\",\n", | ||
" \"labels\": {\n", | ||
" \"Female\": 0,\n", | ||
" \"Male\": 1, \n", | ||
" }\n", | ||
" },\n", | ||
" \"workclass\": {\n", | ||
" \"encode\": \"labeling\",\n", | ||
" },\n", | ||
" \"race\": {\n", | ||
" \"encode\": \"labeling\", \n", | ||
" },\n", | ||
" \"marital-status\": {\n", | ||
" \"encode\": \"labeling\",\n", | ||
" },\n", | ||
" \"occupation\": {\n", | ||
" \"encode\": \"labeling\", \n", | ||
" },\n", | ||
" \"relationship\": {\n", | ||
" \"encode\": \"labeling\", \n", | ||
" },\n", | ||
" \"native-country\": {\n", | ||
" \"encode\": \"labeling\", \n", | ||
" },\n", | ||
" \"age\":{\n", | ||
" \"scaling\": \"min-max\"\n", | ||
" },\n", | ||
" \"educational-num\":{\n", | ||
" \"scaling\": \"min-max\"\n", | ||
" },\n", | ||
" \"capital-gain\":{\n", | ||
" },\n", | ||
" \"capital-loss\":{\n", | ||
" },\n", | ||
" \"hours-per-week\":{\n", | ||
" \"scaling\": \"min-max\"\n", | ||
" }\n", | ||
"}\n", | ||
"\n", | ||
"# Transform the training sample\n", | ||
"training_sample, transformers = dm.transform_training_data(training_sample, transform_dict)\n", | ||
"\n", | ||
"# Transform the test sample\n", | ||
"test_sample = dm.transform_test_data(test_sample, transform_dict, transformers)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# define fairness parameters\n", | ||
"class_attribute='income'\n", | ||
"sensitive_attribute='gender'\n", | ||
"outcome=1 # >50K\n", | ||
"priv_group=1 # Male" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Training sample:\n", | ||
"Statistical/Demographic Parity:\n", | ||
"Outcome: 1\n", | ||
" 1 0\n", | ||
"1 0.0 0.196285\n", | ||
"\n", | ||
"\n", | ||
"Test sample:\n", | ||
"Statistical/Demographic Parity:\n", | ||
"Outcome: 1\n", | ||
" 1 0\n", | ||
"1 0.0 0.187467\n", | ||
"\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# measure statistical parity before classification\n", | ||
"print(\"Training sample:\")\n", | ||
"res=metrics.statistical_parity(training_sample,class_attribute,sensitive_attribute,positive_outcome=outcome,privileged_group=priv_group,verbose=True)\n", | ||
"\n", | ||
"print(\"Test sample:\")\n", | ||
"res=metrics.statistical_parity(test_sample,class_attribute,sensitive_attribute,positive_outcome=outcome,privileged_group=priv_group,verbose=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# define classifier parameters\n", | ||
"classifier_type=\"Naive_Bayes\"\n", | ||
"normal_feat = [\"age\",\"educational-num\",\"hours-per-week\"]\n", | ||
"classifier_params= {\n", | ||
" \"normal_features\":normal_feat\n", | ||
"}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Classifier Accuracy: 0.86\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Train a classifier on training sample\n", | ||
"clf=tools.train_classifier(training_sample,class_attribute,classifier_type,classifier_params)\n", | ||
"\n", | ||
"# Test classifier on test sample\n", | ||
"predicted_test_sample, _, _, _= tools.test_classifier(clf,test_sample,class_attribute,verbose=True)\n", | ||
"prediction=np.array(predicted_test_sample[class_attribute])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Statistical/Demographic Parity:\n", | ||
"Outcome: 1\n", | ||
" 1 0\n", | ||
"1 0.0 0.244571\n", | ||
"\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# measure statistical parity after classification on test sample\n", | ||
"res=metrics.statistical_parity(predicted_test_sample,class_attribute,sensitive_attribute,positive_outcome=outcome,privileged_group=priv_group,verbose=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Mitigation unbiased Naive Bayes Model!\n", | ||
"Step: 0 Discrimination: 0.2510332455486799 \n", | ||
"Step: 5 Discrimination: 0.2340343272980231 \n", | ||
"Step: 10 Discrimination: 0.2192761755440438 \n", | ||
"Step: 15 Discrimination: 0.20374534568776193 \n", | ||
"Step: 20 Discrimination: 0.18249669787444095 \n", | ||
"Step: 25 Discrimination: 0.152672653697833 \n", | ||
"Step: 30 Discrimination: 0.11540482643624986 \n", | ||
"Step: 35 Discrimination: 0.060157934079452 \n", | ||
"Final step: 39 Discrimination: 0.0028079451036868153 \n", | ||
"Classifier Accuracy: 0.83\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Train a bias mitigated re-weighting classifier\n", | ||
"print(\"Mitigation unbiased Naive Bayes Model!\")\n", | ||
"clf=model.unbiasedNaiveBayes(training_sample,class_attribute,sensitive_attribute,normal_features=normal_feat)\n", | ||
"\n", | ||
"# Test a classifier\n", | ||
"predicted_test_sample, _, _, _= tools.test_classifier(clf,test_sample,class_attribute,verbose=True)\n", | ||
"prediction=np.array(predicted_test_sample[class_attribute])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Statistical/Demographic Parity:\n", | ||
"Outcome: 1\n", | ||
" 1 0\n", | ||
"1 0.0 -0.00886\n", | ||
"\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"\n", | ||
"# Measure Discrimination on predicted test sample\n", | ||
"res=metrics.statistical_parity(predicted_test_sample,class_attribute,sensitive_attribute,positive_outcome=outcome,privileged_group=priv_group,verbose=True)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3.11.5 ('aequitas')", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
}, | ||
"orig_nbformat": 4, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "aafc4311810ab362ec8b20f1ad7cefee81be5411161712b32fbd26f1f27127c5" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |