Skip to content

Commit

Permalink
test: add more tests on metrics using labels
Browse files Browse the repository at this point in the history
  • Loading branch information
liamj2311 committed May 13, 2024
1 parent e38334c commit ce891b2
Showing 1 changed file with 21 additions and 32 deletions.
53 changes: 21 additions & 32 deletions test/core/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,35 @@ def test_metrics_on_dataset(self):
self.assertIsInstance(ds.scores_metrics, BinaryLabelDatasetScoresMetric)
self.assertIsNotNone(ds)

### METRICS USING LABELS ###

# Disparate Impact
score = ds.metrics.disparate_impact()
print(f"Disparate impact: {score}")
self.assertIsNotNone(score)

# Statistical Parity
score = ds.metrics.statistical_parity_difference()
print(f"Statistical Parity: {score}")
self.assertIsNotNone(score)

# Dirichlet-smoothed base rates
score = ds.metrics._smoothed_base_rates(ds.labels)
print(f"Dirichlet-smoothed base rates: {score}")
self.assertIsNotNone(score)

# Smoothed EDF
score = ds.metrics.smoothed_empirical_differential_fairness()
print(f"Smoothed EDF: {score}")
self.assertIsNotNone(score)

# Consistency
score = ds.metrics.consistency()
print(f"Consistency: {score}")
self.assertIsNotNone(score)

### METRICS USING SCORES ###

score = ds.scores_metrics.new_fancy_metric()
self.assertIsNotNone(score)

Expand Down Expand Up @@ -140,38 +161,6 @@ def test_dataset_creation_with_scores_via_factory(self):
self.assertIsInstance(ds, MulticlassLabelDataset)
self.assertIsNotNone(ds)

def test_metrics_on_dataset(self):
    """Smoke-test metrics on a factory-built MulticlassLabelDataset.

    Builds the dataset through ``create_dataset`` and checks that both the
    label-based metric accessor (``ds.metrics``) and the score-based one
    (``ds.scores_metrics``) return non-None values.
    """
    ds = create_dataset(
        "multi class",
        # parameters of aequitas.MulticlassLabelDataset init
        unprivileged_groups=[{'prot_attr': 0}],
        privileged_groups=[{'prot_attr': 1}],
        # parameters of aequitas.StructuredDataset init
        imputation_strategy=MCMCImputationStrategy(),
        # parameters of aif360.MulticlassLabelDataset init
        favorable_label=[0, 1., 2.],
        unfavorable_label=[3., 4.],
        # parameters of aif360.StructuredDataset init
        df=generate_multi_label_dataframe_with_scores(),
        label_names=['label'],
        protected_attribute_names=['prot_attr'],
        scores_names="score",
    )
    # NOTE(review): removed dead debug code (`mro = False` guarding an
    # unreachable print whose f-string interpolated __mro__ twice).
    self.assertIsInstance(ds, MulticlassLabelDataset)
    self.assertIsNotNone(ds)

    # Metric computed on labels.
    score = ds.metrics.disparate_impact()
    print(f"Disparate impact: {score}")
    self.assertIsNotNone(score)

    # Metric computed on scores.
    score = ds.scores_metrics.new_fancy_metric()
    self.assertIsNotNone(score)




# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit ce891b2

Please sign in to comment.