Skip to content

Commit

Permalink
move private members precAtK, ndcgAtK into public methods. style change
Browse files Browse the repository at this point in the history
  • Loading branch information
coderxiang committed Oct 15, 2014
1 parent 77c9e5d commit b794cb2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,6 @@ import org.apache.spark.rdd.RDD
@Experimental
class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]) {

/**
 * Precision at each cut-off position, per query: element i of the result array
 * is the fraction of the first (i + 1) predictions that appear in the label set.
 */
private lazy val precAtK: RDD[Array[Double]] = predictionAndLabels.map { case (pred, lab) =>
  val labSet = lab.toSet
  val topKPrec = new Array[Double](pred.length)
  var hits = 0
  var idx = 0
  // Array + while kept for the per-record hot path; running hit count avoids rescanning.
  while (idx < pred.length) {
    if (labSet.contains(pred(idx))) {
      hits += 1
    }
    topKPrec(idx) = hits.toDouble / (idx + 1)
    idx += 1
  }
  topKPrec
}

/**
* Compute the average precision of all the queries, truncated at ranking position k.
* If for a query, the ranking algorithm returns n (n < k) results,
Expand All @@ -64,19 +44,25 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
* @param k the position to compute the truncated precision
* @return the average precision at the first k ranking positions
*/
def precision(k: Int): Double = precAtK.map {topKPrec =>
val n = topKPrec.length
if (k <= n) {
topKPrec(k - 1)
} else {
topKPrec(n - 1) * n / k
/**
 * Compute the average precision of all the queries, truncated at ranking position k.
 *
 * If for a query the ranking algorithm returns n (n < k) results, the precision value
 * will be computed as #(relevant items retrieved) / k — i.e. queries returning fewer
 * than k results are penalized by the fixed divisor k.
 *
 * @param k the position to compute the truncated precision
 * @return the average precision at the first k ranking positions
 */
def precisionAt(k: Int): Double = predictionAndLabels.map { case (pred, lab) =>
  val labSet = lab.toSet
  // Hits among the top-k predictions (take(k) is safe when pred has fewer than k items),
  // divided by k regardless of how many results were actually returned.
  pred.take(k).count(labSet.contains).toDouble / k
}.mean

/**
* Returns the average precision for each query
* Returns the mean average precision (MAP) of all the queries
*/
private lazy val avePrec: RDD[Double] = predictionAndLabels.map {case (pred, lab) =>
lazy val meanAveragePrecision: Double = predictionAndLabels.map { case (pred, lab) =>
val labSet = lab.toSet
var i = 0
var cnt = 0
Expand All @@ -91,54 +77,39 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
i += 1
}
precSum / labSet.size
}

/**
* Returns the mean average precision (MAP) of all the queries
*/
lazy val meanAveragePrecision: Double = avePrec.mean
}.mean

/**
 * Compute the average NDCG value of all the queries, truncated at ranking position k.
 * If for a query, the ranking algorithm returns n (n < k) results, the NDCG value
 * at position n will be used. See the following paper for detail:
 *
 * IR evaluation methods for retrieving highly relevant documents.
 * K. Jarvelin and J. Kekalainen
 *
 * @param k the position to compute the truncated ndcg
 * @return the average ndcg at the first k ranking positions
 */
def ndcgAt(k: Int): Double = predictionAndLabels.map { case (pred, lab) =>
  val labSet = lab.toSet
  val labSetSize = labSet.size
  // Truncate at k, but never below the shorter of the two lists: the ideal DCG
  // is accumulated over the first labSetSize positions, the actual DCG over pred.
  val n = math.min(math.max(pred.length, labSetSize), k)
  var maxDcg = 0.0
  var dcg = 0.0
  var i = 0

  while (i < n) {
    // Binary relevance gain: 1 / log2(i + 2)
    val gain = math.log(2) / math.log(i + 2)
    if (labSet.contains(pred(i))) {
      dcg += gain
    }
    if (i < labSetSize) {
      // Ideal ranking places all relevant items first.
      maxDcg += gain
    }
    i += 1
  }
  dcg / maxDcg
}.mean

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.mllib.evaluation

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.util.LocalSparkContext

Expand All @@ -33,18 +34,18 @@ class RankingMetricsSuite extends FunSuite with LocalSparkContext {
val metrics = new RankingMetrics(predictionAndLabels)
val map = metrics.meanAveragePrecision

assert(metrics.precision(1) ~== 0.5 absTol eps)
assert(metrics.precision(2) ~== 0.5 absTol eps)
assert(metrics.precision(3) ~== 0.5 absTol eps)
assert(metrics.precision(4) ~== 0.375 absTol eps)
assert(metrics.precision(10) ~== 0.4 absTol eps)
assert(metrics.precisionAt(1) ~== 0.5 absTol eps)
assert(metrics.precisionAt(2) ~== 0.5 absTol eps)
assert(metrics.precisionAt(3) ~== 0.5 absTol eps)
assert(metrics.precisionAt(4) ~== 0.375 absTol eps)
assert(metrics.precisionAt(10) ~== 0.4 absTol eps)

assert(map ~== 0.532539 absTol eps)

assert(metrics.ndcg(3) ~== 0.5 absTol eps)
assert(metrics.ndcg(5) ~== 0.493182 absTol eps)
assert(metrics.ndcg(10) ~== 0.731869 absTol eps)
assert(metrics.ndcg(15) ~== metrics.ndcg(10) absTol eps)
assert(metrics.ndcgAt(3) ~== 0.5 absTol eps)
assert(metrics.ndcgAt(5) ~== 0.493182 absTol eps)
assert(metrics.ndcgAt(10) ~== 0.731869 absTol eps)
assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)

}
}

0 comments on commit b794cb2

Please sign in to comment.