forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
metrics.py: COSMIT more commets on precision_recall_curve
- Loading branch information
Showing
1 changed file
with
26 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ | |
# Olivier Grisel <[email protected]> | ||
# License: BSD Style. | ||
|
||
import itertools | ||
from itertools import izip | ||
import numpy as np | ||
from scipy.sparse import coo_matrix | ||
|
||
|
@@ -855,14 +855,17 @@ def precision_recall_curve(y_true, probas_pred): | |
raise ValueError("y_true contains non binary labels: %r" % labels) | ||
|
||
# Sort pred_probas (and corresponding true labels) by pred_proba value | ||
sort_idxs = np.argsort(probas_pred, kind="mergesort")[::-1] | ||
probas_pred = probas_pred[sort_idxs] | ||
y_true = y_true[sort_idxs] | ||
|
||
# Get indices where values of probas_pred decreases | ||
thresh_idxs = np.r_[0, | ||
np.where(np.diff(probas_pred))[0] + 1, | ||
len(probas_pred)] | ||
decreasing_probas_indices = np.argsort(probas_pred, kind="mergesort")[::-1] | ||
probas_pred = probas_pred[decreasing_probas_indices] | ||
y_true = y_true[decreasing_probas_indices] | ||
|
||
# Probas_pred typically has many tied values. Here we extract | ||
# the indices associated with the distinct values. We also | ||
# concatenate values onto the ends of the curve. | ||
distinct_value_indices = np.where(np.diff(probas_pred))[0] + 1 | ||
threshold_idxs = np.r_[0, | ||
distinct_value_indices, | ||
len(probas_pred)] | ||
|
||
# Initialize true and false positive counts, precision and recall | ||
total_positive = float(y_true.sum()) | ||
|
@@ -871,20 +874,20 @@ def precision_recall_curve(y_true, probas_pred): | |
recall = [0.] | ||
thresholds = [] | ||
|
||
# Iterate over indices which indicate distinct values of probas_pred -- | ||
# each of these distinct values will be represented in the curve with a | ||
# coordinate in precision-recall space. To calculate the precision and | ||
# recall associated with each point, we use these indices to select all | ||
# labels associated with the predictions. By incrementally keeping track | ||
# of the number of positive and negative labels seen so far, we can | ||
# calculate precision and recall. | ||
for l_idx, r_idx in itertools.izip(thresh_idxs[:-1], thresh_idxs[1:]): | ||
thresh_labels = y_true[l_idx:r_idx] | ||
n_thresh = r_idx - l_idx | ||
n_pos_thresh = thresh_labels.sum() | ||
n_neg_thresh = n_thresh - n_pos_thresh | ||
tp_count += n_pos_thresh | ||
fp_count += n_neg_thresh | ||
# Iterate over indices which indicate distinct values (thresholds) of | ||
# probas_pred. Each of these threshold values will be represented in the | ||
# curve with a coordinate in precision-recall space. To calculate the | ||
# precision and recall associated with each point, we use these indices to | ||
# select all labels associated with the predictions. By incrementally | ||
# keeping track of the number of positive and negative labels seen so far, | ||
# we can calculate precision and recall. | ||
for l_idx, r_idx in izip(threshold_idxs[:-1], threshold_idxs[1:]): | ||
threshold_labels = y_true[l_idx:r_idx] | ||
n_at_threshold = r_idx - l_idx | ||
n_pos_at_threshold = threshold_labels.sum() | ||
n_neg_at_threshold = n_at_threshold - n_pos_at_threshold | ||
tp_count += n_pos_at_threshold | ||
fp_count += n_neg_at_threshold | ||
fn_count = total_positive - tp_count | ||
precision.append(tp_count / (tp_count + fp_count)) | ||
recall.append(tp_count / (tp_count + fn_count)) | ||
|