diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 88fb21eb0a6eb..eb0780a6ddea9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -96,17 +96,28 @@ object ScalaReflection {
   def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
     // TODO: Why does this not need to flatMap stuff? Does it not support nesting?
     // TODO: What about Option and Product?
-    case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType))
-    case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
+    case (s: Seq[_], arrayType: ArrayType) =>
+      println("convertToScala: Seq")
+      s.map(convertToScala(_, arrayType.elementType))
+    case (m: Map[_, _], mapType: MapType) =>
+      println("convertToScala: Map")
+      m.map { case (k, v) =>
       convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
     }
-    case d: Decimal => d.toBigDecimal
-    case (udt: Row, udtType: UserDefinedType[_]) => udtType.deserialize(udt)
-    case (other, _) => other
+    case (d: Decimal, DecimalType) => d.toBigDecimal
+    case (udt: Row, udtType: UserDefinedType[_]) =>
+      println("convertToScala: udt")
+      udtType.deserialize(udt)
+    case (other, _) =>
+      println("convertToScala: other")
+      other
   }
 
   def convertRowToScala(r: Row, schema: StructType): Row = {
-    new GenericRow(r.toArray.map(convertToScala(_, schema)))
+    println(s"convertRowToScala called with schema: $schema")
+    new GenericRow(
+      r.zip(schema.fields.map(_.dataType))
+        .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
   }
   /** Returns a Sequence of attributes for the given case class type.
    */
@@ -129,9 +140,6 @@ object ScalaReflection {
   def schemaFor(tpe: `Type`, udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = {
     println(s"schemaFor: $tpe")
     tpe match {
-      case t if udtRegistry.contains(t) =>
-        println(s" schemaFor T matched udtRegistry")
-        Schema(udtRegistry(t), nullable = true)
       case t if t <:< typeOf[Option[_]] =>
         val TypeRef(_, _, Seq(optType)) = t
         Schema(schemaFor(optType, udtRegistry).dataType, nullable = true)
@@ -178,6 +186,9 @@ object ScalaReflection {
       case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
       case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
       case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+      case t if udtRegistry.contains(t) =>
+        println(s" schemaFor T matched udtRegistry")
+        Schema(udtRegistry(t), nullable = true)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index f838eeae9e753..e66aa22144cdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -106,7 +106,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
     println(s"createSchemaRDD called")
     val attributeSeq = ScalaReflection.attributesFor[A](udtRegistry)
     val schema = StructType.fromAttributes(attributeSeq)
-    new SchemaRDD(this, LogicalRDD(attributeSeq, RDDConversions.productToRowRdd(rdd, schema))(self))
+    val rowRDD = RDDConversions.productToRowRdd(rdd, schema)
+    println("done with productToRowRdd")
+    new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index f0cee63947721..80bfa8771e22c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -125,7 +125,7 @@ class SchemaRDD(
    *
    * @group schema
    */
-  def schema: StructType = queryExecution.analyzed.schema
+  val schema: StructType = queryExecution.analyzed.schema
 
   // =======================================================================
   // Query DSL
@@ -148,6 +148,7 @@ class SchemaRDD(
       case (ne: NamedExpression, _) => ne
       case (e, i) => Alias(e, s"c$i")()
     }
+    assert(sqlContext != null)
     new SchemaRDD(sqlContext, Project(aliases, logicalPlan))
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
index 455692fa5fd02..15516afb95504 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
@@ -29,8 +29,6 @@ private[sql] trait SchemaRDDLike {
   @transient val sqlContext: SQLContext
   @transient val baseLogicalPlan: LogicalPlan
 
-  assert(sqlContext != null)
-
   private[sql] def baseSchemaRDD: SchemaRDD
 
   /**
@@ -51,10 +49,7 @@ private[sql] trait SchemaRDDLike {
    */
   @transient
   @DeveloperApi
-  lazy val queryExecution = {
-    assert(sqlContext != null)
-    sqlContext.executePlan(baseLogicalPlan)
-  }
+  lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan)
 
   @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match {
     // For various commands (like DDL) and queries with side effects, we force query optimization to
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index c1ae1dd71720e..0d6843ed112f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -82,7 +82,7 @@ class UserDefinedTypeSuite extends QueryTest {
     println("Done converting to SchemaRDD")
 
     println("testing labels")
-    val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v}
+    val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v }
     val labelsArrays: Array[Double] = labels.collect()
     assert(labelsArrays.size === 2)
     assert(labelsArrays.contains(1.0))
@@ -90,7 +90,7 @@ class UserDefinedTypeSuite extends QueryTest {
 
     println("testing features")
     val features: RDD[DenseVector] =
-      pointsRDD.select('features).map { case Row(v: DenseVector) => v}
+      pointsRDD.select('features).map { case Row(v: DenseVector) => v }
     val featuresArrays: Array[DenseVector] = features.collect()
     assert(featuresArrays.size === 2)
     assert(featuresArrays.contains(new DenseVector(Array(0.1, 1.0))))
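
The core behavioral change in the patch is in `ScalaReflection.convertRowToScala`: instead of mapping `convertToScala(_, schema)` over the row, which handed the entire `StructType` to every value, the row's values are now zipped with the `DataType` of their corresponding schema field and converted one by one. Below is a minimal, self-contained sketch of that pattern; it is not Catalyst code, and the simplified `Row`, `DataType`, `StructField`, and `StructType` definitions are stand-ins introduced only for illustration.

```scala
// Sketch of the zip-based row conversion pattern used by the new convertRowToScala:
// pair each row value with the DataType of its schema field, then convert value by value.
object ConvertRowSketch {
  sealed trait DataType
  case object DoubleType extends DataType
  case class ArrayType(elementType: DataType) extends DataType
  case class StructField(name: String, dataType: DataType)
  case class StructType(fields: Seq[StructField])

  // A "row" is just a sequence of column values in this sketch.
  type Row = Seq[Any]

  // Convert a single value according to its declared type (only two cases shown).
  def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
    case (s: Seq[_], ArrayType(elementType)) => s.map(convertToScala(_, elementType))
    case (other, _) => other
  }

  // Zip the row's values with the schema's field types, then convert each pair.
  def convertRowToScala(r: Row, schema: StructType): Row =
    r.zip(schema.fields.map(_.dataType)).map { case (v, dt) => convertToScala(v, dt) }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("label", DoubleType),
      StructField("points", ArrayType(DoubleType))))
    val row: Row = Seq(1.0, Seq(0.1, 0.2))
    println(convertRowToScala(row, schema)) // List(1.0, List(0.1, 0.2))
  }
}
```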
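The other structural change is that `schemaFor` now consults `udtRegistry` only after all of the built-in cases, presumably so that Options, primitives, and the other standard types keep their usual schemas even when an overlapping entry ends up in the registry. The snippet below only illustrates why position matters: Scala match cases are tried top to bottom, so whichever rule appears first wins. The registry contents and type names here are hypothetical, not Spark's.

```scala
// Illustration only: moving a registry lookup below the built-in rules changes
// which rule wins for types that both would match.
object CaseOrderSketch {
  val udtRegistry = Set("Int") // hypothetical: a UDT registered for a built-in type

  def schemaRegistryFirst(t: String): String = t match {
    case x if udtRegistry.contains(x) => s"UDT($x)" // shadows the built-in rule
    case "Int" => "IntegerType"
    case other => s"Unknown($other)"
  }

  def schemaRegistryLast(t: String): String = t match {
    case "Int" => "IntegerType" // built-in rule takes precedence
    case x if udtRegistry.contains(x) => s"UDT($x)"
    case other => s"Unknown($other)"
  }

  def main(args: Array[String]): Unit = {
    println(schemaRegistryFirst("Int")) // UDT(Int)
    println(schemaRegistryLast("Int"))  // IntegerType
  }
}
```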