Skip to content

Commit

Permalink
metrics.py: COSMIT more commets on precision_recall_curve
Browse files Browse the repository at this point in the history
  • Loading branch information
conradlee authored and amueller committed Nov 14, 2012
1 parent 090bed6 commit e31376a
Showing 1 changed file with 26 additions and 23 deletions.
49 changes: 26 additions & 23 deletions sklearn/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# Olivier Grisel <[email protected]>
# License: BSD Style.

import itertools
from itertools import izip
import numpy as np
from scipy.sparse import coo_matrix

Expand Down Expand Up @@ -855,14 +855,17 @@ def precision_recall_curve(y_true, probas_pred):
raise ValueError("y_true contains non binary labels: %r" % labels)

# Sort pred_probas (and corresponding true labels) by pred_proba value
sort_idxs = np.argsort(probas_pred, kind="mergesort")[::-1]
probas_pred = probas_pred[sort_idxs]
y_true = y_true[sort_idxs]

# Get indices where values of probas_pred decreases
thresh_idxs = np.r_[0,
np.where(np.diff(probas_pred))[0] + 1,
len(probas_pred)]
decreasing_probas_indices = np.argsort(probas_pred, kind="mergesort")[::-1]
probas_pred = probas_pred[decreasing_probas_indices]
y_true = y_true[decreasing_probas_indices]

# Probas_pred typically has many tied values. Here we extract
# the indices associated with the distinct values. We also
# concatenate values onto the ends of the curve.
distinct_value_indices = np.where(np.diff(probas_pred))[0] + 1
threshold_idxs = np.r_[0,
distinct_value_indices,
len(probas_pred)]

# Initialize true and false positive counts, precision and recall
total_positive = float(y_true.sum())
Expand All @@ -871,20 +874,20 @@ def precision_recall_curve(y_true, probas_pred):
recall = [0.]
thresholds = []

# Iterate over indices which indicate distinct values of probas_pred --
# each of these distinct values will be represented in the curve with a
# coordinate in precision-recall space. To calculate the precision and
# recall associated with each point, we use these indices to select all
# labels associated with the predictions. By incrementally keeping track
# of the number of positive and negative labels seen so far, we can
# calculate precision and recall.
for l_idx, r_idx in itertools.izip(thresh_idxs[:-1], thresh_idxs[1:]):
thresh_labels = y_true[l_idx:r_idx]
n_thresh = r_idx - l_idx
n_pos_thresh = thresh_labels.sum()
n_neg_thresh = n_thresh - n_pos_thresh
tp_count += n_pos_thresh
fp_count += n_neg_thresh
# Iterate over indices which indicate distinct values (thresholds) of
# probas_pred. Each of these threshold values will be represented in the
# curve with a coordinate in precision-recall space. To calculate the
# precision and recall associated with each point, we use these indices to
# select all labels associated with the predictions. By incrementally
# keeping track of the number of positive and negative labels seen so far,
# we can calculate precision and recall.
for l_idx, r_idx in izip(threshold_idxs[:-1], threshold_idxs[1:]):
threshold_labels = y_true[l_idx:r_idx]
n_at_threshold = r_idx - l_idx
n_pos_at_threshold = threshold_labels.sum()
n_neg_at_threshold = n_at_threshold - n_pos_at_threshold
tp_count += n_pos_at_threshold
fp_count += n_neg_at_threshold
fn_count = total_positive - tp_count
precision.append(tp_count / (tp_count + fp_count))
recall.append(tp_count / (tp_count + fn_count))
Expand Down

0 comments on commit e31376a

Please sign in to comment.