Skip to content

Commit

Permalink
Fix totals for pandas.MultiIndex input fixes #84 (#109)
Browse files Browse the repository at this point in the history
Closes #84
  • Loading branch information
mayerantoine authored and J535D165 committed Jul 12, 2019
1 parent 423a2eb commit 2524fad
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
9 changes: 9 additions & 0 deletions recordlinkage/measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ def confusion_matrix(links_true, links_pred, total=None):
if total is None:
tn = numpy.nan
else:

if isinstance(total, pandas.MultiIndex):
total = len(total)
tn = true_negatives(links_true, links_pred, total)

return numpy.array([[tp, fn], [fp, tn]])
Expand Down Expand Up @@ -383,6 +386,9 @@ def accuracy(links_true, links_pred=None, total=None):
The accuracy
"""

if isinstance(total, pandas.MultiIndex):
total = len(total)

if _isconfusionmatrix(links_true):

confusion_matrix = links_true
Expand Down Expand Up @@ -432,6 +438,9 @@ def specificity(links_true, links_pred=None, total=None):
else:

fp = false_positives(links_true, links_pred)

if isinstance(total, pandas.MultiIndex):
total = len(total)
tn = true_negatives(links_true, links_pred, total)
v = tn / (fp + tn)

Expand Down
8 changes: 6 additions & 2 deletions tests/test_measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
class TestMeasures(object):
def test_confusion_matrix(self):

result = rl.confusion_matrix(LINKS_TRUE, LINKS_PRED, len(FULL_INDEX))
result_len = rl.confusion_matrix(LINKS_TRUE, LINKS_PRED, len(FULL_INDEX))
result_full_index = rl.confusion_matrix(LINKS_TRUE, LINKS_PRED, FULL_INDEX)
expected = numpy.array([[1, 2], [3, 3]])

numpy.testing.assert_array_equal(result, expected)
numpy.testing.assert_array_equal(result_len, expected)
numpy.testing.assert_array_equal(result_full_index, expected)

def test_tp_fp_tn_fn(self):

Expand Down Expand Up @@ -61,6 +63,7 @@ def test_accuracy(self):

assert rl.accuracy(LINKS_TRUE, LINKS_PRED, len(FULL_INDEX)) == 4 / 9
assert rl.accuracy(cm) == 4 / 9
assert rl.accuracy(LINKS_TRUE, LINKS_PRED, FULL_INDEX) == 4 / 9

def test_specificity(self):

Expand All @@ -69,6 +72,7 @@ def test_specificity(self):

assert rl.specificity(LINKS_TRUE, LINKS_PRED, len(FULL_INDEX)) == 1 / 2
assert rl.specificity(cm) == 1 / 2
assert rl.specificity(LINKS_TRUE, LINKS_PRED, FULL_INDEX) == 1 / 2

def test_fscore(self):

Expand Down

0 comments on commit 2524fad

Please sign in to comment.