From 90c007981f54cf2335cc09e5f0f1fa21b790fb72 Mon Sep 17 00:00:00 2001
From: Gael Varoquaux
Date: Sat, 27 Oct 2012 19:33:05 +0200
Subject: [PATCH] BUG: make precision_recall invariant under scaling of probs

---
 sklearn/metrics/metrics.py            |  5 +++--
 sklearn/metrics/tests/test_metrics.py | 16 ++++++++++++++++
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py
index 23fbdacacd0c1..200103eaa4efc 100644
--- a/sklearn/metrics/metrics.py
+++ b/sklearn/metrics/metrics.py
@@ -849,7 +849,6 @@ def precision_recall_curve(y_true, probas_pred):
     # Initialize true and false positive counts, precision and recall
     total_positive = float(y_true.sum())
     tp_count, fp_count = 0., 0.
-    last_prob_val = 1.
     thresholds = []
     precision = [1.]
     recall = [0.]
@@ -864,12 +863,14 @@ def precision_recall_curve(y_true, probas_pred):
     # are encountered)
     sorted_pred_idxs = np.argsort(probas_pred, kind="mergesort")[::-1]
     pairs = np.vstack((probas_pred, y_true)).T
+    last_prob_val = probas_pred[sorted_pred_idxs[0]]
+    smallest_prob_val = probas_pred[sorted_pred_idxs[-1]]
     for idx, (prob_val, class_val) in enumerate(pairs[sorted_pred_idxs, :]):
         if class_val:
             tp_count += 1.
         else:
             fp_count += 1.
-        if (prob_val < last_prob_val) and (prob_val > 0.):
+        if (prob_val < last_prob_val) and (prob_val > smallest_prob_val):
             thresholds.append(prob_val)
             fn_count = float(total_positive - tp_count)
             precision.append(tp_count / (tp_count + fp_count))
diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py
index e5861c0044ffc..5a38c8f452080 100644
--- a/sklearn/metrics/tests/test_metrics.py
+++ b/sklearn/metrics/tests/test_metrics.py
@@ -365,6 +365,22 @@ def test_precision_recall_curve():
     assert_array_almost_equal(precision_recall_auc, 0.75, 3)
 
 
+def test_score_scale_invariance():
+    # Test that average_precision_score and auc_score are invariant
+    # under scaling or shifting of the probabilities
+    y_true, _, probas_pred = make_prediction(binary=True)
+    roc_auc = auc_score(y_true, probas_pred)
+    roc_auc_scaled = auc_score(y_true, 100 * probas_pred)
+    roc_auc_shifted = auc_score(y_true, probas_pred - 10)
+    assert_equal(roc_auc, roc_auc_scaled)
+    assert_equal(roc_auc, roc_auc_shifted)
+    pr_auc = average_precision_score(y_true, probas_pred)
+    pr_auc_scaled = average_precision_score(y_true, 100 * probas_pred)
+    pr_auc_shifted = average_precision_score(y_true, probas_pred - 10)
+    assert_equal(pr_auc, pr_auc_scaled)
+    assert_equal(pr_auc, pr_auc_shifted)
+
+
 def test_losses():
     """Test loss functions"""
     y_true, y_pred, _ = make_prediction(binary=True)
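
Illustration of the invariance this patch enforces: the old cut-off
`prob_val > 0.` implicitly assumed scores in (0, 1], so shifting all
scores below zero dropped every threshold and degenerated the curve;
bounding by the smallest observed score instead survives any
order-preserving rescaling. Below is a minimal sketch, assuming any
scikit-learn version exposing `precision_recall_curve` and
`average_precision_score`; the labels and scores are synthetic,
purely for the example:

    import numpy as np
    from sklearn.metrics import average_precision_score, precision_recall_curve

    # Synthetic labels and scores, purely illustrative.
    rng = np.random.RandomState(0)
    y_true = rng.randint(0, 2, size=50)
    scores = rng.rand(50)

    # Precision and recall depend only on the ordering of the scores,
    # so a positive scale plus a shift must leave them unchanged; only
    # the returned thresholds move along with the scores.
    prec, rec, _ = precision_recall_curve(y_true, scores)
    prec_t, rec_t, _ = precision_recall_curve(y_true, 100 * scores - 10)
    assert np.array_equal(prec, prec_t) and np.array_equal(rec, rec_t)

    # The scalar summary is unchanged as well.
    assert (average_precision_score(y_true, scores)
            == average_precision_score(y_true, 100 * scores - 10))

The equalities are exact, not approximate: both curves are built from
the same true/false positive counts, since the affine transform
preserves the ordering and tie structure of the scores.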