Still working on UDTs
jkbradley committed Nov 2, 2014
1 parent 105c5a3 commit 0eaeb81
Showing 9 changed files with 99 additions and 74 deletions.
@@ -36,13 +36,19 @@ object ScalaReflection {
case class Schema(dataType: DataType, nullable: Boolean)

/** Converts Scala objects to catalyst rows / types */
- def convertToCatalyst(a: Any): Any = a match {
- case o: Option[_] => o.map(convertToCatalyst).orNull
- case s: Seq[_] => s.map(convertToCatalyst)
- case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
- case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
- case d: BigDecimal => Decimal(d)
- case other => other
+ def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
+ case (o: Option[_], oType: _) => convertToCatalyst(o.orNull, oType)
+ case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
+ case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
+ convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
+ }
+ case (p: Product, structType: StructType) => new GenericRow(
+ p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) =>
+ convertToCatalyst(elem, field.dataType)
+ }.toArray)
+ case (udt: _, udtType: UDTType) => udtType.
+ case (d: BigDecimal, _) => Decimal(d)
+ case (other, _) => other
}
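The heart of the change: conversion now dispatches on the pair (value, catalyst DataType), so containers recurse with their element types and a UDT branch becomes expressible (it is left dangling above as "udtType."). A minimal self-contained sketch of the same dispatch pattern, using hypothetical stand-in types rather than Spark's:

    // Stand-in sketch of (value, type)-directed conversion; DType, IntT, and
    // ArrayT are hypothetical substitutes for catalyst's DataType hierarchy.
    sealed trait DType
    case object IntT extends DType
    case class ArrayT(elementType: DType) extends DType

    def convert(a: Any, t: DType): Any = (a, t) match {
      // Unwrap options, converting the payload; None becomes null,
      // mirroring the commit's use of o.orNull.
      case (o: Option[_], ot) => o.map(convert(_, ot)).getOrElse(null)
      // Recurse into sequences with the declared element type,
      // not the element's runtime class.
      case (s: Seq[_], ArrayT(et)) => s.map(convert(_, et))
      // Primitives and anything unrecognized pass through unchanged.
      case (other, _) => other
    }

    // convert(Some(Seq(1, 2)), ArrayT(IntT)) == Seq(1, 2)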

/** Converts Catalyst types used internally in rows to standard Scala types */
@@ -56,27 +62,37 @@ object ScalaReflection {
def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala))

/** Returns a Sequence of attributes for the given case class type. */
- def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
- case Schema(s: StructType, _) =>
- s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
+ def attributesFor[T: TypeTag](
+ udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Seq[Attribute] = {
+ schemaFor[T](udtRegistry) match {
+ case Schema(s: StructType, _) =>
+ s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
+ }
}

/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
- def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])
+ def schemaFor[T: TypeTag](udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = {
+ schemaFor(typeOf[T], udtRegistry)
+ }

- /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
- def schemaFor(tpe: `Type`, udtRegistry: Map[TypeTag[_], UserDefinedType[_]]): Schema = tpe match {
+ /**
+ * Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
+ * TODO: ADD DOC
+ */
+ def schemaFor(
+ tpe: `Type`,
+ udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = tpe match {
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
- Schema(schemaFor(optType).dataType, nullable = true)
+ Schema(schemaFor(optType, udtRegistry).dataType, nullable = true)
case t if t <:< typeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
Schema(StructType(
params.head.map { p =>
val Schema(dataType, nullable) =
- schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
+ schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs), udtRegistry)
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
// Need to decide if we actually need a special type here.
@@ -85,12 +101,12 @@ object ScalaReflection {
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
- val Schema(dataType, nullable) = schemaFor(elementType)
+ val Schema(dataType, nullable) = schemaFor(elementType, udtRegistry)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_,_]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
- val Schema(valueDataType, valueNullable) = schemaFor(valueType)
- Schema(MapType(schemaFor(keyType).dataType,
+ val Schema(valueDataType, valueNullable) = schemaFor(valueType, udtRegistry)
+ Schema(MapType(schemaFor(keyType, udtRegistry).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
@@ -111,6 +127,9 @@ object ScalaReflection {
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+ case t if udtRegistry.contains(tpe) =>
+ val udtStructType: StructType = udtRegistry(tpe).dataType
+ Schema(udtStructType, nullable = true)
}
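Note the keying subtlety: registerUserType (later in this commit) stores entries under a TypeTag, while this case probes the registry with a reflected Type; the loose Map[Any, _] signature permits both, so a hit requires the two key forms to compare equal. A hedged sketch of what the fall-through computes for a registered type, extracted as a helper (not code from this commit):

    // Returns the UDT's catalyst StructType (nullable) when tpe was registered,
    // which then stands in for the user class in the inferred schema.
    def resolveUdtSchema(
        tpe: `Type`,
        udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Option[Schema] =
      if (udtRegistry.contains(tpe)) {
        Some(Schema(udtRegistry(tpe).dataType, nullable = true))
      } else {
        None
      }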

def typeOfObject: PartialFunction[Any, DataType] = {
@@ -142,7 +161,9 @@ object ScalaReflection {
* for the data in the sequence.
*/
def asRelation: LocalRelation = {
- val output = attributesFor[A]
+ // Pass empty map to attributesFor since this method is only used for debugging Catalyst,
+ // not used with SparkSQL.
+ val output = attributesFor[A](Map.empty)
LocalRelation(output, data)
}
}
@@ -21,6 +21,10 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.types.DataType
import org.apache.spark.util.ClosureCleaner

+ /**
+ * User-defined function.
+ * @param dataType Return type of function.
+ */
case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
extends Expression {

@@ -347,6 +351,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
}
// scalastyle:on

- ScalaReflection.convertToCatalyst(result)
+ ScalaReflection.convertToCatalyst(result, dataType)
}
}
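With the declared return type threaded through, the UDF's result is converted against its schema instead of by runtime class alone. A small sketch of the effect, assuming the catalyst types are imported; this is illustrative, not code from the commit:

    // A UDF declared to return Seq[Int] carries ArrayType(IntegerType), so
    // convertToCatalyst can now recurse into the Seq with the element type.
    import org.apache.spark.sql.catalyst.types.{ArrayType, IntegerType}
    val converted = ScalaReflection.convertToCatalyst(Seq(1, 2, 3), ArrayType(IntegerType))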
@@ -342,6 +342,7 @@ object FractionalType {
case _ => false
}
}

abstract class FractionalType extends NumericType {
private[sql] val fractional: Fractional[JvmType]
private[sql] val asIntegral: Integral[JvmType]
@@ -583,22 +584,12 @@ object UDTType {
}

/**
- * The data type for Maps. Keys in a map are not allowed to have `null` values.
- * @param keyType The data type of map keys.
- * @param valueType The data type of map values.
- * @param valueContainsNull Indicates if map values have `null` values.
+ * The data type for UserDefinedType.
*/
- case class UDTType(
- keyType: DataType,
- valueType: DataType,
- valueContainsNull: Boolean) extends DataType {
- private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
- builder.append(s"${prefix}-- key: ${keyType.simpleString}\n")
- builder.append(s"${prefix}-- value: ${valueType.simpleString} " +
- s"(valueContainsNull = ${valueContainsNull})\n")
- DataType.buildFormattedString(keyType, s"$prefix |", builder)
- DataType.buildFormattedString(valueType, s"$prefix |", builder)
- }
+ case class UDTType(dataType: StructType, ) extends DataType {
+ // Used only in regex parser above.
+ //private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { }

- def simpleString: String = "map"
+ // TODO
+ def simpleString: String = "udt"
}
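For orientation, a sketch of what a user-defined type might look like against this work-in-progress API. MyPoint and MyPointUDT are hypothetical; the only member grounded in this commit is dataType, since schemaFor reads udtRegistry(tpe).dataType as a StructType:

    // Hypothetical example, not part of this commit: a user class plus a UDT
    // exposing the StructType that catalyst uses in its place. Assumes
    // UserDefinedType declares a dataType: StructType member, which is how
    // the registry is consumed in ScalaReflection.schemaFor above.
    import org.apache.spark.sql.catalyst.types._

    case class MyPoint(x: Double, y: Double)

    class MyPointUDT extends UserDefinedType[MyPoint] {
      // The catalyst-facing schema: two non-nullable doubles.
      def dataType: StructType = StructType(Seq(
        StructField("x", DoubleType, nullable = false),
        StructField("y", DoubleType, nullable = false)))
    }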
18 changes: 9 additions & 9 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -17,6 +17,8 @@

package org.apache.spark.sql

+ import org.apache.spark.sql.catalyst.types.UserDefinedType
+
import scala.collection.mutable
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
@@ -101,8 +103,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = {
SparkPlan.currentContext.set(self)
- new SchemaRDD(this,
- LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self))
+ new SchemaRDD(this, LogicalRDD(ScalaReflection.attributesFor[A](udtRegistry),
+ RDDConversions.productToRowRdd(rdd, ScalaReflection.schemaFor[A](udtRegistry).dataType))(self))
}

/**
@@ -253,7 +255,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
new SchemaRDD(
this,
ParquetRelation.createEmpty(
- path, ScalaReflection.attributesFor[A], allowExisting, conf, this))
+ path, ScalaReflection.attributesFor[A](udtRegistry), allowExisting, conf, this))
}

/**
@@ -290,16 +292,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Fails if this type has been registered already.
*/
def registerUserType[UserType, UDT <: UserDefinedType[UserType]](
- //userType: Class[UserType],
- udt: UDT): Unit = {
- val userType: TypeTag[UserType] = typeTag[UserType]
- require(!registeredUserTypes.contains(userType),
+ udt: UDT)(implicit userType: TypeTag[UserType]): Unit = {
+ require(!udtRegistry.contains(userType),
"registerUserType called on type which was already registered.")
- registeredUserTypes(userType) = udt
+ udtRegistry(userType) = udt
}

/** Map: UserType --> UserDefinedType */
- protected[sql] val registeredUserTypes = new mutable.HashMap[TypeTag[_], UserDefinedType[_]]()
+ protected[sql] val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]()

protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext: SparkContext = self.sparkContext
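Registration then reads roughly as below, with MyPoint and MyPointUDT being the hypothetical pair sketched after the UDTType hunk above; the TypeTag is supplied implicitly by the compiler:

    // Hedged usage sketch; assumes a SparkContext named sc is in scope.
    val sqlContext = new SQLContext(sc)
    sqlContext.registerUserType[MyPoint, MyPointUDT](new MyPointUDT)
    // A second registration of MyPoint would trip the require() check above.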
46 changes: 23 additions & 23 deletions sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
@@ -78,7 +78,7 @@ private[sql] trait UDFRegistration {
s"""
def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = {
def builder(e: Seq[Expression]) =
- ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}
"""
@@ -87,112 +87,112 @@ private[sql] trait UDFRegistration {

// scalastyle:off
def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}

def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e)
functionRegistry.registerFunction(name, builder)
}
// scalastyle:on
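Every overload now derives the UDF's return schema through the registry, so a return type registered via registerUserType would, once the plumbing is finished, resolve to its UDT's StructType. A hedged usage sketch with an ordinary return type:

    // schemaFor[Double](udtRegistry) yields DoubleType here; a return type
    // found in udtRegistry would take the new UDT case instead.
    sqlContext.registerFunction("hypot", (x: Double, y: Double) => math.sqrt(x * x + y * y))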
@@ -19,18 +19,20 @@ package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
+ import org.apache.spark.sql.{DataType, Row, SQLContext}
import org.apache.spark.sql.catalyst.ScalaReflection
+ import org.apache.spark.sql.catalyst.ScalaReflection.Schema
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
- import org.apache.spark.sql.{Row, SQLContext}
+ import org.apache.spark.sql.catalyst.types.UserDefinedType

/**
* :: DeveloperApi ::
*/
@DeveloperApi
object RDDConversions {
- def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
+ def productToRowRdd[A <: Product](data: RDD[A], dataType: DataType): RDD[Row] = {
data.mapPartitions { iterator =>
if (iterator.isEmpty) {
Iterator.empty
@@ -41,7 +43,7 @@ object RDDConversions {
bufferedIterator.map { r =>
var i = 0
while (i < mutableRow.length) {
- mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
+ mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i), dataType)
i += 1
}

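Note that the loop hands every field the row-level dataType (the whole StructType produced by schemaFor[A]); given how convertToCatalyst recurses on (value, type) pairs, a per-field lookup appears to be the eventual intent. A hedged sketch of that variant, mirroring the loop above; this is not code from the commit:

    // Hypothetical per-field variant; assumes dataType is the StructType
    // describing the Product's fields, with r and mutableRow as in the loop above.
    val structType = dataType.asInstanceOf[StructType]
    var i = 0
    while (i < mutableRow.length) {
      mutableRow(i) =
        ScalaReflection.convertToCatalyst(r.productElement(i), structType.fields(i).dataType)
      i += 1
    }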
