SPARK-3568 [mllib] add ranking metrics
Add common metrics for ranking algorithms (http://www-nlp.stanford.edu/IR-book/), including:

- Mean Average Precision (MAP)
- Precision@n: top-n precision
- Discounted cumulative gain (DCG) and NDCG

The following methods and the corresponding tests are implemented:

```
class RankingMetrics[T](predictionAndLabels: RDD[(Array[T], Array[T])]) {
  /* Returns the precision@k for each query */
  lazy val precAtK: RDD[Array[Double]]

  /**
   * @param k the position to compute the truncated precision
   * @return the average precision at the first k ranking positions
   */
  def precision(k: Int): Double

  /* Returns the average precision for each query */
  lazy val avePrec: RDD[Double]

  /* Returns the mean average precision (MAP) of all the queries */
  lazy val meanAvePrec: Double

  /* Returns the normalized discounted cumulative gain for each query */
  lazy val ndcgAtK: RDD[Array[Double]]

  /**
   * @param k the position to compute the truncated ndcg
   * @return the average ndcg at the first k ranking positions
   */
  def ndcg(k: Int): Double
}
```

Author: coderxiang <[email protected]>

Closes #2667 from coderxiang/rankingmetrics and squashes the following commits:

d881097 [coderxiang] update doc
14d9cd9 [coderxiang] remove unexpected files
d7fb93f [coderxiang] style change and remove ignored files
f113ee1 [coderxiang] modify doc for displaying superscript and subscript
f626896 [coderxiang] improve doc and remove unnecessary computation while labSet is empty
be6645e [coderxiang] set the precision of empty labset to 0.0
d64c120 [coderxiang] add logWarning for empty ground truth set
dfae292 [coderxiang] handle empty labSet for map. add test
62047c4 [coderxiang] style change and add documentation
f66612d [coderxiang] add additional test of precisionAt
b794cb2 [coderxiang] move private members precAtK, ndcgAtK into public methods. style change
77c9e5d [coderxiang] set precAtK and ndcgAtK as private member. Improve documentation
5f87bce [coderxiang] add API to calculate precision and ndcg at each ranking position
b7851cc [coderxiang] Use generic type to represent IDs
e443fee [coderxiang] change style and use alternative builtin methods
3a5a6ff [coderxiang] add ranking metrics
commit 814a9cd · 1 parent 5fdaf52 · 2 changed files with 206 additions and 0 deletions.
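Before the diffs, a quick orientation (not part of the patch): a minimal sketch of how the merged API is meant to be driven. The data mirrors the new test suite, and an already-constructed SparkContext `sc` is assumed.

```scala
import org.apache.spark.mllib.evaluation.RankingMetrics

// Each pair is (predicted ranking, ground truth set); `sc` is an
// existing SparkContext (an assumption of this sketch).
val predictionAndLabels = sc.parallelize(Seq(
  (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
  (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3))
))

val metrics = new RankingMetrics(predictionAndLabels)
metrics.precisionAt(5)        // precision truncated at rank 5, averaged over queries
metrics.meanAveragePrecision  // MAP over all queries
metrics.ndcgAt(5)             // NDCG truncated at rank 5, averaged over queries
```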
mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala (152 additions, 0 deletions)
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.evaluation

import scala.reflect.ClassTag

import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD

/**
 * ::Experimental::
 * Evaluator for ranking algorithms.
 *
 * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
 */
@Experimental
class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])])
  extends Logging with Serializable {

  /**
   * Compute the average precision of all the queries, truncated at ranking position k.
   *
   * If for a query, the ranking algorithm returns n (n < k) results, the precision value will be
   * computed as #(relevant items retrieved) / k. This formula also applies when the size of the
   * ground truth set is less than k.
   *
   * If a query has an empty ground truth set, zero will be used as precision together with
   * a log warning.
   *
   * See the following paper for detail:
   *
   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
   *
   * @param k the position to compute the truncated precision, must be positive
   * @return the average precision at the first k ranking positions
   */
  def precisionAt(k: Int): Double = {
    require(k > 0, "ranking position k should be positive")
    predictionAndLabels.map { case (pred, lab) =>
      val labSet = lab.toSet

      if (labSet.nonEmpty) {
        val n = math.min(pred.length, k)
        var i = 0
        var cnt = 0
        while (i < n) {
          if (labSet.contains(pred(i))) {
            cnt += 1
          }
          i += 1
        }
        cnt.toDouble / k
      } else {
        logWarning("Empty ground truth set, check input data")
        0.0
      }
    }.mean
  }

  /**
   * Returns the mean average precision (MAP) of all the queries.
   * If a query has an empty ground truth set, the average precision will be zero and a log
   * warning is generated.
   */
  lazy val meanAveragePrecision: Double = {
    predictionAndLabels.map { case (pred, lab) =>
      val labSet = lab.toSet

      if (labSet.nonEmpty) {
        var i = 0
        var cnt = 0
        var precSum = 0.0
        val n = pred.length
        while (i < n) {
          if (labSet.contains(pred(i))) {
            cnt += 1
            precSum += cnt.toDouble / (i + 1)
          }
          i += 1
        }
        precSum / labSet.size
      } else {
        logWarning("Empty ground truth set, check input data")
        0.0
      }
    }.mean
  }

  /**
   * Compute the average NDCG value of all the queries, truncated at ranking position k.
   * The discounted cumulative gain at position k is computed as:
   *   sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
   * and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current
   * implementation, the relevance value is binary.
   * If a query has an empty ground truth set, zero will be used as NDCG together with
   * a log warning.
   *
   * See the following paper for detail:
   *
   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
   *
   * @param k the position to compute the truncated ndcg, must be positive
   * @return the average ndcg at the first k ranking positions
   */
  def ndcgAt(k: Int): Double = {
    require(k > 0, "ranking position k should be positive")
    predictionAndLabels.map { case (pred, lab) =>
      val labSet = lab.toSet

      if (labSet.nonEmpty) {
        val labSetSize = labSet.size
        val n = math.min(math.max(pred.length, labSetSize), k)
        var maxDcg = 0.0
        var dcg = 0.0
        var i = 0
        while (i < n) {
          val gain = 1.0 / math.log(i + 2)
          if (labSet.contains(pred(i))) {
            dcg += gain
          }
          if (i < labSetSize) {
            maxDcg += gain
          }
          i += 1
        }
        dcg / maxDcg
      } else {
        logWarning("Empty ground truth set, check input data")
        0.0
      }
    }.mean
  }

}
```
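To make the truncation rules above concrete, here is a hand-worked check (my arithmetic, not part of the patch) for the first query of the new test suite: prediction [1, 6, 2, 7, 8, 3, 9, 10, 4, 5] against ground truth {1, 2, 3, 4, 5}. With binary relevance, the DCG formula in the scaladoc reduces to a sum of 1 / log(i + 2) over the 0-based hit positions.

```scala
// Gain of the item at 0-based position i under binary relevance.
val gain = (i: Int) => 1.0 / math.log(i + 2)

// precisionAt(3): hits at positions 0 and 2 (items 1 and 2) -> 2/3.
val p3 = 2.0 / 3

// precisionAt(15): the divisor stays k even though only 10 items were
// returned, so the 5 hits yield 5/15, not 5/10.
val p15 = 5.0 / 15

// ndcgAt(5): hits in the top 5 sit at positions 0 and 2; the ideal DCG
// fills the first min(|labels|, k) = 5 positions with relevant items.
val dcg5 = gain(0) + gain(2)
val maxDcg5 = (0 until 5).map(gain).sum
val ndcg5 = dcg5 / maxDcg5 // ≈ 0.50874
```

Averaged with the second query's NDCG and the zero contributed by the empty third query, this reproduces the suite's expected `ndcgAt(5)` value of 0.328788.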
mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala (54 additions, 0 deletions)
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.evaluation

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.util.LocalSparkContext

class RankingMetricsSuite extends FunSuite with LocalSparkContext {
  test("Ranking metrics: map, ndcg") {
    val predictionAndLabels = sc.parallelize(
      Seq(
        (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
        (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)),
        (Array[Int](1, 2, 3, 4, 5), Array[Int]())
      ), 2)
    val eps: Double = 1E-5

    val metrics = new RankingMetrics(predictionAndLabels)
    val map = metrics.meanAveragePrecision

    assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(3) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(4) ~== 0.75/3 absTol eps)
    assert(metrics.precisionAt(5) ~== 0.8/3 absTol eps)
    assert(metrics.precisionAt(10) ~== 0.8/3 absTol eps)
    assert(metrics.precisionAt(15) ~== 8.0/45 absTol eps)

    assert(map ~== 0.355026 absTol eps)

    assert(metrics.ndcgAt(3) ~== 1.0/3 absTol eps)
    assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
    assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
    assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
  }
}
```
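The expected `map` value can likewise be re-derived by hand (my arithmetic, following the binary-relevance definition of `meanAveragePrecision` above):

```scala
// Query 1: hits at ranks 1, 3, 6, 9, 10 -> precisions 1/1, 2/3, 3/6, 4/9,
// 5/10, averaged over the 5 ground-truth items.
val ap1 = (1.0 / 1 + 2.0 / 3 + 3.0 / 6 + 4.0 / 9 + 5.0 / 10) / 5

// Query 2: hits at ranks 2, 5, 7 -> precisions 1/2, 2/5, 3/7, over 3 items.
val ap2 = (1.0 / 2 + 2.0 / 5 + 3.0 / 7) / 3

// Query 3 has an empty ground truth set and contributes 0.0 (with a warning).
val map = (ap1 + ap2 + 0.0) / 3 // ≈ 0.355026
```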