SPARK-3568 [mllib] add ranking metrics #2667
Closed
Changes from all commits (16 commits)
3a5a6ff: add ranking metrics (coderxiang)
e443fee: change style and use alternative builtin methods (coderxiang)
b7851cc: Use generic type to represent IDs (coderxiang)
5f87bce: add API to calculate precision and ndcg at each ranking position (coderxiang)
77c9e5d: set precAtK and ndcgAtK as private member. Improve documentation (coderxiang)
b794cb2: move private members precAtK, ndcgAtK into public methods. style change (coderxiang)
f66612d: add additional test of precisionAt (coderxiang)
62047c4: style change and add documentation (coderxiang)
dfae292: handle empty labSet for map. add test (coderxiang)
d64c120: add logWarning for empty ground truth set (coderxiang)
be6645e: set the precision of empty labset to 0.0 (coderxiang)
f626896: improve doc and remove unnecessary computation while labSet is empty (coderxiang)
f113ee1: modify doc for displaying superscript and subscript (coderxiang)
d7fb93f: style change and remove ignored files (coderxiang)
14d9cd9: remove unexpected files (coderxiang)
d881097: update doc (coderxiang)
mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala (152 additions, 0 deletions)
@@ -0,0 +1,152 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.evaluation

import scala.reflect.ClassTag

import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD

/**
 * ::Experimental::
 * Evaluator for ranking algorithms.
 *
 * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
 */
@Experimental
class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])])
  extends Logging with Serializable {

  /**
   * Compute the average precision of all the queries, truncated at ranking position k.
   *
   * If for a query the ranking algorithm returns n (n < k) results, the precision value will be
   * computed as #(relevant items retrieved) / k. This formula also applies when the size of the
   * ground truth set is less than k.
   *
   * If a query has an empty ground truth set, zero will be used as precision together with
   * a log warning.
   *
   * See the following paper for detail:
   *
   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
   *
   * @param k the position to compute the truncated precision, must be positive
   * @return the average precision at the first k ranking positions
   */
  def precisionAt(k: Int): Double = {
    require(k > 0, "ranking position k should be positive")
    predictionAndLabels.map { case (pred, lab) =>
      val labSet = lab.toSet

      if (labSet.nonEmpty) {
        val n = math.min(pred.length, k)
        var i = 0
        var cnt = 0
        while (i < n) {
          if (labSet.contains(pred(i))) {
            cnt += 1
          }
          i += 1
        }
        cnt.toDouble / k
      } else {
        logWarning("Empty ground truth set, check input data")
        0.0
      }
    }.mean
  }

  /**
   * Returns the mean average precision (MAP) of all the queries.
   * If a query has an empty ground truth set, the average precision will be zero and a log
   * warning is generated.
   */
  lazy val meanAveragePrecision: Double = {
    predictionAndLabels.map { case (pred, lab) =>
      val labSet = lab.toSet

      if (labSet.nonEmpty) {
        var i = 0
        var cnt = 0
        var precSum = 0.0
        val n = pred.length
        while (i < n) {
          if (labSet.contains(pred(i))) {
            cnt += 1
            precSum += cnt.toDouble / (i + 1)
          }
          i += 1
        }
        precSum / labSet.size
      } else {
        logWarning("Empty ground truth set, check input data")
        0.0
      }
    }.mean
  }

  /**
   * Compute the average NDCG value of all the queries, truncated at ranking position k.
   * The discounted cumulative gain at position k is computed as:
   *    sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
   * and the NDCG is obtained by dividing the DCG value by the DCG of the ground truth set.
   * In the current implementation, the relevance value is binary.
   *
   * If a query has an empty ground truth set, zero will be used as NDCG together with
   * a log warning.
   *
   * See the following paper for detail:
   *
   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
   *
   * @param k the position to compute the truncated NDCG, must be positive
   * @return the average NDCG at the first k ranking positions
   */
  def ndcgAt(k: Int): Double = {
    require(k > 0, "ranking position k should be positive")
    predictionAndLabels.map { case (pred, lab) =>
      val labSet = lab.toSet

      if (labSet.nonEmpty) {
        val labSetSize = labSet.size
        // n may exceed pred.length when the ground truth set is larger than the
        // prediction list, so the pred(i) lookup below is guarded.
        val n = math.min(math.max(pred.length, labSetSize), k)
        var maxDcg = 0.0
        var dcg = 0.0
        var i = 0
        while (i < n) {
          val gain = 1.0 / math.log(i + 2)
          if (i < pred.length && labSet.contains(pred(i))) {
            dcg += gain
          }
          if (i < labSetSize) {
            maxDcg += gain
          }
          i += 1
        }
        dcg / maxDcg
      } else {
        logWarning("Empty ground truth set, check input data")
        0.0
      }
    }.mean
  }

}
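A minimal usage sketch for reviewers who want to try the new API locally. This is not part of the diff; the object name, local-mode setup, and sample data are illustrative only.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.RankingMetrics

object RankingMetricsDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("ranking-metrics-demo").setMaster("local"))

    // One pair per query: (documents in predicted rank order, relevant document IDs).
    val predictionAndLabels = sc.parallelize(Seq(
      (Array(1, 6, 2), Array(1, 2, 3)),
      (Array(4, 1, 5), Array(1, 2))
    ))

    val metrics = new RankingMetrics(predictionAndLabels)
    println(s"precision@2 = ${metrics.precisionAt(2)}") // each query has 1 hit in its top 2 => 0.5
    println(s"MAP = ${metrics.meanAveragePrecision}")
    println(s"NDCG@2 = ${metrics.ndcgAt(2)}")

    sc.stop()
  }
}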
mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala (54 additions, 0 deletions)
@@ -0,0 +1,54 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.evaluation

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.util.LocalSparkContext

class RankingMetricsSuite extends FunSuite with LocalSparkContext {
  test("Ranking metrics: map, ndcg") {
    val predictionAndLabels = sc.parallelize(
      Seq(
        (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
        (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)),
        (Array[Int](1, 2, 3, 4, 5), Array[Int]())
      ), 2)
    val eps: Double = 1E-5

    val metrics = new RankingMetrics(predictionAndLabels)
    val map = metrics.meanAveragePrecision

    assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(3) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(4) ~== 0.75/3 absTol eps)
    assert(metrics.precisionAt(5) ~== 0.8/3 absTol eps)
    assert(metrics.precisionAt(10) ~== 0.8/3 absTol eps)
    assert(metrics.precisionAt(15) ~== 8.0/45 absTol eps)

    assert(map ~== 0.355026 absTol eps)

    assert(metrics.ndcgAt(3) ~== 1.0/3 absTol eps)
    assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
    assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
    assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
  }
}
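As a hand check of the asserted map value: in the first query the relevant documents appear at ranks 1, 3, 6, 9, and 10, so its average precision is (1/1 + 2/3 + 3/6 + 4/9 + 5/10) / 5 ≈ 0.62222; in the second, hits at ranks 2, 5, and 7 give (1/2 + 2/5 + 3/7) / 3 ≈ 0.44286; the third query has an empty ground truth set and contributes 0.0. The mean over the three queries is (0.62222 + 0.44286 + 0.0) / 3 ≈ 0.355026, matching the assertion.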
The inputs are really ranks, right? Would this not be more natural as Int then? I might have expected that the inputs were instead predicted and ground truth "scores", in which case Double makes sense. But then the methods need to convert to rankings.
@srowen Thanks for all the comments! The current implementation only considers binary relevance, meaning the input is just labels instead of scores. It is true that Int is enough. IMHO, using Double is compatible with the current MLlib setting and could be extended to deal with the score setting.
Hm, are these IDs of some kind then? The example makes them look like rankings, but maybe that's a coincidence. They shouldn't be labels, right? Because the resulting set would almost always be {0.0, 1.0}.
@srowen Take the first test case for example, (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1 to 5)): this means that for this single query, the ideal ranking algorithm should return Document 1 to Document 5. However, the ranking algorithm returns 10 documents, with IDs (1, 6, ...). This setting right now only works for binary relevance.
Yes, I get it. But these could also just as easily be Strings, maybe? Like document IDs? Anything you can put into a set and look for. Could this even be generified to accept any type? At least, Double seemed like the least likely type for an ID. It's not even what MLlib overloads to mean "ID"; for example, ALS assumes Int IDs.
+1 on @srowen's suggestion. We can use a generic type here with a ClassTag.
... and if the code is just going to turn the arguments into a Set, then it need not be an RDD of Array, but something more generic?
@srowen We can optimize the implementation later, if it is worth doing, and using Array may help. Usually, both predictions and labels are small; scanning them sequentially may be faster than toSeq and contains. Array is also consistent across Java, Scala, and Python, and storage-efficient for primitive types.