Added UDTs for Vectors in MLlib, plus DatasetExample using the UDTs
jkbradley committed Nov 2, 2014
1 parent cd60cb4 commit dff99d6
Showing 6 changed files with 242 additions and 12 deletions.
@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.mllib

import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}

/**
* An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with
* {{{
* ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*/
object DatasetExample {

case class Params(
input: String = "data/mllib/sample_libsvm_data.txt",
dataFormat: String = "libsvm") extends AbstractParams[Params]

def main(args: Array[String]) {
val defaultParams = Params()

val parser = new OptionParser[Params]("Dataset") {
head("Dataset: an example app using SchemaRDD as a Dataset for ML.")
opt[String]("input")
.text(s"input path to dataset")
.action((x, c) => c.copy(input = x))
opt[String]("dataFormat")
.text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
.action((x, c) => c.copy(input = x))
checkConfig { params =>
success
}
}

parser.parse(args, defaultParams).map { params =>
run(params)
}.getOrElse {
sys.exit(1)
}
}

def run(params: Params) {

val conf = new SparkConf().setAppName(s"Dataset with $params")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext._ // for implicit conversions

// Load input data
val origData: RDD[LabeledPoint] = params.dataFormat match {
case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
}
println(s"Loaded ${origData.count()} instances from file: ${params.input}")

// Convert input data to SchemaRDD explicitly (the type annotation triggers the implicit conversion).
val schemaRDD: SchemaRDD = origData
println(s"Converted to SchemaRDD with ${schemaRDD.count()} records")

// Select columns, using implicit conversion to SchemaRDD.
val labelsSchemaRDD: SchemaRDD = origData.select('label)
val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v }
val numLabels = labels.count()
val meanLabel = labels.fold(0.0)(_ + _) / numLabels
println(s"Selected label column with average value $meanLabel")

val featuresSchemaRDD: SchemaRDD = origData.select('features)
val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v }
val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
(summary, feat) => summary.add(feat),
(sum1, sum2) => sum1.merge(sum2))
println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")

sc.stop()
}

}
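
Not part of this commit: as a hedged extension of the example above, the same implicit conversions (from `import sqlContext._`) should also allow selecting both columns at once and rebuilding LabeledPoints from the resulting Rows. The `points` name below is illustrative.

// Hypothetical extension of DatasetExample.run, assuming the same imports and
// implicit conversions as above: select label and features together and
// reconstruct LabeledPoints from the Rows.
val points: RDD[LabeledPoint] = origData.select('label, 'features).map {
case Row(label: Double, features: Vector) => LabeledPoint(label, features)
}
println(s"Round-tripped ${points.count()} instances through SchemaRDD")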
131 changes: 130 additions & 1 deletion mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -25,8 +25,13 @@ import scala.collection.JavaConverters._

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}

import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.SparkException
import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.sql.catalyst.UDTRegistry
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.Row

/**
* Represents a numeric vector, whose index type is Int and value type is Double.
@@ -81,6 +86,8 @@ sealed trait Vector extends Serializable {
*/
object Vectors {

UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[Vector], new VectorUDT())

/**
* Creates a dense vector from its values.
*/
@@ -191,6 +198,7 @@ object Vectors {
/**
* A dense vector represented by a value array.
*/
@SQLUserDefinedType(udt = classOf[DenseVectorUDT])
class DenseVector(val values: Array[Double]) extends Vector {

override def size: Int = values.length
@@ -242,3 +250,124 @@ class SparseVector(

private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
}

/**
* User-defined type for [[Vector]] which allows easy interaction with SQL
* via [[org.apache.spark.sql.SchemaRDD]].
*/
private[spark] class VectorUDT extends UserDefinedType[Vector] {

/**
* vectorType: 0 = dense, 1 = sparse.
* dense, sparse: One element holds the vector, and the other is null.
*/
override def sqlType: StructType = StructType(Seq(
StructField("vectorType", ByteType, nullable = false),
StructField("dense", new DenseVectorUDT(), nullable = true),
StructField("sparse", new SparseVectorUDT(), nullable = true)))

override def serialize(obj: Any): Row = {
val row = new GenericMutableRow(3)
obj match {
case v: DenseVector =>
row.setByte(0, 0)
row.update(1, new DenseVectorUDT().serialize(obj))
row.setNullAt(2)
case v: SparseVector =>
row.setByte(0, 1)
row.setNullAt(1)
row.update(2, new SparseVectorUDT().serialize(obj))
}
row
}

override def deserialize(row: Row): Vector = {
require(row.length == 3,
s"VectorUDT.deserialize given row with length ${row.length} but requires length == 3")
val vectorType = row.getByte(0)
vectorType match {
case 0 =>
new DenseVectorUDT().deserialize(row.getAs[Row](1))
case 1 =>
new SparseVectorUDT().deserialize(row.getAs[Row](2))
}
}
}

/**
* User-defined type for [[DenseVector]] which allows easy interaction with SQL
* via [[org.apache.spark.sql.SchemaRDD]].
*/
private[spark] class DenseVectorUDT extends UserDefinedType[DenseVector] {

override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false)

override def serialize(obj: Any): Row = obj match {
case v: DenseVector =>
val row: GenericMutableRow = new GenericMutableRow(v.size)
var i = 0
while (i < v.size) {
row.setDouble(i, v(i))
i += 1
}
row
}

override def deserialize(row: Row): DenseVector = {
val values = new Array[Double](row.length)
var i = 0
while (i < row.length) {
values(i) = row.getDouble(i)
i += 1
}
new DenseVector(values)
}
}

/**
* User-defined type for [[SparseVector]] which allows easy interaction with SQL
* via [[org.apache.spark.sql.SchemaRDD]].
*/
private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] {

override def sqlType: StructType = StructType(Seq(
StructField("size", IntegerType, nullable = false),
StructField("indices", ArrayType(DoubleType, containsNull = false), nullable = false),
StructField("values", ArrayType(DoubleType, containsNull = false), nullable = false)))

override def serialize(obj: Any): Row = obj match {
case v: SparseVector =>
val nnz = v.indices.size
val row: GenericMutableRow = new GenericMutableRow(1 + 2 * nnz)
row.setInt(0, v.size)
var i = 0
while (i < nnz) {
row.setInt(1 + i, v.indices(i))
i += 1
}
i = 0
while (i < nnz) {
row.setDouble(1 + nnz + i, v.values(i))
i += 1
}
row
}

override def deserialize(row: Row): SparseVector = {
require(row.length >= 1,
s"SparseVectorUDT.deserialize given row with length ${row.length} but requires length >= 1")
val vSize = row.getInt(0)
val nnz: Int = (row.length - 1) / 2
require(nnz * 2 + 1 == row.length,
s"SparseVectorUDT.deserialize given row with non-matching indices, values lengths")
val indices = new Array[Int](nnz)
val values = new Array[Double](nnz)
var i = 0
while (i < nnz) {
indices(i) = row.getInt(1 + i)
values(i) = row.getDouble(1 + nnz + i)
i += 1
}
new SparseVector(vSize, indices, values)
}
}
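
Not part of this commit: a hedged round-trip sketch of the struct layout described in the VectorUDT comment above (byte 0 selects dense vs. sparse; the unused slot is null). It assumes the UDT classes added above are available on the classpath.

// Hypothetical round trip through VectorUDT.
val udt = new VectorUDT()
val dense: Vector = Vectors.dense(1.0, 0.0, 3.0)
val denseRow: Row = udt.serialize(dense)
// vectorType byte 0 = dense, so field 1 holds the payload and field 2 is null.
assert(denseRow.getByte(0) == 0 && denseRow.isNullAt(2))
assert(udt.deserialize(denseRow).toArray.sameElements(dense.toArray))

val sparse: Vector = Vectors.sparse(5, Array(1, 3), Array(2.0, 4.0))
val sparseRow: Row = udt.serialize(sparse)
// vectorType byte 1 = sparse, so field 1 is null and field 2 holds the payload.
assert(sparseRow.getByte(0) == 1 && sparseRow.isNullAt(1))
assert(udt.deserialize(sparseRow).toArray.sameElements(sparse.toArray))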
@@ -33,6 +33,8 @@ case class LabeledPoint(label: Double, features: Vector) {
}
}



/**
* Parser for [[org.apache.spark.mllib.regression.LabeledPoint]].
*/
@@ -89,15 +89,18 @@ object ScalaReflection {
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = {
val className: String = tpe.erasure.typeSymbol.asClass.fullName
println(s"schemaFor: className = $className")
tpe match {
case t if Utils.classIsLoadable(className) &&
Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
// Note: We check for classIsLoadable above since Utils.classForName uses Java reflection,
// whereas className is from Scala reflection. This can make it hard to find classes
// in some cases, such as when a class is enclosed in an object (in which case
// Java appends a '$' to the object name but Scala does not).
UDTRegistry.registerType(t)
Schema(UDTRegistry.udtRegistry(t), nullable = true)
val udt = Utils.classForName(className)
.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
UDTRegistry.registerType(t, udt)
Schema(udt, nullable = true)
case t if UDTRegistry.udtRegistry.contains(t) =>
Schema(UDTRegistry.udtRegistry(t), nullable = true)
case t if t <:< typeOf[Option[_]] =>
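
Not part of this commit: the comment above about Java vs. Scala reflection can be illustrated with a small standalone snippet (hypothetical names). For a class enclosed in an object, Scala reflection reports a dotted name while the JVM class name uses '$', which is why classIsLoadable guards the classForName call.

import scala.reflect.runtime.universe._

object MyTypes {
class MyPoint
}

object NameMismatchDemo extends App {
// Scala reflection reports the dotted name for the enclosing object ...
val scalaName = typeOf[MyTypes.MyPoint].typeSymbol.asClass.fullName // "MyTypes.MyPoint"
// ... while the JVM class name joins it with '$'.
val javaName = classOf[MyTypes.MyPoint].getName // "MyTypes$MyPoint"
println(s"Scala reflection: $scalaName, JVM: $javaName")
}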
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.types.UserDefinedType
/**
* Global registry for user-defined types (UDTs).
*/
private[sql] object UDTRegistry {
object UDTRegistry {
/** Map: UserType --> UserDefinedType */
val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]()

@@ -35,14 +35,9 @@ private[sql] object UDTRegistry {
* RDDs of user types and SchemaRDDs.
* If this type has already been registered, this does nothing.
*/
def registerType(userType: Type): Unit = {
def registerType(userType: Type, udt: UserDefinedType[_]): Unit = {
// TODO: Check to see if type is built-in. Throw exception?
if (!UDTRegistry.udtRegistry.contains(userType)) {
val udt =
getClass.getClassLoader.loadClass(userType.typeSymbol.asClass.fullName)
.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
UDTRegistry.udtRegistry(userType) = udt
}
UDTRegistry.udtRegistry(userType) = udt
// TODO: Else: Should we check (assert) that udt is the same as what is in the registry?
}
}
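
Not part of this commit: a hedged sketch of a user type wired through the changed registerType(userType, udt) signature and the SQLUserDefinedType annotation shown in Vectors.scala above. Point and PointUDT are hypothetical names.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.UDTRegistry
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types._

// Hypothetical user class annotated with its UDT, mirroring DenseVector above.
@SQLUserDefinedType(udt = classOf[PointUDT])
case class Point(x: Double, y: Double)

// Hypothetical UDT: stores a Point as a two-field struct of doubles.
class PointUDT extends UserDefinedType[Point] {

override def sqlType: StructType = StructType(Seq(
StructField("x", DoubleType, nullable = false),
StructField("y", DoubleType, nullable = false)))

override def serialize(obj: Any): Row = obj match {
case p: Point =>
val row = new GenericMutableRow(2)
row.setDouble(0, p.x)
row.setDouble(1, p.y)
row
}

override def deserialize(row: Row): Point = {
require(row.length == 2, s"PointUDT.deserialize expects length 2, got ${row.length}")
Point(row.getDouble(0), row.getDouble(1))
}
}

// Eager registration with the new two-argument signature, as object Vectors does above.
object Point {
UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[Point], new PointUDT())
}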
@@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.LogicalRDD
* Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java)
*/
private[sql] trait SchemaRDDLike {
@transient val sqlContext: SQLContext
@transient def sqlContext: SQLContext
@transient val baseLogicalPlan: LogicalPlan

private[sql] def baseSchemaRDD: SchemaRDD