SPARK-3568 [mllib] add ranking metrics #2667
New file: RankingMetrics.scala (@@ -0,0 +1,108 @@)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.evaluation

import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
/**
 * ::Experimental::
 * Evaluator for ranking algorithms.
 *
 * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
 */
@Experimental
class RankingMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {

Review comment: Might check that arguments are not empty and of equal length?

Author reply: I think I should check non-empty. It is not necessary that the two arrays have the same length: the first array contains all the items we believe to be relevant, which could be more or fewer than the ground truth set.
  /**
   * Returns the precision@k for each query
   */

Review comment: Might actually use …
  lazy val precAtK: RDD[Array[Double]] = predictionAndLabels.map { case (pred, lab) =>

Review comment: Returning an RDD may not be useful in evaluation. Usually people look for scalar metrics.

Review comment: Should it be private? Think about what users call it for. Even if we make it public, users may still need some aggregate metrics out of it. For the first version, I think it is safe to provide only the following: …

Review comment: btw, I think it is better to use …

    val labSet: Set[Double] = lab.toSet

Review comment: Given my previous comment, maybe I'm missing something, but isn't one of the two arguments always going to be 1 to n? Either you are ranking the predicted top n versus real rankings, or evaluating the predicted ranking of the known top n? I would have expected the input to be the predicted top n items by ID or something, and the IDs of the real top n, and then making a set and …

Author reply: The current setting is exactly the latter case you mentioned.

    val n = pred.length
    val topkPrec = Array.fill[Double](n)(.0)
    var (i, cnt) = (0, 0)
    while (i < n) {
      if (labSet.contains(pred(i))) {
        cnt += 1
      }
      topkPrec(i) = cnt.toDouble / (i + 1)
      i += 1
    }
    topkPrec
  }
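The precision@k loop above is easy to check by hand. The following is a minimal standalone sketch of the same computation without Spark (the object and method names here are hypothetical, not part of the PR), run on the first query from the test suite later in this page:

```scala
// Standalone sketch of the precision@k computation above (no Spark).
object PrecAtKExample {
  def precAtK(pred: Array[Double], lab: Array[Double]): Array[Double] = {
    val labSet = lab.toSet
    var cnt = 0
    // precision@k at position i is (#relevant items among the first i+1) / (i+1)
    pred.zipWithIndex.map { case (p, i) =>
      if (labSet.contains(p)) cnt += 1
      cnt.toDouble / (i + 1)
    }
  }

  def main(args: Array[String]): Unit = {
    val pred = Array[Double](1, 6, 2, 7, 8, 3, 9, 10, 4, 5)
    val lab = Array[Double](1, 2, 3, 4, 5)
    // At k = 5 (index 4), 2 of the first 5 predictions (items 1 and 2) are relevant: 2/5
    println(precAtK(pred, lab)(4))  // 0.4
  }
}
```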
  /**
   * Returns the average precision for each query
   */
  lazy val avePrec: RDD[Double] = predictionAndLabels.map { case (pred, lab) =>

Review comment: Same as …

    val labSet: Set[Double] = lab.toSet
    var (i, cnt, precSum) = (0, 0, .0)
    val n = pred.length
    while (i < n) {
      if (labSet.contains(pred(i))) {
        cnt += 1
        precSum += cnt.toDouble / (i + 1)
      }
      i += 1
    }
    precSum / labSet.size
  }
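Average precision sums precision@k over the positions where a relevant item appears, then divides by the size of the ground truth set. A hypothetical standalone sketch (no Spark), reproducing the first query's expected value from the test suite:

```scala
// Standalone sketch of the average-precision computation above (no Spark).
object AvePrecExample {
  def avePrec(pred: Array[Double], lab: Array[Double]): Double = {
    val labSet = lab.toSet
    var cnt = 0
    var precSum = 0.0
    for (i <- pred.indices) {
      if (labSet.contains(pred(i))) {
        cnt += 1
        precSum += cnt.toDouble / (i + 1)  // precision at the position of each hit
      }
    }
    precSum / labSet.size
  }

  def main(args: Array[String]): Unit = {
    // Hits at ranks 1, 3, 6, 9, 10: (1/1 + 2/3 + 3/6 + 4/9 + 5/10) / 5 ≈ 0.622222
    println(avePrec(Array[Double](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Double](1, 2, 3, 4, 5)))
  }
}
```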
  /**
   * Returns the mean average precision (MAP) of all the queries
   */
  lazy val meanAvePrec: Double = computeMean(avePrec)
  /**
   * Returns the normalized discounted cumulative gain for each query
   */
  lazy val ndcg: RDD[Double] = predictionAndLabels.map { case (pred, lab) =>
    val labSet = lab.toSet
    val n = math.min(pred.length, labSet.size)
    var (maxDcg, dcg, i) = (.0, .0, 0)
    while (i < n) {
      /* Calculate 1 / log2(i + 2) */
      val gain = 1.0 / (math.log(i + 2) / math.log(2))
      if (labSet.contains(pred(i))) {
        dcg += gain
      }
      maxDcg += gain
      i += 1
    }
    dcg / maxDcg
  }
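The NDCG code above accumulates a gain of 1/log2(i + 2) at each of the first n positions for the ideal DCG, and only at hit positions for the actual DCG. A hypothetical standalone sketch (no Spark) that reproduces the first query's expected value from the test suite:

```scala
// Standalone sketch of the NDCG computation above (no Spark), binary relevance.
object NdcgExample {
  def ndcg(pred: Array[Double], lab: Array[Double]): Double = {
    val labSet = lab.toSet
    // Note: as in the PR code, only the first min(|pred|, |lab|) positions are scanned.
    val n = math.min(pred.length, labSet.size)
    var (maxDcg, dcg) = (0.0, 0.0)
    for (i <- 0 until n) {
      val gain = 1.0 / (math.log(i + 2) / math.log(2))  // 1 / log2(i + 2)
      if (labSet.contains(pred(i))) dcg += gain
      maxDcg += gain
    }
    dcg / maxDcg
  }

  def main(args: Array[String]): Unit = {
    // Hits at positions 1 and 3 contribute gains 1.0 and 0.5;
    // the ideal DCG over 5 positions is about 2.948459, so NDCG ≈ 0.508740.
    println(ndcg(Array[Double](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Double](1, 2, 3, 4, 5)))
  }
}
```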
  /**
   * Returns the mean NDCG of all the queries
   */
  lazy val meanNdcg: Double = computeMean(ndcg)
  private def computeMean(data: RDD[Double]): Double = {
    val stat = data.aggregate((.0, 0))(
      seqOp = (c, v) => (c, v) match { case ((sum, cnt), a) => (sum + a, cnt + 1) },
      combOp = (c1, c2) => (c1, c2) match { case (x, y) => (x._1 + y._1, x._2 + y._2) }
    )
    stat._1 / stat._2
  }
}
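`computeMean` uses a single `aggregate` pass that carries a (sum, count) pair, so the mean is computed without a second traversal. The same accumulation can be sketched on a plain Scala collection (the object name here is hypothetical):

```scala
// Plain-Scala sketch of the (sum, count) accumulation in computeMean above.
object ComputeMeanExample {
  def mean(data: Seq[Double]): Double = {
    // seqOp analogue: fold each value into a running (sum, count) pair
    val (sum, cnt) = data.foldLeft((0.0, 0)) { case ((s, c), v) => (s + v, c + 1) }
    sum / cnt
  }

  def main(args: Array[String]): Unit = {
    // Averaging the two per-query average precisions from the test suite
    // gives the expected MAP value of about 0.532539.
    println(mean(Seq(0.622222, 0.442857)))
  }
}
```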
New file: RankingMetricsSuite.scala (@@ -0,0 +1,49 @@)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.evaluation

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.LocalSparkContext
class RankingMetricsSuite extends FunSuite with LocalSparkContext {
  test("Ranking metrics: map, ndcg") {
    val predictionAndLabels = sc.parallelize(
      Seq(
        (Array[Double](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Double](1, 2, 3, 4, 5)),
        (Array[Double](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Double](1, 2, 3))
      ), 2)
    val eps: Double = 1e-5

    val metrics = new RankingMetrics(predictionAndLabels)
    val precAtK = metrics.precAtK.collect()
    val avePrec = metrics.avePrec.collect()
    val map = metrics.meanAvePrec
    val ndcg = metrics.ndcg.collect()
    val aveNdcg = metrics.meanNdcg
    assert(math.abs(precAtK(0)(4) - 0.4) < eps)

Review comment: Check out the …

    assert(math.abs(precAtK(1)(6) - 3.0/7) < eps)
    assert(math.abs(avePrec(0) - 0.622222) < eps)
    assert(math.abs(avePrec(1) - 0.442857) < eps)
    assert(math.abs(map - 0.532539) < eps)
    assert(math.abs(ndcg(0) - 0.508740) < eps)
    assert(math.abs(ndcg(1) - 0.296082) < eps)
    assert(math.abs(aveNdcg - 0.402411) < eps)
  }
}
Review comment (@srowen): The inputs are really ranks, right? Would this not be more natural as `Int` then? I might have expected the inputs to be predicted and ground truth "scores" instead, in which case `Double` makes sense. But then the methods would need to convert to rankings.

Author reply: @srowen Thanks for all the comments! The current implementation only considers binary relevance, meaning the input is just labels instead of scores. It is true that `Int` is enough. IMHO, using `Double` is compatible with the current MLlib setting and could be extended to deal with the score setting.

Review comment (@srowen): Hm, are these IDs of some kind then? The example makes them look like rankings, but maybe that's coincidence. They shouldn't be labels, right? Because the resulting set would almost always be {0.0, 1.0}.

Author reply: @srowen Take the first test case for example, `(Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1 to 5))`. This means that for this single query, the ideal ranking algorithm should return Document 1 to Document 5, while the ranking algorithm under evaluation returns 10 documents, with IDs (1, 6, …). This setting right now only works for binary relevance.

Review comment (@srowen): Yes, I get it. But these could also just as easily be Strings, maybe? Like document IDs? Anything you can put into a set and look for. Could this even be generified to accept any type? At least, `Double` seemed like the least likely type for an ID. It's not even what MLlib overloads to mean "ID"; for example, ALS assumes `Int` IDs.

Review comment: +1 on @srowen's suggestion. We can use a generic type here with a ClassTag.

Review comment (@srowen): … and if the code is just going to turn the arguments into a `Set`, then it need not be an `RDD` of `Array`, but something more generic?

Author reply: @srowen We can optimize the implementation later, if it is worth doing, and using Array may help. Usually both predictions and labels are small, and scanning them sequentially may be faster than `toSeq` and `contains`. Array is also consistent across Java, Scala, and Python, and storage-efficient for primitive types.