From 2524faddd4dc0a56a50c4b4461a45e72ba7be27e Mon Sep 17 00:00:00 2001 From: Mayer Antoine Date: Fri, 12 Jul 2019 14:10:53 -0400 Subject: [PATCH] Fix totals for pandas.MultiIndex input fixes #84 (#109) Closes #84 --- recordlinkage/measures.py | 9 +++++++++ tests/test_measures.py | 8 ++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/recordlinkage/measures.py b/recordlinkage/measures.py index 91e56ca8..6f26d2e3 100644 --- a/recordlinkage/measures.py +++ b/recordlinkage/measures.py @@ -285,6 +285,9 @@ def confusion_matrix(links_true, links_pred, total=None): if total is None: tn = numpy.nan else: + + if isinstance(total, pandas.MultiIndex): + total = len(total) tn = true_negatives(links_true, links_pred, total) return numpy.array([[tp, fn], [fp, tn]]) @@ -383,6 +386,9 @@ def accuracy(links_true, links_pred=None, total=None): The accuracy """ + if isinstance(total, pandas.MultiIndex): + total = len(total) + if _isconfusionmatrix(links_true): confusion_matrix = links_true @@ -432,6 +438,9 @@ def specificity(links_true, links_pred=None, total=None): else: fp = false_positives(links_true, links_pred) + + if isinstance(total, pandas.MultiIndex): + total = len(total) tn = true_negatives(links_true, links_pred, total) v = tn / (fp + tn) diff --git a/tests/test_measures.py b/tests/test_measures.py index 92891b0c..fbe45ad1 100644 --- a/tests/test_measures.py +++ b/tests/test_measures.py @@ -22,10 +22,12 @@ class TestMeasures(object): def test_confusion_matrix(self): - result = rl.confusion_matrix(LINKS_TRUE, LINKS_PRED, len(FULL_INDEX)) + result_len = rl.confusion_matrix(LINKS_TRUE, LINKS_PRED, len(FULL_INDEX)) + result_full_index = rl.confusion_matrix(LINKS_TRUE, LINKS_PRED, FULL_INDEX) expected = numpy.array([[1, 2], [3, 3]]) - numpy.testing.assert_array_equal(result, expected) + numpy.testing.assert_array_equal(result_len, expected) + numpy.testing.assert_array_equal(result_full_index, expected) def test_tp_fp_tn_fn(self): @@ -61,6 +63,7 @@ def test_accuracy(self): assert rl.accuracy(LINKS_TRUE, LINKS_PRED, len(FULL_INDEX)) == 4 / 9 assert rl.accuracy(cm) == 4 / 9 + assert rl.accuracy(LINKS_TRUE, LINKS_PRED, FULL_INDEX) == 4 / 9 def test_specificity(self): @@ -69,6 +72,7 @@ def test_specificity(self): assert rl.specificity(LINKS_TRUE, LINKS_PRED, len(FULL_INDEX)) == 1 / 2 assert rl.specificity(cm) == 1 / 2 + assert rl.specificity(LINKS_TRUE, LINKS_PRED, FULL_INDEX) == 1 / 2 def test_fscore(self):