From 105c5a366501a8ef6957cba43968f477b56f9a45 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 2 Oct 2014 19:06:49 -0700 Subject: [PATCH 01/46] Adding UserDefinedType to SQL, not done yet. --- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../spark/sql/catalyst/types/dataTypes.scala | 41 ++++++++++- .../org/apache/spark/sql/SQLContext.scala | 20 +++++- .../org/apache/spark/sql/UserTypeSuite.scala | 71 +++++++++++++++++++ 4 files changed, 130 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8fbdf664b71e4..07039a02fd85f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -65,7 +65,7 @@ object ScalaReflection { def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T]) /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor(tpe: `Type`): Schema = tpe match { + def schemaFor(tpe: `Type`, udtRegistry: Map[TypeTag[_], UserDefinedType[_]]): Schema = tpe match { case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType).dataType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 8dda0b182805c..1c2c578992c3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -30,10 +30,10 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row} +import org.apache.spark.sql.catalyst.types.decimal._ import org.apache.spark.sql.catalyst.util.Metadata import org.apache.spark.util.Utils -import org.apache.spark.sql.catalyst.types.decimal._ object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) @@ -565,3 +565,40 @@ case class MapType( ("valueType" -> valueType.jsonValue) ~ ("valueContainsNull" -> valueContainsNull) } + +// TODO: Where should this go? +trait UserDefinedType[T] { + def dataType: StructType + def serialize(obj: T): Row + def deserialize(row: Row): T +} + +object UDTType { + /** + * Construct a [[UDTType]] object with the given key type and value type. + * The `valueContainsNull` is true. + */ + def apply(keyType: DataType, valueType: DataType): MapType = + MapType(keyType: DataType, valueType: DataType, true) +} + +/** + * The data type for Maps. Keys in a map are not allowed to have `null` values. + * @param keyType The data type of map keys. + * @param valueType The data type of map values. + * @param valueContainsNull Indicates if map values have `null` values. 
+ */ +case class UDTType( + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean) extends DataType { + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") + builder.append(s"${prefix}-- value: ${valueType.simpleString} " + + s"(valueContainsNull = ${valueContainsNull})\n") + DataType.buildFormattedString(keyType, s"$prefix |", builder) + DataType.buildFormattedString(valueType, s"$prefix |", builder) + } + + def simpleString: String = "map" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4953f8399a96b..f06e9033382eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql +import scala.collection.mutable import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag +import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.hadoop.conf.Configuration @@ -283,6 +284,23 @@ class SQLContext(@transient val sparkContext: SparkContext) def table(tableName: String): SchemaRDD = new SchemaRDD(this, catalog.lookupRelation(None, tableName)) + /** + * Register a user-defined type and its serializer, to allow automatic conversion between + * RDDs of user types and SchemaRDDs. + * Fails if this type has been registered already. + */ + def registerUserType[UserType, UDT <: UserDefinedType[UserType]]( + //userType: Class[UserType], + udt: UDT): Unit = { + val userType: TypeTag[UserType] = typeTag[UserType] + require(!registeredUserTypes.contains(userType), + "registerUserType called on type which was already registered.") + registeredUserTypes(userType) = udt + } + + /** Map: UserType --> UserDefinedType */ + protected[sql] val registeredUserTypes = new mutable.HashMap[TypeTag[_], UserDefinedType[_]]() + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala new file mode 100644 index 0000000000000..0da7aad4ac6f9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.types.UserDefinedType +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ + +class UserTypeSuite extends QueryTest { + + class DenseVector(val data: Array[Double]) + + class VectorRowSerializer extends UserDefinedType[DenseVector] { + + override def dataType: StructType = + StructType(Seq(StructField("features", ArrayType(DoubleType), nullable = false))) + + override def serialize(obj: DenseVector): Row = Row(obj.data) + + override def deserialize(row: Row): DenseVector = { + val arr = new Array[Double](row.length) + var i = 0 + while (i < row.length) { + arr(i) = row.getDouble(i) + i += 1 + } + new DenseVector(arr) + } + } + + case class LabeledPoint(label: Double, features: DenseVector) + + test("register user type: LabeledPoint") { + TestSQLContext.registerUserType(classOf[DenseVector], classOf[VectorRowSerializer]) + + val points = Seq( + LabeledPoint(1.0, new DenseVector(Array(1.0, 0.0))), + LabeledPoint(0.0, new DenseVector(Array(1.0, -1.0)))) + val pointsRDD: RDD[LabeledPoint] = sparkContext.parallelize(points) + + val features: RDD[DenseVector] = + pointsRDD.select('features).map { case Row(v: DenseVector) => v } + val featuresArrays: Array[DenseVector] = features.collect() + assert(featuresArrays.size === 2) + assert(featuresArrays.contains(new DenseVector(Array(1.0, 0.0)))) + assert(featuresArrays.contains(new DenseVector(Array(1.0, -1.0)))) + + val labels: RDD[Double] = pointsRDD.select('labels).map { case Row(v: Double) => v } + val labelsArrays: Array[Double] = labels.collect() + assert(labelsArrays.size === 2) + assert(labelsArrays.contains(1.0)) + assert(labelsArrays.contains(0.0)) + } + +} From 0eaeb8187342048287b4ada1400a750ca6742f80 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Mon, 6 Oct 2014 09:54:51 -0700 Subject: [PATCH 02/46] Still working on UDTs --- .../spark/sql/catalyst/ScalaReflection.scala | 59 +++++++++++++------ .../sql/catalyst/expressions/ScalaUdf.scala | 6 +- .../spark/sql/catalyst/types/dataTypes.scala | 23 +++----- .../org/apache/spark/sql/SQLContext.scala | 18 +++--- .../apache/spark/sql/UdfRegistration.scala | 46 +++++++-------- .../spark/sql/execution/ExistingRDD.scala | 8 ++- .../spark/sql/execution/SparkStrategies.scala | 3 +- .../org/apache/spark/sql/UserTypeSuite.scala | 7 ++- .../apache/spark/sql/hive/HiveContext.scala | 3 +- 9 files changed, 99 insertions(+), 74 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 07039a02fd85f..92c8bdf376e3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -36,13 +36,19 @@ object ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) /** Converts Scala objects to catalyst rows / types */ - def convertToCatalyst(a: Any): Any = a match { - case o: Option[_] => o.map(convertToCatalyst).orNull - case s: Seq[_] => s.map(convertToCatalyst) - case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } - case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) - case d: BigDecimal => Decimal(d) - case other => other + def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + case (o: Option[_], oType: _) => convertToCatalyst(o.orNull, oType) + case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) + case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => + convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) + } + case (p: Product, structType: StructType) => new GenericRow( + p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => + convertToCatalyst(elem, field.dataType) + }.toArray) + case (udt: _, udtType: UDTType) => udtType. + case (d: BigDecimal, _) => Decimal(d) + case (other, _) => other } /** Converts Catalyst types used internally in rows to standard Scala types */ @@ -56,19 +62,29 @@ object ScalaReflection { def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala)) /** Returns a Sequence of attributes for the given case class type. */ - def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { - case Schema(s: StructType, _) => - s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + def attributesFor[T: TypeTag]( + udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Seq[Attribute] = { + schemaFor[T](udtRegistry) match { + case Schema(s: StructType, _) => + s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + } } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T]) + def schemaFor[T: TypeTag](udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = { + schemaFor(typeOf[T], udtRegistry) + } - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. 
*/ - def schemaFor(tpe: `Type`, udtRegistry: Map[TypeTag[_], UserDefinedType[_]]): Schema = tpe match { + /** + * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. + * TODO: ADD DOC + */ + def schemaFor( + tpe: `Type`, + udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = tpe match { case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - Schema(schemaFor(optType).dataType, nullable = true) + Schema(schemaFor(optType, udtRegistry).dataType, nullable = true) case t if t <:< typeOf[Product] => val formalTypeArgs = t.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = t @@ -76,7 +92,7 @@ object ScalaReflection { Schema(StructType( params.head.map { p => val Schema(dataType, nullable) = - schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) + schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs), udtRegistry) StructField(p.name.toString, dataType, nullable) }), nullable = true) // Need to decide if we actually need a special type here. @@ -85,12 +101,12 @@ object ScalaReflection { sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") case t if t <:< typeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) + val Schema(dataType, nullable) = schemaFor(elementType, udtRegistry) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) case t if t <:< typeOf[Map[_,_]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - Schema(MapType(schemaFor(keyType).dataType, + val Schema(valueDataType, valueNullable) = schemaFor(valueType, udtRegistry) + Schema(MapType(schemaFor(keyType, udtRegistry).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) case t if t <:< typeOf[String] => Schema(StringType, nullable = true) case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) @@ -111,6 +127,9 @@ object ScalaReflection { case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + case t if udtRegistry.contains(tpe) => + val udtStructType: StructType = udtRegistry(tpe).dataType + Schema(udtStructType, nullable = true) } def typeOfObject: PartialFunction[Any, DataType] = { @@ -142,7 +161,9 @@ object ScalaReflection { * for the the data in the sequence. */ def asRelation: LocalRelation = { - val output = attributesFor[A] + // Pass empty map to attributesFor since this method is only used for debugging Catalyst, + // not used with SparkSQL. + val output = attributesFor[A](Map.empty) LocalRelation(output, data) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 1b687a443ef8b..fa1786e74bb3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -21,6 +21,10 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.util.ClosureCleaner +/** + * User-defined function. + * @param dataType Return type of function. 
+ */ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) extends Expression { @@ -347,6 +351,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi } // scalastyle:on - ScalaReflection.convertToCatalyst(result) + ScalaReflection.convertToCatalyst(result, dataType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 1c2c578992c3a..afe4cf101c88e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -342,6 +342,7 @@ object FractionalType { case _ => false } } + abstract class FractionalType extends NumericType { private[sql] val fractional: Fractional[JvmType] private[sql] val asIntegral: Integral[JvmType] @@ -583,22 +584,12 @@ object UDTType { } /** - * The data type for Maps. Keys in a map are not allowed to have `null` values. - * @param keyType The data type of map keys. - * @param valueType The data type of map values. - * @param valueContainsNull Indicates if map values have `null` values. + * The data type for UserDefinedType. */ -case class UDTType( - keyType: DataType, - valueType: DataType, - valueContainsNull: Boolean) extends DataType { - private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") - builder.append(s"${prefix}-- value: ${valueType.simpleString} " + - s"(valueContainsNull = ${valueContainsNull})\n") - DataType.buildFormattedString(keyType, s"$prefix |", builder) - DataType.buildFormattedString(valueType, s"$prefix |", builder) - } +case class UDTType(dataType: StructType, ) extends DataType { + // Used only in regex parser above. + //private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { } - def simpleString: String = "map" + // TODO + def simpleString: String = "udt" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f06e9033382eb..d1811b7512715 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.types.UserDefinedType + import scala.collection.mutable import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} @@ -101,8 +103,8 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = { SparkPlan.currentContext.set(self) - new SchemaRDD(this, - LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self)) + new SchemaRDD(this, LogicalRDD(ScalaReflection.attributesFor[A](udtRegistry), + RDDConversions.productToRowRdd(rdd, ScalaReflection.schemaFor[A](udtRegistry).dataType))(self)) } /** @@ -253,7 +255,7 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD( this, ParquetRelation.createEmpty( - path, ScalaReflection.attributesFor[A], allowExisting, conf, this)) + path, ScalaReflection.attributesFor[A](udtRegistry), allowExisting, conf, this)) } /** @@ -290,16 +292,14 @@ class SQLContext(@transient val sparkContext: SparkContext) * Fails if this type has been registered already. 
*/ def registerUserType[UserType, UDT <: UserDefinedType[UserType]]( - //userType: Class[UserType], - udt: UDT): Unit = { - val userType: TypeTag[UserType] = typeTag[UserType] - require(!registeredUserTypes.contains(userType), + udt: UDT)(implicit userType: TypeTag[UserType]): Unit = { + require(!udtRegistry.contains(userType), "registerUserType called on type which was already registered.") - registeredUserTypes(userType) = udt + udtRegistry(userType) = udt } /** Map: UserType --> UserDefinedType */ - protected[sql] val registeredUserTypes = new mutable.HashMap[TypeTag[_], UserDefinedType[_]]() + protected[sql] val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]() protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 595b4aa36eae3..9946c8aa4d1bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -78,7 +78,7 @@ private[sql] trait UDFRegistration { s""" def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = { def builder(e: Seq[Expression]) = - ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } """ @@ -87,112 +87,112 @@ private[sql] trait UDFRegistration { // scalastyle:off def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = 
ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, 
_, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) functionRegistry.registerFunction(name, builder) } // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 04c51a1ee4b97..3d2cbb3057e0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,18 +19,20 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataType, Row, SQLContext} import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.ScalaReflection.Schema import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, 
Statistics} -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.types.UserDefinedType /** * :: DeveloperApi :: */ @DeveloperApi object RDDConversions { - def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { + def productToRowRdd[A <: Product](data: RDD[A], dataType: DataType): RDD[Row] = { data.mapPartitions { iterator => if (iterator.isEmpty) { Iterator.empty @@ -41,7 +43,7 @@ object RDDConversions { bufferedIterator.map { r => var i = 0 while (i < mutableRow.length) { - mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i)) + mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i), dataType) i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 79e4ddb8c4f5d..d3c7033187048 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -280,7 +280,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val nPartitions = if (data.isEmpty) 1 else numPartitions PhysicalRDD( output, - RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions))) :: Nil + RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions), + StructType.fromAttributes(output))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala index 0da7aad4ac6f9..9daa8ea093d78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala @@ -47,13 +47,18 @@ class UserTypeSuite extends QueryTest { case class LabeledPoint(label: Double, features: DenseVector) test("register user type: LabeledPoint") { - TestSQLContext.registerUserType(classOf[DenseVector], classOf[VectorRowSerializer]) + TestSQLContext.registerUserType(new VectorRowSerializer()) val points = Seq( LabeledPoint(1.0, new DenseVector(Array(1.0, 0.0))), LabeledPoint(0.0, new DenseVector(Array(1.0, -1.0)))) val pointsRDD: RDD[LabeledPoint] = sparkContext.parallelize(points) + println("Converting to SchemaRDD") + val tmpSchemaRDD: SchemaRDD = TestSQLContext.createSchemaRDD(pointsRDD) + println(s"SchemaRDD count: ${tmpSchemaRDD.count()}") + println("Done converting to SchemaRDD") + val features: RDD[DenseVector] = pointsRDD.select('features).map { case Row(v: DenseVector) => v } val featuresArrays: Array[DenseVector] = features.collect() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2e27817d60221..1b751ab1b7c05 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -121,7 +121,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * @tparam A A case class that is used to describe the schema of the table to be created. 
*/ def createTable[A <: Product : TypeTag](tableName: String, allowExisting: Boolean = true) { - catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting) + catalog.createTable("default", tableName, ScalaReflection.attributesFor[A](udtRegistry), + allowExisting) } /** From 19b2f60cf337c6e3e332ee8cddfeb818dad2b699 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 6 Oct 2014 13:18:26 -0700 Subject: [PATCH 03/46] still working on UDTs --- .../apache/spark/sql/catalyst/ScalaReflection.scala | 13 ++++++++++--- .../apache/spark/sql/catalyst/types/dataTypes.scala | 6 ++++++ .../spark/sql/catalyst/ScalaReflectionSuite.scala | 3 ++- .../scala/org/apache/spark/sql/SQLContext.scala | 4 ++-- .../scala/org/apache/spark/sql/UserTypeSuite.scala | 13 +++++++++---- 5 files changed, 29 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 92c8bdf376e3b..3e45b8ca133bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -56,6 +56,7 @@ object ScalaReflection { case s: Seq[_] => s.map(convertToScala) case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) } case d: Decimal => d.toBigDecimal + case (udt: Any, udtType: UserDefinedType[_]) => udtType.serialize(udt) case other => other } @@ -72,7 +73,13 @@ object ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor[T: TypeTag](udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = { - schemaFor(typeOf[T], udtRegistry) + println(s"schemaFor: ${typeTag[T]}") + if (udtRegistry.contains(typeTag[T])) { + val udtStructType: StructType = udtRegistry(typeTag[T]).dataType + Schema(udtStructType, nullable = true) + } else { + schemaFor(typeOf[T], udtRegistry) + } } /** @@ -127,9 +134,9 @@ object ScalaReflection { case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) - case t if udtRegistry.contains(tpe) => +/* case t if udtRegistry.contains(typeTag[t]) => val udtStructType: StructType = udtRegistry(tpe).dataType - Schema(udtStructType, nullable = true) + Schema(udtStructType, nullable = true)*/ } def typeOfObject: PartialFunction[Any, DataType] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index afe4cf101c88e..988f7dd138072 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -590,6 +590,12 @@ case class UDTType(dataType: StructType, ) extends DataType { // Used only in regex parser above. 
//private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { } + def dataType: StructType + + def serialize(obj: Any): Row + + def deserialize(row: Row): UserType + // TODO def simpleString: String = "udt" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 21b2c8e20d4db..37978155f43ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -69,7 +69,7 @@ case class GenericData[A]( class ScalaReflectionSuite extends FunSuite { import ScalaReflection._ - +/* test("primitive data") { val schema = schemaFor[PrimitiveData] assert(schema === Schema( @@ -248,4 +248,5 @@ class ScalaReflectionSuite extends FunSuite { val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true, convertToCatalyst(primitiveData)) assert(convertToCatalyst(data) === convertedData) } + */ } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index d1811b7512715..5a2aa9a3dee01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -291,8 +291,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * RDDs of user types and SchemaRDDs. * Fails if this type has been registered already. */ - def registerUserType[UserType, UDT <: UserDefinedType[UserType]]( - udt: UDT)(implicit userType: TypeTag[UserType]): Unit = { + def registerUserType[UserType]( + udt: UserDefinedType[UserType])(implicit userType: TypeTag[UserType]): Unit = { require(!udtRegistry.contains(userType), "registerUserType called on type which was already registered.") udtRegistry(userType) = udt diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala index 9daa8ea093d78..613e90bc2eb2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala @@ -26,12 +26,14 @@ class UserTypeSuite extends QueryTest { class DenseVector(val data: Array[Double]) - class VectorRowSerializer extends UserDefinedType[DenseVector] { + case class LabeledPoint(label: Double, features: DenseVector) + + class LabeledPointUDT extends UserDefinedType[LabeledPoint] { override def dataType: StructType = StructType(Seq(StructField("features", ArrayType(DoubleType), nullable = false))) - override def serialize(obj: DenseVector): Row = Row(obj.data) + override def serialize(obj: Any): Row = Row(obj.asInstanceOf[DenseVector].data) override def deserialize(row: Row): DenseVector = { val arr = new Array[Double](row.length) @@ -44,10 +46,13 @@ class UserTypeSuite extends QueryTest { } } - case class LabeledPoint(label: Double, features: DenseVector) - test("register user type: LabeledPoint") { TestSQLContext.registerUserType(new VectorRowSerializer()) + println("udtRegistry:") + TestSQLContext.udtRegistry.foreach { case (t,s) => println(s"$t -> $s") } + + println(s"test: ${scala.reflect.runtime.universe.typeTag[DenseVector]}") + assert(TestSQLContext.udtRegistry.contains(scala.reflect.runtime.universe.typeTag[DenseVector])) val points = Seq( LabeledPoint(1.0, new DenseVector(Array(1.0, 0.0))), From 
982c03561d8b1b41d14c061e704910b582f91703 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 6 Oct 2014 19:10:43 -0700 Subject: [PATCH 04/46] still working on UDTs --- .../spark/sql/catalyst/ScalaReflection.scala | 19 ++- .../spark/sql/catalyst/types/dataTypes.scala | 15 ++- .../spark/sql/execution/basicOperators.scala | 4 +- .../spark/sql/UserDefinedTypeSuite.scala | 97 ++++++++++++++++ .../org/apache/spark/sql/UserTypeSuite.scala | 108 +++++++++++------- .../apache/spark/sql/hive/HiveContext.scala | 5 +- 6 files changed, 185 insertions(+), 63 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 3e45b8ca133bc..cd26d863a4f24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -78,7 +78,7 @@ object ScalaReflection { val udtStructType: StructType = udtRegistry(typeTag[T]).dataType Schema(udtStructType, nullable = true) } else { - schemaFor(typeOf[T], udtRegistry) + schemaFor(typeOf[T]) } } @@ -86,12 +86,10 @@ object ScalaReflection { * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. * TODO: ADD DOC */ - def schemaFor( - tpe: `Type`, - udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = tpe match { + def schemaFor(tpe: `Type`): Schema = tpe match { case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - Schema(schemaFor(optType, udtRegistry).dataType, nullable = true) + Schema(schemaFor(optType).dataType, nullable = true) case t if t <:< typeOf[Product] => val formalTypeArgs = t.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = t @@ -99,7 +97,7 @@ object ScalaReflection { Schema(StructType( params.head.map { p => val Schema(dataType, nullable) = - schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs), udtRegistry) + schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) StructField(p.name.toString, dataType, nullable) }), nullable = true) // Need to decide if we actually need a special type here. 
@@ -108,12 +106,12 @@ object ScalaReflection { sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") case t if t <:< typeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType, udtRegistry) + val Schema(dataType, nullable) = schemaFor(elementType) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) case t if t <:< typeOf[Map[_,_]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(valueDataType, valueNullable) = schemaFor(valueType, udtRegistry) - Schema(MapType(schemaFor(keyType, udtRegistry).dataType, + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) case t if t <:< typeOf[String] => Schema(StringType, nullable = true) case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) @@ -134,9 +132,6 @@ object ScalaReflection { case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) -/* case t if udtRegistry.contains(typeTag[t]) => - val udtStructType: StructType = udtRegistry(tpe).dataType - Schema(udtStructType, nullable = true)*/ } def typeOfObject: PartialFunction[Any, DataType] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 988f7dd138072..80bd1c4d951cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -574,24 +574,23 @@ trait UserDefinedType[T] { def deserialize(row: Row): T } -object UDTType { +object UserDefinedType { /** - * Construct a [[UDTType]] object with the given key type and value type. + * Construct a [[UserDefinedType]] object with the given key type and value type. * The `valueContainsNull` is true. */ - def apply(keyType: DataType, valueType: DataType): MapType = - MapType(keyType: DataType, valueType: DataType, true) + //def apply(keyType: DataType, valueType: DataType): MapType = + // MapType(keyType: DataType, valueType: DataType, true) } /** - * The data type for UserDefinedType. + * The data type for User Defined Types. */ -case class UDTType(dataType: StructType, ) extends DataType { +abstract class UserDefinedType[UserType](val dataType: StructType) extends DataType with Serializable { + // Used only in regex parser above. //private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { } - def dataType: StructType - def serialize(obj: Any): Row def deserialize(row: Row): UserType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e6cd1a9d04278..1efb793337f9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -179,8 +179,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) val ord = new RowOrdering(sortOrder, child.output) // TODO: Is this copying for no reason? 
- override def executeCollect() = - child.execute().map(_.copy()).takeOrdered(limit)(ord).map(ScalaReflection.convertRowToScala) + override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord) + .map(ScalaReflection.convertRowToScala) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala new file mode 100644 index 0000000000000..9996dfac93ebd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.types.UserDefinedType +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ + +class UserDefinedTypeSuite extends QueryTest { + + case class LabeledPoint(label: Double, feature: Double) extends Serializable + + object LabeledPointUDT { + + def dataType: StructType = + StructType(Seq( + StructField("label", DoubleType, nullable = false), + StructField("feature", DoubleType, nullable = false))) + + } + + case class LabeledPointUDT() extends UserDefinedType[LabeledPoint](LabeledPointUDT.dataType) with Serializable { + + override def serialize(obj: Any): Row = obj match { + case lp: LabeledPoint => + val row: GenericMutableRow = new GenericMutableRow(2) + row.setDouble(0, lp.label) + row.setDouble(1, lp.feature) + row + } + + override def deserialize(row: Row): LabeledPoint = { + assert(row.length == 2) + val label = row.getDouble(0) + val feature = row.getDouble(1) + LabeledPoint(label, feature) + } + } + + test("register user type: LabeledPoint") { + try { + TestSQLContext.registerUserType(new LabeledPointUDT()) + println("udtRegistry:") + TestSQLContext.udtRegistry.foreach { case (t, s) => println(s"$t -> $s")} + + println(s"test: ${scala.reflect.runtime.universe.typeTag[LabeledPoint]}") + assert(TestSQLContext.udtRegistry.contains(scala.reflect.runtime.universe.typeTag[LabeledPoint])) + + val points = Seq( + LabeledPoint(1.0, 2.0), + LabeledPoint(0.0, 3.0)) + val pointsRDD: RDD[LabeledPoint] = sparkContext.parallelize(points) + + println("Converting to SchemaRDD") + val tmpSchemaRDD: SchemaRDD = TestSQLContext.createSchemaRDD(pointsRDD) + println("blah") + println(s"SchemaRDD count: ${tmpSchemaRDD.count()}") + println("Done converting to SchemaRDD") + + /* + val features: RDD[DenseVector] = + pointsRDD.select('features).map { case Row(v: DenseVector) => v} + val featuresArrays: 
Array[DenseVector] = features.collect() + assert(featuresArrays.size === 2) + assert(featuresArrays.contains(new DenseVector(Array(1.0, 0.0)))) + assert(featuresArrays.contains(new DenseVector(Array(1.0, -1.0)))) + + val labels: RDD[Double] = pointsRDD.select('labels).map { case Row(v: Double) => v} + val labelsArrays: Array[Double] = labels.collect() + assert(labelsArrays.size === 2) + assert(labelsArrays.contains(1.0)) + assert(labelsArrays.contains(0.0)) + */ + } catch { + case e: Exception => + e.printStackTrace() + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala index 613e90bc2eb2c..eeb3793035147 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala @@ -18,64 +18,92 @@ package org.apache.spark.sql import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ class UserTypeSuite extends QueryTest { - class DenseVector(val data: Array[Double]) + class DenseVector(val data: Array[Double]) extends Serializable - case class LabeledPoint(label: Double, features: DenseVector) + case class LabeledPoint(label: Double, features: DenseVector) extends Serializable - class LabeledPointUDT extends UserDefinedType[LabeledPoint] { + object LabeledPointUDT { - override def dataType: StructType = - StructType(Seq(StructField("features", ArrayType(DoubleType), nullable = false))) + def dataType: StructType = + StructType(Seq( + StructField("label", DoubleType, nullable = false), + StructField("features", ArrayType(DoubleType), nullable = false))) - override def serialize(obj: Any): Row = Row(obj.asInstanceOf[DenseVector].data) + } + + case class LabeledPointUDT() extends UserDefinedType[LabeledPoint](LabeledPointUDT.dataType) with Serializable { + + override def serialize(obj: Any): Row = obj match { + case lp: LabeledPoint => + val row: GenericMutableRow = new GenericMutableRow(1 + lp.features.data.size) + row.setDouble(0, lp.label) + var i = 0 + while (i < lp.features.data.size) { + row.setDouble(1 + i, lp.features.data(i)) + i += 1 + } + row + // Array.concat(Array(lp.label), lp.features.data)) + } - override def deserialize(row: Row): DenseVector = { - val arr = new Array[Double](row.length) + override def deserialize(row: Row): LabeledPoint = { + assert(row.length >= 1) + val label = row.getDouble(0) + val arr = new Array[Double](row.length - 1) var i = 0 - while (i < row.length) { - arr(i) = row.getDouble(i) + while (i < row.length - 1) { + arr(i) = row.getDouble(i + 1) i += 1 } - new DenseVector(arr) + LabeledPoint(label, new DenseVector(arr)) } } test("register user type: LabeledPoint") { - TestSQLContext.registerUserType(new VectorRowSerializer()) - println("udtRegistry:") - TestSQLContext.udtRegistry.foreach { case (t,s) => println(s"$t -> $s") } - - println(s"test: ${scala.reflect.runtime.universe.typeTag[DenseVector]}") - assert(TestSQLContext.udtRegistry.contains(scala.reflect.runtime.universe.typeTag[DenseVector])) - - val points = Seq( - LabeledPoint(1.0, new DenseVector(Array(1.0, 0.0))), - LabeledPoint(0.0, new DenseVector(Array(1.0, -1.0)))) - val pointsRDD: RDD[LabeledPoint] = sparkContext.parallelize(points) - - println("Converting to SchemaRDD") - val tmpSchemaRDD: SchemaRDD = 
TestSQLContext.createSchemaRDD(pointsRDD) - println(s"SchemaRDD count: ${tmpSchemaRDD.count()}") - println("Done converting to SchemaRDD") - - val features: RDD[DenseVector] = - pointsRDD.select('features).map { case Row(v: DenseVector) => v } - val featuresArrays: Array[DenseVector] = features.collect() - assert(featuresArrays.size === 2) - assert(featuresArrays.contains(new DenseVector(Array(1.0, 0.0)))) - assert(featuresArrays.contains(new DenseVector(Array(1.0, -1.0)))) - - val labels: RDD[Double] = pointsRDD.select('labels).map { case Row(v: Double) => v } - val labelsArrays: Array[Double] = labels.collect() - assert(labelsArrays.size === 2) - assert(labelsArrays.contains(1.0)) - assert(labelsArrays.contains(0.0)) + try { + TestSQLContext.registerUserType(new LabeledPointUDT()) + println("udtRegistry:") + TestSQLContext.udtRegistry.foreach { case (t, s) => println(s"$t -> $s")} + + println(s"test: ${scala.reflect.runtime.universe.typeTag[LabeledPoint]}") + assert(TestSQLContext.udtRegistry.contains(scala.reflect.runtime.universe.typeTag[LabeledPoint])) + + val points = Seq( + LabeledPoint(1.0, new DenseVector(Array(1.0, 0.0))), + LabeledPoint(0.0, new DenseVector(Array(1.0, -1.0)))) + val pointsRDD: RDD[LabeledPoint] = sparkContext.parallelize(points) + + println("Converting to SchemaRDD") + val tmpSchemaRDD: SchemaRDD = TestSQLContext.createSchemaRDD(pointsRDD) + println("blah") + println(s"SchemaRDD count: ${tmpSchemaRDD.count()}") + println("Done converting to SchemaRDD") + + /* + val features: RDD[DenseVector] = + pointsRDD.select('features).map { case Row(v: DenseVector) => v} + val featuresArrays: Array[DenseVector] = features.collect() + assert(featuresArrays.size === 2) + assert(featuresArrays.contains(new DenseVector(Array(1.0, 0.0)))) + assert(featuresArrays.contains(new DenseVector(Array(1.0, -1.0)))) + + val labels: RDD[Double] = pointsRDD.select('labels).map { case Row(v: Double) => v} + val labelsArrays: Array[Double] = labels.collect() + assert(labelsArrays.size === 2) + assert(labelsArrays.contains(1.0)) + assert(labelsArrays.contains(0.0)) + */ + } catch { + case e: Exception => + e.printStackTrace() + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 1b751ab1b7c05..bed0faf74294b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -371,7 +371,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** Extends QueryExecution with hive specific features. */ protected[sql] abstract class QueryExecution extends super.QueryExecution { - override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) + override lazy val toRdd: RDD[Row] = { + //val dataType = StructType.fromAttributes(logical.output) + executedPlan.execute().map(ScalaReflection.convertRowToScala(_)) + } protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, From 53de70f3a38bde5a1efec9234ee72aa87055cfdc Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 7 Oct 2014 12:25:45 -0700 Subject: [PATCH 05/46] more udts... 
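This round threads the Catalyst schema through the conversion helpers (convertRowToScala(row, schema), SparkPlan.executeCollect, Limit/TakeOrdered, HiveContext.toRdd), so conversion can dispatch on the expected DataType rather than only on the runtime class. For a slot whose DataType is a UserDefinedType, the Catalyst-side representation is the Row produced by the UDT's serializer. A minimal round-trip sketch, reusing the LabeledPoint / LabeledPointUDT pair as used inside UserDefinedTypeSuite above (illustrative only, not part of this commit):

    val udt = new LabeledPointUDT()
    // Scala -> Catalyst: a UDT-typed slot is stored as the Row built by serialize.
    val asRow: Row = udt.serialize(LabeledPoint(1.0, 2.0))   // holds (1.0, 2.0)
    // Catalyst -> Scala: deserialize rebuilds the user object from that Row.
    val back: LabeledPoint = udt.deserialize(asRow)
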
--- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 8 ++++++-- .../src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 2 +- .../scala/org/apache/spark/sql/execution/SparkPlan.scala | 4 +++- .../org/apache/spark/sql/execution/basicOperators.scala | 4 ++-- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 4 ++-- 5 files changed, 14 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index cd26d863a4f24..e5185c963becb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -52,7 +52,9 @@ object ScalaReflection { } /** Converts Catalyst types used internally in rows to standard Scala types */ - def convertToScala(a: Any): Any = a match { + def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { + // TODO: USE DATATYPE + // TODO: What about Option and Product? case s: Seq[_] => s.map(convertToScala) case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) } case d: Decimal => d.toBigDecimal @@ -60,7 +62,9 @@ object ScalaReflection { case other => other } - def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala)) + def convertRowToScala(r: Row, schema: StructType): Row = { + new GenericRow(r.toArray.map(convertToScala(_, schema))) + } /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 018a18c4ac214..f0cee63947721 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -114,7 +114,7 @@ class SchemaRDD( // ========================================================================================= override def compute(split: Partition, context: TaskContext): Iterator[Row] = - firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala) + firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema)) override def getPartitions: Array[Partition] = firstParent[Row].partitions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index aafcce0572b25..21967b14617c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -82,7 +82,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an array. 
*/ - def executeCollect(): Array[Row] = execute().map(ScalaReflection.convertRowToScala).collect() + def executeCollect(): Array[Row] = { + execute().map(ScalaReflection.convertRowToScala(_, schema)).collect() + } protected def newProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 1efb793337f9e..1b8ba3ace2a82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan) partsScanned += numPartsToTry } - buf.toArray.map(ScalaReflection.convertRowToScala) + buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema)) } override def execute() = { @@ -180,7 +180,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) // TODO: Is this copying for no reason? override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord) - .map(ScalaReflection.convertRowToScala) + .map(ScalaReflection.convertRowToScala(_, this.schema)) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index bed0faf74294b..6bb1f1b1d2a7a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -372,8 +372,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[sql] abstract class QueryExecution extends super.QueryExecution { override lazy val toRdd: RDD[Row] = { - //val dataType = StructType.fromAttributes(logical.output) - executedPlan.execute().map(ScalaReflection.convertRowToScala(_)) + val schema = StructType.fromAttributes(logical.output) + executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema)) } protected val primitiveTypes = From 8bebf24ad16f63034cb049c718e3d1b6070eea80 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 7 Oct 2014 15:51:07 -0700 Subject: [PATCH 06/46] commented out convertRowToScala for debugging --- .../spark/sql/catalyst/ScalaReflection.scala | 19 ++++++++++++------- .../spark/sql/catalyst/types/dataTypes.scala | 1 - .../org/apache/spark/sql/SQLContext.scala | 1 + .../org/apache/spark/sql/SchemaRDD.scala | 2 +- .../spark/sql/execution/SparkPlan.scala | 2 +- .../spark/sql/execution/basicOperators.scala | 4 ++-- 6 files changed, 17 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index e5185c963becb..4d755c25880df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -37,7 +37,8 @@ object ScalaReflection { /** Converts Scala objects to catalyst rows / types */ def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - case (o: Option[_], oType: _) => convertToCatalyst(o.orNull, oType) + // TODO: Why does this not need to flatMap stuff? Does it not support nesting? 
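      // Note: nesting of Seq, Map and Product values appears to be handled by the
      // recursive calls below, since each call passes along the corresponding element,
      // key/value, or field DataType (e.g. a Seq[Seq[Int]] hits the ArrayType case twice).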
+ case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) @@ -46,25 +47,29 @@ object ScalaReflection { p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => convertToCatalyst(elem, field.dataType) }.toArray) - case (udt: _, udtType: UDTType) => udtType. + case (udt: Any, udtType: UserDefinedType[_]) => udtType.serialize(udt) case (d: BigDecimal, _) => Decimal(d) case (other, _) => other } + /* /** Converts Catalyst types used internally in rows to standard Scala types */ def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { - // TODO: USE DATATYPE + // TODO: Why does this not need to flatMap stuff? Does it not support nesting? // TODO: What about Option and Product? - case s: Seq[_] => s.map(convertToScala) - case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) } + case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType)) + case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => + convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) + } case d: Decimal => d.toBigDecimal - case (udt: Any, udtType: UserDefinedType[_]) => udtType.serialize(udt) - case other => other + case (udt: Row, udtType: UserDefinedType[_]) => udtType.deserialize(udt) + case (other, _) => other } def convertRowToScala(r: Row, schema: StructType): Row = { new GenericRow(r.toArray.map(convertToScala(_, schema))) } + */ /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 80bd1c4d951cd..ff25d0b136eb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -595,6 +595,5 @@ abstract class UserDefinedType[UserType](val dataType: StructType) extends DataT def deserialize(row: Row): UserType - // TODO def simpleString: String = "udt" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 5a2aa9a3dee01..ad70eee42d598 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -295,6 +295,7 @@ class SQLContext(@transient val sparkContext: SparkContext) udt: UserDefinedType[UserType])(implicit userType: TypeTag[UserType]): Unit = { require(!udtRegistry.contains(userType), "registerUserType called on type which was already registered.") + // TODO: Check to see if type is built-in. Throw exception? 
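    // Illustrative usage (Point / PointUDT are hypothetical types, only sketched here):
    //   sqlContext.registerUserType(new PointUDT())   // registers Point -> PointUDT
    //   sqlContext.registerUserType(new PointUDT())   // a second call fails the require above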
udtRegistry(userType) = udt } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index f0cee63947721..e455ab5d33aa6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -114,7 +114,7 @@ class SchemaRDD( // ========================================================================================= override def compute(split: Partition, context: TaskContext): Iterator[Row] = - firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema)) + firstParent[Row].compute(split, context).map(_.copy) //(ScalaReflection.convertRowToScala(_, this.schema)) override def getPartitions: Array[Partition] = firstParent[Row].partitions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 21967b14617c0..40286eeec8274 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -83,7 +83,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Runs this query returning the result as an array. */ def executeCollect(): Array[Row] = { - execute().map(ScalaReflection.convertRowToScala(_, schema)).collect() + execute().map(_.copy).collect() //(ScalaReflection.convertRowToScala(_, schema)).collect() } protected def newProjection( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 1b8ba3ace2a82..35378f9ef92da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan) partsScanned += numPartsToTry } - buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema)) + buf.toArray//.map(ScalaReflection.convertRowToScala(_, this.schema)) } override def execute() = { @@ -180,7 +180,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) // TODO: Is this copying for no reason? override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord) - .map(ScalaReflection.convertRowToScala(_, this.schema)) + //.map(ScalaReflection.convertRowToScala(_, this.schema)) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. From 273ac9627b4acaced95521dc9ce2f1ef0eab7305 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Tue, 7 Oct 2014 19:22:10 -0700 Subject: [PATCH 07/46] basic UDT is working, but deserialization has yet to be done --- .../spark/sql/catalyst/ScalaReflection.scala | 65 ++++++++-- .../sql/catalyst/expressions/ScalaUdf.scala | 1 + .../spark/sql/catalyst/types/dataTypes.scala | 2 +- .../org/apache/spark/sql/SQLContext.scala | 6 +- .../spark/sql/execution/ExistingRDD.scala | 11 +- .../spark/sql/execution/SparkStrategies.scala | 1 + .../apache/spark/sql/TmpSQLQuerySuite.scala | 63 ++++++++++ .../spark/sql/UserDefinedTypeSuite.scala | 114 ++++++++++-------- 8 files changed, 196 insertions(+), 67 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TmpSQLQuerySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4d755c25880df..4c22d01758ec1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -36,20 +36,60 @@ object ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) /** Converts Scala objects to catalyst rows / types */ - def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + /* + def convertToCatalyst(a: Any, dataType: DataType): Any = a match { // TODO: Why does this not need to flatMap stuff? Does it not support nesting? - case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull - case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) - case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => - convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) + case o: Option[_] => + println(s"convertToCatalyst: option") + o.map(convertToCatalyst(_, dataType)).orNull + case s: Seq[_] => + println(s"convertToCatalyst: array") + s.map(convertToCatalyst(_, null)) + case m: Map[_, _] => + println(s"convertToCatalyst: map") + m.map { case (k, v) => + convertToCatalyst(k, null) -> convertToCatalyst(v, null) + } + case p: Product => + println(s"convertToCatalyst: struct") + new GenericRow(p.productIterator.map(convertToCatalyst(_, null)).toArray) + case other => + println(s"convertToCatalyst: other") + other + } + */ + + def convertToCatalyst(a: Any, dataType: DataType): Any = { + println(s"convertToCatalyst: a = $a, dataType = $dataType") + (a, dataType) match { + // TODO: Why does this not need to flatMap stuff? Does it not support nesting? 
+ case (o: Option[_], _) => + println(s"convertToCatalyst: option") + o.map(convertToCatalyst(_, dataType)).orNull + case (s: Seq[_], arrayType: ArrayType) => + println(s"convertToCatalyst: array") + s.map(convertToCatalyst(_, arrayType.elementType)) + case (m: Map[_, _], mapType: MapType) => + println(s"convertToCatalyst: map") + m.map { case (k, v) => + convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) + } + case (p: Product, structType: StructType) => + println(s"convertToCatalyst: struct with") + println(s"\t p: $p") + println(s"\t structType: $structType") + new GenericRow( + p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => + convertToCatalyst(elem, field.dataType) + }.toArray) + case (udt: Any, udtType: UserDefinedType[_]) => + println(s"convertToCatalyst: udt") + udtType.serialize(udt) + case (d: BigDecimal, _) => Decimal(d) + case (other, _) => + println(s"convertToCatalyst: other") + other } - case (p: Product, structType: StructType) => new GenericRow( - p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => - convertToCatalyst(elem, field.dataType) - }.toArray) - case (udt: Any, udtType: UserDefinedType[_]) => udtType.serialize(udt) - case (d: BigDecimal, _) => Decimal(d) - case (other, _) => other } /* @@ -84,6 +124,7 @@ object ScalaReflection { def schemaFor[T: TypeTag](udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = { println(s"schemaFor: ${typeTag[T]}") if (udtRegistry.contains(typeTag[T])) { + println(s" schemaFor T matched udtRegistry") val udtStructType: StructType = udtRegistry(typeTag[T]).dataType Schema(udtStructType, nullable = true) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index fa1786e74bb3e..fb6dc7d3aad57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -51,6 +51,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi // scalastyle:off override def eval(input: Row): Any = { + println(s"ScalaUdf.eval called") val result = children.size match { case 0 => function.asInstanceOf[() => Any]() case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index ff25d0b136eb7..676910d237857 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -586,7 +586,7 @@ object UserDefinedType { /** * The data type for User Defined Types. */ -abstract class UserDefinedType[UserType](val dataType: StructType) extends DataType with Serializable { +abstract class UserDefinedType[UserType](val dataType: StructType) extends DataType { // Used only in regex parser above. 
//private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ad70eee42d598..f3736299b7b7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -103,8 +103,10 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = { SparkPlan.currentContext.set(self) - new SchemaRDD(this, LogicalRDD(ScalaReflection.attributesFor[A](udtRegistry), - RDDConversions.productToRowRdd(rdd, ScalaReflection.schemaFor[A](udtRegistry).dataType))(self)) + println(s"createSchemaRDD called") + val attributeSeq = ScalaReflection.attributesFor[A](udtRegistry) + val schema = StructType.fromAttributes(attributeSeq) + new SchemaRDD(this, LogicalRDD(attributeSeq, RDDConversions.productToRowRdd(rdd, schema))(self)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 3d2cbb3057e0b..4fd768c6ab8dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataType, Row, SQLContext} +import org.apache.spark.sql.{DataType, StructType, Row, SQLContext} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.ScalaReflection.Schema import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -32,18 +32,21 @@ import org.apache.spark.sql.catalyst.types.UserDefinedType */ @DeveloperApi object RDDConversions { - def productToRowRdd[A <: Product](data: RDD[A], dataType: DataType): RDD[Row] = { + def productToRowRdd[A <: Product](data: RDD[A], schema: StructType): RDD[Row] = { + println(s"productToRowRdd called with datatype: $schema") data.mapPartitions { iterator => if (iterator.isEmpty) { Iterator.empty } else { val bufferedIterator = iterator.buffered val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) - + assert(bufferedIterator.head.productArity == schema.fields.length) + val schemaFields = schema.fields.toArray bufferedIterator.map { r => var i = 0 while (i < mutableRow.length) { - mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i), dataType) + mutableRow(i) = + ScalaReflection.convertToCatalyst(r.productElement(i), schemaFields(i).dataType) i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d3c7033187048..f3d7c94466b39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -278,6 +278,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil case logical.LocalRelation(output, data) => val nPartitions = if (data.isEmpty) 1 else numPartitions + println(s"BasicOperators.apply: creating schema from attributes: $output") PhysicalRDD( output, 
RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TmpSQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TmpSQLQuerySuite.scala new file mode 100644 index 0000000000000..ccc9a9299bc8c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TmpSQLQuerySuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.TimeZone + +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.test._ +import org.scalatest.BeforeAndAfterAll + +/* Implicits */ +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext._ + +class TmpSQLQuerySuite extends QueryTest with BeforeAndAfterAll { + // Make sure the tables are loaded. + TestData + + var origZone: TimeZone = _ + override protected def beforeAll() { + origZone = TimeZone.getDefault + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + } + + override protected def afterAll() { + TimeZone.setDefault(origZone) + } + + test("limit") { + /* + checkAnswer( + sql("SELECT * FROM testData LIMIT 10"), + testData.take(10).toSeq) +*/ + println("blah START") + checkAnswer( + sql("SELECT * FROM arrayData LIMIT 1"), + arrayData.collect().take(1).toSeq) + println("blah END") +/* + checkAnswer( + sql("SELECT * FROM mapData LIMIT 1"), + mapData.collect().take(1).toSeq) + */ + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 9996dfac93ebd..90c768bc9a6eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -23,75 +23,93 @@ import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ -class UserDefinedTypeSuite extends QueryTest { +class DenseVector(val data: Array[Double]) extends Serializable { + override def equals(other: Any): Boolean = other match { + case v: DenseVector => + java.util.Arrays.equals(this.data, v.data) + case _ => false + } +} - case class LabeledPoint(label: Double, feature: Double) extends Serializable +case class LabeledPoint(label: Double, features: DenseVector) - object LabeledPointUDT { +class UserDefinedTypeSuite extends QueryTest { + object LabeledPointUDT { def dataType: StructType = StructType(Seq( StructField("label", DoubleType, nullable = false), - StructField("feature", DoubleType, nullable = false))) - + StructField("features", ArrayType(DoubleType, containsNull = false), 
nullable = false))) } - case class LabeledPointUDT() extends UserDefinedType[LabeledPoint](LabeledPointUDT.dataType) with Serializable { + case class LabeledPointUDT() extends UserDefinedType[LabeledPoint](LabeledPointUDT.dataType) { override def serialize(obj: Any): Row = obj match { case lp: LabeledPoint => - val row: GenericMutableRow = new GenericMutableRow(2) + val row: GenericMutableRow = new GenericMutableRow(1 + lp.features.data.length) row.setDouble(0, lp.label) - row.setDouble(1, lp.feature) + var i = 0 + while (i < lp.features.data.length) { + row.setDouble(1 + i, lp.features.data(i)) + i += 1 + } row } override def deserialize(row: Row): LabeledPoint = { - assert(row.length == 2) + assert(row.length >= 1) val label = row.getDouble(0) - val feature = row.getDouble(1) - LabeledPoint(label, feature) + val numFeatures = row.length - 1 + val features = new DenseVector(new Array[Double](numFeatures)) + var i = 0 + while (i < numFeatures) { + features.data(i) = row.getDouble(1 + i) + i += 1 + } + LabeledPoint(label, features) } } test("register user type: LabeledPoint") { - try { - TestSQLContext.registerUserType(new LabeledPointUDT()) - println("udtRegistry:") - TestSQLContext.udtRegistry.foreach { case (t, s) => println(s"$t -> $s")} - - println(s"test: ${scala.reflect.runtime.universe.typeTag[LabeledPoint]}") - assert(TestSQLContext.udtRegistry.contains(scala.reflect.runtime.universe.typeTag[LabeledPoint])) - - val points = Seq( - LabeledPoint(1.0, 2.0), - LabeledPoint(0.0, 3.0)) - val pointsRDD: RDD[LabeledPoint] = sparkContext.parallelize(points) - - println("Converting to SchemaRDD") - val tmpSchemaRDD: SchemaRDD = TestSQLContext.createSchemaRDD(pointsRDD) - println("blah") - println(s"SchemaRDD count: ${tmpSchemaRDD.count()}") - println("Done converting to SchemaRDD") - - /* - val features: RDD[DenseVector] = - pointsRDD.select('features).map { case Row(v: DenseVector) => v} - val featuresArrays: Array[DenseVector] = features.collect() - assert(featuresArrays.size === 2) - assert(featuresArrays.contains(new DenseVector(Array(1.0, 0.0)))) - assert(featuresArrays.contains(new DenseVector(Array(1.0, -1.0)))) - - val labels: RDD[Double] = pointsRDD.select('labels).map { case Row(v: Double) => v} - val labelsArrays: Array[Double] = labels.collect() - assert(labelsArrays.size === 2) - assert(labelsArrays.contains(1.0)) - assert(labelsArrays.contains(0.0)) - */ - } catch { - case e: Exception => - e.printStackTrace() - } + TestSQLContext.registerUserType(new LabeledPointUDT()) + println("udtRegistry:") + TestSQLContext.udtRegistry.foreach { case (t, s) => println(s"$t -> $s")} + + println(s"test: ${scala.reflect.runtime.universe.typeTag[LabeledPoint]}") + assert(TestSQLContext.udtRegistry.contains(scala.reflect.runtime.universe.typeTag[LabeledPoint])) + + val points = Seq( + LabeledPoint(1.0, new DenseVector(Array(0.1, 1.0))), + LabeledPoint(0.0, new DenseVector(Array(0.2, 2.0)))) + val pointsRDD: RDD[LabeledPoint] = sparkContext.parallelize(points) + + println("Converting to SchemaRDD") + val tmpSchemaRDD: SchemaRDD = TestSQLContext.createSchemaRDD(pointsRDD) + println("blah") + println(s"SchemaRDD count: ${tmpSchemaRDD.count()}") + println("Done converting to SchemaRDD") + + // TODO: This test works even when the deserialize method is never used. How can I test deserialize? 
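    // One possible way to exercise deserialize directly (a sketch, using only the
    // LabeledPointUDT, LabeledPoint and DenseVector defined in this suite):
    //   val udt = new LabeledPointUDT()
    //   val p = LabeledPoint(1.0, new DenseVector(Array(0.1, 1.0)))
    //   assert(udt.deserialize(udt.serialize(p)) === p)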
+ val features: RDD[DenseVector] = + pointsRDD.select('features).map { case Row(v: DenseVector) => v} + val featuresArrays: Array[DenseVector] = features.collect() + assert(featuresArrays.size === 2) + assert(featuresArrays.contains(new DenseVector(Array(0.1, 1.0)))) + assert(featuresArrays.contains(new DenseVector(Array(0.2, 2.0)))) + + val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v} + val labelsArrays: Array[Double] = labels.collect() + assert(labelsArrays.size === 2) + assert(labelsArrays.contains(1.0)) + assert(labelsArrays.contains(0.0)) + } + + test("UDTs cannot be registered twice") { + // TODO + } + + test("UDTs cannot override built-in types") { + // TODO } } From 39f870732aedc70dc7b6f7509f3c18c2dc1964c6 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 7 Oct 2014 19:31:59 -0700 Subject: [PATCH 08/46] removed old udt suite --- .../org/apache/spark/sql/UserTypeSuite.scala | 109 ------------------ 1 file changed, 109 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala deleted file mode 100644 index eeb3793035147..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserTypeSuite.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.types.UserDefinedType -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ - -class UserTypeSuite extends QueryTest { - - class DenseVector(val data: Array[Double]) extends Serializable - - case class LabeledPoint(label: Double, features: DenseVector) extends Serializable - - object LabeledPointUDT { - - def dataType: StructType = - StructType(Seq( - StructField("label", DoubleType, nullable = false), - StructField("features", ArrayType(DoubleType), nullable = false))) - - } - - case class LabeledPointUDT() extends UserDefinedType[LabeledPoint](LabeledPointUDT.dataType) with Serializable { - - override def serialize(obj: Any): Row = obj match { - case lp: LabeledPoint => - val row: GenericMutableRow = new GenericMutableRow(1 + lp.features.data.size) - row.setDouble(0, lp.label) - var i = 0 - while (i < lp.features.data.size) { - row.setDouble(1 + i, lp.features.data(i)) - i += 1 - } - row - // Array.concat(Array(lp.label), lp.features.data)) - } - - override def deserialize(row: Row): LabeledPoint = { - assert(row.length >= 1) - val label = row.getDouble(0) - val arr = new Array[Double](row.length - 1) - var i = 0 - while (i < row.length - 1) { - arr(i) = row.getDouble(i + 1) - i += 1 - } - LabeledPoint(label, new DenseVector(arr)) - } - } - - test("register user type: LabeledPoint") { - try { - TestSQLContext.registerUserType(new LabeledPointUDT()) - println("udtRegistry:") - TestSQLContext.udtRegistry.foreach { case (t, s) => println(s"$t -> $s")} - - println(s"test: ${scala.reflect.runtime.universe.typeTag[LabeledPoint]}") - assert(TestSQLContext.udtRegistry.contains(scala.reflect.runtime.universe.typeTag[LabeledPoint])) - - val points = Seq( - LabeledPoint(1.0, new DenseVector(Array(1.0, 0.0))), - LabeledPoint(0.0, new DenseVector(Array(1.0, -1.0)))) - val pointsRDD: RDD[LabeledPoint] = sparkContext.parallelize(points) - - println("Converting to SchemaRDD") - val tmpSchemaRDD: SchemaRDD = TestSQLContext.createSchemaRDD(pointsRDD) - println("blah") - println(s"SchemaRDD count: ${tmpSchemaRDD.count()}") - println("Done converting to SchemaRDD") - - /* - val features: RDD[DenseVector] = - pointsRDD.select('features).map { case Row(v: DenseVector) => v} - val featuresArrays: Array[DenseVector] = features.collect() - assert(featuresArrays.size === 2) - assert(featuresArrays.contains(new DenseVector(Array(1.0, 0.0)))) - assert(featuresArrays.contains(new DenseVector(Array(1.0, -1.0)))) - - val labels: RDD[Double] = pointsRDD.select('labels).map { case Row(v: Double) => v} - val labelsArrays: Array[Double] = labels.collect() - assert(labelsArrays.size === 2) - assert(labelsArrays.contains(1.0)) - assert(labelsArrays.contains(0.0)) - */ - } catch { - case e: Exception => - e.printStackTrace() - } - } - -} From 04303c9b1c179b6bb08b7e0c5987ebffadc65c92 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Thu, 9 Oct 2014 12:39:44 -0700 Subject: [PATCH 09/46] udts --- .../spark/sql/catalyst/ScalaReflection.scala | 118 +++++++++--------- .../spark/sql/catalyst/types/dataTypes.scala | 4 +- .../org/apache/spark/sql/SQLContext.scala | 6 +- .../org/apache/spark/sql/SchemaRDD.scala | 2 +- .../spark/sql/execution/SparkPlan.scala | 2 +- .../spark/sql/execution/basicOperators.scala | 4 +- .../spark/sql/UserDefinedTypeSuite.scala | 75 +++++------ 7 files changed, 102 insertions(+), 109 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4c22d01758ec1..88fb21eb0a6eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -82,8 +82,8 @@ object ScalaReflection { p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => convertToCatalyst(elem, field.dataType) }.toArray) - case (udt: Any, udtType: UserDefinedType[_]) => - println(s"convertToCatalyst: udt") + case (udt, udtType: UserDefinedType[_]) => + println(s"convertToCatalyst: udt with $udtType") udtType.serialize(udt) case (d: BigDecimal, _) => Decimal(d) case (other, _) => @@ -92,7 +92,6 @@ object ScalaReflection { } } - /* /** Converts Catalyst types used internally in rows to standard Scala types */ def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { // TODO: Why does this not need to flatMap stuff? Does it not support nesting? @@ -109,7 +108,6 @@ object ScalaReflection { def convertRowToScala(r: Row, schema: StructType): Row = { new GenericRow(r.toArray.map(convertToScala(_, schema))) } - */ /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]( @@ -121,67 +119,66 @@ object ScalaReflection { } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag](udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = { - println(s"schemaFor: ${typeTag[T]}") - if (udtRegistry.contains(typeTag[T])) { - println(s" schemaFor T matched udtRegistry") - val udtStructType: StructType = udtRegistry(typeTag[T]).dataType - Schema(udtStructType, nullable = true) - } else { - schemaFor(typeOf[T]) - } - } + def schemaFor[T: TypeTag](udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = + schemaFor(typeOf[T], udtRegistry) /** * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. * TODO: ADD DOC */ - def schemaFor(tpe: `Type`): Schema = tpe match { - case t if t <:< typeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - Schema(schemaFor(optType).dataType, nullable = true) - case t if t <:< typeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val params = t.member(nme.CONSTRUCTOR).asMethod.paramss - Schema(StructType( - params.head.map { p => - val Schema(dataType, nullable) = - schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) - StructField(p.name.toString, dataType, nullable) - }), nullable = true) - // Need to decide if we actually need a special type here. 
- case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) - case t if t <:< typeOf[Array[_]] => - sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") - case t if t <:< typeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< typeOf[Map[_,_]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - Schema(MapType(schemaFor(keyType).dataType, - valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if t <:< typeOf[String] => Schema(StringType, nullable = true) - case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) - case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) - case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) - case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) - case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) - case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) - case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) - case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) - case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) - case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) - case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) - case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) - case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) - case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) - case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) - case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + def schemaFor(tpe: `Type`, udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = { + println(s"schemaFor: $tpe") + tpe match { + case t if udtRegistry.contains(t) => + println(s" schemaFor T matched udtRegistry") + Schema(udtRegistry(t), nullable = true) + case t if t <:< typeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + Schema(schemaFor(optType, udtRegistry).dataType, nullable = true) + case t if t <:< typeOf[Product] => + println(s" --schemaFor matched on Product") + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val params = t.member(nme.CONSTRUCTOR).asMethod.paramss + Schema(StructType( + params.head.map { p => + val Schema(dataType, nullable) = + schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs), udtRegistry) + StructField(p.name.toString, dataType, nullable) + }), nullable = true) + // Need to decide if we actually need a special type here. 
+ case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) + case t if t <:< typeOf[Array[_]] => + sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") + case t if t <:< typeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType, udtRegistry) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + case t if t <:< typeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(valueDataType, valueNullable) = schemaFor(valueType, udtRegistry) + Schema(MapType(schemaFor(keyType, udtRegistry).dataType, + valueDataType, valueContainsNull = valueNullable), nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) + case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) + case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) + case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) + case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) + case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) + case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) + case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) + case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) + case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) + case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) + case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) + case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) + case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) + case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + } } def typeOfObject: PartialFunction[Any, DataType] = { @@ -219,4 +216,5 @@ object ScalaReflection { LocalRelation(output, data) } } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 676910d237857..2beb7fe1693cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -586,11 +586,13 @@ object UserDefinedType { /** * The data type for User Defined Types. */ -abstract class UserDefinedType[UserType](val dataType: StructType) extends DataType { +abstract class UserDefinedType[UserType] extends DataType { // Used only in regex parser above. 
//private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { } + def sqlType: DataType + def serialize(obj: Any): Row def deserialize(row: Row): UserType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f3736299b7b7b..f838eeae9e753 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -293,12 +293,12 @@ class SQLContext(@transient val sparkContext: SparkContext) * RDDs of user types and SchemaRDDs. * Fails if this type has been registered already. */ - def registerUserType[UserType]( + def registerType[UserType]( udt: UserDefinedType[UserType])(implicit userType: TypeTag[UserType]): Unit = { - require(!udtRegistry.contains(userType), + require(!udtRegistry.contains(userType.tpe), "registerUserType called on type which was already registered.") // TODO: Check to see if type is built-in. Throw exception? - udtRegistry(userType) = udt + udtRegistry(userType.tpe) = udt } /** Map: UserType --> UserDefinedType */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index e455ab5d33aa6..f0cee63947721 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -114,7 +114,7 @@ class SchemaRDD( // ========================================================================================= override def compute(split: Partition, context: TaskContext): Iterator[Row] = - firstParent[Row].compute(split, context).map(_.copy) //(ScalaReflection.convertRowToScala(_, this.schema)) + firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema)) override def getPartitions: Array[Partition] = firstParent[Row].partitions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 40286eeec8274..21967b14617c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -83,7 +83,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Runs this query returning the result as an array. */ def executeCollect(): Array[Row] = { - execute().map(_.copy).collect() //(ScalaReflection.convertRowToScala(_, schema)).collect() + execute().map(ScalaReflection.convertRowToScala(_, schema)).collect() } protected def newProjection( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 35378f9ef92da..1b8ba3ace2a82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan) partsScanned += numPartsToTry } - buf.toArray//.map(ScalaReflection.convertRowToScala(_, this.schema)) + buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema)) } override def execute() = { @@ -180,7 +180,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) // TODO: Is this copying for no reason? 
override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord) - //.map(ScalaReflection.convertRowToScala(_, this.schema)) + .map(ScalaReflection.convertRowToScala(_, this.schema)) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 90c768bc9a6eb..38600cbaaf4a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -33,50 +33,42 @@ class DenseVector(val data: Array[Double]) extends Serializable { case class LabeledPoint(label: Double, features: DenseVector) -class UserDefinedTypeSuite extends QueryTest { - - object LabeledPointUDT { - def dataType: StructType = - StructType(Seq( - StructField("label", DoubleType, nullable = false), - StructField("features", ArrayType(DoubleType, containsNull = false), nullable = false))) - } +case object DenseVectorUDT extends UserDefinedType[DenseVector] { - case class LabeledPointUDT() extends UserDefinedType[LabeledPoint](LabeledPointUDT.dataType) { - - override def serialize(obj: Any): Row = obj match { - case lp: LabeledPoint => - val row: GenericMutableRow = new GenericMutableRow(1 + lp.features.data.length) - row.setDouble(0, lp.label) - var i = 0 - while (i < lp.features.data.length) { - row.setDouble(1 + i, lp.features.data(i)) - i += 1 - } - row - } + override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false) - override def deserialize(row: Row): LabeledPoint = { - assert(row.length >= 1) - val label = row.getDouble(0) - val numFeatures = row.length - 1 - val features = new DenseVector(new Array[Double](numFeatures)) + override def serialize(obj: Any): Row = obj match { + case features: DenseVector => + val row: GenericMutableRow = new GenericMutableRow(features.data.length) + // TODO: Is there a copyTo command I can use? var i = 0 - while (i < numFeatures) { - features.data(i) = row.getDouble(1 + i) + while (i < features.data.length) { + row.setDouble(i, features.data(i)) i += 1 } - LabeledPoint(label, features) + row + } + + override def deserialize(row: Row): DenseVector = { + val features = new DenseVector(new Array[Double](row.length)) + var i = 0 + while (i < row.length) { + features.data(i) = row.getDouble(i) + i += 1 } + features } +} + +class UserDefinedTypeSuite extends QueryTest { - test("register user type: LabeledPoint") { - TestSQLContext.registerUserType(new LabeledPointUDT()) + test("register user type: DenseVector for LabeledPoint") { + TestSQLContext.registerType(DenseVectorUDT) println("udtRegistry:") TestSQLContext.udtRegistry.foreach { case (t, s) => println(s"$t -> $s")} - println(s"test: ${scala.reflect.runtime.universe.typeTag[LabeledPoint]}") - assert(TestSQLContext.udtRegistry.contains(scala.reflect.runtime.universe.typeTag[LabeledPoint])) + println(s"test: ${scala.reflect.runtime.universe.typeOf[DenseVector]}") + assert(TestSQLContext.udtRegistry.contains(scala.reflect.runtime.universe.typeOf[DenseVector])) val points = Seq( LabeledPoint(1.0, new DenseVector(Array(0.1, 1.0))), @@ -89,27 +81,28 @@ class UserDefinedTypeSuite extends QueryTest { println(s"SchemaRDD count: ${tmpSchemaRDD.count()}") println("Done converting to SchemaRDD") - // TODO: This test works even when the deserialize method is never used. How can I test deserialize? 
+ val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v} + val labelsArrays: Array[Double] = labels.collect() + assert(labelsArrays.size === 2) + assert(labelsArrays.contains(1.0)) + assert(labelsArrays.contains(0.0)) + val features: RDD[DenseVector] = pointsRDD.select('features).map { case Row(v: DenseVector) => v} val featuresArrays: Array[DenseVector] = features.collect() assert(featuresArrays.size === 2) assert(featuresArrays.contains(new DenseVector(Array(0.1, 1.0)))) assert(featuresArrays.contains(new DenseVector(Array(0.2, 2.0)))) - - val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v} - val labelsArrays: Array[Double] = labels.collect() - assert(labelsArrays.size === 2) - assert(labelsArrays.contains(1.0)) - assert(labelsArrays.contains(0.0)) } - test("UDTs cannot be registered twice") { + /* + test("UDTs can be registered twice, overriding previous registration") { // TODO } test("UDTs cannot override built-in types") { // TODO } + */ } From 50f97269654b859f1babdef526c9fdebb9fa78f2 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 9 Oct 2014 13:09:15 -0700 Subject: [PATCH 10/46] udts --- .../main/scala/org/apache/spark/sql/SchemaRDDLike.scala | 7 ++++++- .../scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 4 +++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 15516afb95504..455692fa5fd02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -29,6 +29,8 @@ private[sql] trait SchemaRDDLike { @transient val sqlContext: SQLContext @transient val baseLogicalPlan: LogicalPlan + assert(sqlContext != null) + private[sql] def baseSchemaRDD: SchemaRDD /** @@ -49,7 +51,10 @@ private[sql] trait SchemaRDDLike { */ @transient @DeveloperApi - lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) + lazy val queryExecution = { + assert(sqlContext != null) + sqlContext.executePlan(baseLogicalPlan) + } @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match { // For various commands (like DDL) and queries with side effects, we force query optimization to diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 38600cbaaf4a9..c1ae1dd71720e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -63,7 +63,7 @@ case object DenseVectorUDT extends UserDefinedType[DenseVector] { class UserDefinedTypeSuite extends QueryTest { test("register user type: DenseVector for LabeledPoint") { - TestSQLContext.registerType(DenseVectorUDT) + registerType(DenseVectorUDT) println("udtRegistry:") TestSQLContext.udtRegistry.foreach { case (t, s) => println(s"$t -> $s")} @@ -81,12 +81,14 @@ class UserDefinedTypeSuite extends QueryTest { println(s"SchemaRDD count: ${tmpSchemaRDD.count()}") println("Done converting to SchemaRDD") + println("testing labels") val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v} val labelsArrays: Array[Double] = labels.collect() assert(labelsArrays.size === 2) assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) + println("testing features") val features: RDD[DenseVector] = 
pointsRDD.select('features).map { case Row(v: DenseVector) => v} val featuresArrays: Array[DenseVector] = features.collect() From 893ee4cacefcfd8d6516481bc15166a6f3aced60 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 9 Oct 2014 14:18:41 -0700 Subject: [PATCH 11/46] udt finallly working --- .../spark/sql/catalyst/ScalaReflection.scala | 29 +++++++++++++------ .../org/apache/spark/sql/SQLContext.scala | 4 ++- .../org/apache/spark/sql/SchemaRDD.scala | 3 +- .../org/apache/spark/sql/SchemaRDDLike.scala | 7 +---- .../spark/sql/UserDefinedTypeSuite.scala | 4 +-- 5 files changed, 28 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 88fb21eb0a6eb..eb0780a6ddea9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -96,17 +96,28 @@ object ScalaReflection { def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { // TODO: Why does this not need to flatMap stuff? Does it not support nesting? // TODO: What about Option and Product? - case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType)) - case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => + case (s: Seq[_], arrayType: ArrayType) => + println("convertToScala: Seq") + s.map(convertToScala(_, arrayType.elementType)) + case (m: Map[_, _], mapType: MapType) => + println("convertToScala: Map") + m.map { case (k, v) => convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) } - case d: Decimal => d.toBigDecimal - case (udt: Row, udtType: UserDefinedType[_]) => udtType.deserialize(udt) - case (other, _) => other + case (d: Decimal, DecimalType) => d.toBigDecimal + case (udt: Row, udtType: UserDefinedType[_]) => + println("convertToScala: udt") + udtType.deserialize(udt) + case (other, _) => + println("convertToScala: other") + other } def convertRowToScala(r: Row, schema: StructType): Row = { - new GenericRow(r.toArray.map(convertToScala(_, schema))) + println("convertRowToScala called with schema: $schema") + new GenericRow( + r.zip(schema.fields.map(_.dataType)) + .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray) } /** Returns a Sequence of attributes for the given case class type. 
*/ @@ -129,9 +140,6 @@ object ScalaReflection { def schemaFor(tpe: `Type`, udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = { println(s"schemaFor: $tpe") tpe match { - case t if udtRegistry.contains(t) => - println(s" schemaFor T matched udtRegistry") - Schema(udtRegistry(t), nullable = true) case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType, udtRegistry).dataType, nullable = true) @@ -178,6 +186,9 @@ object ScalaReflection { case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + case t if udtRegistry.contains(t) => + println(s" schemaFor T matched udtRegistry") + Schema(udtRegistry(t), nullable = true) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f838eeae9e753..e66aa22144cdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -106,7 +106,9 @@ class SQLContext(@transient val sparkContext: SparkContext) println(s"createSchemaRDD called") val attributeSeq = ScalaReflection.attributesFor[A](udtRegistry) val schema = StructType.fromAttributes(attributeSeq) - new SchemaRDD(this, LogicalRDD(attributeSeq, RDDConversions.productToRowRdd(rdd, schema))(self)) + val rowRDD = RDDConversions.productToRowRdd(rdd, schema) + println("done with productToRowRdd") + new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index f0cee63947721..80bfa8771e22c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -125,7 +125,7 @@ class SchemaRDD( * * @group schema */ - def schema: StructType = queryExecution.analyzed.schema + val schema: StructType = queryExecution.analyzed.schema // ======================================================================= // Query DSL @@ -148,6 +148,7 @@ class SchemaRDD( case (ne: NamedExpression, _) => ne case (e, i) => Alias(e, s"c$i")() } + assert(sqlContext != null) new SchemaRDD(sqlContext, Project(aliases, logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 455692fa5fd02..15516afb95504 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -29,8 +29,6 @@ private[sql] trait SchemaRDDLike { @transient val sqlContext: SQLContext @transient val baseLogicalPlan: LogicalPlan - assert(sqlContext != null) - private[sql] def baseSchemaRDD: SchemaRDD /** @@ -51,10 +49,7 @@ private[sql] trait SchemaRDDLike { */ @transient @DeveloperApi - lazy val queryExecution = { - assert(sqlContext != null) - sqlContext.executePlan(baseLogicalPlan) - } + lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match { // For various commands (like DDL) and queries with side effects, we force query optimization to diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index c1ae1dd71720e..0d6843ed112f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -82,7 +82,7 @@ class UserDefinedTypeSuite extends QueryTest { println("Done converting to SchemaRDD") println("testing labels") - val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v} + val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() assert(labelsArrays.size === 2) assert(labelsArrays.contains(1.0)) @@ -90,7 +90,7 @@ class UserDefinedTypeSuite extends QueryTest { println("testing features") val features: RDD[DenseVector] = - pointsRDD.select('features).map { case Row(v: DenseVector) => v} + pointsRDD.select('features).map { case Row(v: DenseVector) => v } val featuresArrays: Array[DenseVector] = features.collect() assert(featuresArrays.size === 2) assert(featuresArrays.contains(new DenseVector(Array(0.1, 1.0)))) From 964b32e532c2949976d764239298064dea9c5081 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 9 Oct 2014 14:48:33 -0700 Subject: [PATCH 12/46] some cleanups --- .../spark/sql/catalyst/ScalaReflection.scala | 194 ++++++------------ .../sql/catalyst/expressions/ScalaUdf.scala | 1 - .../spark/sql/catalyst/types/dataTypes.scala | 23 +-- 3 files changed, 72 insertions(+), 146 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index eb0780a6ddea9..537207adc19ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -36,85 +36,34 @@ object ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) /** Converts Scala objects to catalyst rows / types */ - /* - def convertToCatalyst(a: Any, dataType: DataType): Any = a match { - // TODO: Why does this not need to flatMap stuff? Does it not support nesting? - case o: Option[_] => - println(s"convertToCatalyst: option") - o.map(convertToCatalyst(_, dataType)).orNull - case s: Seq[_] => - println(s"convertToCatalyst: array") - s.map(convertToCatalyst(_, null)) - case m: Map[_, _] => - println(s"convertToCatalyst: map") - m.map { case (k, v) => - convertToCatalyst(k, null) -> convertToCatalyst(v, null) - } - case p: Product => - println(s"convertToCatalyst: struct") - new GenericRow(p.productIterator.map(convertToCatalyst(_, null)).toArray) - case other => - println(s"convertToCatalyst: other") - other - } - */ - - def convertToCatalyst(a: Any, dataType: DataType): Any = { - println(s"convertToCatalyst: a = $a, dataType = $dataType") - (a, dataType) match { - // TODO: Why does this not need to flatMap stuff? Does it not support nesting? 
- case (o: Option[_], _) => - println(s"convertToCatalyst: option") - o.map(convertToCatalyst(_, dataType)).orNull - case (s: Seq[_], arrayType: ArrayType) => - println(s"convertToCatalyst: array") - s.map(convertToCatalyst(_, arrayType.elementType)) - case (m: Map[_, _], mapType: MapType) => - println(s"convertToCatalyst: map") - m.map { case (k, v) => - convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) - } - case (p: Product, structType: StructType) => - println(s"convertToCatalyst: struct with") - println(s"\t p: $p") - println(s"\t structType: $structType") - new GenericRow( - p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => - convertToCatalyst(elem, field.dataType) - }.toArray) - case (udt, udtType: UserDefinedType[_]) => - println(s"convertToCatalyst: udt with $udtType") - udtType.serialize(udt) - case (d: BigDecimal, _) => Decimal(d) - case (other, _) => - println(s"convertToCatalyst: other") - other + def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull + case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) + case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => + convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) } + case (p: Product, structType: StructType) => + new GenericRow( + p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => + convertToCatalyst(elem, field.dataType) + }.toArray) + case (d: BigDecimal, _) => Decimal(d) + case (udt, udtType: UserDefinedType[_]) => udtType.serialize(udt) + case (other, _) => other } /** Converts Catalyst types used internally in rows to standard Scala types */ def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { - // TODO: Why does this not need to flatMap stuff? Does it not support nesting? - // TODO: What about Option and Product? - case (s: Seq[_], arrayType: ArrayType) => - println("convertToScala: Seq") - s.map(convertToScala(_, arrayType.elementType)) - case (m: Map[_, _], mapType: MapType) => - println("convertToScala: Map") - m.map { case (k, v) => + case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType)) + case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) } case (d: Decimal, DecimalType) => d.toBigDecimal - case (udt: Row, udtType: UserDefinedType[_]) => - println("convertToScala: udt") - udtType.deserialize(udt) - case (other, _) => - println("convertToScala: other") - other + case (udt: Row, udtType: UserDefinedType[_]) => udtType.deserialize(udt) + case (other, _) => other } def convertRowToScala(r: Row, schema: StructType): Row = { - println("convertRowToScala called with schema: $schema") new GenericRow( r.zip(schema.fields.map(_.dataType)) .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray) @@ -133,63 +82,57 @@ object ScalaReflection { def schemaFor[T: TypeTag](udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = schemaFor(typeOf[T], udtRegistry) - /** - * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. 
- * TODO: ADD DOC - */ - def schemaFor(tpe: `Type`, udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = { - println(s"schemaFor: $tpe") - tpe match { - case t if t <:< typeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - Schema(schemaFor(optType, udtRegistry).dataType, nullable = true) - case t if t <:< typeOf[Product] => - println(s" --schemaFor matched on Product") - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val params = t.member(nme.CONSTRUCTOR).asMethod.paramss - Schema(StructType( - params.head.map { p => - val Schema(dataType, nullable) = - schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs), udtRegistry) - StructField(p.name.toString, dataType, nullable) - }), nullable = true) - // Need to decide if we actually need a special type here. - case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) - case t if t <:< typeOf[Array[_]] => - sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") - case t if t <:< typeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType, udtRegistry) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< typeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(valueDataType, valueNullable) = schemaFor(valueType, udtRegistry) - Schema(MapType(schemaFor(keyType, udtRegistry).dataType, - valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if t <:< typeOf[String] => Schema(StringType, nullable = true) - case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) - case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) - case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) - case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) - case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) - case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) - case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) - case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) - case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) - case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) - case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) - case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) - case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) - case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) - case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) - case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) - case t if udtRegistry.contains(t) => - println(s" schemaFor T matched udtRegistry") - Schema(udtRegistry(t), nullable = true) - } + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. 
*/ + def schemaFor( + tpe: `Type`, + udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = tpe match { + case t if t <:< typeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + Schema(schemaFor(optType, udtRegistry).dataType, nullable = true) + case t if t <:< typeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val params = t.member(nme.CONSTRUCTOR).asMethod.paramss + Schema(StructType( + params.head.map { p => + val Schema(dataType, nullable) = + schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs), udtRegistry) + StructField(p.name.toString, dataType, nullable) + }), nullable = true) + // Need to decide if we actually need a special type here. + case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) + case t if t <:< typeOf[Array[_]] => + sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") + case t if t <:< typeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType, udtRegistry) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + case t if t <:< typeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(valueDataType, valueNullable) = schemaFor(valueType, udtRegistry) + Schema(MapType(schemaFor(keyType, udtRegistry).dataType, + valueDataType, valueContainsNull = valueNullable), nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) + case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) + case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) + case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) + case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) + case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) + case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) + case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) + case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) + case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) + case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) + case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) + case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) + case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) + case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + case t if udtRegistry.contains(t) => + Schema(udtRegistry(t), nullable = true) } def typeOfObject: PartialFunction[Any, DataType] = { @@ -227,5 +170,4 @@ object ScalaReflection { LocalRelation(output, data) } } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index fb6dc7d3aad57..fa1786e74bb3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -51,7 +51,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi // scalastyle:off override def eval(input: Row): Any = { - println(s"ScalaUdf.eval called") val result = children.size match { case 0 => function.asInstanceOf[() => Any]() case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 2beb7fe1693cf..ebed3f15ab326 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -567,34 +567,19 @@ case class MapType( ("valueContainsNull" -> valueContainsNull) } -// TODO: Where should this go? -trait UserDefinedType[T] { - def dataType: StructType - def serialize(obj: T): Row - def deserialize(row: Row): T -} - -object UserDefinedType { - /** - * Construct a [[UserDefinedType]] object with the given key type and value type. - * The `valueContainsNull` is true. - */ - //def apply(keyType: DataType, valueType: DataType): MapType = - // MapType(keyType: DataType, valueType: DataType, true) -} - /** * The data type for User Defined Types. */ abstract class UserDefinedType[UserType] extends DataType { - // Used only in regex parser above. - //private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { } - + /** Underlying storage type for this UDT used by SparkSQL */ def sqlType: DataType + /** Convert the user type to a Row object */ + // TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, where we need to convert Any to UserType. def serialize(obj: Any): Row + /** Convert a Row object to the user type */ def deserialize(row: Row): UserType def simpleString: String = "udt" From fea04af0cf855149c6bed75792942ed6081e1995 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 9 Oct 2014 14:56:09 -0700 Subject: [PATCH 13/46] more cleanups --- .../spark/sql/catalyst/types/dataTypes.scala | 3 +- .../org/apache/spark/sql/SchemaRDD.scala | 10 ++- .../spark/sql/execution/ExistingRDD.scala | 2 - .../spark/sql/execution/SparkPlan.scala | 3 +- .../apache/spark/sql/TmpSQLQuerySuite.scala | 63 ------------------- 5 files changed, 7 insertions(+), 74 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TmpSQLQuerySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index ebed3f15ab326..84371db9f642c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -576,7 +576,8 @@ abstract class UserDefinedType[UserType] extends DataType { def sqlType: DataType /** Convert the user type to a Row object */ - // TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, where we need to convert Any to UserType. + // TODO: Can we make this take obj: UserType? + // The issue is in ScalaReflection.convertToCatalyst, where we need to convert Any to UserType. 
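// Editor's note (not part of the patch): a minimal sketch of a concrete UDT under this contract,
// shown stand-alone with its own imports and hypothetical names (Point, PointUDT). It mirrors the
// DenseVectorUDT used in the test suite and illustrates why serialize takes Any rather than
// UserType: the call site in ScalaReflection.convertToCatalyst only sees the value as Any and the
// type parameter is erased at runtime, so the concrete UDT recovers its type with a pattern match.
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
import org.apache.spark.sql.catalyst.types._

class Point(val x: Double, val y: Double)

class PointUDT extends UserDefinedType[Point] {
  // Storage type used by SparkSQL for the serialized form.
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // Accept Any and pattern-match back to the concrete user type.
  override def serialize(obj: Any): Row = obj match {
    case p: Point =>
      val row = new GenericMutableRow(2)
      row.setDouble(0, p.x)
      row.setDouble(1, p.y)
      row
  }

  override def deserialize(row: Row): Point = new Point(row.getDouble(0), row.getDouble(1))
}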
def serialize(obj: Any): Row /** Convert a Row object to the user type */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 80bfa8771e22c..3e59becb0d143 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -17,10 +17,7 @@ package org.apache.spark.sql -import java.util.{Map => JMap, List => JList} - -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.storage.StorageLevel +import java.util.{List => JList} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -29,6 +26,7 @@ import net.razorvine.pickle.Pickler import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext} import org.apache.spark.annotation.{AlphaComponent, Experimental} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.sql.api.java.JavaSchemaRDD import org.apache.spark.sql.catalyst.analysis._ @@ -36,7 +34,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} -import org.apache.spark.api.java.JavaRDD +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.storage.StorageLevel /** * :: AlphaComponent :: @@ -148,7 +147,6 @@ class SchemaRDD( case (ne: NamedExpression, _) => ne case (e, i) => Alias(e, s"c$i")() } - assert(sqlContext != null) new SchemaRDD(sqlContext, Project(aliases, logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 4fd768c6ab8dd..a330ade4d9aa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -33,14 +33,12 @@ import org.apache.spark.sql.catalyst.types.UserDefinedType @DeveloperApi object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], schema: StructType): RDD[Row] = { - println(s"productToRowRdd called with datatype: $schema") data.mapPartitions { iterator => if (iterator.isEmpty) { Iterator.empty } else { val bufferedIterator = iterator.buffered val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) - assert(bufferedIterator.head.productArity == schema.fields.length) val schemaFields = schema.fields.toArray bufferedIterator.map { r => var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 21967b14617c0..df3b0d70b8fd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -82,9 +82,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an array. 
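   * Editor's note (not part of the patch): each collected row is passed through
   * ScalaReflection.convertRowToScala with this plan's schema, so a column whose DataType is a
   * user-defined type is handed back as the user's object (via the UDT's deserialize) rather
   * than in its serialized Row form.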
*/ - def executeCollect(): Array[Row] = { + def executeCollect(): Array[Row] = execute().map(ScalaReflection.convertRowToScala(_, schema)).collect() - } protected def newProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TmpSQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TmpSQLQuerySuite.scala deleted file mode 100644 index ccc9a9299bc8c..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/TmpSQLQuerySuite.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.util.TimeZone - -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.test._ -import org.scalatest.BeforeAndAfterAll - -/* Implicits */ -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext._ - -class TmpSQLQuerySuite extends QueryTest with BeforeAndAfterAll { - // Make sure the tables are loaded. - TestData - - var origZone: TimeZone = _ - override protected def beforeAll() { - origZone = TimeZone.getDefault - TimeZone.setDefault(TimeZone.getTimeZone("UTC")) - } - - override protected def afterAll() { - TimeZone.setDefault(origZone) - } - - test("limit") { - /* - checkAnswer( - sql("SELECT * FROM testData LIMIT 10"), - testData.take(10).toSeq) -*/ - println("blah START") - checkAnswer( - sql("SELECT * FROM arrayData LIMIT 1"), - arrayData.collect().take(1).toSeq) - println("blah END") -/* - checkAnswer( - sql("SELECT * FROM mapData LIMIT 1"), - mapData.collect().take(1).toSeq) - */ - } - -} From b226b9e56687bbd31b3ea2cdd0c4dc0e9609fb8e Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Fri, 10 Oct 2014 10:33:37 -0700 Subject: [PATCH 14/46] Changing UDT to annotation --- .../apache/spark/annotation/DeveloperApi.java | 4 +- .../spark/sql/catalyst/ScalaReflection.scala | 123 +++++++++--------- .../spark/sql/catalyst/UDTRegistry.scala | 59 +++++++++ .../catalyst/annotation/UserDefinedType.java | 32 +++++ .../spark/sql/catalyst/types/dataTypes.scala | 2 +- .../org/apache/spark/sql/SQLContext.scala | 23 +--- .../apache/spark/sql/UdfRegistration.scala | 46 +++---- .../spark/sql/UserDefinedTypeSuite.scala | 14 +- 8 files changed, 188 insertions(+), 115 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/UserDefinedType.java diff --git a/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java b/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java index 0ecef6db0e039..fbcda23c7b7e2 100644 --- a/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java +++ b/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java @@ -17,7 +17,7 @@ package org.apache.spark.annotation; -import java.lang.annotation.*; + import java.lang.annotation.*; /** * A lower-level, unstable API intended for developers. @@ -31,5 +31,5 @@ */ @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, - ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) public @interface DeveloperApi {} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 537207adc19ee..ab4720ef313d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.annotation.UserDefinedType import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types._ @@ -35,7 +36,11 @@ object ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) - /** Converts Scala objects to catalyst rows / types */ + /** + * Converts Scala objects to catalyst rows / types. + * Note: This is always called after schemaFor has been called. + * This ordering is important for UDT registration. 
+ */ def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) @@ -48,7 +53,7 @@ object ScalaReflection { convertToCatalyst(elem, field.dataType) }.toArray) case (d: BigDecimal, _) => Decimal(d) - case (udt, udtType: UserDefinedType[_]) => udtType.serialize(udt) + case (udt, udtType: UserDefinedTypeType[_]) => udtType.serialize(udt) case (other, _) => other } @@ -59,7 +64,7 @@ object ScalaReflection { convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) } case (d: Decimal, DecimalType) => d.toBigDecimal - case (udt: Row, udtType: UserDefinedType[_]) => udtType.deserialize(udt) + case (udt: Row, udtType: UserDefinedTypeType[_]) => udtType.deserialize(udt) case (other, _) => other } @@ -70,69 +75,69 @@ object ScalaReflection { } /** Returns a Sequence of attributes for the given case class type. */ - def attributesFor[T: TypeTag]( - udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Seq[Attribute] = { - schemaFor[T](udtRegistry) match { + def attributesFor[T: TypeTag]: Seq[Attribute] = { + schemaFor[T] match { case Schema(s: StructType, _) => s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) } } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag](udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = - schemaFor(typeOf[T], udtRegistry) + def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T]) /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor( - tpe: `Type`, - udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = tpe match { - case t if t <:< typeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - Schema(schemaFor(optType, udtRegistry).dataType, nullable = true) - case t if t <:< typeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val params = t.member(nme.CONSTRUCTOR).asMethod.paramss - Schema(StructType( - params.head.map { p => - val Schema(dataType, nullable) = - schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs), udtRegistry) - StructField(p.name.toString, dataType, nullable) - }), nullable = true) - // Need to decide if we actually need a special type here. 
- case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) - case t if t <:< typeOf[Array[_]] => - sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") - case t if t <:< typeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType, udtRegistry) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< typeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(valueDataType, valueNullable) = schemaFor(valueType, udtRegistry) - Schema(MapType(schemaFor(keyType, udtRegistry).dataType, - valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if t <:< typeOf[String] => Schema(StringType, nullable = true) - case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) - case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) - case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) - case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) - case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) - case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) - case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) - case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) - case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) - case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) - case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) - case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) - case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) - case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) - case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) - case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) - case t if udtRegistry.contains(t) => - Schema(udtRegistry(t), nullable = true) + def schemaFor(tpe: `Type`): Schema = { + println(s"schemaFor: tpe = $tpe, tpe.getClass = ${tpe.getClass}, classOf[UserDefinedType] = ${classOf[UserDefinedType]}") + tpe match { + case t if t <:< typeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + Schema(schemaFor(optType).dataType, nullable = true) + case t if t <:< typeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val params = t.member(nme.CONSTRUCTOR).asMethod.paramss + Schema(StructType( + params.head.map { p => + val Schema(dataType, nullable) = + schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) + StructField(p.name.toString, dataType, nullable) + }), nullable = true) + // Need to decide if we actually need a special type here. 
+ case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) + case t if t <:< typeOf[Array[_]] => + sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") + case t if t <:< typeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + case t if t <:< typeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + Schema(MapType(schemaFor(keyType).dataType, + valueDataType, valueContainsNull = valueNullable), nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) + case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) + case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) + case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) + case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) + case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) + case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) + case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) + case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) + case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) + case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) + case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) + case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) + case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) + case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + case t if t.getClass.isAnnotationPresent(classOf[UserDefinedType]) => + UDTRegistry.registerType(t) + Schema(UDTRegistry.udtRegistry(t), nullable = true) + } } def typeOfObject: PartialFunction[Any, DataType] = { @@ -166,7 +171,7 @@ object ScalaReflection { def asRelation: LocalRelation = { // Pass empty map to attributesFor since this method is only used for debugging Catalyst, // not used with SparkSQL. - val output = attributesFor[A](Map.empty) + val output = attributesFor[A] LocalRelation(output, data) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala new file mode 100644 index 0000000000000..09dde4a995f34 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.sql.catalyst.annotation.UserDefinedType + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.types.UserDefinedTypeType + +import scala.reflect.runtime.universe._ + +/** + * Global registry for user-defined types (UDTs). + */ +private[sql] object UDTRegistry { + /** Map: UserType --> UserDefinedType */ + val udtRegistry = new mutable.HashMap[Any, UserDefinedTypeType[_]]() + + /** + * Register a user-defined type and its serializer, to allow automatic conversion between + * RDDs of user types and SchemaRDDs. + * Fails if this type has been registered already. + */ + /* + def registerType[UserType](implicit userType: TypeTag[UserType]): Unit = { + // TODO: Check to see if type is built-in. Throw exception? + val udt: UserDefinedTypeType[_] = + userType.getClass.getAnnotation(classOf[UserDefinedType]).udt().newInstance() + UDTRegistry.udtRegistry(userType.tpe) = udt + }*/ + + def registerType[UserType](implicit userType: Type): Unit = { + // TODO: Check to see if type is built-in. Throw exception? + if (!UDTRegistry.udtRegistry.contains(userType)) { + val udt: UserDefinedTypeType[_] = + userType.getClass.getAnnotation(classOf[UserDefinedType]).udt().newInstance() + UDTRegistry.udtRegistry(userType) = udt + } + // TODO: Else: Should we check (assert) that udt is the same as what is in the registry? + } + + //def getUDT +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/UserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/UserDefinedType.java new file mode 100644 index 0000000000000..f697712725752 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/UserDefinedType.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.annotation; + +import org.apache.spark.sql.catalyst.types.UserDefinedTypeType; + +import java.lang.annotation.*; + +/** + * A user-defined type which can be automatically recognized by a SQLContext and registered. + */ +// TODO: Should I used @Documented ? 
+@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface UserDefinedType { + Class<? extends UserDefinedTypeType<?> > udt(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 84371db9f642c..9eabe36195684 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -570,7 +570,7 @@ case class MapType( /** * The data type for User Defined Types. */ -abstract class UserDefinedType[UserType] extends DataType { +abstract class UserDefinedTypeType[UserType] extends DataType { /** Underlying storage type for this UDT used by SparkSQL */ def sqlType: DataType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e66aa22144cdd..9c725436df3df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.types.UserDefinedType - -import scala.collection.mutable import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} @@ -104,7 +101,7 @@ class SQLContext(@transient val sparkContext: SparkContext) implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = { SparkPlan.currentContext.set(self) println(s"createSchemaRDD called") - val attributeSeq = ScalaReflection.attributesFor[A](udtRegistry) + val attributeSeq = ScalaReflection.attributesFor[A] val schema = StructType.fromAttributes(attributeSeq) val rowRDD = RDDConversions.productToRowRdd(rdd, schema) println("done with productToRowRdd") @@ -259,7 +256,7 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD( this, ParquetRelation.createEmpty( - path, ScalaReflection.attributesFor[A](udtRegistry), allowExisting, conf, this)) + path, ScalaReflection.attributesFor[A], allowExisting, conf, this)) } /** @@ -290,22 +287,6 @@ class SQLContext(@transient val sparkContext: SparkContext) def table(tableName: String): SchemaRDD = new SchemaRDD(this, catalog.lookupRelation(None, tableName)) - /** - * Register a user-defined type and its serializer, to allow automatic conversion between - * RDDs of user types and SchemaRDDs. - * Fails if this type has been registered already. - */ - def registerType[UserType]( - udt: UserDefinedType[UserType])(implicit userType: TypeTag[UserType]): Unit = { - require(!udtRegistry.contains(userType.tpe), - "registerUserType called on type which was already registered.") - // TODO: Check to see if type is built-in. Throw exception?
- udtRegistry(userType.tpe) = udt - } - - /** Map: UserType --> UserDefinedType */ - protected[sql] val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]() - protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 9946c8aa4d1bc..6d4c0d82ac7af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -78,7 +78,7 @@ private[sql] trait UDFRegistration { s""" def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = { def builder(e: Seq[Expression]) = - ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } """ @@ -87,112 +87,112 @@ private[sql] trait UDFRegistration { // scalastyle:off def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: 
Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def 
registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T](udtRegistry).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } // scalastyle:on diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 0d6843ed112f9..652980efec57f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -18,11 +18,13 @@ package org.apache.spark.sql import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.annotation.UserDefinedType import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.types.UserDefinedType +import org.apache.spark.sql.catalyst.types.UserDefinedTypeType import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +@UserDefinedType(udt = classOf[DenseVectorUDT]) class DenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { case v: DenseVector => @@ -33,7 +35,7 @@ class DenseVector(val data: Array[Double]) extends Serializable { case class LabeledPoint(label: Double, features: DenseVector) -case object DenseVectorUDT extends UserDefinedType[DenseVector] { +class DenseVectorUDT extends UserDefinedTypeType[DenseVector] { override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false) @@ -63,13 +65,7 @@ case object DenseVectorUDT extends UserDefinedType[DenseVector] { class UserDefinedTypeSuite extends QueryTest { test("register user type: DenseVector for LabeledPoint") { - registerType(DenseVectorUDT) - println("udtRegistry:") - 
TestSQLContext.udtRegistry.foreach { case (t, s) => println(s"$t -> $s")} - - println(s"test: ${scala.reflect.runtime.universe.typeOf[DenseVector]}") - assert(TestSQLContext.udtRegistry.contains(scala.reflect.runtime.universe.typeOf[DenseVector])) - + //registerType(DenseVectorUDT) val points = Seq( LabeledPoint(1.0, new DenseVector(Array(0.1, 1.0))), LabeledPoint(0.0, new DenseVector(Array(0.2, 2.0)))) From 357903547e731bfc1a15ead6fa8903737c547316 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 10 Oct 2014 11:53:27 -0700 Subject: [PATCH 15/46] udt annotation now working --- .../spark/sql/catalyst/ScalaReflection.scala | 4 ++-- .../apache/spark/sql/catalyst/UDTRegistry.scala | 17 ++++------------- .../catalyst/annotation/UserDefinedType.java | 2 ++ .../spark/sql/catalyst/types/dataTypes.scala | 4 +++- .../scala/org/apache/spark/sql/SQLContext.scala | 2 -- .../apache/spark/sql/UserDefinedTypeSuite.scala | 10 ---------- 6 files changed, 11 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index ab4720ef313d0..fd2ab2e6f3b6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -87,7 +87,6 @@ object ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = { - println(s"schemaFor: tpe = $tpe, tpe.getClass = ${tpe.getClass}, classOf[UserDefinedType] = ${classOf[UserDefinedType]}") tpe match { case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -134,7 +133,8 @@ object ScalaReflection { case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) - case t if t.getClass.isAnnotationPresent(classOf[UserDefinedType]) => + case t if getClass.getClassLoader.loadClass(t.typeSymbol.asClass.fullName) + .isAnnotationPresent(classOf[UserDefinedType]) => UDTRegistry.registerType(t) Schema(UDTRegistry.udtRegistry(t), nullable = true) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala index 09dde4a995f34..c5ea5ca804e07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala @@ -35,25 +35,16 @@ private[sql] object UDTRegistry { /** * Register a user-defined type and its serializer, to allow automatic conversion between * RDDs of user types and SchemaRDDs. - * Fails if this type has been registered already. + * If this type has already been registered, this does nothing. */ - /* - def registerType[UserType](implicit userType: TypeTag[UserType]): Unit = { - // TODO: Check to see if type is built-in. Throw exception? - val udt: UserDefinedTypeType[_] = - userType.getClass.getAnnotation(classOf[UserDefinedType]).udt().newInstance() - UDTRegistry.udtRegistry(userType.tpe) = udt - }*/ - def registerType[UserType](implicit userType: Type): Unit = { // TODO: Check to see if type is built-in. Throw exception? 
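// Editor's sketch (not part of the patch): the reflective lookup that registerType performs,
// pulled out as a stand-alone helper with hypothetical names (UDTLookup, udtFor). It assumes the
// class behind `tpe` carries the UserDefinedType annotation (getAnnotation returns null otherwise)
// and that the udt() class named by the annotation has a no-arg constructor, which newInstance needs.
import scala.reflect.runtime.universe.Type
import org.apache.spark.sql.catalyst.annotation.UserDefinedType
import org.apache.spark.sql.catalyst.types.UserDefinedTypeType

object UDTLookup {
  def udtFor(tpe: Type): UserDefinedTypeType[_] = {
    // Resolve the runtime Class from the reflected type's fully-qualified name,
    // e.g. "org.apache.spark.sql.DenseVector" for the test suite's annotated class.
    val runtimeClass = getClass.getClassLoader.loadClass(tpe.typeSymbol.asClass.fullName)
    // Read the annotation and instantiate the UDT class it points to.
    runtimeClass.getAnnotation(classOf[UserDefinedType]).udt().newInstance()
  }
}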
if (!UDTRegistry.udtRegistry.contains(userType)) { - val udt: UserDefinedTypeType[_] = - userType.getClass.getAnnotation(classOf[UserDefinedType]).udt().newInstance() + val udt = + getClass.getClassLoader.loadClass(userType.typeSymbol.asClass.fullName) + .getAnnotation(classOf[UserDefinedType]).udt().newInstance() UDTRegistry.udtRegistry(userType) = udt } // TODO: Else: Should we check (assert) that udt is the same as what is in the registry? } - - //def getUDT } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/UserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/UserDefinedType.java index f697712725752..21d6809154822 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/UserDefinedType.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/UserDefinedType.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.annotation; +import org.apache.spark.annotation.DeveloperApi; import org.apache.spark.sql.catalyst.types.UserDefinedTypeType; import java.lang.annotation.*; @@ -25,6 +26,7 @@ * A user-defined type which can be automatically recognized by a SQLContext and registered. */ // TODO: Should I used @Documented ? +@DeveloperApi @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface UserDefinedType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 9eabe36195684..916b6a2ea7b53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -29,6 +29,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row} import org.apache.spark.sql.catalyst.types.decimal._ @@ -570,7 +571,8 @@ case class MapType( /** * The data type for User Defined Types. 
*/ -abstract class UserDefinedTypeType[UserType] extends DataType { +@DeveloperApi +abstract class UserDefinedTypeType[UserType] extends DataType with Serializable { /** Underlying storage type for this UDT used by SparkSQL */ def sqlType: DataType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9c725436df3df..4df38a88ebcd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -100,11 +100,9 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = { SparkPlan.currentContext.set(self) - println(s"createSchemaRDD called") val attributeSeq = ScalaReflection.attributesFor[A] val schema = StructType.fromAttributes(attributeSeq) val rowRDD = RDDConversions.productToRowRdd(rdd, schema) - println("done with productToRowRdd") new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 652980efec57f..45d366f20b811 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -42,7 +42,6 @@ class DenseVectorUDT extends UserDefinedTypeType[DenseVector] { override def serialize(obj: Any): Row = obj match { case features: DenseVector => val row: GenericMutableRow = new GenericMutableRow(features.data.length) - // TODO: Is there a copyTo command I can use? var i = 0 while (i < features.data.length) { row.setDouble(i, features.data(i)) @@ -65,26 +64,17 @@ class DenseVectorUDT extends UserDefinedTypeType[DenseVector] { class UserDefinedTypeSuite extends QueryTest { test("register user type: DenseVector for LabeledPoint") { - //registerType(DenseVectorUDT) val points = Seq( LabeledPoint(1.0, new DenseVector(Array(0.1, 1.0))), LabeledPoint(0.0, new DenseVector(Array(0.2, 2.0)))) val pointsRDD: RDD[LabeledPoint] = sparkContext.parallelize(points) - println("Converting to SchemaRDD") - val tmpSchemaRDD: SchemaRDD = TestSQLContext.createSchemaRDD(pointsRDD) - println("blah") - println(s"SchemaRDD count: ${tmpSchemaRDD.count()}") - println("Done converting to SchemaRDD") - - println("testing labels") val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() assert(labelsArrays.size === 2) assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) - println("testing features") val features: RDD[DenseVector] = pointsRDD.select('features).map { case Row(v: DenseVector) => v } val featuresArrays: Array[DenseVector] = features.collect() From 2f40c02a8891e3466fdf1cbb10120cc17b961b96 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Fri, 10 Oct 2014 13:13:35 -0700 Subject: [PATCH 16/46] renamed UDT types --- .../spark/sql/catalyst/ScalaReflection.scala | 108 +++++++++--------- .../spark/sql/catalyst/UDTRegistry.scala | 8 +- ...finedType.java => SQLUserDefinedType.java} | 6 +- .../spark/sql/catalyst/types/dataTypes.scala | 2 +- .../spark/sql/UserDefinedTypeSuite.scala | 19 +-- 5 files changed, 65 insertions(+), 78 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/{UserDefinedType.java => SQLUserDefinedType.java} (88%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index fd2ab2e6f3b6d..8989dafa2b2a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} -import org.apache.spark.sql.catalyst.annotation.UserDefinedType +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types._ @@ -53,7 +53,7 @@ object ScalaReflection { convertToCatalyst(elem, field.dataType) }.toArray) case (d: BigDecimal, _) => Decimal(d) - case (udt, udtType: UserDefinedTypeType[_]) => udtType.serialize(udt) + case (udt, udtType: UserDefinedType[_]) => udtType.serialize(udt) case (other, _) => other } @@ -64,7 +64,7 @@ object ScalaReflection { convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) } case (d: Decimal, DecimalType) => d.toBigDecimal - case (udt: Row, udtType: UserDefinedTypeType[_]) => udtType.deserialize(udt) + case (udt: Row, udtType: UserDefinedType[_]) => udtType.deserialize(udt) case (other, _) => other } @@ -86,58 +86,56 @@ object ScalaReflection { def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T]) /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor(tpe: `Type`): Schema = { - tpe match { - case t if t <:< typeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - Schema(schemaFor(optType).dataType, nullable = true) - case t if t <:< typeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val params = t.member(nme.CONSTRUCTOR).asMethod.paramss - Schema(StructType( - params.head.map { p => - val Schema(dataType, nullable) = - schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) - StructField(p.name.toString, dataType, nullable) - }), nullable = true) - // Need to decide if we actually need a special type here. 
- case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) - case t if t <:< typeOf[Array[_]] => - sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") - case t if t <:< typeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< typeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - Schema(MapType(schemaFor(keyType).dataType, - valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if t <:< typeOf[String] => Schema(StringType, nullable = true) - case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) - case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) - case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) - case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) - case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) - case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) - case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) - case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) - case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) - case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) - case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) - case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) - case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) - case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) - case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) - case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) - case t if getClass.getClassLoader.loadClass(t.typeSymbol.asClass.fullName) - .isAnnotationPresent(classOf[UserDefinedType]) => - UDTRegistry.registerType(t) - Schema(UDTRegistry.udtRegistry(t), nullable = true) - } + def schemaFor(tpe: `Type`): Schema = tpe match { + case t if getClass.getClassLoader.loadClass(t.typeSymbol.asClass.fullName) + .isAnnotationPresent(classOf[SQLUserDefinedType]) => + UDTRegistry.registerType(t) + Schema(UDTRegistry.udtRegistry(t), nullable = true) + case t if t <:< typeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + Schema(schemaFor(optType).dataType, nullable = true) + case t if t <:< typeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val params = t.member(nme.CONSTRUCTOR).asMethod.paramss + Schema(StructType( + params.head.map { p => + val Schema(dataType, nullable) = + schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) + StructField(p.name.toString, dataType, nullable) + }), nullable = true) + // Need to decide if we actually need a special type here. 
+ case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) + case t if t <:< typeOf[Array[_]] => + sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") + case t if t <:< typeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + case t if t <:< typeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + Schema(MapType(schemaFor(keyType).dataType, + valueDataType, valueContainsNull = valueNullable), nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) + case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) + case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) + case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) + case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) + case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) + case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) + case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) + case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) + case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) + case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) + case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) + case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) + case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) + case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) } def typeOfObject: PartialFunction[Any, DataType] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala index c5ea5ca804e07..60dd782c1c45c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.sql.catalyst.annotation.UserDefinedType +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import scala.collection.mutable -import org.apache.spark.sql.catalyst.types.UserDefinedTypeType +import org.apache.spark.sql.catalyst.types.UserDefinedType import scala.reflect.runtime.universe._ @@ -30,7 +30,7 @@ import scala.reflect.runtime.universe._ */ private[sql] object UDTRegistry { /** Map: UserType --> UserDefinedType */ - val udtRegistry = new mutable.HashMap[Any, UserDefinedTypeType[_]]() + val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]() /** * Register a user-defined type and its serializer, to allow automatic conversion between @@ -42,7 +42,7 @@ private[sql] object UDTRegistry { if (!UDTRegistry.udtRegistry.contains(userType)) { val udt = getClass.getClassLoader.loadClass(userType.typeSymbol.asClass.fullName) - 
.getAnnotation(classOf[UserDefinedType]).udt().newInstance() + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() UDTRegistry.udtRegistry(userType) = udt } // TODO: Else: Should we check (assert) that udt is the same as what is in the registry? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/UserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java similarity index 88% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/UserDefinedType.java rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java index 21d6809154822..1ecb0ac00bb09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/UserDefinedType.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.annotation; import org.apache.spark.annotation.DeveloperApi; -import org.apache.spark.sql.catalyst.types.UserDefinedTypeType; +import org.apache.spark.sql.catalyst.types.UserDefinedType; import java.lang.annotation.*; @@ -29,6 +29,6 @@ @DeveloperApi @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) -public @interface UserDefinedType { - Class > udt(); +public @interface SQLUserDefinedType { + Class > udt(); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 916b6a2ea7b53..0c9c73af24d21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -572,7 +572,7 @@ case class MapType( * The data type for User Defined Types. 
*/ @DeveloperApi -abstract class UserDefinedTypeType[UserType] extends DataType with Serializable { +abstract class UserDefinedType[UserType] extends DataType with Serializable { /** Underlying storage type for this UDT used by SparkSQL */ def sqlType: DataType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 45d366f20b811..bcbae7cd50aec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.annotation.UserDefinedType +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.types.UserDefinedTypeType -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.test.TestSQLContext._ -@UserDefinedType(udt = classOf[DenseVectorUDT]) +@SQLUserDefinedType(udt = classOf[DenseVectorUDT]) class DenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { case v: DenseVector => @@ -35,7 +34,7 @@ class DenseVector(val data: Array[Double]) extends Serializable { case class LabeledPoint(label: Double, features: DenseVector) -class DenseVectorUDT extends UserDefinedTypeType[DenseVector] { +class DenseVectorUDT extends UserDefinedType[DenseVector] { override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false) @@ -83,14 +82,4 @@ class UserDefinedTypeSuite extends QueryTest { assert(featuresArrays.contains(new DenseVector(Array(0.2, 2.0)))) } - /* - test("UDTs can be registered twice, overriding previous registration") { - // TODO - } - - test("UDTs cannot override built-in types") { - // TODO - } - */ - } From e1f7b9cd053b53550a23909bd5a9ace5074a9066 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 10 Oct 2014 14:02:32 -0700 Subject: [PATCH 17/46] blah --- .../apache/spark/annotation/DeveloperApi.java | 4 ++-- .../sql/catalyst/ScalaReflectionSuite.scala | 16 ++++++++++------ .../org/apache/spark/sql/hive/HiveContext.scala | 2 +- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java b/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java index fbcda23c7b7e2..0ecef6db0e039 100644 --- a/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java +++ b/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java @@ -17,7 +17,7 @@ package org.apache.spark.annotation; - import java.lang.annotation.*; +import java.lang.annotation.*; /** * A lower-level, unstable API intended for developers. 
@@ -31,5 +31,5 @@ */ @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, - ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) public @interface DeveloperApi {} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 37978155f43ae..0fe3cdeeba53c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -69,7 +69,7 @@ case class GenericData[A]( class ScalaReflectionSuite extends FunSuite { import ScalaReflection._ -/* + test("primitive data") { val schema = schemaFor[PrimitiveData] assert(schema === Schema( @@ -239,14 +239,18 @@ class ScalaReflectionSuite extends FunSuite { test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) - assert(convertToCatalyst(data) === convertedData) + val dataType = schemaFor[PrimitiveData].dataType + assert(convertToCatalyst(data, dataType) === convertedData) } test("convert Option[Product] to catalyst") { val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true) - val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true), Some(primitiveData)) - val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true, convertToCatalyst(primitiveData)) - assert(convertToCatalyst(data) === convertedData) + val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true), + Some(primitiveData)) + val dataType = schemaFor[PrimitiveData].dataType + val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true, + convertToCatalyst(primitiveData, dataType)) + assert(convertToCatalyst(data, dataType) === convertedData) } - */ + } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 6bb1f1b1d2a7a..15cc6067ecb05 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -121,7 +121,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * @tparam A A case class that is used to describe the schema of the table to be created. */ def createTable[A <: Product : TypeTag](tableName: String, allowExisting: Boolean = true) { - catalog.createTable("default", tableName, ScalaReflection.attributesFor[A](udtRegistry), + catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting) } From 34a5831eb28b1422b1562f256e7be1f290a70c40 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 10 Oct 2014 15:14:29 -0700 Subject: [PATCH 18/46] Added MLlib dependency on SQL. 
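For sbt-based builds, the spark-sql dependency that this patch adds to mllib/pom.xml corresponds roughly to the line sketched here; `sparkVersion` is only a placeholder for the Maven `${project.version}` property and is not defined anywhere in this patch:

    // Illustrative sbt equivalent of the Maven dependency added below (sparkVersion is a placeholder).
    libraryDependencies += "org.apache.spark" %% "spark-sql" % sparkVersion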
--- mllib/pom.xml | 5 +++++ .../scala/org/apache/spark/sql/catalyst/UDTRegistry.scala | 6 ++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mllib/pom.xml b/mllib/pom.xml index fb7239e779aae..0335409c9123d 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -40,6 +40,11 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + org.apache.spark spark-streaming_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala index 60dd782c1c45c..b82f0b5f3eb02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType - import scala.collection.mutable +import scala.reflect.runtime.universe._ +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.types.UserDefinedType -import scala.reflect.runtime.universe._ - /** * Global registry for user-defined types (UDTs). */ From cd60cb48d36142a152cd02a263212f5c041e6c23 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 21 Oct 2014 11:57:14 -0700 Subject: [PATCH 19/46] Trying to get other SQL tests to run --- .../spark/sql/catalyst/ScalaReflection.scala | 110 ++++++++++-------- .../spark/sql/catalyst/UDTRegistry.scala | 2 +- .../annotation/SQLUserDefinedType.java | 5 + .../org/apache/spark/sql/SQLQuerySuite.scala | 40 ++++--- 4 files changed, 88 insertions(+), 69 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8989dafa2b2a5..75eb3b4872475 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} +import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation @@ -86,56 +87,65 @@ object ScalaReflection { def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T]) /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. 
*/ - def schemaFor(tpe: `Type`): Schema = tpe match { - case t if getClass.getClassLoader.loadClass(t.typeSymbol.asClass.fullName) - .isAnnotationPresent(classOf[SQLUserDefinedType]) => - UDTRegistry.registerType(t) - Schema(UDTRegistry.udtRegistry(t), nullable = true) - case t if t <:< typeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - Schema(schemaFor(optType).dataType, nullable = true) - case t if t <:< typeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val params = t.member(nme.CONSTRUCTOR).asMethod.paramss - Schema(StructType( - params.head.map { p => - val Schema(dataType, nullable) = - schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) - StructField(p.name.toString, dataType, nullable) - }), nullable = true) - // Need to decide if we actually need a special type here. - case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) - case t if t <:< typeOf[Array[_]] => - sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") - case t if t <:< typeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< typeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - Schema(MapType(schemaFor(keyType).dataType, - valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if t <:< typeOf[String] => Schema(StringType, nullable = true) - case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) - case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) - case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) - case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) - case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) - case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) - case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) - case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) - case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) - case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) - case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) - case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) - case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) - case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) - case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) - case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + def schemaFor(tpe: `Type`): Schema = { + val className: String = tpe.erasure.typeSymbol.asClass.fullName + tpe match { + case t if Utils.classIsLoadable(className) && + Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + // Note: We check for classIsLoadable above since Utils.classForName uses Java reflection, + // whereas className is from Scala reflection. 
This can make it hard to find classes + // in some cases, such as when a class is enclosed in an object (in which case + // Java appends a '$' to the object name but Scala does not). + UDTRegistry.registerType(t) + Schema(UDTRegistry.udtRegistry(t), nullable = true) + case t if UDTRegistry.udtRegistry.contains(t) => + Schema(UDTRegistry.udtRegistry(t), nullable = true) + case t if t <:< typeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + Schema(schemaFor(optType).dataType, nullable = true) + case t if t <:< typeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val params = t.member(nme.CONSTRUCTOR).asMethod.paramss + Schema(StructType( + params.head.map { p => + val Schema(dataType, nullable) = + schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) + StructField(p.name.toString, dataType, nullable) + }), nullable = true) + // Need to decide if we actually need a special type here. + case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) + case t if t <:< typeOf[Array[_]] => + sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") + case t if t <:< typeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + case t if t <:< typeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + Schema(MapType(schemaFor(keyType).dataType, + valueDataType, valueContainsNull = valueNullable), nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) + case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) + case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) + case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) + case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) + case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) + case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) + case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) + case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) + case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) + case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) + case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) + case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) + case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) + case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + } } def typeOfObject: PartialFunction[Any, DataType] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala index b82f0b5f3eb02..a9be187ded96e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala @@ -35,7 +35,7 @@ private[sql] object UDTRegistry { * RDDs of user types and SchemaRDDs. * If this type has already been registered, this does nothing. */ - def registerType[UserType](implicit userType: Type): Unit = { + def registerType(userType: Type): Unit = { // TODO: Check to see if type is built-in. Throw exception? if (!UDTRegistry.udtRegistry.contains(userType)) { val udt = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java index 1ecb0ac00bb09..fa909a9eb1b3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java @@ -24,6 +24,11 @@ /** * A user-defined type which can be automatically recognized by a SQLContext and registered. + * + * WARNING: This annotation will only work if both Java and Scala reflection return the same class + * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class + * is enclosed in an object (a singleton). In these cases, the UDT must be registered + * manually. */ // TODO: Should I used @Documented ? @DeveloperApi diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6befe1b755cc6..73dac52452f23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -737,28 +737,32 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("throw errors for non-aggregate attributes with aggregation") { - def checkAggregation(query: String, isInvalidQuery: Boolean = true) { - val logicalPlan = sql(query).queryExecution.logical - - if (isInvalidQuery) { - val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed) - assert( - e.getMessage.startsWith("Expression not in GROUP BY"), - "Non-aggregate attribute(s) not detected\n" + logicalPlan) - } else { - // Should not throw - sql(query).queryExecution.analyzed + try { + def checkAggregation(query: String, isInvalidQuery: Boolean = true) { + val logicalPlan = sql(query).queryExecution.logical + + if (isInvalidQuery) { + val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed) + assert( + e.getMessage.startsWith("Expression not in GROUP BY"), + "Non-aggregate attribute(s) not detected\n" + logicalPlan) + } else { + // Should not throw + sql(query).queryExecution.analyzed + } } - } - checkAggregation("SELECT key, COUNT(*) FROM testData") - checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false) + checkAggregation("SELECT key, COUNT(*) FROM testData") + checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false) - checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key") - checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false) + checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key") + checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false) - checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1") - checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false) + checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY 
key + 1") + checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false) + } catch { + case e: Exception => println(e.getStackTraceString) + } } test("Test to check we can use Long.MinValue") { From dff99d6b29b02b33be24dae00d8da7122e0d7d2f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 21 Oct 2014 19:26:41 -0700 Subject: [PATCH 20/46] Added UDTs for Vectors in MLlib, plus DatasetExample using the UDTs --- .../spark/examples/mllib/DatasetExample.scala | 101 ++++++++++++++ .../apache/spark/mllib/linalg/Vectors.scala | 131 +++++++++++++++++- .../spark/mllib/regression/LabeledPoint.scala | 2 + .../spark/sql/catalyst/ScalaReflection.scala | 7 +- .../spark/sql/catalyst/UDTRegistry.scala | 11 +- .../org/apache/spark/sql/SchemaRDDLike.scala | 2 +- 6 files changed, 242 insertions(+), 12 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala new file mode 100644 index 0000000000000..0728a0021497c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} + +/** + * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.Dataset [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
+ */ +object DatasetExample { + + case class Params( + input: String = "data/mllib/sample_libsvm_data.txt", + dataFormat: String = "libsvm") extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("Dataset") { + head("Dataset: an example app using SchemaRDD as a Dataset for ML.") + opt[String]("input") + .text(s"input path to dataset") + .action((x, c) => c.copy(input = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(input = x)) + checkConfig { params => + success + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"Dataset with $params") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ // for implicit conversions + + // Load input data + val origData: RDD[LabeledPoint] = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.input) + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) + } + println(s"Loaded ${origData.count()} instances from file: ${params.input}") + + // Convert input data to SchemaRDD explicitly. + val schemaRDD: SchemaRDD = origData + println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") + + // Select columns, using implicit conversion to SchemaRDD. + val labelsSchemaRDD: SchemaRDD = origData.select('label) + val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } + val numLabels = labels.count() + val meanLabel = labels.fold(0.0)(_ + _) / numLabels + println(s"Selected label column with average value $meanLabel") + + val featuresSchemaRDD: SchemaRDD = origData.select('features) + val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } + val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + + sc.stop() + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 6af225b7f49f7..d473eff1202dc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -25,8 +25,13 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} -import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException +import org.apache.spark.mllib.util.NumericParser +import org.apache.spark.sql.catalyst.UDTRegistry +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.Row /** * Represents a numeric vector, whose index type is Int and value type is Double. @@ -81,6 +86,8 @@ sealed trait Vector extends Serializable { */ object Vectors { + UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[Vector], new VectorUDT()) + /** * Creates a dense vector from its values. */ @@ -191,6 +198,7 @@ object Vectors { /** * A dense vector represented by a value array. 
*/ +@SQLUserDefinedType(udt = classOf[DenseVectorUDT]) class DenseVector(val values: Array[Double]) extends Vector { override def size: Int = values.length @@ -242,3 +250,124 @@ class SparseVector( private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) } + +/** + * User-defined type for [[Vector]] which allows easy interaction with SQL + * via [[org.apache.spark.sql.SchemaRDD]]. + */ +private[spark] class VectorUDT extends UserDefinedType[Vector] { + + /** + * vectorType: 0 = dense, 1 = sparse. + * dense, sparse: One element holds the vector, and the other is null. + */ + override def sqlType: StructType = StructType(Seq( + StructField("vectorType", ByteType, nullable = false), + StructField("dense", new DenseVectorUDT(), nullable = true), + StructField("sparse", new SparseVectorUDT(), nullable = true))) + + override def serialize(obj: Any): Row = { + val row = new GenericMutableRow(3) + obj match { + case v: DenseVector => + row.setByte(0, 0) + row.update(1, new DenseVectorUDT().serialize(obj)) + row.setNullAt(2) + case v: SparseVector => + row.setByte(0, 1) + row.setNullAt(1) + row.update(2, new SparseVectorUDT().serialize(obj)) + } + row + } + + override def deserialize(row: Row): Vector = { + require(row.length == 3, + s"VectorUDT.deserialize given row with length ${row.length} but requires length == 3") + val vectorType = row.getByte(0) + vectorType match { + case 0 => + new DenseVectorUDT().deserialize(row.getAs[Row](1)) + case 1 => + new SparseVectorUDT().deserialize(row.getAs[Row](2)) + } + } +} + +/** + * User-defined type for [[DenseVector]] which allows easy interaction with SQL + * via [[org.apache.spark.sql.SchemaRDD]]. + */ +private[spark] class DenseVectorUDT extends UserDefinedType[DenseVector] { + + override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false) + + override def serialize(obj: Any): Row = obj match { + case v: DenseVector => + val row: GenericMutableRow = new GenericMutableRow(v.size) + var i = 0 + while (i < v.size) { + row.setDouble(i, v(i)) + i += 1 + } + row + } + + override def deserialize(row: Row): DenseVector = { + val values = new Array[Double](row.length) + var i = 0 + while (i < row.length) { + values(i) = row.getDouble(i) + i += 1 + } + new DenseVector(values) + } +} + +/** + * User-defined type for [[SparseVector]] which allows easy interaction with SQL + * via [[org.apache.spark.sql.SchemaRDD]]. 
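 * (Illustrative round trip for the VectorUDT defined above; this snippet is not part of the patch,
 * and since VectorUDT is private[spark] such code would have to live inside a Spark package.)
 * {{{
 *   val udt = new VectorUDT()
 *   val row = udt.serialize(Vectors.dense(1.0, 2.0))
 *   // row(0) == 0 (dense marker byte), row(1) holds the dense sub-row, row(2) is null
 *   val v = udt.deserialize(row)   // yields DenseVector(1.0, 2.0)
 * }}}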
+ */ +private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] { + + override def sqlType: StructType = StructType(Seq( + StructField("size", IntegerType, nullable = false), + StructField("indices", ArrayType(DoubleType, containsNull = false), nullable = false), + StructField("values", ArrayType(DoubleType, containsNull = false), nullable = false))) + + override def serialize(obj: Any): Row = obj match { + case v: SparseVector => + val nnz = v.indices.size + val row: GenericMutableRow = new GenericMutableRow(1 + 2 * nnz) + row.setInt(0, v.size) + var i = 0 + while (i < nnz) { + row.setInt(1 + i, v.indices(i)) + i += 1 + } + i = 0 + while (i < nnz) { + row.setDouble(1 + nnz + i, v.values(i)) + i += 1 + } + row + } + + override def deserialize(row: Row): SparseVector = { + require(row.length >= 1, + s"SparseVectorUDT.deserialize given row with length ${row.length} but requires length >= 1") + val vSize = row.getInt(0) + val nnz: Int = (row.length - 1) / 2 + require(nnz * 2 + 1 == row.length, + s"SparseVectorUDT.deserialize given row with non-matching indices, values lengths") + val indices = new Array[Int](nnz) + val values = new Array[Double](nnz) + var i = 0 + while (i < nnz) { + indices(i) = row.getInt(1 + i) + values(i) = row.getDouble(1 + nnz + i) + i += 1 + } + new SparseVector(vSize, indices, values) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 17c753c56681f..936b82a00f14f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -33,6 +33,8 @@ case class LabeledPoint(label: Double, features: Vector) { } } + + /** * Parser for [[org.apache.spark.mllib.regression.LabeledPoint]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 75eb3b4872475..a6e255434baef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -89,6 +89,7 @@ object ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = { val className: String = tpe.erasure.typeSymbol.asClass.fullName + println(s"schemaFor: className = $className") tpe match { case t if Utils.classIsLoadable(className) && Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => @@ -96,8 +97,10 @@ object ScalaReflection { // whereas className is from Scala reflection. This can make it hard to find classes // in some cases, such as when a class is enclosed in an object (in which case // Java appends a '$' to the object name but Scala does not). 
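The naming mismatch described in the comment above can be reproduced with a tiny sketch; the Outer/Inner names here are hypothetical and not part of this patch:

    import scala.reflect.runtime.universe._

    object Outer { class Inner }   // a class enclosed in an object (a singleton)

    val scalaName = typeOf[Outer.Inner].typeSymbol.asClass.fullName   // e.g. "Outer.Inner"  (Scala reflection)
    val javaName  = classOf[Outer.Inner].getName                      // e.g. "Outer$Inner"  (Java reflection)
    // Class.forName on the dotted Scala name would typically throw ClassNotFoundException,
    // which is why the annotation lookup is guarded by Utils.classIsLoadable(className).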
- UDTRegistry.registerType(t) - Schema(UDTRegistry.udtRegistry(t), nullable = true) + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + UDTRegistry.registerType(t, udt) + Schema(udt, nullable = true) case t if UDTRegistry.udtRegistry.contains(t) => Schema(UDTRegistry.udtRegistry(t), nullable = true) case t if t <:< typeOf[Option[_]] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala index a9be187ded96e..7c333e3bc6758 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.types.UserDefinedType /** * Global registry for user-defined types (UDTs). */ -private[sql] object UDTRegistry { +object UDTRegistry { /** Map: UserType --> UserDefinedType */ val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]() @@ -35,14 +35,9 @@ private[sql] object UDTRegistry { * RDDs of user types and SchemaRDDs. * If this type has already been registered, this does nothing. */ - def registerType(userType: Type): Unit = { + def registerType(userType: Type, udt: UserDefinedType[_]): Unit = { // TODO: Check to see if type is built-in. Throw exception? - if (!UDTRegistry.udtRegistry.contains(userType)) { - val udt = - getClass.getClassLoader.loadClass(userType.typeSymbol.asClass.fullName) - .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() - UDTRegistry.udtRegistry(userType) = udt - } + UDTRegistry.udtRegistry(userType) = udt // TODO: Else: Should we check (assert) that udt is the same as what is in the registry? } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 15516afb95504..fd5f4abcbcd65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.LogicalRDD * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) */ private[sql] trait SchemaRDDLike { - @transient val sqlContext: SQLContext + @transient def sqlContext: SQLContext @transient val baseLogicalPlan: LogicalPlan private[sql] def baseSchemaRDD: SchemaRDD From 85872f6e2fbb2385793b645a629ed26ee2e98cbc Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 23 Oct 2014 14:17:55 -0700 Subject: [PATCH 21/46] Allow schema calculation to be lazy, but ensure its available on executors. --- .../scala/org/apache/spark/sql/SchemaRDD.scala | 16 ++++++++++------ .../org/apache/spark/sql/hive/HiveContext.scala | 7 +------ 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 3e59becb0d143..a3212ff9ec6d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -117,14 +117,18 @@ class SchemaRDD( override def getPartitions: Array[Partition] = firstParent[Row].partitions - override protected def getDependencies: Seq[Dependency[_]] = + override protected def getDependencies: Seq[Dependency[_]] = { + schema // Force reification of the schema so it is available on executors. 
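    // Illustrative note, not a line of this patch: `schema` is declared as a lazy val further down
    // in this file, so referencing it here evaluates it once on the driver; the computed StructType
    // is then captured in the serialized SchemaRDD instead of being re-derived on executors, where
    // the analyzed plan and SQLContext may not be available.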
+ List(new OneToOneDependency(queryExecution.toRdd)) + } - /** Returns the schema of this SchemaRDD (represented by a [[StructType]]). - * - * @group schema - */ - val schema: StructType = queryExecution.analyzed.schema + /** + * Returns the schema of this SchemaRDD (represented by a [[StructType]]). + * + * @group schema + */ + lazy val schema: StructType = queryExecution.analyzed.schema // ======================================================================= // Query DSL diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 15cc6067ecb05..9915a17882488 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -371,11 +371,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** Extends QueryExecution with hive specific features. */ protected[sql] abstract class QueryExecution extends super.QueryExecution { - override lazy val toRdd: RDD[Row] = { - val schema = StructType.fromAttributes(logical.output) - executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema)) - } - protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, ShortType, DateType, TimestampType, BinaryType) @@ -433,7 +428,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { command.executeCollect().map(_.head.toString) case other => - val result: Seq[Seq[Any]] = toRdd.collect().toSeq + val result: Seq[Seq[Any]] = toRdd.map(_.copy()).collect().toSeq // We need the types so we can output struct field names val types = analyzed.output.map(_.dataType) // Reformat to match hive tab delimited output. From f025035b77b6fa21a3cda3f05f2875895013bdfa Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 23 Oct 2014 17:46:52 -0700 Subject: [PATCH 22/46] Cleanups before PR. Added new tests --- .../spark/examples/mllib/DatasetExample.scala | 6 +- .../apache/spark/mllib/linalg/Vectors.scala | 6 ++ .../spark/mllib/regression/LabeledPoint.scala | 2 - .../apache/spark/mllib/rdd/DatasetSuite.scala | 84 +++++++++++++++++++ .../spark/sql/catalyst/ScalaReflection.scala | 11 +-- .../spark/sql/catalyst/UDTRegistry.scala | 7 +- .../annotation/SQLUserDefinedType.java | 7 +- .../spark/sql/catalyst/types/dataTypes.scala | 5 +- .../sql/catalyst/ScalaReflectionSuite.scala | 1 - .../org/apache/spark/sql/SQLContext.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 40 ++++----- .../spark/sql/UserDefinedTypeSuite.scala | 54 ++++++++---- .../apache/spark/sql/hive/HiveContext.scala | 3 +- 13 files changed, 166 insertions(+), 62 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/rdd/DatasetSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index 0728a0021497c..277393c6e1718 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} /** * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. 
Run with * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.Dataset [options] + * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ @@ -43,7 +43,7 @@ object DatasetExample { def main(args: Array[String]) { val defaultParams = Params() - val parser = new OptionParser[Params]("Dataset") { + val parser = new OptionParser[Params]("DatasetExample") { head("Dataset: an example app using SchemaRDD as a Dataset for ML.") opt[String]("input") .text(s"input path to dataset") @@ -65,7 +65,7 @@ object DatasetExample { def run(params: Params) { - val conf = new SparkConf().setAppName(s"Dataset with $params") + val conf = new SparkConf().setAppName(s"DatasetExample with $params") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) import sqlContext._ // for implicit conversions diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index d473eff1202dc..e9c57831710db 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -86,7 +86,13 @@ sealed trait Vector extends Serializable { */ object Vectors { + // Note: Explicit registration is only needed for Vector and SparseVector; + // the annotation works for DenseVector. UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[Vector], new VectorUDT()) + UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[DenseVector], + new DenseVectorUDT()) + UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[SparseVector], + new SparseVectorUDT()) /** * Creates a dense vector from its values. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 936b82a00f14f..17c753c56681f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -33,8 +33,6 @@ case class LabeledPoint(label: Double, features: Vector) { } } - - /** * Parser for [[org.apache.spark.mllib.regression.LabeledPoint]]. */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/DatasetSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/DatasetSuite.scala new file mode 100644 index 0000000000000..784ed2b3cc37f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/DatasetSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.rdd + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{Vectors, DenseVector, SparseVector, Vector} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} + + +private case class DenseVectorLabeledPoint(label: Double, features: DenseVector) +private case class SparseVectorLabeledPoint(label: Double, features: SparseVector) + +class DatasetSuite extends FunSuite with LocalSparkContext { + + test("SQL and Vector") { + val sqlContext = new SQLContext(sc) + import sqlContext._ + val points = Seq( + LabeledPoint(1.0, Vectors.dense(Array(1.0, 2.0))), + LabeledPoint(2.0, Vectors.dense(Array(3.0, 6.0))), + LabeledPoint(3.0, Vectors.dense(Array(3.0, 3.0))), + LabeledPoint(4.0, Vectors.dense(Array(4.0, 8.0)))) + val data: RDD[LabeledPoint] = sc.parallelize(points) + val labels = data.select('label).map { case Row(label: Double) => label }.collect().toSet + assert(labels == Set(1.0, 2.0, 3.0, 4.0)) + val features = data.select('features).map { case Row(features: Vector) => features }.collect() + assert(features.size === 4) + assert(features.forall(_.size == 2)) + } + + test("SQL and DenseVector") { + val sqlContext = new SQLContext(sc) + import sqlContext._ + val points = Seq( + DenseVectorLabeledPoint(1.0, new DenseVector(Array(1.0, 2.0))), + DenseVectorLabeledPoint(2.0, new DenseVector(Array(3.0, 6.0))), + DenseVectorLabeledPoint(3.0, new DenseVector(Array(3.0, 3.0))), + DenseVectorLabeledPoint(4.0, new DenseVector(Array(4.0, 8.0)))) + val data: RDD[DenseVectorLabeledPoint] = sc.parallelize(points) + val labels = data.select('label).map { case Row(label: Double) => label }.collect().toSet + assert(labels == Set(1.0, 2.0, 3.0, 4.0)) + val features = + data.select('features).map { case Row(features: DenseVector) => features }.collect() + assert(features.size === 4) + assert(features.forall(_.size == 2)) + } + + test("SQL and SparseVector") { + val sqlContext = new SQLContext(sc) + import sqlContext._ + val vSize = 2 + val points = Seq( + SparseVectorLabeledPoint(1.0, new SparseVector(vSize, Array(0, 1), Array(1.0, 2.0))), + SparseVectorLabeledPoint(2.0, new SparseVector(vSize, Array(0, 1), Array(3.0, 6.0))), + SparseVectorLabeledPoint(3.0, new SparseVector(vSize, Array(0, 1), Array(3.0, 3.0))), + SparseVectorLabeledPoint(4.0, new SparseVector(vSize, Array(0, 1), Array(4.0, 8.0)))) + val data: RDD[SparseVectorLabeledPoint] = sc.parallelize(points) + val labels = data.select('label).map { case Row(label: Double) => label }.collect().toSet + assert(labels == Set(1.0, 2.0, 3.0, 4.0)) + val features = + data.select('features).map { case Row(features: SparseVector) => features }.collect() + assert(features.size === 4) + assert(features.forall(_.size == 2)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index a6e255434baef..55f384efe2418 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -76,11 +76,9 @@ object ScalaReflection { } /** Returns a Sequence of attributes for the given case class type. 
*/ - def attributesFor[T: TypeTag]: Seq[Attribute] = { - schemaFor[T] match { - case Schema(s: StructType, _) => - s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) - } + def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { + case Schema(s: StructType, _) => + s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ @@ -89,7 +87,6 @@ object ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = { val className: String = tpe.erasure.typeSymbol.asClass.fullName - println(s"schemaFor: className = $className") tpe match { case t if Utils.classIsLoadable(className) && Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => @@ -180,8 +177,6 @@ object ScalaReflection { * for the the data in the sequence. */ def asRelation: LocalRelation = { - // Pass empty map to attributesFor since this method is only used for debugging Catalyst, - // not used with SparkSQL. val output = attributesFor[A] LocalRelation(output, data) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala index 7c333e3bc6758..5c1fc01efddf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala @@ -20,12 +20,16 @@ package org.apache.spark.sql.catalyst import scala.collection.mutable import scala.reflect.runtime.universe._ -import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.types.UserDefinedType /** + * ::DeveloperApi:: * Global registry for user-defined types (UDTs). + * + * WARNING: UDTs are currently only supported from Scala. */ +@DeveloperApi object UDTRegistry { /** Map: UserType --> UserDefinedType */ val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]() @@ -38,6 +42,5 @@ object UDTRegistry { def registerType(userType: Type, udt: UserDefinedType[_]): Unit = { // TODO: Check to see if type is built-in. Throw exception? UDTRegistry.udtRegistry(userType) = udt - // TODO: Else: Should we check (assert) that udt is the same as what is in the registry? } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java index fa909a9eb1b3b..fd815b3b6207f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java @@ -17,18 +17,21 @@ package org.apache.spark.sql.catalyst.annotation; +import java.lang.annotation.*; + import org.apache.spark.annotation.DeveloperApi; import org.apache.spark.sql.catalyst.types.UserDefinedType; -import java.lang.annotation.*; - /** + * ::DeveloperApi:: * A user-defined type which can be automatically recognized by a SQLContext and registered. * * WARNING: This annotation will only work if both Java and Scala reflection return the same class * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class * is enclosed in an object (a singleton). 
In these cases, the UDT must be registered * manually. + * + * WARNING: UDTs are currently only supported from Scala. */ // TODO: Should I used @Documented ? @DeveloperApi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 0c9c73af24d21..834cfffb0b3cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -569,6 +569,7 @@ case class MapType( } /** + * ::DeveloperApi:: * The data type for User Defined Types. */ @DeveloperApi @@ -578,8 +579,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { def sqlType: DataType /** Convert the user type to a Row object */ - // TODO: Can we make this take obj: UserType? - // The issue is in ScalaReflection.convertToCatalyst, where we need to convert Any to UserType. + // TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, + // where we need to convert Any to UserType. def serialize(obj: Any): Row /** Convert a Row object to the user type */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 0fe3cdeeba53c..7219e5b833771 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -252,5 +252,4 @@ class ScalaReflectionSuite extends FunSuite { convertToCatalyst(primitiveData, dataType)) assert(convertToCatalyst(data, dataType) === convertedData) } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4df38a88ebcd0..173d7f5af05cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import scala.language.implicitConversions -import scala.reflect.runtime.universe.{TypeTag, typeTag} +import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 73dac52452f23..6befe1b755cc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -737,32 +737,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("throw errors for non-aggregate attributes with aggregation") { - try { - def checkAggregation(query: String, isInvalidQuery: Boolean = true) { - val logicalPlan = sql(query).queryExecution.logical - - if (isInvalidQuery) { - val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed) - assert( - e.getMessage.startsWith("Expression not in GROUP BY"), - "Non-aggregate attribute(s) not detected\n" + logicalPlan) - } else { - // Should not throw - sql(query).queryExecution.analyzed - } + def checkAggregation(query: String, isInvalidQuery: Boolean = true) { + val logicalPlan = sql(query).queryExecution.logical + + if (isInvalidQuery) { + val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed) + assert( + 
e.getMessage.startsWith("Expression not in GROUP BY"), + "Non-aggregate attribute(s) not detected\n" + logicalPlan) + } else { + // Should not throw + sql(query).queryExecution.analyzed } + } - checkAggregation("SELECT key, COUNT(*) FROM testData") - checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false) + checkAggregation("SELECT key, COUNT(*) FROM testData") + checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false) - checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key") - checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false) + checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key") + checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false) - checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1") - checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false) - } catch { - case e: Exception => println(e.getStackTraceString) - } + checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1") + checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false) } test("Test to check we can use Long.MinValue") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index bcbae7cd50aec..f008104391dff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -23,23 +23,23 @@ import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.test.TestSQLContext._ -@SQLUserDefinedType(udt = classOf[DenseVectorUDT]) -class DenseVector(val data: Array[Double]) extends Serializable { +@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) +class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { - case v: DenseVector => + case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data) case _ => false } } -case class LabeledPoint(label: Double, features: DenseVector) +case class MyLabeledPoint(label: Double, features: MyDenseVector) -class DenseVectorUDT extends UserDefinedType[DenseVector] { +class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false) override def serialize(obj: Any): Row = obj match { - case features: DenseVector => + case features: MyDenseVector => val row: GenericMutableRow = new GenericMutableRow(features.data.length) var i = 0 while (i < features.data.length) { @@ -49,8 +49,8 @@ class DenseVectorUDT extends UserDefinedType[DenseVector] { row } - override def deserialize(row: Row): DenseVector = { - val features = new DenseVector(new Array[Double](row.length)) + override def deserialize(row: Row): MyDenseVector = { + val features = new MyDenseVector(new Array[Double](row.length)) var i = 0 while (i < row.length) { features.data(i) = row.getDouble(i) @@ -60,13 +60,33 @@ class DenseVectorUDT extends UserDefinedType[DenseVector] { } } +object UserDefinedTypeSuiteObject { + + class ClassInObject(val dv: MyDenseVector) extends Serializable + + case class MyLabeledPointInObject(label: Double, features: ClassInObject) + + class ClassInObjectUDT extends UserDefinedType[ClassInObject] { + + override def sqlType: ArrayType = ArrayType(DoubleType, 
containsNull = false) + + private val dvUDT = new MyDenseVectorUDT() + + override def serialize(obj: Any): Row = obj match { + case cio: ClassInObject => dvUDT.serialize(cio) + } + + override def deserialize(row: Row): ClassInObject = new ClassInObject(dvUDT.deserialize(row)) + } +} + class UserDefinedTypeSuite extends QueryTest { - test("register user type: DenseVector for LabeledPoint") { + test("register user type: MyDenseVector for MyLabeledPoint") { val points = Seq( - LabeledPoint(1.0, new DenseVector(Array(0.1, 1.0))), - LabeledPoint(0.0, new DenseVector(Array(0.2, 2.0)))) - val pointsRDD: RDD[LabeledPoint] = sparkContext.parallelize(points) + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) + val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() @@ -74,12 +94,12 @@ class UserDefinedTypeSuite extends QueryTest { assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) - val features: RDD[DenseVector] = - pointsRDD.select('features).map { case Row(v: DenseVector) => v } - val featuresArrays: Array[DenseVector] = features.collect() + val features: RDD[MyDenseVector] = + pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } + val featuresArrays: Array[MyDenseVector] = features.collect() assert(featuresArrays.size === 2) - assert(featuresArrays.contains(new DenseVector(Array(0.1, 1.0)))) - assert(featuresArrays.contains(new DenseVector(Array(0.2, 2.0)))) + assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) + assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 9915a17882488..c0f06623d1afb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -121,8 +121,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * @tparam A A case class that is used to describe the schema of the table to be created. */ def createTable[A <: Product : TypeTag](tableName: String, allowExisting: Boolean = true) { - catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], - allowExisting) + catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting) } /** From 51e5282c346d58c9e433c90d019fe35825a2fec1 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley"
Date: Fri, 24 Oct 2014 11:10:21 -0700
Subject: [PATCH 23/46] fixed 1 test

---
 .../org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 7219e5b833771..22619df6cd023 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -247,7 +247,7 @@ class ScalaReflectionSuite extends FunSuite {
     val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
     val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true),
       Some(primitiveData))
-    val dataType = schemaFor[PrimitiveData].dataType
+    val dataType = schemaFor[OptionalData].dataType
     val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true,
       convertToCatalyst(primitiveData, dataType))
     assert(convertToCatalyst(data, dataType) === convertedData)

From 63626a4f2a62e03f22c9b0bc453b754ef5858988 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Fri, 24 Oct 2014 11:25:44 -0700
Subject: [PATCH 24/46] Updated ScalaReflectionsSuite per @marmbrus suggestions

---
 .../apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 22619df6cd023..ddc3d44869c98 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.sql.catalyst.expressions.Row
 import org.apache.spark.sql.catalyst.types._
 
 case class PrimitiveData(
@@ -245,11 +246,11 @@ class ScalaReflectionSuite extends FunSuite {
 
   test("convert Option[Product] to catalyst") {
     val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
-    val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true),
+    val data = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
       Some(primitiveData))
     val dataType = schemaFor[OptionalData].dataType
-    val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true,
-      convertToCatalyst(primitiveData, dataType))
+    val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true,
+      Row(1, 1, 1, 1, 1, 1, true))
     assert(convertToCatalyst(data, dataType) === convertedData)
   }
 }

From 759af7ac349579967ea5929f5b019097318c28ee Mon Sep 17 00:00:00 2001
From: "Joseph K.
Bradley" Date: Mon, 27 Oct 2014 14:25:50 -0700 Subject: [PATCH 25/46] Added more doc to UserDefineType --- .../apache/spark/sql/catalyst/types/dataTypes.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 834cfffb0b3cc..eba8274f491bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -570,7 +570,17 @@ case class MapType( /** * ::DeveloperApi:: - * The data type for User Defined Types. + * The data type for User Defined Types (UDTs). + * + * This interface allows a user to make their own classes more interoperable with SparkSQL; + * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create a SchemaRDD + * which has class X in the schema. + * + * For SparkSQL to recognize UDTs, the UDT must be registered in + * [[org.apache.spark.sql.catalyst.UDTRegistry]]. This registration can be done either + * explicitly by calling [[org.apache.spark.sql.catalyst.UDTRegistry.registerType()]] before using + * the UDT with SparkSQL, or implicitly by annotating the UDT with + * [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]]. */ @DeveloperApi abstract class UserDefinedType[UserType] extends DataType with Serializable { From db16139ba5030fc0e6c4455a3adc1abd853211a5 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 27 Oct 2014 18:05:34 -0700 Subject: [PATCH 26/46] Added more doc for UserDefinedType. Removed unused code in Suite --- .../org/apache/spark/sql/catalyst/types/dataTypes.scala | 3 +++ .../scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index eba8274f491bc..6d6870aa7ff90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -581,6 +581,9 @@ case class MapType( * explicitly by calling [[org.apache.spark.sql.catalyst.UDTRegistry.registerType()]] before using * the UDT with SparkSQL, or implicitly by annotating the UDT with * [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]]. + * + * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD. + * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. 
*/ @DeveloperApi abstract class UserDefinedType[UserType] extends DataType with Serializable { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index f008104391dff..3f94dfbd6b540 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.UDTRegistry import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.UserDefinedType @@ -60,6 +61,10 @@ class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } } +/* +// This is to test registering a UDT which is defined within an object (where Java and Scala +// reflection use different class names). This functionality is currently not supported but +// should be later on. object UserDefinedTypeSuiteObject { class ClassInObject(val dv: MyDenseVector) extends Serializable @@ -79,6 +84,7 @@ object UserDefinedTypeSuiteObject { override def deserialize(row: Row): ClassInObject = new ClassInObject(dvUDT.deserialize(row)) } } +*/ class UserDefinedTypeSuite extends QueryTest { From cfbc3215332571a4cb033a27de495d9865c8a4dc Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Oct 2014 03:44:11 -0700 Subject: [PATCH 27/46] support UDT in parquet --- .../spark/examples/mllib/DatasetExample.scala | 11 +++ .../apache/spark/mllib/linalg/Vectors.scala | 67 +++++++------------ .../spark/sql/catalyst/ScalaReflection.scala | 16 ++--- .../spark/sql/catalyst/UDTRegistry.scala | 6 +- .../annotation/SQLUserDefinedType.java | 4 +- .../spark/sql/catalyst/types/dataTypes.scala | 23 +++++-- .../spark/sql/parquet/ParquetConverter.scala | 3 + .../sql/parquet/ParquetTableSupport.scala | 4 ++ .../spark/sql/parquet/ParquetTypes.scala | 3 + .../spark/sql/UserDefinedTypeSuite.scala | 8 ++- 10 files changed, 81 insertions(+), 64 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index 277393c6e1718..344da2c90c94b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -79,6 +79,7 @@ object DatasetExample { // Convert input data to SchemaRDD explicitly. val schemaRDD: SchemaRDD = origData + println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") // Select columns, using implicit conversion to SchemaRDD. 
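(Illustration only, not part of the patch content.) The UserDefinedType scaladoc added in PATCH 25/46 and 26/46 above describes two ways to expose a user class to SparkSQL: explicit registration in UDTRegistry, or the @SQLUserDefinedType annotation. The sketch below shows both paths for a hypothetical Point class; the names, field layout, and the PointUDTSetup helper are invented for the example, and it follows the API as documented in those patches (serialize to a catalyst Row, deserialize back to the user type), which PATCH 27/46 briefly renames and PATCH 28/46 restores.

    // Sketch only: hypothetical Point / PointUDT, written against the UserDefinedType
    // API documented in PATCH 25/46 and 26/46 (sqlType + serialize to Row + deserialize from Row).
    import org.apache.spark.sql.catalyst.UDTRegistry
    import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
    import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
    import org.apache.spark.sql.catalyst.types._

    // Implicit path: the annotation lets ScalaReflection discover the UDT on its own.
    @SQLUserDefinedType(udt = classOf[PointUDT])
    class Point(val x: Double, val y: Double) extends Serializable

    class PointUDT extends UserDefinedType[Point] {
      // Underlying storage type used by SparkSQL for Point values.
      override def sqlType: DataType = StructType(Seq(
        StructField("x", DoubleType, nullable = false),
        StructField("y", DoubleType, nullable = false)))

      // Called when a SchemaRDD is instantiated from an RDD[Point].
      override def serialize(obj: Any): Row = obj match {
        case p: Point =>
          val row = new GenericMutableRow(2)
          row.setDouble(0, p.x)
          row.setDouble(1, p.y)
          row
      }

      // Called when Point values are read back out of a SchemaRDD.
      override def deserialize(row: Row): Point = new Point(row.getDouble(0), row.getDouble(1))
    }

    object PointUDTSetup {
      // Explicit path: needed when the annotation cannot be used, e.g. when the user
      // class is enclosed in an object and Java/Scala reflection disagree on its name.
      def register(): Unit =
        UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[Point], new PointUDT)
    }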
@@ -95,6 +96,16 @@ object DatasetExample { (sum1, sum2) => sum1.merge(sum2)) println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + schemaRDD.saveAsParquetFile("/tmp/dataset") + val newDataset = sqlContext.parquetFile("/tmp/dataset") + + println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") + val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + sc.stop() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index e9c57831710db..0fe78f0f08265 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -89,8 +89,6 @@ object Vectors { // Note: Explicit registration is only needed for Vector and SparseVector; // the annotation works for DenseVector. UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[Vector], new VectorUDT()) - UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[DenseVector], - new DenseVectorUDT()) UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[SparseVector], new SparseVectorUDT()) @@ -204,7 +202,7 @@ object Vectors { /** * A dense vector represented by a value array. */ -@SQLUserDefinedType(udt = classOf[DenseVectorUDT]) +@SQLUserDefinedType(serdes = classOf[DenseVectorUDT]) class DenseVector(val values: Array[Double]) extends Vector { override def size: Int = values.length @@ -261,7 +259,7 @@ class SparseVector( * User-defined type for [[Vector]] which allows easy interaction with SQL * via [[org.apache.spark.sql.SchemaRDD]]. */ -private[spark] class VectorUDT extends UserDefinedType[Vector] { +private[spark] class VectorUDT extends UserDefinedTypeSerDes[Vector] { /** * vectorType: 0 = dense, 1 = sparse. @@ -269,8 +267,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { */ override def sqlType: StructType = StructType(Seq( StructField("vectorType", ByteType, nullable = false), - StructField("dense", new DenseVectorUDT(), nullable = true), - StructField("sparse", new SparseVectorUDT(), nullable = true))) + StructField("dense", new UserDefinedType(new DenseVectorUDT), nullable = true), + StructField("sparse", new UserDefinedType(new SparseVectorUDT), nullable = true))) override def serialize(obj: Any): Row = { val row = new GenericMutableRow(3) @@ -298,64 +296,52 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { new SparseVectorUDT().deserialize(row.getAs[Row](2)) } } + + override def userType: Class[Vector] = classOf[Vector] } /** * User-defined type for [[DenseVector]] which allows easy interaction with SQL * via [[org.apache.spark.sql.SchemaRDD]]. 
*/ -private[spark] class DenseVectorUDT extends UserDefinedType[DenseVector] { +private[spark] class DenseVectorUDT extends UserDefinedTypeSerDes[DenseVector] { override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false) override def serialize(obj: Any): Row = obj match { case v: DenseVector => - val row: GenericMutableRow = new GenericMutableRow(v.size) - var i = 0 - while (i < v.size) { - row.setDouble(i, v(i)) - i += 1 - } + val row: GenericMutableRow = new GenericMutableRow(1) + row.update(0, v.values.toSeq) row } override def deserialize(row: Row): DenseVector = { - val values = new Array[Double](row.length) - var i = 0 - while (i < row.length) { - values(i) = row.getDouble(i) - i += 1 - } + val values = row.getAs[Seq[Double]](0).toArray new DenseVector(values) } + + override def userType: Class[DenseVector] = classOf[DenseVector] } /** * User-defined type for [[SparseVector]] which allows easy interaction with SQL * via [[org.apache.spark.sql.SchemaRDD]]. */ -private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] { +private[spark] class SparseVectorUDT extends UserDefinedTypeSerDes[SparseVector] { override def sqlType: StructType = StructType(Seq( StructField("size", IntegerType, nullable = false), - StructField("indices", ArrayType(DoubleType, containsNull = false), nullable = false), + StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = false), StructField("values", ArrayType(DoubleType, containsNull = false), nullable = false))) override def serialize(obj: Any): Row = obj match { case v: SparseVector => - val nnz = v.indices.size - val row: GenericMutableRow = new GenericMutableRow(1 + 2 * nnz) + val row: GenericMutableRow = new GenericMutableRow(3) row.setInt(0, v.size) - var i = 0 - while (i < nnz) { - row.setInt(1 + i, v.indices(i)) - i += 1 - } - i = 0 - while (i < nnz) { - row.setDouble(1 + nnz + i, v.values(i)) - i += 1 - } + row.update(1, v.indices.toSeq) + row.update(2, v.values.toSeq) + row + case row: Row => row } @@ -363,17 +349,10 @@ private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] { require(row.length >= 1, s"SparseVectorUDT.deserialize given row with length ${row.length} but requires length >= 1") val vSize = row.getInt(0) - val nnz: Int = (row.length - 1) / 2 - require(nnz * 2 + 1 == row.length, - s"SparseVectorUDT.deserialize given row with non-matching indices, values lengths") - val indices = new Array[Int](nnz) - val values = new Array[Double](nnz) - var i = 0 - while (i < nnz) { - indices(i) = row.getInt(1 + i) - values(i) = row.getDouble(1 + nnz + i) - i += 1 - } + val indices = row.getAs[Seq[Int]](1).toArray + val values = row.getAs[Seq[Double]](2).toArray new SparseVector(vSize, indices, values) } + + override def userType: Class[SparseVector] = classOf[SparseVector] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 55f384efe2418..701d2213996e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -54,7 +54,7 @@ object ScalaReflection { convertToCatalyst(elem, field.dataType) }.toArray) case (d: BigDecimal, _) => Decimal(d) - case (udt, udtType: UserDefinedType[_]) => udtType.serialize(udt) + case (udt, udtType: UserDefinedTypeSerDes[_]) => udtType.serialize(udt) case (other, _) => other } @@ -64,8 +64,8 @@ 
object ScalaReflection { case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) } - case (d: Decimal, DecimalType) => d.toBigDecimal - case (udt: Row, udtType: UserDefinedType[_]) => udtType.deserialize(udt) + case (d: Decimal, _: DecimalType) => d.toBigDecimal + case (r: Row, udt: UserDefinedType[_]) => udt.serdes.deserialize(r) case (other, _) => other } @@ -94,12 +94,12 @@ object ScalaReflection { // whereas className is from Scala reflection. This can make it hard to find classes // in some cases, such as when a class is enclosed in an object (in which case // Java appends a '$' to the object name but Scala does not). - val udt = Utils.classForName(className) - .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() - UDTRegistry.registerType(t, udt) - Schema(udt, nullable = true) + val serdes = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).serdes().newInstance() + UDTRegistry.registerType(t, serdes) + Schema(new UserDefinedType(serdes), nullable = true) case t if UDTRegistry.udtRegistry.contains(t) => - Schema(UDTRegistry.udtRegistry(t), nullable = true) + Schema(new UserDefinedType(UDTRegistry.udtRegistry(t)), nullable = true) case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType).dataType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala index 5c1fc01efddf6..7f7ee4452d1fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import scala.reflect.runtime.universe._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.types.UserDefinedType +import org.apache.spark.sql.catalyst.types.UserDefinedTypeSerDes /** * ::DeveloperApi:: @@ -32,14 +32,14 @@ import org.apache.spark.sql.catalyst.types.UserDefinedType @DeveloperApi object UDTRegistry { /** Map: UserType --> UserDefinedType */ - val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]() + val udtRegistry = new mutable.HashMap[Any, UserDefinedTypeSerDes[_]]() /** * Register a user-defined type and its serializer, to allow automatic conversion between * RDDs of user types and SchemaRDDs. * If this type has already been registered, this does nothing. */ - def registerType(userType: Type, udt: UserDefinedType[_]): Unit = { + def registerType(userType: Type, udt: UserDefinedTypeSerDes[_]): Unit = { // TODO: Check to see if type is built-in. Throw exception? 
UDTRegistry.udtRegistry(userType) = udt } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java index fd815b3b6207f..ba0fc666fad52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java @@ -20,7 +20,7 @@ import java.lang.annotation.*; import org.apache.spark.annotation.DeveloperApi; -import org.apache.spark.sql.catalyst.types.UserDefinedType; +import org.apache.spark.sql.catalyst.types.UserDefinedTypeSerDes; /** * ::DeveloperApi:: @@ -38,5 +38,5 @@ @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface SQLUserDefinedType { - Class > udt(); + Class > serdes(); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 6d6870aa7ff90..cf2735ebaf932 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -30,7 +30,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.{UDTRegistry, ScalaReflectionLock} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row} import org.apache.spark.sql.catalyst.types.decimal._ import org.apache.spark.sql.catalyst.util.Metadata @@ -68,6 +68,13 @@ object DataType { ("fields", JArray(fields)), ("type", JString("struct"))) => StructType(fields.map(parseStructField)) + + case JSortedObject( + ("serdes", JString(serdesClass)), + ("type", JString("udt"))) => { + val serdes = Class.forName(serdesClass).newInstance().asInstanceOf[UserDefinedTypeSerDes[_]] + new UserDefinedType(serdes) + } } private def parseStructField(json: JValue): StructField = json match { @@ -573,7 +580,7 @@ case class MapType( * The data type for User Defined Types (UDTs). * * This interface allows a user to make their own classes more interoperable with SparkSQL; - * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create a SchemaRDD + * e.g., by creating a [[UserDefinedTypeSerDes]] for a class X, it becomes possible to create a SchemaRDD * which has class X in the schema. * * For SparkSQL to recognize UDTs, the UDT must be registered in @@ -586,7 +593,9 @@ case class MapType( * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. 
*/ @DeveloperApi -abstract class UserDefinedType[UserType] extends DataType with Serializable { +abstract class UserDefinedTypeSerDes[UserType] extends Serializable { + + def userType: Class[UserType] /** Underlying storage type for this UDT used by SparkSQL */ def sqlType: DataType @@ -598,6 +607,12 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** Convert a Row object to the user type */ def deserialize(row: Row): UserType +} - def simpleString: String = "udt" +case class UserDefinedType[UserType](serdes: UserDefinedTypeSerDes[UserType]) + extends DataType with Serializable { + override private[sql] def jsonValue: JValue = { + ("type" -> "udt") ~ + ("serdes" -> serdes.getClass.getName) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 08feced61a899..4df8a90adcd42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -99,6 +99,9 @@ private[sql] object CatalystConverter { fieldIndex, parent) } + case UserDefinedType(serdes) => { + createConverter(field.copy(dataType = serdes.sqlType), fieldIndex, parent) + } // Strings, Shorts and Bytes do not have a corresponding type in Parquet // so we need to treat them separately case StringType => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 2a5f23b24e8e8..d4c0e591cbce9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -183,6 +183,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case t @ StructType(_) => writeStruct( t, value.asInstanceOf[CatalystConverter.StructScalaType[_]]) + case UserDefinedType(serdes) => { + println(value.getClass) + writeValue(serdes.sqlType, serdes.serialize(value)) + } case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index e5077de8dd908..5a27616bbff67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -337,6 +337,9 @@ private[parquet] object ParquetTypesConverter extends Logging { parquetKeyType, parquetValueType) } + case UserDefinedType(serdes) => { + fromDataType(serdes.sqlType, name, nullable, inArray) + } case _ => sys.error(s"Unsupported datatype $ctype") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 3f94dfbd6b540..8202aa462f8b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -21,10 +21,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.UDTRegistry import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.types.UserDefinedType +import 
org.apache.spark.sql.catalyst.types.UserDefinedTypeSerDes import org.apache.spark.sql.test.TestSQLContext._ -@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) +@SQLUserDefinedType(serdes = classOf[MyDenseVectorUDT]) class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { case v: MyDenseVector => @@ -35,7 +35,9 @@ class MyDenseVector(val data: Array[Double]) extends Serializable { case class MyLabeledPoint(label: Double, features: MyDenseVector) -class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { +class MyDenseVectorUDT extends UserDefinedTypeSerDes[MyDenseVector] { + + override def userType: Class[MyDenseVector] = classOf[MyDenseVector] override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false) From 3143ac304a9cbb81e766d915863d343cd81dc673 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Oct 2014 12:19:10 -0700 Subject: [PATCH 28/46] remove unnecessary changes --- .../apache/spark/mllib/linalg/Vectors.scala | 23 +++++++-------- .../spark/sql/catalyst/ScalaReflection.scala | 14 ++++----- .../spark/sql/catalyst/UDTRegistry.scala | 6 ++-- .../annotation/SQLUserDefinedType.java | 4 +-- .../spark/sql/catalyst/types/dataTypes.scala | 21 +++++--------- .../spark/sql/parquet/ParquetConverter.scala | 4 +-- .../sql/parquet/ParquetTableSupport.scala | 5 +--- .../spark/sql/parquet/ParquetTypes.scala | 4 +-- .../spark/sql/UserDefinedTypeSuite.scala | 29 ++++++------------- 9 files changed, 45 insertions(+), 65 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 0fe78f0f08265..4680ca0b92c2c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -202,7 +202,7 @@ object Vectors { /** * A dense vector represented by a value array. */ -@SQLUserDefinedType(serdes = classOf[DenseVectorUDT]) +@SQLUserDefinedType(udt = classOf[DenseVectorUDT]) class DenseVector(val values: Array[Double]) extends Vector { override def size: Int = values.length @@ -259,7 +259,7 @@ class SparseVector( * User-defined type for [[Vector]] which allows easy interaction with SQL * via [[org.apache.spark.sql.SchemaRDD]]. */ -private[spark] class VectorUDT extends UserDefinedTypeSerDes[Vector] { +private[spark] class VectorUDT extends UserDefinedType[Vector] { /** * vectorType: 0 = dense, 1 = sparse. @@ -267,8 +267,8 @@ private[spark] class VectorUDT extends UserDefinedTypeSerDes[Vector] { */ override def sqlType: StructType = StructType(Seq( StructField("vectorType", ByteType, nullable = false), - StructField("dense", new UserDefinedType(new DenseVectorUDT), nullable = true), - StructField("sparse", new UserDefinedType(new SparseVectorUDT), nullable = true))) + StructField("dense", new DenseVectorUDT, nullable = true), + StructField("sparse", new SparseVectorUDT, nullable = true))) override def serialize(obj: Any): Row = { val row = new GenericMutableRow(3) @@ -297,16 +297,17 @@ private[spark] class VectorUDT extends UserDefinedTypeSerDes[Vector] { } } - override def userType: Class[Vector] = classOf[Vector] + // override def userType: Class[Vector] = classOf[Vector] } /** * User-defined type for [[DenseVector]] which allows easy interaction with SQL * via [[org.apache.spark.sql.SchemaRDD]]. 
*/ -private[spark] class DenseVectorUDT extends UserDefinedTypeSerDes[DenseVector] { +private[spark] class DenseVectorUDT extends UserDefinedType[DenseVector] { - override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false) + override def sqlType: StructType = StructType(Seq( + StructField("values", ArrayType(DoubleType, containsNull = false), nullable = false))) override def serialize(obj: Any): Row = obj match { case v: DenseVector => @@ -320,14 +321,14 @@ private[spark] class DenseVectorUDT extends UserDefinedTypeSerDes[DenseVector] { new DenseVector(values) } - override def userType: Class[DenseVector] = classOf[DenseVector] + // override def userType: Class[DenseVector] = classOf[DenseVector] } /** * User-defined type for [[SparseVector]] which allows easy interaction with SQL * via [[org.apache.spark.sql.SchemaRDD]]. */ -private[spark] class SparseVectorUDT extends UserDefinedTypeSerDes[SparseVector] { +private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] { override def sqlType: StructType = StructType(Seq( StructField("size", IntegerType, nullable = false), @@ -341,8 +342,6 @@ private[spark] class SparseVectorUDT extends UserDefinedTypeSerDes[SparseVector] row.update(1, v.indices.toSeq) row.update(2, v.values.toSeq) row - case row: Row => - row } override def deserialize(row: Row): SparseVector = { @@ -354,5 +353,5 @@ private[spark] class SparseVectorUDT extends UserDefinedTypeSerDes[SparseVector] new SparseVector(vSize, indices, values) } - override def userType: Class[SparseVector] = classOf[SparseVector] + // override def userType: Class[SparseVector] = classOf[SparseVector] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 701d2213996e6..db1a8924c008e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -54,7 +54,7 @@ object ScalaReflection { convertToCatalyst(elem, field.dataType) }.toArray) case (d: BigDecimal, _) => Decimal(d) - case (udt, udtType: UserDefinedTypeSerDes[_]) => udtType.serialize(udt) + case (obj, udt: UserDefinedType[_]) => udt.serialize(obj) case (other, _) => other } @@ -65,7 +65,7 @@ object ScalaReflection { convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) } case (d: Decimal, _: DecimalType) => d.toBigDecimal - case (r: Row, udt: UserDefinedType[_]) => udt.serdes.deserialize(r) + case (r: Row, udt: UserDefinedType[_]) => udt.deserialize(r) case (other, _) => other } @@ -94,12 +94,12 @@ object ScalaReflection { // whereas className is from Scala reflection. This can make it hard to find classes // in some cases, such as when a class is enclosed in an object (in which case // Java appends a '$' to the object name but Scala does not). 
- val serdes = Utils.classForName(className) - .getAnnotation(classOf[SQLUserDefinedType]).serdes().newInstance() - UDTRegistry.registerType(t, serdes) - Schema(new UserDefinedType(serdes), nullable = true) + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + UDTRegistry.registerType(t, udt) + Schema(udt, nullable = true) case t if UDTRegistry.udtRegistry.contains(t) => - Schema(new UserDefinedType(UDTRegistry.udtRegistry(t)), nullable = true) + Schema(UDTRegistry.udtRegistry(t), nullable = true) case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType).dataType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala index 7f7ee4452d1fc..5c1fc01efddf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import scala.reflect.runtime.universe._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.types.UserDefinedTypeSerDes +import org.apache.spark.sql.catalyst.types.UserDefinedType /** * ::DeveloperApi:: @@ -32,14 +32,14 @@ import org.apache.spark.sql.catalyst.types.UserDefinedTypeSerDes @DeveloperApi object UDTRegistry { /** Map: UserType --> UserDefinedType */ - val udtRegistry = new mutable.HashMap[Any, UserDefinedTypeSerDes[_]]() + val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]() /** * Register a user-defined type and its serializer, to allow automatic conversion between * RDDs of user types and SchemaRDDs. * If this type has already been registered, this does nothing. */ - def registerType(userType: Type, udt: UserDefinedTypeSerDes[_]): Unit = { + def registerType(userType: Type, udt: UserDefinedType[_]): Unit = { // TODO: Check to see if type is built-in. Throw exception? 
UDTRegistry.udtRegistry(userType) = udt } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java index ba0fc666fad52..fd815b3b6207f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java @@ -20,7 +20,7 @@ import java.lang.annotation.*; import org.apache.spark.annotation.DeveloperApi; -import org.apache.spark.sql.catalyst.types.UserDefinedTypeSerDes; +import org.apache.spark.sql.catalyst.types.UserDefinedType; /** * ::DeveloperApi:: @@ -38,5 +38,5 @@ @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface SQLUserDefinedType { - Class > serdes(); + Class > udt(); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index cf2735ebaf932..2fa0c35f1095d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -70,11 +70,9 @@ object DataType { StructType(fields.map(parseStructField)) case JSortedObject( - ("serdes", JString(serdesClass)), - ("type", JString("udt"))) => { - val serdes = Class.forName(serdesClass).newInstance().asInstanceOf[UserDefinedTypeSerDes[_]] - new UserDefinedType(serdes) - } + ("class", JString(udtClass)), + ("type", JString("udt"))) => + Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] } private def parseStructField(json: JValue): StructField = json match { @@ -580,7 +578,7 @@ case class MapType( * The data type for User Defined Types (UDTs). * * This interface allows a user to make their own classes more interoperable with SparkSQL; - * e.g., by creating a [[UserDefinedTypeSerDes]] for a class X, it becomes possible to create a SchemaRDD + * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create a SchemaRDD * which has class X in the schema. * * For SparkSQL to recognize UDTs, the UDT must be registered in @@ -593,12 +591,12 @@ case class MapType( * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. */ @DeveloperApi -abstract class UserDefinedTypeSerDes[UserType] extends Serializable { +abstract class UserDefinedType[UserType] extends DataType with Serializable { - def userType: Class[UserType] + // def userType: Class[UserType] /** Underlying storage type for this UDT used by SparkSQL */ - def sqlType: DataType + def sqlType: StructType /** Convert the user type to a Row object */ // TODO: Can we make this take obj: UserType? 
The issue is in ScalaReflection.convertToCatalyst, @@ -607,12 +605,9 @@ abstract class UserDefinedTypeSerDes[UserType] extends Serializable { /** Convert a Row object to the user type */ def deserialize(row: Row): UserType -} -case class UserDefinedType[UserType](serdes: UserDefinedTypeSerDes[UserType]) - extends DataType with Serializable { override private[sql] def jsonValue: JValue = { ("type" -> "udt") ~ - ("serdes" -> serdes.getClass.getName) + ("class" -> this.getClass.getName) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 4df8a90adcd42..476d776230600 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -99,8 +99,8 @@ private[sql] object CatalystConverter { fieldIndex, parent) } - case UserDefinedType(serdes) => { - createConverter(field.copy(dataType = serdes.sqlType), fieldIndex, parent) + case udt: UserDefinedType[_] => { + createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent) } // Strings, Shorts and Bytes do not have a corresponding type in Parquet // so we need to treat them separately diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index d4c0e591cbce9..4fdb4bde87d24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -183,10 +183,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case t @ StructType(_) => writeStruct( t, value.asInstanceOf[CatalystConverter.StructScalaType[_]]) - case UserDefinedType(serdes) => { - println(value.getClass) - writeValue(serdes.sqlType, serdes.serialize(value)) - } + case t: UserDefinedType[_] => writeValue(t.sqlType, value) case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 5a27616bbff67..f9b33214a83d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -337,8 +337,8 @@ private[parquet] object ParquetTypesConverter extends Logging { parquetKeyType, parquetValueType) } - case UserDefinedType(serdes) => { - fromDataType(serdes.sqlType, name, nullable, inArray) + case udt: UserDefinedType[_] => { + fromDataType(udt.sqlType, name, nullable, inArray) } case _ => sys.error(s"Unsupported datatype $ctype") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 8202aa462f8b0..a470c2765f19e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.UDTRegistry import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.types.UserDefinedTypeSerDes +import 
org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.test.TestSQLContext._ -@SQLUserDefinedType(serdes = classOf[MyDenseVectorUDT]) +@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { case v: MyDenseVector => @@ -35,31 +34,21 @@ class MyDenseVector(val data: Array[Double]) extends Serializable { case class MyLabeledPoint(label: Double, features: MyDenseVector) -class MyDenseVectorUDT extends UserDefinedTypeSerDes[MyDenseVector] { +class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { - override def userType: Class[MyDenseVector] = classOf[MyDenseVector] - - override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false) + override def sqlType: StructType = StructType(Seq( + StructField("data", ArrayType(DoubleType, containsNull = false), nullable = false))) override def serialize(obj: Any): Row = obj match { case features: MyDenseVector => - val row: GenericMutableRow = new GenericMutableRow(features.data.length) - var i = 0 - while (i < features.data.length) { - row.setDouble(i, features.data(i)) - i += 1 - } + val row: GenericMutableRow = new GenericMutableRow(1) + row.update(0, features.data.toSeq) row } override def deserialize(row: Row): MyDenseVector = { - val features = new MyDenseVector(new Array[Double](row.length)) - var i = 0 - while (i < row.length) { - features.data(i) = row.getDouble(i) - i += 1 - } - features + val features = row.getAs[Seq[Double]](0).toArray + new MyDenseVector(features) } } From 87264a5aa500f3d44c3a806893fcc9df6b5e0e90 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Oct 2014 12:20:06 -0700 Subject: [PATCH 29/46] remove debug code --- .../main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 6 ------ .../org/apache/spark/sql/catalyst/types/dataTypes.scala | 2 -- 2 files changed, 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 4680ca0b92c2c..c593156f30233 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -296,8 +296,6 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { new SparseVectorUDT().deserialize(row.getAs[Row](2)) } } - - // override def userType: Class[Vector] = classOf[Vector] } /** @@ -320,8 +318,6 @@ private[spark] class DenseVectorUDT extends UserDefinedType[DenseVector] { val values = row.getAs[Seq[Double]](0).toArray new DenseVector(values) } - - // override def userType: Class[DenseVector] = classOf[DenseVector] } /** @@ -353,5 +349,3 @@ private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] { new SparseVector(vSize, indices, values) } - // override def userType: Class[SparseVector] = classOf[SparseVector] -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 2fa0c35f1095d..b1d90dba16ce7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -593,8 +593,6 @@ case class MapType( @DeveloperApi abstract class UserDefinedType[UserType] extends DataType with Serializable { - // def userType: Class[UserType] - /** Underlying storage type for this UDT used by 
SparkSQL */ def sqlType: StructType From 4500d8a3bd04c515af29fa42d803ee29e415e8a8 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Oct 2014 12:25:57 -0700 Subject: [PATCH 30/46] update example code --- .../spark/examples/mllib/DatasetExample.scala | 15 ++++++++++++--- .../org/apache/spark/mllib/linalg/Vectors.scala | 2 +- .../apache/spark/sql/UserDefinedTypeSuite.scala | 4 ++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index 344da2c90c94b..f8d83f4ec7327 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -17,6 +17,9 @@ package org.apache.spark.examples.mllib +import java.io.File + +import com.google.common.io.Files import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} @@ -96,15 +99,21 @@ object DatasetExample { (sum1, sum2) => sum1.merge(sum2)) println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") - schemaRDD.saveAsParquetFile("/tmp/dataset") - val newDataset = sqlContext.parquetFile("/tmp/dataset") + val tmpDir = Files.createTempDir() + tmpDir.deleteOnExit() + val outputDir = new File(tmpDir, "dataset").toString + println(s"Saving to $outputDir as Parquet file.") + schemaRDD.saveAsParquetFile(outputDir) + + println(s"Loading Parquet file with UDT from $outputDir.") + val newDataset = sqlContext.parquetFile(outputDir) println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) - println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") sc.stop() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index c593156f30233..17d7684b1ddf5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -348,4 +348,4 @@ private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] { val values = row.getAs[Seq[Double]](2).toArray new SparseVector(vSize, indices, values) } - +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index a470c2765f19e..3208c910a5bc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -47,8 +47,8 @@ class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } override def deserialize(row: Row): MyDenseVector = { - val features = row.getAs[Seq[Double]](0).toArray - new MyDenseVector(features) + val data = row.getAs[Seq[Double]](0).toArray + new MyDenseVector(data) } } From b028675714ba5178f7a1e233eeb35f399ac19ee4 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Oct 2014 12:49:45 -0700 Subject: [PATCH 31/46] allow any type in UDT --- .../apache/spark/mllib/linalg/Vectors.scala | 59 +++++++++++-------- 
.../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../spark/sql/catalyst/types/dataTypes.scala | 10 ++-- .../spark/sql/UserDefinedTypeSuite.scala | 22 +++---- 4 files changed, 50 insertions(+), 43 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 17d7684b1ddf5..9aaafa34f8c03 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -285,15 +285,18 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { row } - override def deserialize(row: Row): Vector = { - require(row.length == 3, - s"VectorUDT.deserialize given row with length ${row.length} but requires length == 3") - val vectorType = row.getByte(0) - vectorType match { - case 0 => - new DenseVectorUDT().deserialize(row.getAs[Row](1)) - case 1 => - new SparseVectorUDT().deserialize(row.getAs[Row](2)) + override def deserialize(datum: Any): Vector = { + datum match { + case row: Row => + require(row.length == 3, + s"VectorUDT.deserialize given row with length ${row.length} but requires length == 3") + val vectorType = row.getByte(0) + vectorType match { + case 0 => + new DenseVectorUDT().deserialize(row.getAs[Row](1)) + case 1 => + new SparseVectorUDT().deserialize(row.getAs[Row](2)) + } } } } @@ -304,19 +307,20 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { */ private[spark] class DenseVectorUDT extends UserDefinedType[DenseVector] { - override def sqlType: StructType = StructType(Seq( - StructField("values", ArrayType(DoubleType, containsNull = false), nullable = false))) + override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) - override def serialize(obj: Any): Row = obj match { - case v: DenseVector => - val row: GenericMutableRow = new GenericMutableRow(1) - row.update(0, v.values.toSeq) - row + override def serialize(obj: Any): Seq[Double] = { + obj match { + case v: DenseVector => + v.values.toSeq + } } - override def deserialize(row: Row): DenseVector = { - val values = row.getAs[Seq[Double]](0).toArray - new DenseVector(values) + override def deserialize(datum: Any): DenseVector = { + datum match { + case values: Seq[_] => + new DenseVector(values.asInstanceOf[Seq[Double]].toArray) + } } } @@ -340,12 +344,15 @@ private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] { row } - override def deserialize(row: Row): SparseVector = { - require(row.length >= 1, - s"SparseVectorUDT.deserialize given row with length ${row.length} but requires length >= 1") - val vSize = row.getInt(0) - val indices = row.getAs[Seq[Int]](1).toArray - val values = row.getAs[Seq[Double]](2).toArray - new SparseVector(vSize, indices, values) + override def deserialize(datum: Any): SparseVector = { + datum match { + case row: Row => + require(row.length == 3, + s"SparseVectorUDT.deserialize given row with length ${row.length} but expect 3.") + val vSize = row.getInt(0) + val indices = row.getAs[Seq[Int]](1).toArray + val values = row.getAs[Seq[Double]](2).toArray + new SparseVector(vSize, indices, values) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index db1a8924c008e..de409b8c376b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -65,7 +65,7 @@ object ScalaReflection { convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) } case (d: Decimal, _: DecimalType) => d.toBigDecimal - case (r: Row, udt: UserDefinedType[_]) => udt.deserialize(r) + case (d, udt: UserDefinedType[_]) => udt.deserialize(d) case (other, _) => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index b1d90dba16ce7..220a347af5c0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -594,15 +594,15 @@ case class MapType( abstract class UserDefinedType[UserType] extends DataType with Serializable { /** Underlying storage type for this UDT used by SparkSQL */ - def sqlType: StructType + def sqlType: DataType - /** Convert the user type to a Row object */ + /** Convert the user type to a SQL datum */ // TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, // where we need to convert Any to UserType. - def serialize(obj: Any): Row + def serialize(obj: Any): Any - /** Convert a Row object to the user type */ - def deserialize(row: Row): UserType + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType override private[sql] def jsonValue: JValue = { ("type" -> "udt") ~ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 3208c910a5bc4..cf793ccbd0c02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.test.TestSQLContext._ @@ -36,19 +35,20 @@ case class MyLabeledPoint(label: Double, features: MyDenseVector) class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { - override def sqlType: StructType = StructType(Seq( - StructField("data", ArrayType(DoubleType, containsNull = false), nullable = false))) + override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) - override def serialize(obj: Any): Row = obj match { - case features: MyDenseVector => - val row: GenericMutableRow = new GenericMutableRow(1) - row.update(0, features.data.toSeq) - row + override def serialize(obj: Any): Seq[Double] = { + obj match { + case features: MyDenseVector => + features.data.toSeq + } } - override def deserialize(row: Row): MyDenseVector = { - val data = row.getAs[Seq[Double]](0).toArray - new MyDenseVector(data) + override def deserialize(datum: Any): MyDenseVector = { + datum match { + case data: Seq[_] => + new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray) + } } } From 7f29656d50ba77ebceed3f5b73a4269a00f8d250 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 28 Oct 2014 13:56:07 -0700 Subject: [PATCH 32/46] Moved udt case to top of all matches. 
Small cleanups --- .../spark/sql/catalyst/ScalaReflection.scala | 6 ++++-- .../spark/sql/parquet/ParquetConverter.scala | 17 +++++++++-------- .../spark/sql/parquet/ParquetTableSupport.scala | 3 ++- .../apache/spark/sql/parquet/ParquetTypes.scala | 7 ++++--- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index de409b8c376b3..6508f6d13c7cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -43,6 +43,8 @@ object ScalaReflection { * This ordering is important for UDT registration. */ def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (obj, udt: UserDefinedType[_]) => udt.serialize(obj) case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => @@ -54,18 +56,18 @@ object ScalaReflection { convertToCatalyst(elem, field.dataType) }.toArray) case (d: BigDecimal, _) => Decimal(d) - case (obj, udt: UserDefinedType[_]) => udt.serialize(obj) case (other, _) => other } /** Converts Catalyst types used internally in rows to standard Scala types */ def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (d, udt: UserDefinedType[_]) => udt.deserialize(d) case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType)) case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) } case (d: Decimal, _: DecimalType) => d.toBigDecimal - case (d, udt: UserDefinedType[_]) => udt.deserialize(d) case (other, _) => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 476d776230600..1097a4e52cb39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -77,6 +77,10 @@ private[sql] object CatalystConverter { parent: CatalystConverter): Converter = { val fieldType: DataType = field.dataType fieldType match { + // Check UDT first since UDTs can override other types + case udt: UserDefinedType[_] => { + createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent) + } // For native JVM types we use a converter with native arrays case ArrayType(elementType: NativeType, false) => { new CatalystNativeArrayConverter(elementType, fieldIndex, parent) @@ -99,9 +103,6 @@ private[sql] object CatalystConverter { fieldIndex, parent) } - case udt: UserDefinedType[_] => { - createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent) - } // Strings, Shorts and Bytes do not have a corresponding type in Parquet // so we need to treat them separately case StringType => { @@ -258,8 +259,8 @@ private[parquet] class CatalystGroupConverter( schema, index, parent, - current=null, - buffer=new ArrayBuffer[Row]( + current = null, + buffer = new ArrayBuffer[Row]( CatalystArrayConverter.INITIAL_ARRAY_SIZE)) /** @@ -304,7 +305,7 @@ 
private[parquet] class CatalystGroupConverter( override def end(): Unit = { if (!isRootConverter) { - assert(current!=null) // there should be no empty groups + assert(current != null) // there should be no empty groups buffer.append(new GenericRow(current.toArray)) parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]])) } @@ -361,7 +362,7 @@ private[parquet] class CatalystPrimitiveRowConverter( override def end(): Unit = {} - // Overriden here to avoid auto-boxing for primitive types + // Overridden here to avoid auto-boxing for primitive types override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = current.setBoolean(fieldIndex, value) @@ -536,7 +537,7 @@ private[parquet] class CatalystNativeArrayConverter( override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = throw new UnsupportedOperationException - // Overriden here to avoid auto-boxing for primitive types + // Overridden here to avoid auto-boxing for primitive types override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = { checkGrowBuffer() buffer(elements) = value.asInstanceOf[NativeType] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 4fdb4bde87d24..81ec5f6e5427c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -174,6 +174,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] def writeValue(schema: DataType, value: Any): Unit = { if (value != null) { schema match { + // Check UDT first since UDTs can override other types + case t: UserDefinedType[_] => writeValue(t.sqlType, value) case t @ ArrayType(_, _) => writeArray( t, value.asInstanceOf[CatalystConverter.ArrayScalaType[_]]) @@ -183,7 +185,6 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case t @ StructType(_) => writeStruct( t, value.asInstanceOf[CatalystConverter.StructScalaType[_]]) - case t: UserDefinedType[_] => writeValue(t.sqlType, value) case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index f9b33214a83d3..67882bbc80bf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -290,6 +290,10 @@ private[parquet] object ParquetTypesConverter extends Logging { builder.named(name) }.getOrElse { ctype match { + // Check UDT first since UDTs can override other types + case udt: UserDefinedType[_] => { + fromDataType(udt.sqlType, name, nullable, inArray) + } case ArrayType(elementType, false) => { val parquetElementType = fromDataType( elementType, @@ -337,9 +341,6 @@ private[parquet] object ParquetTypesConverter extends Logging { parquetKeyType, parquetValueType) } - case udt: UserDefinedType[_] => { - fromDataType(udt.sqlType, name, nullable, inArray) - } case _ => sys.error(s"Unsupported datatype $ctype") } } From 8b242eae480c49eccbd3c5eb218b5e02a9210dcd Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 29 Oct 2014 11:56:21 -0700 Subject: [PATCH 33/46] Fixed merge error after last merge. 
Note: Last merge commit also removed SQL UDT examples from mllib. --- sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index a3212ff9ec6d6..3ee2ea05cfa2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import java.util.{List => JList} import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ import net.razorvine.pickle.Pickler @@ -29,12 +28,12 @@ import org.apache.spark.annotation.{AlphaComponent, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.sql.api.java.JavaSchemaRDD +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} -import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.storage.StorageLevel /** From 8de957cdec9c8e08715ef3bf904cdf8357998915 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 30 Oct 2014 13:23:41 -0700 Subject: [PATCH 34/46] Modified UserDefinedType to store Java class of user type so that registerUDT takes only the udt argument. Mid-process adding Java support for UDTs. --- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../spark/sql/catalyst/UDTRegistry.scala | 17 +- .../spark/sql/catalyst/types/dataTypes.scala | 26 +-- .../apache/spark/sql/api/java/DataType.java | 15 ++ .../spark/sql/api/java/UserDefinedType.java | 53 +++++ .../org/apache/spark/sql/SQLContext.scala | 14 +- .../spark/sql/api/java/JavaSQLContext.scala | 11 + .../spark/sql/api/java/UDTWrappers.scala | 70 ++++++ .../scala/org/apache/spark/sql/package.scala | 22 ++ .../sql/types/util/DataTypeConversions.scala | 7 +- .../api/java/JavaUserDefinedTypeSuite.java | 202 ++++++++++++++++++ .../spark/sql/UserDefinedTypeSuite.scala | 12 +- 12 files changed, 419 insertions(+), 32 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala create mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 6508f6d13c7cd..fdf9fb28b87ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -98,7 +98,7 @@ object ScalaReflection { // Java appends a '$' to the object name but Scala does not). 
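// Editorial sketch, not part of this patch: at this point in schemaFor, a class annotated with
// SQLUserDefinedType is mapped to its UDT via reflection. The MyDenseVector example from
// UserDefinedTypeSuite later in this patch series shows the intended shape:

@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
class MyDenseVector(val data: Array[Double]) extends Serializable

// With the annotation present, schemaFor instantiates MyDenseVectorUDT itself; in this patch it
// also records the instance in UDTRegistry as a side effect.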
val udt = Utils.classForName(className) .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() - UDTRegistry.registerType(t, udt) + UDTRegistry.registerType(udt) Schema(udt, nullable = true) case t if UDTRegistry.udtRegistry.contains(t) => Schema(UDTRegistry.udtRegistry(t), nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala index 5c1fc01efddf6..a25014f704072 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala @@ -18,20 +18,19 @@ package org.apache.spark.sql.catalyst import scala.collection.mutable -import scala.reflect.runtime.universe._ -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.types.UserDefinedType /** - * ::DeveloperApi:: * Global registry for user-defined types (UDTs). - * - * WARNING: UDTs are currently only supported from Scala. */ -@DeveloperApi +private[sql] object UDTRegistry { - /** Map: UserType --> UserDefinedType */ + /** + * Map: UserType --> UserDefinedType + * + * Internally, we use [[java.lang.Class]] instances for keys in this registry. + */ val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]() /** @@ -39,8 +38,8 @@ object UDTRegistry { * RDDs of user types and SchemaRDDs. * If this type has already been registered, this does nothing. */ - def registerType(userType: Type, udt: UserDefinedType[_]): Unit = { + def registerType(udt: UserDefinedType[_]): Unit = { // TODO: Check to see if type is built-in. Throw exception? - UDTRegistry.udtRegistry(userType) = udt + UDTRegistry.udtRegistry(udt.userClass) = udt } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 220a347af5c0a..4fc5dd402fe27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -576,29 +576,19 @@ case class MapType( /** * ::DeveloperApi:: * The data type for User Defined Types (UDTs). - * - * This interface allows a user to make their own classes more interoperable with SparkSQL; - * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create a SchemaRDD - * which has class X in the schema. - * - * For SparkSQL to recognize UDTs, the UDT must be registered in - * [[org.apache.spark.sql.catalyst.UDTRegistry]]. This registration can be done either - * explicitly by calling [[org.apache.spark.sql.catalyst.UDTRegistry.registerType()]] before using - * the UDT with SparkSQL, or implicitly by annotating the UDT with - * [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]]. - * - * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD. - * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. */ @DeveloperApi abstract class UserDefinedType[UserType] extends DataType with Serializable { - /** Underlying storage type for this UDT used by SparkSQL */ + /** Underlying storage type for this UDT */ def sqlType: DataType - /** Convert the user type to a SQL datum */ - // TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, - // where we need to convert Any to UserType. 
+ /** + * Convert the user type to a SQL datum + * + * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, + * where we need to convert Any to UserType. + */ def serialize(obj: Any): Any /** Convert a SQL datum to the user type */ @@ -608,4 +598,6 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { ("type" -> "udt") ~ ("class" -> this.getClass.getName) } + + def userClass: java.lang.Class[UserType] } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index c38354039d686..dd60ac9c3c2ae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -200,4 +200,19 @@ public static StructType createStructType(StructField[] fields) { return new StructType(fields); } +/* + public static org.apache.spark.sql.UserDefinedType + wrapAsScala(UserDefinedType udtType) { + // TODO: Check if we can unwrap instead of wrapping. + return new JavaToScalaUDTWrapper(udtType); + } + +// EDITING HERE: Does this method need to be implemented in Scala in order to avoid exposing Catalyst? + public static UserDefinedType + wrapAsJava(org.apache.spark.sql.UserDefinedType udtType) { + // TODO: Check if we can unwrap instead of wrapping. + return new ScalaToJavaUDTWrapper(udtType); + } +*/ + } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java new file mode 100644 index 0000000000000..d98d001eb94e1 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import org.apache.spark.annotation.DeveloperApi; + +/** + * ::DeveloperApi:: + * The data type representing User-Defined Types (UDTs). + * UDTs may use any other DataType for an underlying representation. + * + * TODO: Do we need to provide DataType#createUserDefinedType methods? + */ +@DeveloperApi +public abstract class UserDefinedType extends DataType { + + protected UserDefinedType() { // TODO? 
+ } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + UserDefinedType that = (UserDefinedType) o; + return this.sqlType().equals(that.sqlType()); + } + + /** Underlying storage type for this UDT */ + public abstract DataType sqlType(); + + /** Convert the user type to a SQL datum */ + public abstract Object serialize(Object obj); + + /** Convert a SQL datum to the user type */ + public abstract UserType deserialize(Object datum); + + public abstract Class userClass(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 173d7f5af05cb..eee0f6320a602 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.{UDTRegistry, ScalaReflection} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl.ExpressionConversions import org.apache.spark.sql.catalyst.expressions._ @@ -467,3 +467,15 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } } + +object SQLContext { + + /** + * Registers a User-Defined Type (UDT) so that schemas can include this type. + * UDTs can override built-in types. + */ + def registerUDT(udt: UserDefinedType[_]): Unit = { + UDTRegistry.registerType(udt) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 876b1c6edef20..9d7ff2817ebaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.sql.catalyst.UDTRegistry import org.apache.spark.sql.json.JsonRDD import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.{SQLContext, StructType => SStructType} @@ -240,3 +241,13 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { } } } + +object JavaSQLContext { + /** + * Registers a User-Defined Type (UDT) so that schemas can include this type. + * UDTs can override built-in types. + */ + def registerUDT(udt: UserDefinedType[_]): Unit = { + UDTRegistry.registerType(UDTWrappers.wrapAsScala(udt)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala new file mode 100644 index 0000000000000..33566b7bc008e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java + +import org.apache.spark.sql.{DataType => ScalaDataType, UserDefinedType => ScalaUserDefinedType} +import org.apache.spark.sql.types.util.DataTypeConversions + +/** + * Scala wrapper for a Java UserDefinedType + */ +private[sql] class JavaToScalaUDTWrapper[UserType](val javaUDT: UserDefinedType[UserType]) + extends ScalaUserDefinedType[UserType] { + + /** Underlying storage type for this UDT */ + val sqlType: ScalaDataType = DataTypeConversions.asScalaDataType(javaUDT.sqlType()) + + /** Convert the user type to a SQL datum */ + def serialize(obj: Any): Any = javaUDT.serialize(obj) + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType = javaUDT.deserialize(datum) + + val userClass: java.lang.Class[UserType] = javaUDT.userClass() +} + +/** + * Java wrapper for a Scala UserDefinedType + */ +private[sql] class ScalaToJavaUDTWrapper[UserType](val scalaUDT: ScalaUserDefinedType[UserType]) + extends UserDefinedType[UserType] { + + /** Underlying storage type for this UDT */ + val sqlType: DataType = DataTypeConversions.asJavaDataType(scalaUDT.sqlType) + + /** Convert the user type to a SQL datum */ + def serialize(obj: Any): java.lang.Object = scalaUDT.serialize(obj).asInstanceOf[java.lang.Object] + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType = scalaUDT.deserialize(datum) + + val userClass: java.lang.Class[UserType] = scalaUDT.userClass +} + +private[sql] object UDTWrappers { + + def wrapAsScala(udtType: UserDefinedType[_]): JavaToScalaUDTWrapper[_] = { + // TODO: Check if we can unwrap instead of wrapping. + new JavaToScalaUDTWrapper(udtType) + } + + def wrapAsJava(udtType: ScalaUserDefinedType[_]): ScalaToJavaUDTWrapper[_] = { + // TODO: Check if we can unwrap instead of wrapping. + new ScalaToJavaUDTWrapper(udtType) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 05926a24c5307..6a5aae197799d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -451,4 +451,26 @@ package object sql { * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. */ type MetadataBuilder = catalyst.util.MetadataBuilder + + /** + * The data type for User Defined Types (UDTs). + * + * This interface allows a user to make their own classes more interoperable with SparkSQL; + * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create + * a SchemaRDD which has class X in the schema. + * + * For SparkSQL to recognize UDTs, the UDT must be registered in + * [[org.apache.spark.sql.catalyst.UDTRegistry]]. 
This registration can be done either + * explicitly by calling [[org.apache.spark.sql.catalyst.UDTRegistry.registerType()]] + * before using the UDT with SparkSQL, or implicitly by annotating the UDT with + * [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]]. + * + * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD. + * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. + */ + @DeveloperApi + type UserDefinedType[UserType] = catalyst.types.UserDefinedType[UserType] + + @DeveloperApi + type SQLUserDefinedType = catalyst.annotation.SQLUserDefinedType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 7564bf3923032..c24b85740ab34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.types.util import org.apache.spark.sql._ -import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, MetadataBuilder => JMetaDataBuilder} +import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, + MetadataBuilder => JMetaDataBuilder, UDTWrappers, JavaToScalaUDTWrapper} import org.apache.spark.sql.api.java.{DecimalType => JDecimalType} import org.apache.spark.sql.catalyst.types.decimal.Decimal @@ -63,6 +64,8 @@ protected[sql] object DataTypeConversions { mapType.valueContainsNull) case structType: StructType => JDataType.createStructType( structType.fields.map(asJavaStructField).asJava) + case udtType: UserDefinedType[_] => + UDTWrappers.wrapAsJava(udtType) } /** @@ -118,6 +121,8 @@ protected[sql] object DataTypeConversions { mapType.isValueContainsNull) case structType: org.apache.spark.sql.api.java.StructType => StructType(structType.getFields.map(asScalaStructField)) + case udtType: org.apache.spark.sql.api.java.UserDefinedType[_] => + UDTWrappers.wrapAsScala(udtType) } /** Converts Java objects to catalyst rows / types */ diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java new file mode 100644 index 0000000000000..02511306b052d --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.sql.types.util.DataTypeConversions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + + +public class JavaUserDefinedTypeSuite implements Serializable { + private transient JavaSparkContext javaCtx; + private transient JavaSQLContext javaSqlCtx; + + @Before + public void setUp() { + javaCtx = new JavaSparkContext("local", "JavaUserDefinedTypeSuite"); + javaSqlCtx = new JavaSQLContext(javaCtx); + JavaSQLContext$.MODULE$.registerUDT(new MyDenseVectorUDT()); + } + + @After + public void tearDown() { + javaCtx.stop(); + javaCtx = null; + javaSqlCtx = null; + } + + // Note: Annotation is awkward since it requires an argument which is a Scala UserDefinedType. + //@SQLUserDefinedType(udt = MyDenseVectorUDT.class) + class MyDenseVector implements Serializable { + + public MyDenseVector(double[] data) { + this.data = data; + } + + public double[] data; + + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + + MyDenseVector dv = (MyDenseVector) other; + return Arrays.equals(this.data, dv.data); + } + } + + class MyLabeledPoint { + double label; + MyDenseVector features; + public MyLabeledPoint(double label, MyDenseVector features) { + this.label = label; + this.features = features; + } + } + + class MyDenseVectorUDT extends UserDefinedType { + + /** + * Underlying storage type for this UDT + */ + public DataType sqlType() { + return DataType.createArrayType(DataType.DoubleType, false); + } + + /** + * Convert the user type to a SQL datum + */ + public Object serialize(Object obj) { + return ((MyDenseVector) obj).data; + } + + /** + * Convert a SQL datum to the user type + */ + public MyDenseVector deserialize(Object datum) { + return new MyDenseVector((double[]) datum); + } + + public Class userClass() { + return MyDenseVector.class; + } + } + + @Test + public void useScalaUDT() { + List points = Arrays.asList( + new org.apache.spark.sql.MyLabeledPoint(1.0, + new org.apache.spark.sql.MyDenseVector(new double[]{0.1, 1.0})), + new org.apache.spark.sql.MyLabeledPoint(0.0, + new org.apache.spark.sql.MyDenseVector(new double[]{0.2, 2.0}))); + JavaRDD rowRDD = javaCtx.parallelize(points).map( + new Function() { + public Row call(org.apache.spark.sql.MyLabeledPoint lp) throws Exception { + return Row.create(lp.label(), lp.features()); + } + }); + + List fields = new ArrayList(2); + fields.add(DataType.createStructField("label", DataType.DoubleType, false)); + // EDITING HERE: HOW TO CONVERT SCALA UDT TO JAVA UDT (without exposing Catalyst)? 
+ fields.add(DataType.createStructField("features", ???, false)); + StructType schema = DataType.createStructType(fields); + + JavaSchemaRDD schemaRDD = + javaSqlCtx.applySchema(pointRDD, org.apache.spark.sql.MyLabeledPoint.class); + + schemaRDD.registerTempTable("points"); + List actual = javaSqlCtx.sql("SELECT * FROM points").collect(); +// check set + List expected = new ArrayList(2); + expected.add(Row.create(1.0, + new org.apache.spark.sql.MyDenseVector(new double[]{0.1, 1.0}))); + expected.add(Row.create(0.0, + new org.apache.spark.sql.MyDenseVector(new double[]{0.2, 2.0}))); + + Assert.assertEquals(expected, actual); + } + + // test("register user type: MyDenseVector for MyLabeledPoint") + @Test + public void registerUDT() { + /* + List points = Arrays.asList( + new MyLabeledPoint(1.0, new MyDenseVector(new double[]{0.1, 1.0})), + new MyLabeledPoint(0.0, new MyDenseVector(new double[]{0.2, 2.0}))); + JavaRDD pointsRDD = javaCtx.parallelize(points).map( + new Function() { + public Row call(MyLabeledPoint lp) throws Exception { + return Row.create(lp.label, lp.features) + } + } + ); + JavaSchemaRDD schemaRDD = pointsRDD; + */ + /* + JavaRDD labels = pointsRDD.select('label).map { case Row(v: double) => v } + val labelsArrays: Array[double] = labels.collect() + assert(labelsArrays.size === 2) + assert(labelsArrays.contains(1.0)) + assert(labelsArrays.contains(0.0)) + + val features: RDD[MyDenseVector] = + pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } + val featuresArrays: Array[MyDenseVector] = features.collect() + assert(featuresArrays.size === 2) + assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) + assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) + */ + +/* + JavaRDD rowRDD = javaCtx.parallelize(personList).map( + new Function() { + public Row call(Person person) throws Exception { + return Row.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList(2); + fields.add(DataType.createStructField("name", DataType.StringType, false)); + fields.add(DataType.createStructField("age", DataType.IntegerType, false)); + StructType schema = DataType.createStructType(fields); + + JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD, schema); + schemaRDD.registerTempTable("people"); + List actual = javaSqlCtx.sql("SELECT * FROM people").collect(); + + List expected = new ArrayList(2); + expected.add(Row.create("Michael", 29)); + expected.add(Row.create("Yin", 28)); + + Assert.assertEquals(expected, actual); + + */ + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index cf793ccbd0c02..65ee110f9e74e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.sql import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType -import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.test.TestSQLContext._ @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) @@ -31,7 +29,13 @@ class MyDenseVector(val data: Array[Double]) extends Serializable { } } -case class MyLabeledPoint(label: Double, features: MyDenseVector) +case class MyLabeledPoint(label: Double, features: MyDenseVector) { + override def equals(other: Any): Boolean = other match { + case lp: MyLabeledPoint => + label == lp.label && 
features.equals(lp.features) + case _ => false + } +} class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { @@ -50,6 +54,8 @@ class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray) } } + + override def userClass = classOf[MyDenseVector] } /* From fa86b206b97dc1854622d0e128d18f0cfe6167d4 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 30 Oct 2014 22:52:08 -0700 Subject: [PATCH 35/46] Removed Java UserDefinedType, and made UDTs private[spark] for now --- .../spark/sql/catalyst/ScalaReflection.scala | 6 +- .../spark/sql/catalyst/types/dataTypes.scala | 13 ++ .../apache/spark/sql/api/java/DataType.java | 3 +- .../spark/sql/api/java/UserDefinedType.java | 4 +- .../org/apache/spark/sql/SQLContext.scala | 14 +- .../spark/sql/api/java/JavaSQLContext.scala | 28 ++-- .../org/apache/spark/sql/api/java/Row.scala | 2 +- .../spark/sql/api/java/UDTWrappers.scala | 23 ++-- .../scala/org/apache/spark/sql/package.scala | 21 --- .../sql/types/util/DataTypeConversions.scala | 13 +- .../api/java/JavaUserDefinedTypeSuite.java | 124 ++++++------------ .../spark/sql/UserDefinedTypeSuite.scala | 40 +----- 12 files changed, 102 insertions(+), 189 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index fdf9fb28b87ce..cc5473a1298ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -98,10 +98,10 @@ object ScalaReflection { // Java appends a '$' to the object name but Scala does not). val udt = Utils.classForName(className) .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() - UDTRegistry.registerType(udt) + //UDTRegistry.registerType(udt) Schema(udt, nullable = true) - case t if UDTRegistry.udtRegistry.contains(t) => - Schema(UDTRegistry.udtRegistry(t), nullable = true) + //case t if UDTRegistry.udtRegistry.contains(t) => + //Schema(UDTRegistry.udtRegistry(t), nullable = true) case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType).dataType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 4fc5dd402fe27..ee78125a056f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -576,6 +576,19 @@ case class MapType( /** * ::DeveloperApi:: * The data type for User Defined Types (UDTs). + * + * This interface allows a user to make their own classes more interoperable with SparkSQL; + * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create + * a SchemaRDD which has class X in the schema. + * + * For SparkSQL to recognize UDTs, the UDT must be registered in + * [[org.apache.spark.sql.catalyst.UDTRegistry]]. This registration can be done either + * explicitly by calling [[org.apache.spark.sql.catalyst.UDTRegistry.registerType()]] + * before using the UDT with SparkSQL, or implicitly by annotating the UDT with + * [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]]. + * + * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD. 
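// Editorial sketch, not part of this patch: a minimal concrete UserDefinedType, mirroring the
// MyDenseVectorUDT used by UserDefinedTypeSuite in this patch series (sqlType, serialize,
// deserialize, and userClass are the members this abstract class requires).

private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
  // Store the vector as a SQL array of doubles.
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // serialize: user type -> SQL datum; runs when a SchemaRDD is built from an RDD of user types.
  override def serialize(obj: Any): Seq[Double] = obj match {
    case v: MyDenseVector => v.data.toSeq
  }

  // deserialize: SQL datum -> user type; runs when user objects are read back out of a SchemaRDD.
  override def deserialize(datum: Any): MyDenseVector = datum match {
    case values: Seq[_] => new MyDenseVector(values.asInstanceOf[Seq[Double]].toArray)
  }

  override def userClass: Class[MyDenseVector] = classOf[MyDenseVector]
}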
+ * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. */ @DeveloperApi abstract class UserDefinedType[UserType] extends DataType with Serializable { diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index dd60ac9c3c2ae..ebe78333bac3a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.api.java; +import java.io.Serializable; import java.util.*; /** @@ -25,7 +26,7 @@ * To get/create specific data type, users should use singleton objects and factory methods * provided by this class. */ -public abstract class DataType { +public abstract class DataType implements Serializable { /** * Gets the StringType object. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java index d98d001eb94e1..cc28201b369c3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.api.java; +import java.io.Serializable; + import org.apache.spark.annotation.DeveloperApi; /** @@ -27,7 +29,7 @@ * TODO: Do we need to provide DataType#createUserDefinedType methods? */ @DeveloperApi -public abstract class UserDefinedType extends DataType { +public abstract class UserDefinedType extends DataType implements Serializable { protected UserDefinedType() { // TODO? } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index eee0f6320a602..173d7f5af05cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{UDTRegistry, ScalaReflection} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl.ExpressionConversions import org.apache.spark.sql.catalyst.expressions._ @@ -467,15 +467,3 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } } - -object SQLContext { - - /** - * Registers a User-Defined Type (UDT) so that schemas can include this type. - * UDTs can override built-in types. 
- */ - def registerUDT(udt: UserDefinedType[_]): Unit = { - UDTRegistry.registerType(udt) - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 9d7ff2817ebaa..1a7b9864a6f8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -23,13 +23,13 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} -import org.apache.spark.sql.catalyst.UDTRegistry import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.types.util.DataTypeConversions -import org.apache.spark.sql.{SQLContext, StructType => SStructType} +import org.apache.spark.sql.{StructType => SStructType, SQLContext} +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} -import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType import org.apache.spark.util.Utils @@ -89,7 +89,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { * Applies a schema to an RDD of Java Beans. */ def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): JavaSchemaRDD = { - val schema = getSchema(beanClass) + val attributeSeq = getSchema(beanClass) val className = beanClass.getName val rowRdd = rdd.rdd.mapPartitions { iter => // BeanInfo is not serializable so we must rediscover it remotely for each partition. @@ -100,11 +100,13 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { iter.map { row => new GenericRow( - extractors.map(e => DataTypeConversions.convertJavaToCatalyst(e.invoke(row))).toArray[Any] + extractors.zip(attributeSeq).map { case (e, attr) => + DataTypeConversions.convertJavaToCatalyst(e.invoke(row), attr.dataType) + }.toArray[Any] ): ScalaRow } } - new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext)) + new JavaSchemaRDD(sqlContext, LogicalRDD(attributeSeq, rowRdd)(sqlContext)) } /** @@ -199,6 +201,8 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") fields.map { property => val (dataType, nullable) = property.getPropertyType match { + case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => + (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) case c: Class[_] if c == classOf[java.lang.String] => (org.apache.spark.sql.StringType, true) case c: Class[_] if c == java.lang.Short.TYPE => @@ -241,13 +245,3 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { } } } - -object JavaSQLContext { - /** - * Registers a User-Defined Type (UDT) so that schemas can include this type. - * UDTs can override built-in types. 
- */ - def registerUDT(udt: UserDefinedType[_]): Unit = { - UDTRegistry.registerType(UDTWrappers.wrapAsScala(udt)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index 401798e317e96..cfe4fa22268c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -103,7 +103,7 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { override def equals(other: Any): Boolean = other match { case that: Row => (that canEqual this) && - row == that.row + row.equals(that.row) case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala index 33566b7bc008e..a7d0f4f127ecc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.api.java -import org.apache.spark.sql.{DataType => ScalaDataType, UserDefinedType => ScalaUserDefinedType} +import org.apache.spark.sql.catalyst.types.{UserDefinedType => ScalaUserDefinedType} +import org.apache.spark.sql.{DataType => ScalaDataType} import org.apache.spark.sql.types.util.DataTypeConversions /** * Scala wrapper for a Java UserDefinedType */ private[sql] class JavaToScalaUDTWrapper[UserType](val javaUDT: UserDefinedType[UserType]) - extends ScalaUserDefinedType[UserType] { + extends ScalaUserDefinedType[UserType] with Serializable { /** Underlying storage type for this UDT */ val sqlType: ScalaDataType = DataTypeConversions.asScalaDataType(javaUDT.sqlType()) @@ -42,7 +43,7 @@ private[sql] class JavaToScalaUDTWrapper[UserType](val javaUDT: UserDefinedType[ * Java wrapper for a Scala UserDefinedType */ private[sql] class ScalaToJavaUDTWrapper[UserType](val scalaUDT: ScalaUserDefinedType[UserType]) - extends UserDefinedType[UserType] { + extends UserDefinedType[UserType] with Serializable { /** Underlying storage type for this UDT */ val sqlType: DataType = DataTypeConversions.asJavaDataType(scalaUDT.sqlType) @@ -58,13 +59,17 @@ private[sql] class ScalaToJavaUDTWrapper[UserType](val scalaUDT: ScalaUserDefine private[sql] object UDTWrappers { - def wrapAsScala(udtType: UserDefinedType[_]): JavaToScalaUDTWrapper[_] = { - // TODO: Check if we can unwrap instead of wrapping. - new JavaToScalaUDTWrapper(udtType) + def wrapAsScala(udtType: UserDefinedType[_]): ScalaUserDefinedType[_] = { + udtType match { + case t: ScalaToJavaUDTWrapper[_] => t.scalaUDT + case _ => new JavaToScalaUDTWrapper(udtType) + } } - def wrapAsJava(udtType: ScalaUserDefinedType[_]): ScalaToJavaUDTWrapper[_] = { - // TODO: Check if we can unwrap instead of wrapping. 
- new ScalaToJavaUDTWrapper(udtType) + def wrapAsJava(udtType: ScalaUserDefinedType[_]): UserDefinedType[_] = { + udtType match { + case t: JavaToScalaUDTWrapper[_] => t.javaUDT + case _ => new ScalaToJavaUDTWrapper(udtType) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 6a5aae197799d..a4c01a50d2ebe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -452,25 +452,4 @@ package object sql { */ type MetadataBuilder = catalyst.util.MetadataBuilder - /** - * The data type for User Defined Types (UDTs). - * - * This interface allows a user to make their own classes more interoperable with SparkSQL; - * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create - * a SchemaRDD which has class X in the schema. - * - * For SparkSQL to recognize UDTs, the UDT must be registered in - * [[org.apache.spark.sql.catalyst.UDTRegistry]]. This registration can be done either - * explicitly by calling [[org.apache.spark.sql.catalyst.UDTRegistry.registerType()]] - * before using the UDT with SparkSQL, or implicitly by annotating the UDT with - * [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]]. - * - * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD. - * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. - */ - @DeveloperApi - type UserDefinedType[UserType] = catalyst.types.UserDefinedType[UserType] - - @DeveloperApi - type SQLUserDefinedType = catalyst.annotation.SQLUserDefinedType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index c24b85740ab34..67ba50f2da708 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -17,13 +17,16 @@ package org.apache.spark.sql.types.util +import scala.collection.JavaConverters._ + import org.apache.spark.sql._ import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, MetadataBuilder => JMetaDataBuilder, UDTWrappers, JavaToScalaUDTWrapper} import org.apache.spark.sql.api.java.{DecimalType => JDecimalType} import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.types.UserDefinedType -import scala.collection.JavaConverters._ protected[sql] object DataTypeConversions { @@ -126,9 +129,11 @@ protected[sql] object DataTypeConversions { } /** Converts Java objects to catalyst rows / types */ - def convertJavaToCatalyst(a: Any): Any = a match { - case d: java.math.BigDecimal => Decimal(BigDecimal(d)) - case other => other + def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type + case (d: java.math.BigDecimal, _) => Decimal(BigDecimal(d)) + case (d: java.math.BigDecimal, _) => BigDecimal(d) + case (other, _) => other } /** Converts Java objects to catalyst rows / types */ diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java index 02511306b052d..28285166bdba0 
100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java @@ -18,11 +18,8 @@ package org.apache.spark.sql.api.java; import java.io.Serializable; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; +import java.util.*; -import org.apache.spark.sql.types.util.DataTypeConversions; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -41,7 +38,7 @@ public class JavaUserDefinedTypeSuite implements Serializable { public void setUp() { javaCtx = new JavaSparkContext("local", "JavaUserDefinedTypeSuite"); javaSqlCtx = new JavaSQLContext(javaCtx); - JavaSQLContext$.MODULE$.registerUDT(new MyDenseVectorUDT()); + //JavaSQLContext$.MODULE$.registerUDT(new MyDenseVectorUDT()); } @After @@ -70,7 +67,7 @@ public boolean equals(Object other) { } } - class MyLabeledPoint { + class MyLabeledPoint implements Serializable { double label; MyDenseVector features; public MyLabeledPoint(double label, MyDenseVector features) { @@ -79,7 +76,7 @@ public MyLabeledPoint(double label, MyDenseVector features) { } } - class MyDenseVectorUDT extends UserDefinedType { + class MyDenseVectorUDT extends UserDefinedType implements Serializable { /** * Underlying storage type for this UDT @@ -114,89 +111,48 @@ public void useScalaUDT() { new org.apache.spark.sql.MyDenseVector(new double[]{0.1, 1.0})), new org.apache.spark.sql.MyLabeledPoint(0.0, new org.apache.spark.sql.MyDenseVector(new double[]{0.2, 2.0}))); - JavaRDD rowRDD = javaCtx.parallelize(points).map( - new Function() { - public Row call(org.apache.spark.sql.MyLabeledPoint lp) throws Exception { - return Row.create(lp.label(), lp.features()); - } - }); - - List fields = new ArrayList(2); - fields.add(DataType.createStructField("label", DataType.DoubleType, false)); - // EDITING HERE: HOW TO CONVERT SCALA UDT TO JAVA UDT (without exposing Catalyst)? - fields.add(DataType.createStructField("features", ???, false)); - StructType schema = DataType.createStructType(fields); + JavaRDD pointsRDD = javaCtx.parallelize(points); JavaSchemaRDD schemaRDD = - javaSqlCtx.applySchema(pointRDD, org.apache.spark.sql.MyLabeledPoint.class); + javaSqlCtx.applySchema(pointsRDD, org.apache.spark.sql.MyLabeledPoint.class); schemaRDD.registerTempTable("points"); List actual = javaSqlCtx.sql("SELECT * FROM points").collect(); -// check set - List expected = new ArrayList(2); - expected.add(Row.create(1.0, - new org.apache.spark.sql.MyDenseVector(new double[]{0.1, 1.0}))); - expected.add(Row.create(0.0, - new org.apache.spark.sql.MyDenseVector(new double[]{0.2, 2.0}))); - - Assert.assertEquals(expected, actual); - } - - // test("register user type: MyDenseVector for MyLabeledPoint") - @Test - public void registerUDT() { - /* - List points = Arrays.asList( - new MyLabeledPoint(1.0, new MyDenseVector(new double[]{0.1, 1.0})), - new MyLabeledPoint(0.0, new MyDenseVector(new double[]{0.2, 2.0}))); - JavaRDD pointsRDD = javaCtx.parallelize(points).map( - new Function() { - public Row call(MyLabeledPoint lp) throws Exception { - return Row.create(lp.label, lp.features) - } - } - ); - JavaSchemaRDD schemaRDD = pointsRDD; - */ + List actualPoints = + new LinkedList(); + for (Row r : actual) { + // Note: JavaSQLContext.getSchema switches the ordering of the Row elements + // in the MyLabeledPoint case class. 
+ actualPoints.add(new org.apache.spark.sql.MyLabeledPoint( + r.getDouble(1), (org.apache.spark.sql.MyDenseVector)r.get(0))); + } + for (org.apache.spark.sql.MyLabeledPoint lp : points) { + Assert.assertTrue(actualPoints.contains(lp)); + } /* - JavaRDD labels = pointsRDD.select('label).map { case Row(v: double) => v } - val labelsArrays: Array[double] = labels.collect() - assert(labelsArrays.size === 2) - assert(labelsArrays.contains(1.0)) - assert(labelsArrays.contains(0.0)) - - val features: RDD[MyDenseVector] = - pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } - val featuresArrays: Array[MyDenseVector] = features.collect() - assert(featuresArrays.size === 2) - assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) - assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) - */ - -/* - JavaRDD rowRDD = javaCtx.parallelize(personList).map( - new Function() { - public Row call(Person person) throws Exception { - return Row.create(person.getName(), person.getAge()); - } - }); - - List fields = new ArrayList(2); - fields.add(DataType.createStructField("name", DataType.StringType, false)); - fields.add(DataType.createStructField("age", DataType.IntegerType, false)); - StructType schema = DataType.createStructType(fields); - - JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD, schema); - schemaRDD.registerTempTable("people"); - List actual = javaSqlCtx.sql("SELECT * FROM people").collect(); - - List expected = new ArrayList(2); - expected.add(Row.create("Michael", 29)); - expected.add(Row.create("Yin", 28)); - - Assert.assertEquals(expected, actual); + // THIS FAILS BECAUSE JavaSQLContext.getSchema switches the ordering of the Row elements + // in the MyLabeledPoint case class. + List expected = new LinkedList(); + expected.add(Row.create(new org.apache.spark.sql.MyLabeledPoint(1.0, + new org.apache.spark.sql.MyDenseVector(new double[]{0.1, 1.0})))); + expected.add(Row.create(new org.apache.spark.sql.MyLabeledPoint(0.0, + new org.apache.spark.sql.MyDenseVector(new double[]{0.2, 2.0})))); + System.out.println("Expected:"); + for (Row r : expected) { + System.out.println("r: " + r.toString()); + for (int i = 0; i < r.length(); ++i) { + System.out.println(" r[i]: " + r.get(i).toString()); + } + } - */ + System.out.println("Actual:"); + for (Row r : actual) { + System.out.println("r: " + r.toString()); + for (int i = 0; i < r.length(); ++i) { + System.out.println(" r[i]: " + r.get(i).toString()); + } + Assert.assertTrue(expected.contains(r)); + } + */ } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 65ee110f9e74e..092b8fe12be2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -18,10 +18,12 @@ package org.apache.spark.sql import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.test.TestSQLContext._ @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) -class MyDenseVector(val data: Array[Double]) extends Serializable { +private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data) @@ -29,15 +31,9 @@ class MyDenseVector(val data: Array[Double]) 
extends Serializable { } } -case class MyLabeledPoint(label: Double, features: MyDenseVector) { - override def equals(other: Any): Boolean = other match { - case lp: MyLabeledPoint => - label == lp.label && features.equals(lp.features) - case _ => false - } -} +private[sql] case class MyLabeledPoint(label: Double, features: MyDenseVector) -class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { +private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) @@ -58,31 +54,6 @@ class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { override def userClass = classOf[MyDenseVector] } -/* -// This is to test registering a UDT which is defined within an object (where Java and Scala -// reflection use different class names). This functionality is currently not supported but -// should be later on. -object UserDefinedTypeSuiteObject { - - class ClassInObject(val dv: MyDenseVector) extends Serializable - - case class MyLabeledPointInObject(label: Double, features: ClassInObject) - - class ClassInObjectUDT extends UserDefinedType[ClassInObject] { - - override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false) - - private val dvUDT = new MyDenseVectorUDT() - - override def serialize(obj: Any): Row = obj match { - case cio: ClassInObject => dvUDT.serialize(cio) - } - - override def deserialize(row: Row): ClassInObject = new ClassInObject(dvUDT.deserialize(row)) - } -} -*/ - class UserDefinedTypeSuite extends QueryTest { test("register user type: MyDenseVector for MyLabeledPoint") { @@ -104,5 +75,4 @@ class UserDefinedTypeSuite extends QueryTest { assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) } - } From 20630bc0f426fbd426f21b161eaf09acdb036f92 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 30 Oct 2014 23:03:31 -0700 Subject: [PATCH 36/46] fixed scalastyle --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index cc5473a1298ee..1de6f589fc455 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -98,10 +98,7 @@ object ScalaReflection { // Java appends a '$' to the object name but Scala does not). val udt = Utils.classForName(className) .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() - //UDTRegistry.registerType(udt) Schema(udt, nullable = true) - //case t if UDTRegistry.udtRegistry.contains(t) => - //Schema(UDTRegistry.udtRegistry(t), nullable = true) case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType).dataType, nullable = true) From 6fddc1c33d36860075bc13bd7662ecd53ccce3c5 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Fri, 31 Oct 2014 11:11:50 -0700 Subject: [PATCH 37/46] Made MyLabeledPoint into a Java Bean --- .../scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 092b8fe12be2c..666235e57f812 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import scala.beans.{BeanInfo, BeanProperty} + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.types.UserDefinedType @@ -31,7 +33,10 @@ private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { } } -private[sql] case class MyLabeledPoint(label: Double, features: MyDenseVector) +@BeanInfo +private[sql] case class MyLabeledPoint( + @BeanProperty label: Double, + @BeanProperty features: MyDenseVector) private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { From a571bb6090a2da71bfde0c591525b19b02751333 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 31 Oct 2014 12:39:52 -0700 Subject: [PATCH 38/46] Removed old UDT code (registry and Java UDTs). Cleaned up other code. Extended JavaUserDefinedTypeSuite --- .../spark/sql/catalyst/UDTRegistry.scala | 45 -------- .../annotation/SQLUserDefinedType.java | 8 +- .../spark/sql/catalyst/types/dataTypes.scala | 8 +- .../apache/spark/sql/api/java/DataType.java | 16 +-- .../spark/sql/api/java/UserDefinedType.java | 6 +- .../org/apache/spark/sql/api/java/Row.scala | 2 +- .../spark/sql/execution/SparkPlan.scala | 2 - .../sql/types/util/DataTypeConversions.scala | 12 +- .../api/java/JavaUserDefinedTypeSuite.java | 104 ++++++------------ 9 files changed, 53 insertions(+), 150 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala deleted file mode 100644 index a25014f704072..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import scala.collection.mutable - -import org.apache.spark.sql.catalyst.types.UserDefinedType - -/** - * Global registry for user-defined types (UDTs). 
- */ -private[sql] -object UDTRegistry { - /** - * Map: UserType --> UserDefinedType - * - * Internally, we use [[java.lang.Class]] instances for keys in this registry. - */ - val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]() - - /** - * Register a user-defined type and its serializer, to allow automatic conversion between - * RDDs of user types and SchemaRDDs. - * If this type has already been registered, this does nothing. - */ - def registerType(udt: UserDefinedType[_]): Unit = { - // TODO: Check to see if type is built-in. Throw exception? - UDTRegistry.udtRegistry(udt.userClass) = udt - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java index fd815b3b6207f..e966aeea1cb23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java @@ -28,8 +28,7 @@ * * WARNING: This annotation will only work if both Java and Scala reflection return the same class * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class - * is enclosed in an object (a singleton). In these cases, the UDT must be registered - * manually. + * is enclosed in an object (a singleton). * * WARNING: UDTs are currently only supported from Scala. */ @@ -38,5 +37,10 @@ @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface SQLUserDefinedType { + + /** + * Returns an instance of the UserDefinedType which can serialize and deserialize the user + * class to and from Catalyst built-in types. + */ Class > udt(); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index ee78125a056f8..4f0ac6ebbb604 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -581,10 +581,7 @@ case class MapType( * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create * a SchemaRDD which has class X in the schema. * - * For SparkSQL to recognize UDTs, the UDT must be registered in - * [[org.apache.spark.sql.catalyst.UDTRegistry]]. This registration can be done either - * explicitly by calling [[org.apache.spark.sql.catalyst.UDTRegistry.registerType()]] - * before using the UDT with SparkSQL, or implicitly by annotating the UDT with + * For SparkSQL to recognize UDTs, the UDT must be annotated with * [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]]. * * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD. 
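To make the annotation path concrete, here is a minimal sketch of a user class and its UDT, modeled on the MyDenseVector / MyDenseVectorUDT pair used by the test suites in this series. Point and PointUDT are illustrative names only, and the imports follow the catalyst packages as they exist at this point in the series:

    import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
    import org.apache.spark.sql.catalyst.types._

    // The annotation on the user class is what ScalaReflection looks for.
    @SQLUserDefinedType(udt = classOf[PointUDT])
    class Point(val x: Double, val y: Double) extends Serializable

    // The UDT declares the underlying SQL type and the two-way conversion.
    class PointUDT extends UserDefinedType[Point] {
      override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

      // Invoked when building a SchemaRDD from an RDD[Point].
      override def serialize(obj: Any): Seq[Double] = obj match {
        case p: Point => Seq(p.x, p.y)
      }

      // Invoked when handing values back to user code.
      override def deserialize(datum: Any): Point = datum match {
        case values: Seq[_] =>
          val ds = values.asInstanceOf[Seq[Double]]
          new Point(ds(0), ds(1))
      }

      override def userClass: Class[Point] = classOf[Point]
    }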
@@ -612,5 +609,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { ("class" -> this.getClass.getName) } + /** + * Class object for the UserType + */ def userClass: java.lang.Class[UserType] } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index ebe78333bac3a..811cb0e697d81 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -26,7 +26,7 @@ * To get/create specific data type, users should use singleton objects and factory methods * provided by this class. */ -public abstract class DataType implements Serializable { +public abstract class DataType { /** * Gets the StringType object. @@ -201,19 +201,5 @@ public static StructType createStructType(StructField[] fields) { return new StructType(fields); } -/* - public static org.apache.spark.sql.UserDefinedType - wrapAsScala(UserDefinedType udtType) { - // TODO: Check if we can unwrap instead of wrapping. - return new JavaToScalaUDTWrapper(udtType); - } - -// EDITING HERE: Does this method need to be implemented in Scala in order to avoid exposing Catalyst? - public static UserDefinedType - wrapAsJava(org.apache.spark.sql.UserDefinedType udtType) { - // TODO: Check if we can unwrap instead of wrapping. - return new ScalaToJavaUDTWrapper(udtType); - } -*/ } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java index cc28201b369c3..b751847b464fd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java @@ -25,14 +25,11 @@ * ::DeveloperApi:: * The data type representing User-Defined Types (UDTs). * UDTs may use any other DataType for an underlying representation. - * - * TODO: Do we need to provide DataType#createUserDefinedType methods? */ @DeveloperApi public abstract class UserDefinedType extends DataType implements Serializable { - protected UserDefinedType() { // TODO? - } + protected UserDefinedType() { } @Override public boolean equals(Object o) { @@ -51,5 +48,6 @@ public boolean equals(Object o) { /** Convert a SQL datum to the user type */ public abstract UserType deserialize(Object datum); + /** Class object for the UserType */ public abstract Class userClass(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index cfe4fa22268c4..4a8178e0765bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -103,7 +103,7 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { override def equals(other: Any): Boolean = other match { case that: Row => (that canEqual this) && - row.equals(that.row) + row == that.row // Should this be row.equals(that.row)? 
case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index df3b0d70b8fd9..81c60e00505c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging import org.apache.spark.rdd.RDD - - import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.{ScalaReflection, trees} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 67ba50f2da708..8d45363a9b6c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -45,6 +45,10 @@ protected[sql] object DataTypeConversions { * Returns the equivalent DataType in Java for the given DataType in Scala. */ def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match { + // Check UDT first since UDTs can override other types + case udtType: UserDefinedType[_] => + UDTWrappers.wrapAsJava(udtType) + case StringType => JDataType.StringType case BinaryType => JDataType.BinaryType case BooleanType => JDataType.BooleanType @@ -67,8 +71,6 @@ protected[sql] object DataTypeConversions { mapType.valueContainsNull) case structType: StructType => JDataType.createStructType( structType.fields.map(asJavaStructField).asJava) - case udtType: UserDefinedType[_] => - UDTWrappers.wrapAsJava(udtType) } /** @@ -86,6 +88,10 @@ protected[sql] object DataTypeConversions { * Returns the equivalent DataType in Scala for the given DataType in Java. 
*/ def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match { + // Check UDT first since UDTs can override other types + case udtType: org.apache.spark.sql.api.java.UserDefinedType[_] => + UDTWrappers.wrapAsScala(udtType) + case stringType: org.apache.spark.sql.api.java.StringType => StringType case binaryType: org.apache.spark.sql.api.java.BinaryType => @@ -124,8 +130,6 @@ protected[sql] object DataTypeConversions { mapType.isValueContainsNull) case structType: org.apache.spark.sql.api.java.StructType => StructType(structType.getFields.map(asScalaStructField)) - case udtType: org.apache.spark.sql.api.java.UserDefinedType[_] => - UDTWrappers.wrapAsScala(udtType) } /** Converts Java objects to catalyst rows / types */ diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java index 28285166bdba0..f6a68b52a4cbf 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java @@ -27,8 +27,8 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; - +import org.apache.spark.sql.MyDenseVector; +import org.apache.spark.sql.MyLabeledPoint; public class JavaUserDefinedTypeSuite implements Serializable { private transient JavaSparkContext javaCtx; @@ -38,7 +38,6 @@ public class JavaUserDefinedTypeSuite implements Serializable { public void setUp() { javaCtx = new JavaSparkContext("local", "JavaUserDefinedTypeSuite"); javaSqlCtx = new JavaSQLContext(javaCtx); - //JavaSQLContext$.MODULE$.registerUDT(new MyDenseVectorUDT()); } @After @@ -48,95 +47,54 @@ public void tearDown() { javaSqlCtx = null; } - // Note: Annotation is awkward since it requires an argument which is a Scala UserDefinedType. 
- //@SQLUserDefinedType(udt = MyDenseVectorUDT.class) - class MyDenseVector implements Serializable { - - public MyDenseVector(double[] data) { - this.data = data; - } - - public double[] data; - - public boolean equals(Object other) { - if (this == other) return true; - if (other == null || getClass() != other.getClass()) return false; - - MyDenseVector dv = (MyDenseVector) other; - return Arrays.equals(this.data, dv.data); - } - } - - class MyLabeledPoint implements Serializable { - double label; - MyDenseVector features; - public MyLabeledPoint(double label, MyDenseVector features) { - this.label = label; - this.features = features; - } - } + @Test + public void useScalaUDT() { + List points = Arrays.asList( + new MyLabeledPoint(1.0, new MyDenseVector(new double[]{0.1, 1.0})), + new MyLabeledPoint(0.0, new MyDenseVector(new double[]{0.2, 2.0}))); + JavaRDD pointsRDD = javaCtx.parallelize(points); - class MyDenseVectorUDT extends UserDefinedType implements Serializable { + JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(pointsRDD, MyLabeledPoint.class); + schemaRDD.registerTempTable("points"); - /** - * Underlying storage type for this UDT - */ - public DataType sqlType() { - return DataType.createArrayType(DataType.DoubleType, false); + List actualLabelRows = javaSqlCtx.sql("SELECT label FROM points").collect(); + List actualLabels = new LinkedList(); + for (Row r : actualLabelRows) { + actualLabels.add(r.getDouble(0)); } - - /** - * Convert the user type to a SQL datum - */ - public Object serialize(Object obj) { - return ((MyDenseVector) obj).data; + for (MyLabeledPoint lp : points) { + Assert.assertTrue(actualLabels.contains(lp.label())); } - /** - * Convert a SQL datum to the user type - */ - public MyDenseVector deserialize(Object datum) { - return new MyDenseVector((double[]) datum); + List actualFeatureRows = javaSqlCtx.sql("SELECT features FROM points").collect(); + List actualFeatures = new LinkedList(); + for (Row r : actualFeatureRows) { + actualFeatures.add((MyDenseVector)r.get(0)); } - - public Class userClass() { - return MyDenseVector.class; + for (MyLabeledPoint lp : points) { + Assert.assertTrue(actualFeatures.contains(lp.features())); } - } - @Test - public void useScalaUDT() { - List points = Arrays.asList( - new org.apache.spark.sql.MyLabeledPoint(1.0, - new org.apache.spark.sql.MyDenseVector(new double[]{0.1, 1.0})), - new org.apache.spark.sql.MyLabeledPoint(0.0, - new org.apache.spark.sql.MyDenseVector(new double[]{0.2, 2.0}))); - JavaRDD pointsRDD = javaCtx.parallelize(points); - - JavaSchemaRDD schemaRDD = - javaSqlCtx.applySchema(pointsRDD, org.apache.spark.sql.MyLabeledPoint.class); - - schemaRDD.registerTempTable("points"); List actual = javaSqlCtx.sql("SELECT * FROM points").collect(); - List actualPoints = - new LinkedList(); + List actualPoints = + new LinkedList(); for (Row r : actual) { // Note: JavaSQLContext.getSchema switches the ordering of the Row elements // in the MyLabeledPoint case class. - actualPoints.add(new org.apache.spark.sql.MyLabeledPoint( - r.getDouble(1), (org.apache.spark.sql.MyDenseVector)r.get(0))); + actualPoints.add(new MyLabeledPoint( + r.getDouble(1), (MyDenseVector)r.get(0))); } - for (org.apache.spark.sql.MyLabeledPoint lp : points) { + for (MyLabeledPoint lp : points) { Assert.assertTrue(actualPoints.contains(lp)); } /* // THIS FAILS BECAUSE JavaSQLContext.getSchema switches the ordering of the Row elements // in the MyLabeledPoint case class. 
List expected = new LinkedList(); - expected.add(Row.create(new org.apache.spark.sql.MyLabeledPoint(1.0, - new org.apache.spark.sql.MyDenseVector(new double[]{0.1, 1.0})))); - expected.add(Row.create(new org.apache.spark.sql.MyLabeledPoint(0.0, - new org.apache.spark.sql.MyDenseVector(new double[]{0.2, 2.0})))); + expected.add(Row.create(new MyLabeledPoint(1.0, + new MyDenseVector(new double[]{0.1, 1.0})))); + expected.add(Row.create(new MyLabeledPoint(0.0, + new MyDenseVector(new double[]{0.2, 2.0})))); System.out.println("Expected:"); for (Row r : expected) { System.out.println("r: " + r.toString()); From d0633801f845d8ddb28f33d4e42b32aeea43f110 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 31 Oct 2014 13:04:51 -0700 Subject: [PATCH 39/46] Cleaned up Java UDT Suite, and added warning about element ordering when creating schema from Java Bean --- .../spark/sql/api/java/JavaSQLContext.scala | 11 ++++++- .../api/java/JavaUserDefinedTypeSuite.java | 32 ++----------------- 2 files changed, 12 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 1a7b9864a6f8f..80a417697eb8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -87,6 +87,10 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { /** * Applies a schema to an RDD of Java Beans. + * + * WARNING: The ordering of elements in the schema may differ from Scala. + * If you create a [[org.apache.spark.sql.SchemaRDD]] using [[SQLContext]] + * with the same Java Bean, row elements may be in a different order. */ def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): JavaSchemaRDD = { val attributeSeq = getSchema(beanClass) @@ -193,11 +197,16 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { sqlContext.registerRDDAsTable(rdd.baseSchemaRDD, tableName) } - /** Returns a Catalyst Schema for the given java bean class. */ + /** + * Returns a Catalyst Schema for the given java bean class. + */ protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. val beanInfo = Introspector.getBeanInfo(beanClass) + // Note: The ordering of elements may differ from when the schema is inferred in Scala. + // This is because beanInfo.getPropertyDescriptors gives no guarantees about + // element ordering. 
val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") fields.map { property => val (dataType, nullable) = property.getPropertyType match { diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java index f6a68b52a4cbf..0caa8219a63e9 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java @@ -75,42 +75,14 @@ public void useScalaUDT() { Assert.assertTrue(actualFeatures.contains(lp.features())); } - List actual = javaSqlCtx.sql("SELECT * FROM points").collect(); + List actual = javaSqlCtx.sql("SELECT label, features FROM points").collect(); List actualPoints = new LinkedList(); for (Row r : actual) { - // Note: JavaSQLContext.getSchema switches the ordering of the Row elements - // in the MyLabeledPoint case class. - actualPoints.add(new MyLabeledPoint( - r.getDouble(1), (MyDenseVector)r.get(0))); + actualPoints.add(new MyLabeledPoint(r.getDouble(0), (MyDenseVector)r.get(1))); } for (MyLabeledPoint lp : points) { Assert.assertTrue(actualPoints.contains(lp)); } - /* - // THIS FAILS BECAUSE JavaSQLContext.getSchema switches the ordering of the Row elements - // in the MyLabeledPoint case class. - List expected = new LinkedList(); - expected.add(Row.create(new MyLabeledPoint(1.0, - new MyDenseVector(new double[]{0.1, 1.0})))); - expected.add(Row.create(new MyLabeledPoint(0.0, - new MyDenseVector(new double[]{0.2, 2.0})))); - System.out.println("Expected:"); - for (Row r : expected) { - System.out.println("r: " + r.toString()); - for (int i = 0; i < r.length(); ++i) { - System.out.println(" r[i]: " + r.get(i).toString()); - } - } - - System.out.println("Actual:"); - for (Row r : actual) { - System.out.println("r: " + r.toString()); - for (int i = 0; i < r.length(); ++i) { - System.out.println(" r[i]: " + r.get(i).toString()); - } - Assert.assertTrue(expected.contains(r)); - } - */ } } From 30ce5b2d4d469bb35e92075f599f897841636269 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sat, 1 Nov 2014 22:47:30 -0700 Subject: [PATCH 40/46] updates based on code review --- .../scala/org/apache/spark/sql/api/java/JavaSQLContext.scala | 5 ++--- .../src/main/scala/org/apache/spark/sql/api/java/Row.scala | 2 +- .../org/apache/spark/sql/parquet/ParquetConverter.scala | 1 - .../org/apache/spark/sql/parquet/ParquetTableSupport.scala | 1 - .../scala/org/apache/spark/sql/parquet/ParquetTypes.scala | 1 - 5 files changed, 3 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 80a417697eb8b..6e4e651ba2986 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -88,9 +88,8 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { /** * Applies a schema to an RDD of Java Beans. * - * WARNING: The ordering of elements in the schema may differ from Scala. - * If you create a [[org.apache.spark.sql.SchemaRDD]] using [[SQLContext]] - * with the same Java Bean, row elements may be in a different order. + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. 
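Because java.beans.Introspector makes no ordering guarantee for the properties that getSchema walks, the reliable pattern is to expose bean-style accessors and then bind columns by name, as the reworked Java test above does with "SELECT label, features FROM points". A small sketch of the class shape on the Scala side (MyPoint and its fields are stand-ins, not part of this patch):

    import scala.beans.{BeanInfo, BeanProperty}

    // @BeanProperty generates getLabel()/getWeight() JavaBean getters;
    // @BeanInfo makes the properties visible to java.beans.Introspector,
    // which is what JavaSQLContext.getSchema inspects.
    @BeanInfo
    case class MyPoint(
        @BeanProperty label: Double,
        @BeanProperty weight: Double)

    // Query tables built from such beans with an explicit column list,
    // e.g. "SELECT label, weight FROM points", rather than "SELECT *".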
*/ def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): JavaSchemaRDD = { val attributeSeq = getSchema(beanClass) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index 4a8178e0765bd..401798e317e96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -103,7 +103,7 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { override def equals(other: Any): Boolean = other match { case that: Row => (that canEqual this) && - row == that.row // Should this be row.equals(that.row)? + row == that.row case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 1097a4e52cb39..1bbb66aaa19a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -77,7 +77,6 @@ private[sql] object CatalystConverter { parent: CatalystConverter): Converter = { val fieldType: DataType = field.dataType fieldType match { - // Check UDT first since UDTs can override other types case udt: UserDefinedType[_] => { createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 81ec5f6e5427c..aaa970cb93510 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -174,7 +174,6 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] def writeValue(schema: DataType, value: Any): Unit = { if (value != null) { schema match { - // Check UDT first since UDTs can override other types case t: UserDefinedType[_] => writeValue(t.sqlType, value) case t @ ArrayType(_, _) => writeArray( t, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 67882bbc80bf2..fa37d1f2ae7e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -290,7 +290,6 @@ private[parquet] object ParquetTypesConverter extends Logging { builder.named(name) }.getOrElse { ctype match { - // Check UDT first since UDTs can override other types case udt: UserDefinedType[_] => { fromDataType(udt.sqlType, name, nullable, inArray) } From 5817b2b79a140f77f7eedfd751217340f47d9e9c Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Sun, 2 Nov 2014 00:50:22 -0700 Subject: [PATCH 41/46] style edits --- .../org/apache/spark/sql/parquet/ParquetTableSupport.scala | 2 +- .../org/apache/spark/sql/types/util/DataTypeConversions.scala | 2 -- .../src/test/scala/org/apache/spark/sql/json/JsonSuite.scala | 4 ++-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index aaa970cb93510..7bc249660053a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.parquet import java.util.{HashMap => JHashMap} import org.apache.hadoop.conf.Configuration -import org.apache.spark.sql.catalyst.types.decimal.Decimal import parquet.column.ParquetProperties import parquet.hadoop.ParquetOutputFormat import parquet.hadoop.api.ReadSupport.ReadContext @@ -31,6 +30,7 @@ import parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /** * A `parquet.io.api.RecordMaterializer` for Rows. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 8d45363a9b6c4..1bc15146f0fe8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -45,7 +45,6 @@ protected[sql] object DataTypeConversions { * Returns the equivalent DataType in Java for the given DataType in Scala. */ def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match { - // Check UDT first since UDTs can override other types case udtType: UserDefinedType[_] => UDTWrappers.wrapAsJava(udtType) @@ -88,7 +87,6 @@ protected[sql] object DataTypeConversions { * Returns the equivalent DataType in Scala for the given DataType in Java. */ def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match { - // Check UDT first since UDTs can override other types case udtType: org.apache.spark.sql.api.java.UserDefinedType[_] => UDTWrappers.wrapAsScala(udtType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 1cb6c23c58f36..2b82d4db054bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -233,8 +233,8 @@ class JsonSuite extends QueryTest { StructField("field2", StringType, true) :: StructField("field3", StringType, true) :: Nil), false), true) :: StructField("struct", StructType( - StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: + StructField("field1", BooleanType, true) :: + StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: StructField("structWithArrayFields", StructType( StructField("field1", ArrayType(IntegerType, false), true) :: StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil) From e13cd8ae5a5a9fae8b0dee1d2f6d890328b13210 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Sun, 2 Nov 2014 14:55:12 -0800 Subject: [PATCH 42/46] Removed Vector UDTs --- .../apache/spark/mllib/linalg/Vectors.scala | 111 ------------------ .../spark/sql/catalyst/types/dataTypes.scala | 2 +- 2 files changed, 1 insertion(+), 112 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 9aaafa34f8c03..5070032e49809 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -27,11 +27,9 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException import org.apache.spark.mllib.util.NumericParser -import org.apache.spark.sql.catalyst.UDTRegistry import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.Row /** * Represents a numeric vector, whose index type is Int and value type is Double. @@ -86,12 +84,6 @@ sealed trait Vector extends Serializable { */ object Vectors { - // Note: Explicit registration is only needed for Vector and SparseVector; - // the annotation works for DenseVector. - UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[Vector], new VectorUDT()) - UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[SparseVector], - new SparseVectorUDT()) - /** * Creates a dense vector from its values. */ @@ -202,7 +194,6 @@ object Vectors { /** * A dense vector represented by a value array. */ -@SQLUserDefinedType(udt = classOf[DenseVectorUDT]) class DenseVector(val values: Array[Double]) extends Vector { override def size: Int = values.length @@ -254,105 +245,3 @@ class SparseVector( private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) } - -/** - * User-defined type for [[Vector]] which allows easy interaction with SQL - * via [[org.apache.spark.sql.SchemaRDD]]. - */ -private[spark] class VectorUDT extends UserDefinedType[Vector] { - - /** - * vectorType: 0 = dense, 1 = sparse. - * dense, sparse: One element holds the vector, and the other is null. - */ - override def sqlType: StructType = StructType(Seq( - StructField("vectorType", ByteType, nullable = false), - StructField("dense", new DenseVectorUDT, nullable = true), - StructField("sparse", new SparseVectorUDT, nullable = true))) - - override def serialize(obj: Any): Row = { - val row = new GenericMutableRow(3) - obj match { - case v: DenseVector => - row.setByte(0, 0) - row.update(1, new DenseVectorUDT().serialize(obj)) - row.setNullAt(2) - case v: SparseVector => - row.setByte(0, 1) - row.setNullAt(1) - row.update(2, new SparseVectorUDT().serialize(obj)) - } - row - } - - override def deserialize(datum: Any): Vector = { - datum match { - case row: Row => - require(row.length == 3, - s"VectorUDT.deserialize given row with length ${row.length} but requires length == 3") - val vectorType = row.getByte(0) - vectorType match { - case 0 => - new DenseVectorUDT().deserialize(row.getAs[Row](1)) - case 1 => - new SparseVectorUDT().deserialize(row.getAs[Row](2)) - } - } - } -} - -/** - * User-defined type for [[DenseVector]] which allows easy interaction with SQL - * via [[org.apache.spark.sql.SchemaRDD]]. 
- */ -private[spark] class DenseVectorUDT extends UserDefinedType[DenseVector] { - - override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) - - override def serialize(obj: Any): Seq[Double] = { - obj match { - case v: DenseVector => - v.values.toSeq - } - } - - override def deserialize(datum: Any): DenseVector = { - datum match { - case values: Seq[_] => - new DenseVector(values.asInstanceOf[Seq[Double]].toArray) - } - } -} - -/** - * User-defined type for [[SparseVector]] which allows easy interaction with SQL - * via [[org.apache.spark.sql.SchemaRDD]]. - */ -private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] { - - override def sqlType: StructType = StructType(Seq( - StructField("size", IntegerType, nullable = false), - StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = false), - StructField("values", ArrayType(DoubleType, containsNull = false), nullable = false))) - - override def serialize(obj: Any): Row = obj match { - case v: SparseVector => - val row: GenericMutableRow = new GenericMutableRow(3) - row.setInt(0, v.size) - row.update(1, v.indices.toSeq) - row.update(2, v.values.toSeq) - row - } - - override def deserialize(datum: Any): SparseVector = { - datum match { - case row: Row => - require(row.length == 3, - s"SparseVectorUDT.deserialize given row with length ${row.length} but expect 3.") - val vSize = row.getInt(0) - val indices = row.getAs[Seq[Int]](1).toArray - val values = row.getAs[Seq[Double]](2).toArray - new SparseVector(vSize, indices, values) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 4f0ac6ebbb604..df35577b659c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -30,7 +30,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.{UDTRegistry, ScalaReflectionLock} +import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row} import org.apache.spark.sql.catalyst.types.decimal._ import org.apache.spark.sql.catalyst.util.Metadata From f3c72feb144d399bc2f02d265f3335a1dbf019b1 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Sun, 2 Nov 2014 15:30:22 -0800 Subject: [PATCH 43/46] Fixing merge --- .../spark/examples/mllib/DatasetExample.scala | 121 ------------------ mllib/pom.xml | 5 - .../apache/spark/mllib/linalg/Vectors.scala | 5 +- .../apache/spark/mllib/rdd/DatasetSuite.scala | 84 ------------ .../apache/spark/sql/api/java/DataType.java | 2 - .../scala/org/apache/spark/sql/package.scala | 1 - 6 files changed, 1 insertion(+), 217 deletions(-) delete mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala delete mode 100644 mllib/src/test/scala/org/apache/spark/mllib/rdd/DatasetSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala deleted file mode 100644 index f8d83f4ec7327..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib - -import java.io.File - -import com.google.common.io.Files -import scopt.OptionParser - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} - -/** - * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with - * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] - * }}} - * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
- */ -object DatasetExample { - - case class Params( - input: String = "data/mllib/sample_libsvm_data.txt", - dataFormat: String = "libsvm") extends AbstractParams[Params] - - def main(args: Array[String]) { - val defaultParams = Params() - - val parser = new OptionParser[Params]("DatasetExample") { - head("Dataset: an example app using SchemaRDD as a Dataset for ML.") - opt[String]("input") - .text(s"input path to dataset") - .action((x, c) => c.copy(input = x)) - opt[String]("dataFormat") - .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") - .action((x, c) => c.copy(input = x)) - checkConfig { params => - success - } - } - - parser.parse(args, defaultParams).map { params => - run(params) - }.getOrElse { - sys.exit(1) - } - } - - def run(params: Params) { - - val conf = new SparkConf().setAppName(s"DatasetExample with $params") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - import sqlContext._ // for implicit conversions - - // Load input data - val origData: RDD[LabeledPoint] = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) - } - println(s"Loaded ${origData.count()} instances from file: ${params.input}") - - // Convert input data to SchemaRDD explicitly. - val schemaRDD: SchemaRDD = origData - println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") - println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") - - // Select columns, using implicit conversion to SchemaRDD. - val labelsSchemaRDD: SchemaRDD = origData.select('label) - val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } - val numLabels = labels.count() - val meanLabel = labels.fold(0.0)(_ + _) / numLabels - println(s"Selected label column with average value $meanLabel") - - val featuresSchemaRDD: SchemaRDD = origData.select('features) - val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } - val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( - (summary, feat) => summary.add(feat), - (sum1, sum2) => sum1.merge(sum2)) - println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") - - val tmpDir = Files.createTempDir() - tmpDir.deleteOnExit() - val outputDir = new File(tmpDir, "dataset").toString - println(s"Saving to $outputDir as Parquet file.") - schemaRDD.saveAsParquetFile(outputDir) - - println(s"Loading Parquet file with UDT from $outputDir.") - val newDataset = sqlContext.parquetFile(outputDir) - - println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } - val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( - (summary, feat) => summary.add(feat), - (sum1, sum2) => sum1.merge(sum2)) - println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") - - sc.stop() - } - -} diff --git a/mllib/pom.xml b/mllib/pom.xml index 0335409c9123d..fb7239e779aae 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -40,11 +40,6 @@ spark-core_${scala.binary.version} ${project.version} - - org.apache.spark - spark-sql_${scala.binary.version} - ${project.version} - org.apache.spark spark-streaming_${scala.binary.version} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 5070032e49809..6af225b7f49f7 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -25,11 +25,8 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} -import org.apache.spark.SparkException import org.apache.spark.mllib.util.NumericParser -import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.SparkException /** * Represents a numeric vector, whose index type is Int and value type is Double. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/DatasetSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/DatasetSuite.scala deleted file mode 100644 index 784ed2b3cc37f..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/DatasetSuite.scala +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.mllib.rdd - -import org.scalatest.FunSuite - -import org.apache.spark.mllib.linalg.{Vectors, DenseVector, SparseVector, Vector} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.LocalSparkContext -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} - - -private case class DenseVectorLabeledPoint(label: Double, features: DenseVector) -private case class SparseVectorLabeledPoint(label: Double, features: SparseVector) - -class DatasetSuite extends FunSuite with LocalSparkContext { - - test("SQL and Vector") { - val sqlContext = new SQLContext(sc) - import sqlContext._ - val points = Seq( - LabeledPoint(1.0, Vectors.dense(Array(1.0, 2.0))), - LabeledPoint(2.0, Vectors.dense(Array(3.0, 6.0))), - LabeledPoint(3.0, Vectors.dense(Array(3.0, 3.0))), - LabeledPoint(4.0, Vectors.dense(Array(4.0, 8.0)))) - val data: RDD[LabeledPoint] = sc.parallelize(points) - val labels = data.select('label).map { case Row(label: Double) => label }.collect().toSet - assert(labels == Set(1.0, 2.0, 3.0, 4.0)) - val features = data.select('features).map { case Row(features: Vector) => features }.collect() - assert(features.size === 4) - assert(features.forall(_.size == 2)) - } - - test("SQL and DenseVector") { - val sqlContext = new SQLContext(sc) - import sqlContext._ - val points = Seq( - DenseVectorLabeledPoint(1.0, new DenseVector(Array(1.0, 2.0))), - DenseVectorLabeledPoint(2.0, new DenseVector(Array(3.0, 6.0))), - DenseVectorLabeledPoint(3.0, new DenseVector(Array(3.0, 3.0))), - DenseVectorLabeledPoint(4.0, new DenseVector(Array(4.0, 8.0)))) - val data: RDD[DenseVectorLabeledPoint] = sc.parallelize(points) - val labels = data.select('label).map { case Row(label: Double) => label }.collect().toSet - assert(labels == Set(1.0, 2.0, 3.0, 4.0)) - val features = - data.select('features).map { case Row(features: DenseVector) => features }.collect() - assert(features.size === 4) - assert(features.forall(_.size == 2)) - } - - test("SQL and SparseVector") { - val sqlContext = new SQLContext(sc) - import sqlContext._ - val vSize = 2 - val points = Seq( - SparseVectorLabeledPoint(1.0, new SparseVector(vSize, Array(0, 1), Array(1.0, 2.0))), - SparseVectorLabeledPoint(2.0, new SparseVector(vSize, Array(0, 1), Array(3.0, 6.0))), - SparseVectorLabeledPoint(3.0, new SparseVector(vSize, Array(0, 1), Array(3.0, 3.0))), - SparseVectorLabeledPoint(4.0, new SparseVector(vSize, Array(0, 1), Array(4.0, 8.0)))) - val data: RDD[SparseVectorLabeledPoint] = sc.parallelize(points) - val labels = data.select('label).map { case Row(label: Double) => label }.collect().toSet - assert(labels == Set(1.0, 2.0, 3.0, 4.0)) - val features = - data.select('features).map { case Row(features: SparseVector) => features }.collect() - assert(features.size === 4) - assert(features.forall(_.size == 2)) - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index 811cb0e697d81..c38354039d686 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.api.java; -import java.io.Serializable; import java.util.*; /** @@ -201,5 +200,4 @@ public static StructType createStructType(StructField[] fields) { return new StructType(fields); } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index a4c01a50d2ebe..05926a24c5307 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -451,5 +451,4 @@ package object sql { * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. */ type MetadataBuilder = catalyst.util.MetadataBuilder - } From 6cc434d3c48f0dcc60d75fc59fa5fb1c2691ad2d Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 2 Nov 2014 16:49:36 -0800 Subject: [PATCH 44/46] Recursively convert rows. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 1de6f589fc455..9cda373623cb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -67,6 +67,7 @@ object ScalaReflection { case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) } + case (r: Row, s: StructType) => convertRowToScala(r, s) case (d: Decimal, _: DecimalType) => d.toBigDecimal case (other, _) => other } From 46a3aee35d8b2bc04ebaacec946cac8f30a7ec92 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 2 Nov 2014 16:49:56 -0800 Subject: [PATCH 45/46] Slightly easier to read test output. --- .../test/scala/org/apache/spark/sql/json/JsonSuite.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index b329d3df5a9dd..cade244f7ac39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ @@ -292,8 +291,8 @@ class JsonSuite extends QueryTest { // Access a struct and fields inside of it. 
checkAnswer( sql("select struct, struct.field1, struct.field2 from jsonTable"), - ( - Seq(true, BigDecimal("92233720368547758070")), + Row( + Row(true, BigDecimal("92233720368547758070")), true, BigDecimal("92233720368547758070")) :: Nil ) From 7ccfc0dbed791e2b3b9e646afd65a5d50c43f4f0 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 2 Nov 2014 16:59:08 -0800 Subject: [PATCH 46/46] remove println --- .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 292478663c378..cc7e0c05ffc70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -278,7 +278,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil case logical.LocalRelation(output, data) => val nPartitions = if (data.isEmpty) 1 else numPartitions - println(s"BasicOperators.apply: creating schema from attributes: $output") PhysicalRDD( output, RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions),
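Patch 44's one-line addition (case (r: Row, s: StructType) => convertRowToScala(r, s)) makes convertToScala recurse into struct-valued fields instead of passing them through as raw catalyst rows. A self-contained toy model of that idea, using stand-in types rather than Spark's Row and DataType classes:

    // Stand-in "internal" representation and type descriptors (not Spark classes).
    case class InternalDecimal(unscaled: BigInt, scale: Int)
    sealed trait ToyType
    case object ToyDecimal extends ToyType
    case class ToyStruct(fieldTypes: Seq[ToyType]) extends ToyType

    // Convert depth-first: rows nested inside rows get their fields converted too.
    def convertToScala(value: Any, tpe: ToyType): Any = (value, tpe) match {
      case (d: InternalDecimal, ToyDecimal) => BigDecimal(d.unscaled, d.scale)
      case (row: Seq[_], ToyStruct(ts)) =>
        row.zip(ts).map { case (v, t) => convertToScala(v, t) }
      case (other, _) => other
    }

    val nested = Seq(InternalDecimal(BigInt("92233720368547758070"), 0),
                     Seq(InternalDecimal(BigInt(12345), 2)))
    val tpe = ToyStruct(Seq(ToyDecimal, ToyStruct(Seq(ToyDecimal))))
    println(convertToScala(nested, tpe))  // List(92233720368547758070, List(123.45))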