SPARK-3568 [mllib] add ranking metrics #2667

Closed · wants to merge 16 commits
RankingMetrics.scala (new file)
@@ -0,0 +1,152 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.evaluation

import scala.reflect.ClassTag

import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD

/**
* ::Experimental::
* Evaluator for ranking algorithms.
*
* @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
Member (@srowen):
The inputs are really ranks, right? Would this not be more natural as Int then?
I might have expected the inputs to be predicted and ground-truth "scores" instead, in which case Double makes sense. But then the methods would need to convert them to rankings.

Contributor Author:
@srowen Thanks for all the comments! The current implementation only considers binary relevance, meaning the input is just labels instead of scores. It is true that Int is enough. IMHO, using Double is compatible with the current MLlib conventions and could later be extended to handle scores.

Member (@srowen):
Hm, are these IDs of some kind then? The example makes them look like rankings, but maybe that's a coincidence. They shouldn't be labels, right? Because then the resulting set would almost always be {0.0, 1.0}.

Contributor Author:
@srowen Take the first test case as an example: (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1 to 5)). For this single query, the ideal ranking algorithm should return Documents 1 to 5. The ranking algorithm being evaluated, however, returns 10 documents, with IDs (1, 6, ...). This setting currently only works for binary relevance.

Member (@srowen):
Yes, I get it. But these could also just as easily be Strings, maybe? Like document IDs? Anything you can put into a set and look for. Could this even be generified to accept any type? At least, Double seemed like the least likely type for an ID. It's not even what MLlib overloads to mean "ID"; ALS, for example, assumes Int IDs.

Contributor:
+1 on @srowen's suggestion. We can use a generic type here with a ClassTag.

Member (@srowen):
... and if the code is just going to turn the arguments into a Set, then it need not be an RDD of Array but something more generic?

Contributor:
@srowen We can optimize the implementation later, if it is worth doing, and using Array may help. Usually both predictions and labels are small, so scanning them sequentially may be faster than toSeq and contains. Array is also consistent across Java, Scala, and Python, and storage-efficient for primitive types.
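To make the generic-type suggestion above concrete, here is a minimal sketch of how the metrics could be driven with non-numeric IDs once the class takes a type parameter T. The String document IDs and the SparkContext named sc are assumptions for illustration only, not part of the patch:

import org.apache.spark.rdd.RDD

// One (predicted ranking, ground truth set) pair per query, using plain String document IDs.
// Assumes an existing SparkContext named `sc`.
val predictionAndLabels: RDD[(Array[String], Array[String])] = sc.parallelize(Seq(
  (Array("d1", "d6", "d2"), Array("d1", "d2", "d3")),
  (Array("d4", "d1", "d5"), Array("d1", "d2"))))

val metrics = new RankingMetrics(predictionAndLabels)
// Average, over queries, of the fraction of the top-2 predictions found in the ground truth set.
val p2 = metrics.precisionAt(2)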

*/
@Experimental
class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])])
  extends Logging with Serializable {

  /**
   * Compute the average precision of all the queries, truncated at ranking position k.
   *
   * If for a query, the ranking algorithm returns n (n < k) results, the precision value will be
   * computed as #(relevant items retrieved) / k. This formula also applies when the size of the
   * ground truth set is less than k.
   *
   * If a query has an empty ground truth set, zero will be used as precision together with
   * a log warning.
   *
   * See the following paper for detail:
   *
   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
   *
   * @param k the position to compute the truncated precision, must be positive
   * @return the average precision at the first k ranking positions
   */
  def precisionAt(k: Int): Double = {
    require(k > 0, "ranking position k should be positive")
    predictionAndLabels.map { case (pred, lab) =>
      val labSet = lab.toSet

      if (labSet.nonEmpty) {
        val n = math.min(pred.length, k)
        var i = 0
        var cnt = 0
        while (i < n) {
          if (labSet.contains(pred(i))) {
            cnt += 1
          }
          i += 1
        }
        cnt.toDouble / k
      } else {
        logWarning("Empty ground truth set, check input data")
        0.0
      }
    }.mean
  }

  /**
   * Returns the mean average precision (MAP) of all the queries.
   * If a query has an empty ground truth set, the average precision will be zero and a log
   * warning is generated.
   */
  lazy val meanAveragePrecision: Double = {
    predictionAndLabels.map { case (pred, lab) =>
      val labSet = lab.toSet

      if (labSet.nonEmpty) {
        var i = 0
        var cnt = 0
        var precSum = 0.0
        val n = pred.length
        while (i < n) {
          if (labSet.contains(pred(i))) {
            cnt += 1
            precSum += cnt.toDouble / (i + 1)
          }
          i += 1
        }
        precSum / labSet.size
      } else {
        logWarning("Empty ground truth set, check input data")
        0.0
      }
    }.mean
  }

  /**
   * Compute the average NDCG value of all the queries, truncated at ranking position k.
   * The discounted cumulative gain at position k is computed as:
   *   sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
   * and the NDCG is obtained by dividing it by the ideal DCG computed from the ground truth set.
   * In the current implementation, the relevance value is binary.
   *
   * If a query has an empty ground truth set, zero will be used as NDCG together with
   * a log warning.
   *
   * See the following paper for detail:
   *
   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
   *
   * @param k the position to compute the truncated NDCG, must be positive
   * @return the average NDCG at the first k ranking positions
   */
  def ndcgAt(k: Int): Double = {
    require(k > 0, "ranking position k should be positive")
    predictionAndLabels.map { case (pred, lab) =>
      val labSet = lab.toSet

      if (labSet.nonEmpty) {
        val labSetSize = labSet.size
        val n = math.min(math.max(pred.length, labSetSize), k)
        var maxDcg = 0.0
        var dcg = 0.0
        var i = 0
        while (i < n) {
          val gain = 1.0 / math.log(i + 2)
          if (labSet.contains(pred(i))) {
            dcg += gain
          }
          if (i < labSetSize) {
            maxDcg += gain
          }
          i += 1
        }
        dcg / maxDcg
      } else {
        logWarning("Empty ground truth set, check input data")
        0.0
      }
    }.mean
  }

}
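As an editorial sanity check of the definitions above (not part of the patch): for the first query in the test suite below, prediction (1, 6, 2, 7, 8, 3, 9, 10, 4, 5) against ground truth {1, 2, 3, 4, 5}, the relevant documents sit at ranks 1, 3, 6, 9 and 10, so its average precision is (1/1 + 2/3 + 3/6 + 4/9 + 5/10) / 5 ≈ 0.6222. Doing the same for the second query and averaging over all three queries (the third contributes zero because its ground truth set is empty) reproduces the MAP value asserted below; the value names here are illustrative only:

// Hand-computed average precision per query.
val apQuery1 = (1.0 / 1 + 2.0 / 3 + 3.0 / 6 + 4.0 / 9 + 5.0 / 10) / 5  // relevant at ranks 1, 3, 6, 9, 10
val apQuery2 = (1.0 / 2 + 2.0 / 5 + 3.0 / 7) / 3                       // relevant at ranks 2, 5, 7
val map = (apQuery1 + apQuery2 + 0.0) / 3                              // ≈ 0.355026, matching the test suite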
RankingMetricsSuite.scala (new file)
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.evaluation

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.util.LocalSparkContext

class RankingMetricsSuite extends FunSuite with LocalSparkContext {
  test("Ranking metrics: map, ndcg") {
    val predictionAndLabels = sc.parallelize(
      Seq(
        (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
        (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)),
        (Array[Int](1, 2, 3, 4, 5), Array[Int]())
      ), 2)
    val eps: Double = 1E-5

    val metrics = new RankingMetrics(predictionAndLabels)
    val map = metrics.meanAveragePrecision

    assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(3) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(4) ~== 0.75/3 absTol eps)
    assert(metrics.precisionAt(5) ~== 0.8/3 absTol eps)
    assert(metrics.precisionAt(10) ~== 0.8/3 absTol eps)
    assert(metrics.precisionAt(15) ~== 8.0/45 absTol eps)

    assert(map ~== 0.355026 absTol eps)

    assert(metrics.ndcgAt(3) ~== 1.0/3 absTol eps)
    assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
    assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
    assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
  }
}
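For reference, the ndcgAt(3) expectation above can be reproduced by hand (an editorial check, not part of the patch). With binary relevance, the gain at 1-based position p is 1 / log(p + 1), and since both non-empty queries have at least three relevant documents, the ideal DCG at cut-off 3 is simply the sum of the first three gains:

// Gains at 1-based ranking positions 1, 2 and 3 under binary relevance: 1 / log(p + 1).
val gains = (1 to 3).map(p => 1.0 / math.log(p + 1))
val idealDcg = gains.sum
val ndcg1 = (gains(0) + gains(2)) / idealDcg  // query 1 has relevant documents at positions 1 and 3
val ndcg2 = gains(1) / idealDcg               // query 2 has a relevant document at position 2
val ndcgAt3 = (ndcg1 + ndcg2 + 0.0) / 3       // empty third query contributes 0; the result is 1.0 / 3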