SPARK-3568 [mllib] add ranking metrics
Add common metrics for ranking algorithms (http://www-nlp.stanford.edu/IR-book/), including:
 - Mean Average Precision
 - Precision@n: top-n precision
 - Discounted cumulative gain (DCG) and NDCG

The following methods and the corresponding tests are implemented:

```
class RankingMetrics[T](predictionAndLabels: RDD[(Array[T], Array[T])]) {
  /* Returns the precision@k for each query */
  lazy val precAtK: RDD[Array[Double]]

  /**
   * param k the position to compute the truncated precision
   * return the average precision at the first k ranking positions
   */
  def precision(k: Int): Double

  /* Returns the average precision for each query */
  lazy val avePrec: RDD[Double]

  /*Returns the mean average precision (MAP) of all the queries*/
  lazy val meanAvePrec: Double

  /*Returns the normalized discounted cumulative gain for each query */
  lazy val ndcgAtK: RDD[Array[Double]]

  /**
   * param k the position to compute the truncated ndcg
   * return the average ndcg at the first k ranking positions
   */
  def ndcg(k: Int): Double
}
```

Author: coderxiang <[email protected]>

Closes #2667 from coderxiang/rankingmetrics and squashes the following commits:

d881097 [coderxiang] update doc
14d9cd9 [coderxiang] remove unexpected files
d7fb93f [coderxiang] style change and remove ignored files
f113ee1 [coderxiang] modify doc for displaying superscript and subscript
f626896 [coderxiang] improve doc and remove unnecessary computation while labSet is empty
be6645e [coderxiang] set the precision of empty labset to 0.0
d64c120 [coderxiang] add logWarning for empty ground truth set
dfae292 [coderxiang] handle empty labSet for map. add test
62047c4 [coderxiang] style change and add documentation
f66612d [coderxiang] add additional test of precisionAt
b794cb2 [coderxiang] move private members precAtK, ndcgAtK into public methods. style change
77c9e5d [coderxiang] set precAtK and ndcgAtK as private member. Improve documentation
5f87bce [coderxiang] add API to calculate precision and ndcg at each ranking position
b7851cc [coderxiang] Use generic type to represent IDs
e443fee [coderxiang] change style and use alternative builtin methods
3a5a6ff [coderxiang] add ranking metrics
coderxiang authored and mengxr committed Oct 21, 2014
1 parent 5fdaf52 commit 814a9cd
Showing 2 changed files with 206 additions and 0 deletions.
@@ -0,0 +1,152 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.evaluation

import scala.reflect.ClassTag

import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD

/**
* ::Experimental::
* Evaluator for ranking algorithms.
*
* @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
*/
@Experimental
class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])])
extends Logging with Serializable {

/**
* Compute the average precision of all the queries, truncated at ranking position k.
*
* If for a query, the ranking algorithm returns n (n < k) results, the precision value will be
* computed as #(relevant items retrieved) / k. This formula also applies when the size of the
* ground truth set is less than k.
*
* If a query has an empty ground truth set, zero will be used as precision together with
* a log warning.
*
* See the following paper for detail:
*
* IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
*
* @param k the position to compute the truncated precision, must be positive
* @return the average precision at the first k ranking positions
*/
def precisionAt(k: Int): Double = {
  require(k > 0, "ranking position k should be positive")
  predictionAndLabels.map { case (pred, lab) =>
    val labSet = lab.toSet

    if (labSet.nonEmpty) {
      val n = math.min(pred.length, k)
      var i = 0
      var cnt = 0
      while (i < n) {
        if (labSet.contains(pred(i))) {
          cnt += 1
        }
        i += 1
      }
      cnt.toDouble / k
    } else {
      logWarning("Empty ground truth set, check input data")
      0.0
    }
  }.mean
}

/**
* Returns the mean average precision (MAP) of all the queries.
* If a query has an empty ground truth set, the average precision will be zero and a log
* warning is generated.
*/
lazy val meanAveragePrecision: Double = {
  predictionAndLabels.map { case (pred, lab) =>
    val labSet = lab.toSet

    if (labSet.nonEmpty) {
      var i = 0
      var cnt = 0
      var precSum = 0.0
      val n = pred.length
      while (i < n) {
        if (labSet.contains(pred(i))) {
          cnt += 1
          precSum += cnt.toDouble / (i + 1)
        }
        i += 1
      }
      precSum / labSet.size
    } else {
      logWarning("Empty ground truth set, check input data")
      0.0
    }
  }.mean
}

/**
* Compute the average NDCG value of all the queries, truncated at ranking position k.
* The discounted cumulative gain at position k is computed as:
* sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
* and the NDCG is obtained by dividing by the DCG value of the ideal ranking, in which the
* ground truth items come first. In the current implementation, the relevance value is binary.
* If a query has an empty ground truth set, zero will be used as ndcg together with
* a log warning.
*
* See the following paper for detail:
*
* IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
*
* @param k the position to compute the truncated ndcg, must be positive
* @return the average ndcg at the first k ranking positions
*/
def ndcgAt(k: Int): Double = {
  require(k > 0, "ranking position k should be positive")
  predictionAndLabels.map { case (pred, lab) =>
    val labSet = lab.toSet

    if (labSet.nonEmpty) {
      val labSetSize = labSet.size
      val n = math.min(math.max(pred.length, labSetSize), k)
      var maxDcg = 0.0
      var dcg = 0.0
      var i = 0
      while (i < n) {
        val gain = 1.0 / math.log(i + 2)
        // n can exceed pred.length when the ground truth set is larger than the
        // prediction list, so guard the index before reading pred(i).
        if (i < pred.length && labSet.contains(pred(i))) {
          dcg += gain
        }
        if (i < labSetSize) {
          maxDcg += gain
        }
        i += 1
      }
      dcg / maxDcg
    } else {
      logWarning("Empty ground truth set, check input data")
      0.0
    }
  }.mean
}

}
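
The per-query arithmetic behind `precisionAt` and `meanAveragePrecision` can be checked without a Spark cluster. The following is a minimal sketch on local collections, not part of the commit: the object name `RankingSketch` and its two helpers are hypothetical, and the sketch returns 0.0 for an empty ground truth set without logging a warning.

```scala
// Hypothetical local-collection sketch; names are illustrative, not Spark API.
object RankingSketch {

  /** Precision truncated at k for one query: hits among the first k predictions, divided by k. */
  def precisionAtK[T](pred: Seq[T], labels: Set[T], k: Int): Double = {
    require(k > 0, "ranking position k should be positive")
    if (labels.isEmpty) 0.0
    else pred.take(k).count(labels.contains).toDouble / k
  }

  /** Average precision for one query: sum of precision at each hit, divided by the label count. */
  def averagePrecision[T](pred: Seq[T], labels: Set[T]): Double = {
    if (labels.isEmpty) {
      0.0
    } else {
      var hits = 0
      var precSum = 0.0
      for ((item, i) <- pred.zipWithIndex if labels.contains(item)) {
        hits += 1
        precSum += hits.toDouble / (i + 1) // precision at this hit's rank (1-based)
      }
      precSum / labels.size
    }
  }
}
```

On the three sample queries used by the commit's test suite, `averagePrecision` gives roughly 0.622222, 0.442857, and 0.0; their mean is about 0.355026, which is the MAP value the suite asserts.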
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.evaluation

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.util.LocalSparkContext

class RankingMetricsSuite extends FunSuite with LocalSparkContext {
  test("Ranking metrics: map, ndcg") {
    val predictionAndLabels = sc.parallelize(
      Seq(
        (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
        (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)),
        (Array[Int](1, 2, 3, 4, 5), Array[Int]())
      ), 2)
    val eps: Double = 1E-5

    val metrics = new RankingMetrics(predictionAndLabels)
    val map = metrics.meanAveragePrecision

    assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(3) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(4) ~== 0.75/3 absTol eps)
    assert(metrics.precisionAt(5) ~== 0.8/3 absTol eps)
    assert(metrics.precisionAt(10) ~== 0.8/3 absTol eps)
    assert(metrics.precisionAt(15) ~== 8.0/45 absTol eps)

    assert(map ~== 0.355026 absTol eps)

    assert(metrics.ndcgAt(3) ~== 1.0/3 absTol eps)
    assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
    assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
    assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
  }
}
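
The NDCG expectations in the suite can be reproduced by hand with binary relevance and a gain of 1 / ln(i + 2) at zero-based position i. Below is a minimal sketch on local collections, assuming the same truncation rule as `ndcgAt`. The object `NdcgSketch` and its method are hypothetical names, and the index guard for prediction lists shorter than the ground truth set is an assumption of this sketch.

```scala
// Hypothetical local-collection sketch of per-query NDCG with binary relevance.
object NdcgSketch {

  /** NDCG truncated at k for one query; returns 0.0 when the ground truth set is empty. */
  def ndcgAtK[T](pred: IndexedSeq[T], labels: Set[T], k: Int): Double = {
    require(k > 0, "ranking position k should be positive")
    if (labels.isEmpty) {
      0.0
    } else {
      val n = math.min(math.max(pred.length, labels.size), k)
      var dcg = 0.0
      var maxDcg = 0.0
      var i = 0
      while (i < n) {
        val gain = 1.0 / math.log(i + 2)
        // Guard the index: n may exceed pred.length when labels outnumber predictions.
        if (i < pred.length && labels.contains(pred(i))) dcg += gain
        // The ideal ranking places every ground truth item first.
        if (i < labels.size) maxDcg += gain
        i += 1
      }
      dcg / maxDcg
    }
  }
}
```

At k = 3 the first two sample queries score about 0.703918 and 0.296082, and the third (empty ground truth) scores 0.0, so the mean is 1/3 as asserted above.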
