[SPARK-6094] [MLLIB] Add MultilabelMetrics in PySpark/MLlib
Add MultilabelMetrics in PySpark/MLlib

Author: Yanbo Liang <[email protected]>

Closes #6276 from yanboliang/spark-6094 and squashes the following commits:

b8e3343 [Yanbo Liang] Add MultilabelMetrics in PySpark/MLlib

(cherry picked from commit 98a46f9)
Signed-off-by: Xiangrui Meng <[email protected]>
yanboliang authored and mengxr committed May 20, 2015
1 parent 996e2d4 commit 606ae3e
Showing 2 changed files with 125 additions and 0 deletions.
8 changes: 8 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.evaluation

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.sql.DataFrame

/**
 * Evaluator for multilabel classification.
@@ -27,6 +28,13 @@ import org.apache.spark.SparkContext._
 */
class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {

  /**
   * An auxiliary constructor taking a DataFrame.
   * @param predictionAndLabels a DataFrame with two double array columns: prediction and label
   */
  private[mllib] def this(predictionAndLabels: DataFrame) =
    this(predictionAndLabels.map(r => (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray)))

  private lazy val numDocs: Long = predictionAndLabels.count()

  private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
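The auxiliary constructor above exists to serve the Python wrapper added below: an RDD of Python tuples cannot cross the Py4J boundary, but a DataFrame's underlying Java handle can. A minimal sketch of the bridge, assuming a running SparkContext named sc (the direct _jvm call is for illustration only; the MultilabelMetrics wrapper below does this internally):

from pyspark.sql import SQLContext

# Build the (predictions, labels) pairs on the Python side.
predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([2.0], [2.0])])

# Convert to a DataFrame with two array<double> columns; its Java handle
# (_jdf) is what the new Scala constructor accepts.
df = SQLContext(sc).createDataFrame(predictionAndLabels, ["prediction", "label"])
java_metrics = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics(df._jdf)
print(java_metrics.microPrecision())  # same value as the wrapper's property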
117 changes: 117 additions & 0 deletions python/pyspark/mllib/evaluation.py
@@ -343,6 +343,123 @@ def ndcgAt(self, k):
        return self.call("ndcgAt", int(k))


class MultilabelMetrics(JavaModelWrapper):
    """
    Evaluator for multilabel classification.

    >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
    ...     ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
    ...     ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])
    >>> metrics = MultilabelMetrics(predictionAndLabels)
    >>> metrics.precision(0.0)
    1.0
    >>> metrics.recall(1.0)
    0.66...
    >>> metrics.f1Measure(2.0)
    0.5
    >>> metrics.precision()
    0.66...
    >>> metrics.recall()
    0.64...
    >>> metrics.f1Measure()
    0.63...
    >>> metrics.microPrecision
    0.72...
    >>> metrics.microRecall
    0.66...
    >>> metrics.microF1Measure
    0.69...
    >>> metrics.hammingLoss
    0.33...
    >>> metrics.subsetAccuracy
    0.28...
    >>> metrics.accuracy
    0.54...
    """

    def __init__(self, predictionAndLabels):
        sc = predictionAndLabels.ctx
        sql_ctx = SQLContext(sc)
        # An RDD of Python tuples cannot cross the Py4J boundary directly, so
        # convert it to a DataFrame and pass the underlying Java DataFrame
        # (_jdf) to the Scala MultilabelMetrics constructor added above.
        df = sql_ctx.createDataFrame(predictionAndLabels,
                                     schema=sql_ctx._inferSchema(predictionAndLabels))
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics
        java_model = java_class(df._jdf)
        super(MultilabelMetrics, self).__init__(java_model)

    def precision(self, label=None):
        """
        Returns precision averaged over all documents, or the precision for a
        given label (category) if specified.
        """
        if label is None:
            return self.call("precision")
        else:
            return self.call("precision", float(label))

    def recall(self, label=None):
        """
        Returns recall averaged over all documents, or the recall for a given
        label (category) if specified.
        """
        if label is None:
            return self.call("recall")
        else:
            return self.call("recall", float(label))

    def f1Measure(self, label=None):
        """
        Returns the F1 measure averaged over all documents, or the F1 measure
        for a given label (category) if specified.
        """
        if label is None:
            return self.call("f1Measure")
        else:
            return self.call("f1Measure", float(label))

    @property
    def microPrecision(self):
        """
        Returns micro-averaged label-based precision
        (equal to micro-averaged document-based precision).
        """
        return self.call("microPrecision")

    @property
    def microRecall(self):
        """
        Returns micro-averaged label-based recall
        (equal to micro-averaged document-based recall).
        """
        return self.call("microRecall")

    @property
    def microF1Measure(self):
        """
        Returns micro-averaged label-based F1 measure
        (equal to micro-averaged document-based F1 measure).
        """
        return self.call("microF1Measure")

    @property
    def hammingLoss(self):
        """
        Returns the Hamming loss: the fraction of (document, label) pairs that
        are misclassified.
        """
        return self.call("hammingLoss")

    @property
    def subsetAccuracy(self):
        """
        Returns subset accuracy: the fraction of documents whose predicted
        label set exactly matches the true label set.
        """
        return self.call("subsetAccuracy")

    @property
    def accuracy(self):
        """
        Returns accuracy: the average, over all documents, of the Jaccard
        overlap between the predicted and true label sets.
        """
        return self.call("accuracy")


def _test():
    import doctest
    from pyspark import SparkContext
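For reference, every value in the doctest above follows from the standard multilabel definitions and can be reproduced without Spark. A plain-Python sanity check (not part of the commit; names are illustrative):

# The seven (predictions, labels) pairs from the doctest.
data = [([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
        ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
        ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]

n_docs = float(len(data))                                    # 7 documents
distinct = {l for _, ls in data for l in ls} | {p for ps, _ in data for p in ps}
n_labels = float(len(distinct))                              # 3 distinct labels

tp = float(sum(len(set(p) & set(l)) for p, l in data))       # 8 correct label assignments
n_pred = sum(len(p) for p, _ in data)                        # 11 predicted labels
n_true = sum(len(l) for _, l in data)                        # 12 true labels

print(tp / n_pred)                   # microPrecision: 8/11  ~ 0.727
print(tp / n_true)                   # microRecall:    8/12  ~ 0.667
print(2.0 * tp / (n_pred + n_true))  # microF1Measure: 16/23 ~ 0.696

# hammingLoss: misclassified (document, label) pairs over all such pairs.
print(sum(len(set(p) ^ set(l)) for p, l in data) / (n_docs * n_labels))  # 7/21 ~ 0.333

# subsetAccuracy: documents with an exact set match (the 4th and 5th above).
print(sum(set(p) == set(l) for p, l in data) / n_docs)       # 2/7 ~ 0.286

# accuracy: per-document Jaccard overlap, averaged over documents.
print(sum(len(set(p) & set(l)) / float(len(set(p) | set(l)))
          for p, l in data) / n_docs)                        # ~ 0.548

# precision(0.0): all 4 predictions of label 0.0 are correct.
print(sum(0.0 in p and 0.0 in l for p, l in data)
      / float(sum(0.0 in p for p, _ in data)))               # 4/4 = 1.0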
