Skip to content

Commit

Permalink
BUG: make precision_recall invariant to scaling of probs
Browse files Browse the repository at this point in the history
  • Loading branch information
GaelVaroquaux committed Oct 27, 2012
1 parent 402a304 commit 90c0079
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
5 changes: 3 additions & 2 deletions sklearn/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,6 @@ def precision_recall_curve(y_true, probas_pred):
# Initialize true and false positive counts, precision and recall
total_positive = float(y_true.sum())
tp_count, fp_count = 0., 0.
last_prob_val = 1.
thresholds = []
precision = [1.]
recall = [0.]
Expand All @@ -864,12 +863,14 @@ def precision_recall_curve(y_true, probas_pred):
# are encountered)
sorted_pred_idxs = np.argsort(probas_pred, kind="mergesort")[::-1]
pairs = np.vstack((probas_pred, y_true)).T
last_prob_val = probas_pred[sorted_pred_idxs[0]]
smallest_prob_val = probas_pred[sorted_pred_idxs[-1]]
for idx, (prob_val, class_val) in enumerate(pairs[sorted_pred_idxs, :]):
if class_val:
tp_count += 1.
else:
fp_count += 1.
if (prob_val < last_prob_val) and (prob_val > 0.):
if (prob_val < last_prob_val) and (prob_val > smallest_prob_val):
thresholds.append(prob_val)
fn_count = float(total_positive - tp_count)
precision.append(tp_count / (tp_count + fp_count))
Expand Down
16 changes: 16 additions & 0 deletions sklearn/metrics/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,22 @@ def test_precision_recall_curve():
assert_array_almost_equal(precision_recall_auc, 0.75, 3)


def test_score_scale_invariance():
    """Check that ranking metrics ignore affine transforms of the scores.

    Both auc_score and average_precision_score depend only on the ordering
    of the predicted scores, so multiplying the scores by a positive
    constant or subtracting an offset must leave the metrics unchanged.
    """
    y_true, _, probas_pred = make_prediction(binary=True)
    for metric in (auc_score, average_precision_score):
        # Reference value computed on the untransformed scores.
        baseline = metric(y_true, probas_pred)
        # Order-preserving transforms: positive scaling and a shift.
        for transformed_pred in (100 * probas_pred, probas_pred - 10):
            assert_equal(baseline, metric(y_true, transformed_pred))


def test_losses():
"""Test loss functions"""
y_true, y_pred, _ = make_prediction(binary=True)
Expand Down

0 comments on commit 90c0079

Please sign in to comment.