Commit
still working on UDTs
jkbradley committed Nov 2, 2014
1 parent 19b2f60 commit 982c035
Showing 6 changed files with 185 additions and 63 deletions.
@@ -78,28 +78,26 @@ object ScalaReflection {
val udtStructType: StructType = udtRegistry(typeTag[T]).dataType
Schema(udtStructType, nullable = true)
} else {
schemaFor(typeOf[T], udtRegistry)
schemaFor(typeOf[T])
}
}

/**
* Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
* TODO: ADD DOC
*/
def schemaFor(
tpe: `Type`,
udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = tpe match {
def schemaFor(tpe: `Type`): Schema = tpe match {
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType, udtRegistry).dataType, nullable = true)
Schema(schemaFor(optType).dataType, nullable = true)
case t if t <:< typeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
Schema(StructType(
params.head.map { p =>
val Schema(dataType, nullable) =
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs), udtRegistry)
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
// Need to decide if we actually need a special type here.
@@ -108,12 +106,12 @@ object ScalaReflection {
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType, udtRegistry)
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_,_]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val Schema(valueDataType, valueNullable) = schemaFor(valueType, udtRegistry)
Schema(MapType(schemaFor(keyType, udtRegistry).dataType,
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
@@ -134,9 +132,6 @@ object ScalaReflection {
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
/* case t if udtRegistry.contains(typeTag[t]) =>
val udtStructType: StructType = udtRegistry(tpe).dataType
Schema(udtStructType, nullable = true)*/
}
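For context, the reflection pattern schemaFor relies on, walking a case class's primary constructor to recover field names and types, can be sketched outside Catalyst. This is a minimal, self-contained illustration; the names FieldInfo and describe are hypothetical, and nme.CONSTRUCTOR/paramss in the Scala 2.10 API above correspond to termNames.CONSTRUCTOR/paramLists in later Scala versions.

import scala.reflect.runtime.universe._

// Hypothetical stand-in for Catalyst's Schema/StructField.
case class FieldInfo(name: String, tpe: Type)

// Walk the primary constructor, substituting formal type parameters with the
// actual type arguments, as schemaFor does for Product types above.
def describe[T: TypeTag]: List[FieldInfo] = {
  val t = typeOf[T]
  val formalTypeArgs = t.typeSymbol.asClass.typeParams
  val TypeRef(_, _, actualTypeArgs) = t
  val params = t.member(termNames.CONSTRUCTOR).asMethod.paramLists
  params.head.map { p =>
    FieldInfo(p.name.toString,
      p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
  }
}

case class Point(x: Double, y: Double)
// describe[Point].map(_.name) == List("x", "y")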

def typeOfObject: PartialFunction[Any, DataType] = {
@@ -574,24 +574,23 @@ trait UserDefinedType[T] {
def deserialize(row: Row): T
}

object UDTType {
object UserDefinedType {
/**
* Construct a [[UDTType]] object with the given key type and value type.
* Construct a [[UserDefinedType]] object with the given key type and value type.
* The `valueContainsNull` is true.
*/
def apply(keyType: DataType, valueType: DataType): MapType =
MapType(keyType: DataType, valueType: DataType, true)
//def apply(keyType: DataType, valueType: DataType): MapType =
// MapType(keyType: DataType, valueType: DataType, true)
}

/**
* The data type for UserDefinedType.
* The data type for User Defined Types.
*/
case class UDTType(dataType: StructType, ) extends DataType {
abstract class UserDefinedType[UserType](val dataType: StructType) extends DataType with Serializable {

// Used only in regex parser above.
//private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { }

def dataType: StructType

def serialize(obj: Any): Row

def deserialize(row: Row): UserType
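Taken together with the first hunk, which checks udtRegistry.contains(typeTag[T]) before falling back to schemaFor, the intended flow seems to be: register a UserDefinedType instance keyed by the TypeTag of the user class, then look its dataType up during schema inference. Below is a toy, self-contained model of that flow, not the Spark API; ToyUDT and ToyRegistry are hypothetical names.

import scala.collection.mutable
import scala.reflect.runtime.universe._

// Stand-in for the abstract UserDefinedType above: a schema description plus
// serialize/deserialize hooks for the user class.
abstract class ToyUDT[UserType](val dataType: String) {
  def serialize(obj: Any): Seq[Any]
  def deserialize(values: Seq[Any]): UserType
}

object ToyRegistry {
  // Keyed by the TypeTag of the user class, mirroring the
  // udtRegistry.contains(typeTag[T]) check in the first hunk.
  private val udtRegistry = mutable.Map.empty[TypeTag[_], ToyUDT[_]]

  def registerUserType[T: TypeTag](udt: ToyUDT[T]): Unit =
    udtRegistry(typeTag[T]) = udt

  // Return the registered schema if T has a UDT; the real code would fall
  // back to schemaFor otherwise.
  def schemaForType[T: TypeTag]: Option[String] =
    udtRegistry.get(typeTag[T]).map(_.dataType)
}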
@@ -179,8 +179,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
val ord = new RowOrdering(sortOrder, child.output)

// TODO: Is this copying for no reason?
override def executeCollect() =
child.execute().map(_.copy()).takeOrdered(limit)(ord).map(ScalaReflection.convertRowToScala)
override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord)
.map(ScalaReflection.convertRowToScala)

// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types.UserDefinedType
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._

class UserDefinedTypeSuite extends QueryTest {

case class LabeledPoint(label: Double, feature: Double) extends Serializable

object LabeledPointUDT {

def dataType: StructType =
StructType(Seq(
StructField("label", DoubleType, nullable = false),
StructField("feature", DoubleType, nullable = false)))

}

case class LabeledPointUDT() extends UserDefinedType[LabeledPoint](LabeledPointUDT.dataType) with Serializable {

override def serialize(obj: Any): Row = obj match {
case lp: LabeledPoint =>
val row: GenericMutableRow = new GenericMutableRow(2)
row.setDouble(0, lp.label)
row.setDouble(1, lp.feature)
row
}

override def deserialize(row: Row): LabeledPoint = {
assert(row.length == 2)
val label = row.getDouble(0)
val feature = row.getDouble(1)
LabeledPoint(label, feature)
}
}

test("register user type: LabeledPoint") {
try {
TestSQLContext.registerUserType(new LabeledPointUDT())
println("udtRegistry:")
TestSQLContext.udtRegistry.foreach { case (t, s) => println(s"$t -> $s")}

println(s"test: ${scala.reflect.runtime.universe.typeTag[LabeledPoint]}")
assert(TestSQLContext.udtRegistry.contains(scala.reflect.runtime.universe.typeTag[LabeledPoint]))

val points = Seq(
LabeledPoint(1.0, 2.0),
LabeledPoint(0.0, 3.0))
val pointsRDD: RDD[LabeledPoint] = sparkContext.parallelize(points)

println("Converting to SchemaRDD")
val tmpSchemaRDD: SchemaRDD = TestSQLContext.createSchemaRDD(pointsRDD)
println("blah")
println(s"SchemaRDD count: ${tmpSchemaRDD.count()}")
println("Done converting to SchemaRDD")

/*
val features: RDD[DenseVector] =
pointsRDD.select('features).map { case Row(v: DenseVector) => v}
val featuresArrays: Array[DenseVector] = features.collect()
assert(featuresArrays.size === 2)
assert(featuresArrays.contains(new DenseVector(Array(1.0, 0.0))))
assert(featuresArrays.contains(new DenseVector(Array(1.0, -1.0))))
val labels: RDD[Double] = pointsRDD.select('labels).map { case Row(v: Double) => v}
val labelsArrays: Array[Double] = labels.collect()
assert(labelsArrays.size === 2)
assert(labelsArrays.contains(1.0))
assert(labelsArrays.contains(0.0))
*/
} catch {
case e: Exception =>
e.printStackTrace()
}
}

}
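One check this suite stops short of (the SchemaRDD assertions are still commented out) is a direct serialize/deserialize round trip. A minimal sketch, not part of the commit, assuming it lives inside the suite next to the test above and using the LabeledPointUDT defined in this file (row layout: position 0 = label, position 1 = feature):

test("LabeledPointUDT round trip") {
  val udt = new LabeledPointUDT()
  val point = LabeledPoint(1.0, 2.0)
  val row = udt.serialize(point)
  assert(row.getDouble(0) == 1.0)          // label
  assert(row.getDouble(1) == 2.0)          // feature
  assert(udt.deserialize(row) == point)    // case class equality
}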
108 changes: 68 additions & 40 deletions sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala
@@ -18,64 +18,92 @@
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types.UserDefinedType
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._

class UserTypeSuite extends QueryTest {

class DenseVector(val data: Array[Double])
class DenseVector(val data: Array[Double]) extends Serializable

case class LabeledPoint(label: Double, features: DenseVector)
case class LabeledPoint(label: Double, features: DenseVector) extends Serializable

class LabeledPointUDT extends UserDefinedType[LabeledPoint] {
object LabeledPointUDT {

override def dataType: StructType =
StructType(Seq(StructField("features", ArrayType(DoubleType), nullable = false)))
def dataType: StructType =
StructType(Seq(
StructField("label", DoubleType, nullable = false),
StructField("features", ArrayType(DoubleType), nullable = false)))

override def serialize(obj: Any): Row = Row(obj.asInstanceOf[DenseVector].data)
}

case class LabeledPointUDT() extends UserDefinedType[LabeledPoint](LabeledPointUDT.dataType) with Serializable {

override def serialize(obj: Any): Row = obj match {
case lp: LabeledPoint =>
val row: GenericMutableRow = new GenericMutableRow(1 + lp.features.data.size)
row.setDouble(0, lp.label)
var i = 0
while (i < lp.features.data.size) {
row.setDouble(1 + i, lp.features.data(i))
i += 1
}
row
// Array.concat(Array(lp.label), lp.features.data))
}

override def deserialize(row: Row): DenseVector = {
val arr = new Array[Double](row.length)
override def deserialize(row: Row): LabeledPoint = {
assert(row.length >= 1)
val label = row.getDouble(0)
val arr = new Array[Double](row.length - 1)
var i = 0
while (i < row.length) {
arr(i) = row.getDouble(i)
while (i < row.length - 1) {
arr(i) = row.getDouble(i + 1)
i += 1
}
new DenseVector(arr)
LabeledPoint(label, new DenseVector(arr))
}
}

test("register user type: LabeledPoint") {
TestSQLContext.registerUserType(new VectorRowSerializer())
println("udtRegistry:")
TestSQLContext.udtRegistry.foreach { case (t,s) => println(s"$t -> $s") }

println(s"test: ${scala.reflect.runtime.universe.typeTag[DenseVector]}")
assert(TestSQLContext.udtRegistry.contains(scala.reflect.runtime.universe.typeTag[DenseVector]))

val points = Seq(
LabeledPoint(1.0, new DenseVector(Array(1.0, 0.0))),
LabeledPoint(0.0, new DenseVector(Array(1.0, -1.0))))
val pointsRDD: RDD[LabeledPoint] = sparkContext.parallelize(points)

println("Converting to SchemaRDD")
val tmpSchemaRDD: SchemaRDD = TestSQLContext.createSchemaRDD(pointsRDD)
println(s"SchemaRDD count: ${tmpSchemaRDD.count()}")
println("Done converting to SchemaRDD")

val features: RDD[DenseVector] =
pointsRDD.select('features).map { case Row(v: DenseVector) => v }
val featuresArrays: Array[DenseVector] = features.collect()
assert(featuresArrays.size === 2)
assert(featuresArrays.contains(new DenseVector(Array(1.0, 0.0))))
assert(featuresArrays.contains(new DenseVector(Array(1.0, -1.0))))

val labels: RDD[Double] = pointsRDD.select('labels).map { case Row(v: Double) => v }
val labelsArrays: Array[Double] = labels.collect()
assert(labelsArrays.size === 2)
assert(labelsArrays.contains(1.0))
assert(labelsArrays.contains(0.0))
try {
TestSQLContext.registerUserType(new LabeledPointUDT())
println("udtRegistry:")
TestSQLContext.udtRegistry.foreach { case (t, s) => println(s"$t -> $s")}

println(s"test: ${scala.reflect.runtime.universe.typeTag[LabeledPoint]}")
assert(TestSQLContext.udtRegistry.contains(scala.reflect.runtime.universe.typeTag[LabeledPoint]))

val points = Seq(
LabeledPoint(1.0, new DenseVector(Array(1.0, 0.0))),
LabeledPoint(0.0, new DenseVector(Array(1.0, -1.0))))
val pointsRDD: RDD[LabeledPoint] = sparkContext.parallelize(points)

println("Converting to SchemaRDD")
val tmpSchemaRDD: SchemaRDD = TestSQLContext.createSchemaRDD(pointsRDD)
println("blah")
println(s"SchemaRDD count: ${tmpSchemaRDD.count()}")
println("Done converting to SchemaRDD")

/*
val features: RDD[DenseVector] =
pointsRDD.select('features).map { case Row(v: DenseVector) => v}
val featuresArrays: Array[DenseVector] = features.collect()
assert(featuresArrays.size === 2)
assert(featuresArrays.contains(new DenseVector(Array(1.0, 0.0))))
assert(featuresArrays.contains(new DenseVector(Array(1.0, -1.0))))
val labels: RDD[Double] = pointsRDD.select('labels).map { case Row(v: Double) => v}
val labelsArrays: Array[Double] = labels.collect()
assert(labelsArrays.size === 2)
assert(labelsArrays.contains(1.0))
assert(labelsArrays.contains(0.0))
*/
} catch {
case e: Exception =>
e.printStackTrace()
}
}

}
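The variable-length UDT in this suite flattens a LabeledPoint into one row: position 0 holds the label and positions 1 through n hold the feature values. Since DenseVector does not override equals (which is also why the commented-out contains(new DenseVector(...)) assertions above would need an element-wise comparison), a round-trip sketch (an assumption, not part of the commit) has to compare the underlying arrays:

test("LabeledPointUDT round trip with DenseVector features") {
  val udt = LabeledPointUDT()
  val lp = LabeledPoint(1.0, new DenseVector(Array(1.0, 0.0)))
  val row = udt.serialize(lp)
  assert(row.length == 3)                          // label + two features
  assert(row.getDouble(0) == 1.0)                  // label
  assert(row.getDouble(1) == 1.0 && row.getDouble(2) == 0.0)
  val back = udt.deserialize(row)
  assert(back.label == lp.label)
  assert(back.features.data.sameElements(lp.features.data))
}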
@@ -371,7 +371,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
/** Extends QueryExecution with hive specific features. */
protected[sql] abstract class QueryExecution extends super.QueryExecution {

override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy())
override lazy val toRdd: RDD[Row] = {
//val dataType = StructType.fromAttributes(logical.output)
executedPlan.execute().map(ScalaReflection.convertRowToScala(_))
}

protected val primitiveTypes =
Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
