From 0eaeb8187342048287b4ada1400a750ca6742f80 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Mon, 6 Oct 2014 09:54:51 -0700
Subject: [PATCH] Still working on UDTs

---
 .../spark/sql/catalyst/ScalaReflection.scala  | 59 +++++++++++++------
 .../sql/catalyst/expressions/ScalaUdf.scala   |  6 +-
 .../spark/sql/catalyst/types/dataTypes.scala  | 23 +++----
 .../org/apache/spark/sql/SQLContext.scala     | 18 +++---
 .../apache/spark/sql/UdfRegistration.scala    | 46 +++++++--------
 .../spark/sql/execution/ExistingRDD.scala     |  8 ++-
 .../spark/sql/execution/SparkStrategies.scala |  3 +-
 .../org/apache/spark/sql/UserTypeSuite.scala  |  7 ++-
 .../apache/spark/sql/hive/HiveContext.scala   |  3 +-
 9 files changed, 99 insertions(+), 74 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 07039a02fd85f..92c8bdf376e3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -36,13 +36,19 @@ object ScalaReflection {
   case class Schema(dataType: DataType, nullable: Boolean)

   /** Converts Scala objects to catalyst rows / types */
-  def convertToCatalyst(a: Any): Any = a match {
-    case o: Option[_] => o.map(convertToCatalyst).orNull
-    case s: Seq[_] => s.map(convertToCatalyst)
-    case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
-    case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
-    case d: BigDecimal => Decimal(d)
-    case other => other
+  def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
+    case (o: Option[_], oType) => convertToCatalyst(o.orNull, oType)
+    case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
+    case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
+      convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
+    }
+    case (p: Product, structType: StructType) => new GenericRow(
+      p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) =>
+        convertToCatalyst(elem, field.dataType)
+      }.toArray)
+    case (udtValue, udtType: UDTType) =>
+      // TODO: serialize udtValue into a Row via the UserDefinedType registered for it.
+      sys.error(s"UDT conversion not implemented yet: $udtType")
+    case (d: BigDecimal, _) => Decimal(d)
+    case (other, _) => other
   }

   /** Converts Catalyst types used internally in rows to standard Scala types */
@@ -56,19 +62,29 @@ object ScalaReflection {
   def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala))

   /** Returns a Sequence of attributes for the given case class type. */
-  def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
-    case Schema(s: StructType, _) =>
-      s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
+  def attributesFor[T: TypeTag](
+      udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Seq[Attribute] = {
+    schemaFor[T](udtRegistry) match {
+      case Schema(s: StructType, _) =>
+        s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
+    }
   }

   /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
-  def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])
+  def schemaFor[T: TypeTag](udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = {
+    schemaFor(typeOf[T], udtRegistry)
+  }

-  /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
-  def schemaFor(tpe: `Type`, udtRegistry: Map[TypeTag[_], UserDefinedType[_]]): Schema = tpe match {
+  /**
+   * Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
+   * @param udtRegistry Map from user types (keyed by TypeTag) to their UserDefinedType,
+   *                    consulted when the given type is not one of the natively supported types.
+   */
+  def schemaFor(
+      tpe: `Type`,
+      udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = tpe match {
     case t if t <:< typeOf[Option[_]] =>
       val TypeRef(_, _, Seq(optType)) = t
-      Schema(schemaFor(optType).dataType, nullable = true)
+      Schema(schemaFor(optType, udtRegistry).dataType, nullable = true)
     case t if t <:< typeOf[Product] =>
       val formalTypeArgs = t.typeSymbol.asClass.typeParams
       val TypeRef(_, _, actualTypeArgs) = t
@@ -76,7 +92,7 @@ object ScalaReflection {
       Schema(StructType(
         params.head.map { p =>
           val Schema(dataType, nullable) =
-            schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
+            schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs), udtRegistry)
           StructField(p.name.toString, dataType, nullable)
         }), nullable = true)
     // Need to decide if we actually need a special type here.
     case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
     case t if t <:< typeOf[Array[_]] =>
       sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
     case t if t <:< typeOf[Seq[_]] =>
       val TypeRef(_, _, Seq(elementType)) = t
-      val Schema(dataType, nullable) = schemaFor(elementType)
+      val Schema(dataType, nullable) = schemaFor(elementType, udtRegistry)
       Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
     case t if t <:< typeOf[Map[_,_]] =>
       val TypeRef(_, _, Seq(keyType, valueType)) = t
-      val Schema(valueDataType, valueNullable) = schemaFor(valueType)
-      Schema(MapType(schemaFor(keyType).dataType,
+      val Schema(valueDataType, valueNullable) = schemaFor(valueType, udtRegistry)
+      Schema(MapType(schemaFor(keyType, udtRegistry).dataType,
         valueDataType, valueContainsNull = valueNullable), nullable = true)
     case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
     case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
@@ -111,6 +127,9 @@ object ScalaReflection {
     case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
     case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
     case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+    case t if udtRegistry.contains(t) =>
+      val udtStructType: StructType = udtRegistry(t).dataType
+      Schema(udtStructType, nullable = true)
   }

   def typeOfObject: PartialFunction[Any, DataType] = {
@@ -142,7 +161,9 @@ object ScalaReflection {
    * for the data in the sequence.
    */
   def asRelation: LocalRelation = {
-    val output = attributesFor[A]
+    // Pass an empty registry to attributesFor: this method is only used when debugging
+    // Catalyst and is never called by Spark SQL itself.
+    val output = attributesFor[A](Map.empty)
     LocalRelation(output, data)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 1b687a443ef8b..fa1786e74bb3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -21,6 +21,10 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.types.DataType
 import org.apache.spark.util.ClosureCleaner

+/**
+ * User-defined function.
+ * @param dataType Return type of the function.
+ */ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) extends Expression { @@ -347,6 +351,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi } // scalastyle:on - ScalaReflection.convertToCatalyst(result) + ScalaReflection.convertToCatalyst(result, dataType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 1c2c578992c3a..afe4cf101c88e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -342,6 +342,7 @@ object FractionalType { case _ => false } } + abstract class FractionalType extends NumericType { private[sql] val fractional: Fractional[JvmType] private[sql] val asIntegral: Integral[JvmType] @@ -583,22 +584,12 @@ object UDTType { } /** - * The data type for Maps. Keys in a map are not allowed to have `null` values. - * @param keyType The data type of map keys. - * @param valueType The data type of map values. - * @param valueContainsNull Indicates if map values have `null` values. + * The data type for UserDefinedType. */ -case class UDTType( - keyType: DataType, - valueType: DataType, - valueContainsNull: Boolean) extends DataType { - private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") - builder.append(s"${prefix}-- value: ${valueType.simpleString} " + - s"(valueContainsNull = ${valueContainsNull})\n") - DataType.buildFormattedString(keyType, s"$prefix |", builder) - DataType.buildFormattedString(valueType, s"$prefix |", builder) - } +case class UDTType(dataType: StructType, ) extends DataType { + // Used only in regex parser above. + //private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { } - def simpleString: String = "map" + // TODO + def simpleString: String = "udt" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f06e9033382eb..d1811b7512715 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.types.UserDefinedType + import scala.collection.mutable import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} @@ -101,8 +103,8 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = { SparkPlan.currentContext.set(self) - new SchemaRDD(this, - LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self)) + new SchemaRDD(this, LogicalRDD(ScalaReflection.attributesFor[A](udtRegistry), + RDDConversions.productToRowRdd(rdd, ScalaReflection.schemaFor[A](udtRegistry).dataType))(self)) } /** @@ -253,7 +255,7 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD( this, ParquetRelation.createEmpty( - path, ScalaReflection.attributesFor[A], allowExisting, conf, this)) + path, ScalaReflection.attributesFor[A](udtRegistry), allowExisting, conf, this)) } /** @@ -290,16 +292,14 @@ class SQLContext(@transient val sparkContext: SparkContext) * Fails if this type has been registered already. 
*/ def registerUserType[UserType, UDT <: UserDefinedType[UserType]]( - //userType: Class[UserType], - udt: UDT): Unit = { - val userType: TypeTag[UserType] = typeTag[UserType] - require(!registeredUserTypes.contains(userType), + udt: UDT)(implicit userType: TypeTag[UserType]): Unit = { + require(!udtRegistry.contains(userType), "registerUserType called on type which was already registered.") - registeredUserTypes(userType) = udt + udtRegistry(userType) = udt } /** Map: UserType --> UserDefinedType */ - protected[sql] val registeredUserTypes = new mutable.HashMap[TypeTag[_], UserDefinedType[_]]() + protected[sql] val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]() protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 595b4aa36eae3..9946c8aa4d1bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -78,7 +78,7 @@ private[sql] trait UDFRegistration { s""" def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = { def builder(e: Seq[Expression]) = - ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } """ @@ -87,112 +87,112 @@ private[sql] trait UDFRegistration { // scalastyle:off def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = 
ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, 
_, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 04c51a1ee4b97..3d2cbb3057e0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,18 +19,20 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataType, Row, SQLContext} import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.ScalaReflection.Schema import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, 
Statistics}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.catalyst.types.{StructType, UserDefinedType}

 /**
  * :: DeveloperApi ::
  */
 @DeveloperApi
 object RDDConversions {
-  def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
+  def productToRowRdd[A <: Product](data: RDD[A], dataType: DataType): RDD[Row] = {
     data.mapPartitions { iterator =>
       if (iterator.isEmpty) {
         Iterator.empty
@@ -41,7 +43,7 @@ object RDDConversions {
         bufferedIterator.map { r =>
           var i = 0
           while (i < mutableRow.length) {
-            mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
+            // dataType is the StructType of the whole row, so convert each column
+            // using its own field type, not the row type.
+            mutableRow(i) = ScalaReflection.convertToCatalyst(
+              r.productElement(i), dataType.asInstanceOf[StructType].fields(i).dataType)
             i += 1
           }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 79e4ddb8c4f5d..d3c7033187048 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -280,7 +280,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         val nPartitions = if (data.isEmpty) 1 else numPartitions
         PhysicalRDD(
           output,
-          RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions))) :: Nil
+          RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions),
+            StructType.fromAttributes(output))) :: Nil
       case logical.Limit(IntegerLiteral(limit), child) =>
         execution.Limit(limit, planLater(child)) :: Nil
       case Unions(unionChildren) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala
index 0da7aad4ac6f9..9daa8ea093d78 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala
@@ -47,13 +47,18 @@ class UserTypeSuite extends QueryTest {
   case class LabeledPoint(label: Double, features: DenseVector)

   test("register user type: LabeledPoint") {
-    TestSQLContext.registerUserType(classOf[DenseVector], classOf[VectorRowSerializer])
+    TestSQLContext.registerUserType(new VectorRowSerializer())

     val points = Seq(
       LabeledPoint(1.0, new DenseVector(Array(1.0, 0.0))),
       LabeledPoint(0.0, new DenseVector(Array(1.0, -1.0))))
     val pointsRDD: RDD[LabeledPoint] = sparkContext.parallelize(points)

+    println("Converting to SchemaRDD")
+    val tmpSchemaRDD: SchemaRDD = TestSQLContext.createSchemaRDD(pointsRDD)
+    println(s"SchemaRDD count: ${tmpSchemaRDD.count()}")
+    println("Done converting to SchemaRDD")
+
     val features: RDD[DenseVector] =
       pointsRDD.select('features).map { case Row(v: DenseVector) => v }
     val featuresArrays: Array[DenseVector] = features.collect()
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 2e27817d60221..1b751ab1b7c05 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -121,7 +121,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
    * @tparam A A case class that is used to describe the schema of the table to be created.
*/ def createTable[A <: Product : TypeTag](tableName: String, allowExisting: Boolean = true) { - catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting) + catalog.createTable("default", tableName, ScalaReflection.attributesFor[A](udtRegistry), + allowExisting) } /**
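
For reference, a minimal sketch of how the pieces introduced in this patch are meant to fit together, modeled on UserTypeSuite. Only the `dataType: StructType` member of UserDefinedType is exercised by this diff; the serialize/deserialize hooks below are assumptions about where convertToCatalyst / convertRowToScala would eventually delegate, and the schema chosen for DenseVector is illustrative:

  import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row}
  import org.apache.spark.sql.catalyst.types._

  // A user class that Catalyst knows nothing about.
  class DenseVector(val data: Array[Double])

  // Sketch of a UDT for DenseVector. `dataType` is what schemaFor reads out of
  // udtRegistry; serialize/deserialize are assumed hooks, not yet in this patch.
  class VectorRowSerializer extends UserDefinedType[DenseVector] {
    override val dataType: StructType = StructType(Seq(
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = false)))

    /** User object -> Catalyst Row matching `dataType`. */
    def serialize(v: DenseVector): Row = new GenericRow(Array[Any](v.data.toSeq))

    /** Catalyst Row -> user object. */
    def deserialize(row: Row): DenseVector =
      new DenseVector(row(0).asInstanceOf[Seq[Double]].toArray)
  }

  // Registration captures TypeTag[DenseVector] implicitly, keying udtRegistry so that
  // schemaFor[LabeledPoint] can resolve the `features` field:
  //   TestSQLContext.registerUserType(new VectorRowSerializer())

The net effect is that schemaFor maps a registered user class to its UDT's StructType, so everything downstream (Parquet, UDFs, row conversion) only ever sees standard Catalyst types.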