From 65e34fa7be911c625975e85e57bba4fd8a143661 Mon Sep 17 00:00:00 2001
From: Michal Senkyr
Date: Fri, 6 Jan 2017 15:05:20 +0800
Subject: [PATCH] [SPARK-16792][SQL] Dataset containing a Case Class with a
 List type causes a CompileException (converting sequence to list)

## What changes were proposed in this pull request?

Added a `to` call at the end of the code generated by
`ScalaReflection.deserializerFor` when the requested type is not a supertype
of `WrappedArray[_]`. The added call uses `CanBuildFrom[_, _, _]` to convert
the result into an arbitrary subtype of `Seq[_]`.

Care was taken to preserve the original deserialization where possible, so
that the overhead of conversion is avoided in cases where it is not needed.

`ScalaReflection.serializerFor` could already serialize any `Seq[_]`, so it
was not altered.

`SQLImplicits` had to be altered and new implicit encoders added to permit
serialization of other sequence types. To keep implicit resolution
unambiguous, `newProductEncoder` moves into a new `LowPrioritySQLImplicits`
parent trait (a short sketch of that pattern follows the diff).

Also fixes [SPARK-16815]: Dataset[List[T]] leads to ArrayStoreException.

## How was this patch tested?

```bash
./build/mvn -DskipTests clean package && ./dev/run-tests
```

Also verified by manually executing the following sets of commands in the
Spark shell:

```scala
case class TestCC(key: Int, letters: List[String])

val ds1 = sc.makeRDD(Seq(
  List("D"),
  List("S", "H"),
  List("F", "H"),
  List("D", "L", "L")
)).map(x => (x.length, x)).toDF("key", "letters").as[TestCC]

val test1 = ds1.map(_.key)
test1.show
```

```scala
case class X(l: List[String])

spark.createDataset(Seq(List("A"))).map(X).show
```

```scala
spark.sqlContext.createDataset(sc.parallelize(List(1) :: Nil)).collect
```

After adding arbitrary sequence support, the change was also tested with the
following commands:

```scala
case class QueueClass(q: scala.collection.immutable.Queue[Int])

spark.createDataset(Seq(List(1, 2, 3)))
  .map(x => QueueClass(scala.collection.immutable.Queue(x: _*)))
  .map(_.q.dequeue)
  .collect
```

Author: Michal Senkyr

Closes #16240 from michalsenkyr/sql-caseclass-list-fix.
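Illustration (not part of the patch): the conversion performed by the
generated deserializer is roughly equivalent to the following plain Scala
sketch, written against the Scala 2.11/2.12 collections API that Spark uses
here. The object and value names are made up for the example; the point is
that `to` obtains its `CanBuildFrom` from the target collection's companion
object, which is what the generated code invokes reflectively in the
`ScalaReflection.scala` hunk below:

```scala
import scala.collection.immutable.Queue
import scala.collection.mutable.WrappedArray

object ConversionSketch extends App {
  // Deserialization first materializes the column data as a WrappedArray.
  val wrapped: Seq[String] = WrappedArray.make(Array("a", "b", "c"))

  // If the requested type is a supertype of WrappedArray (e.g. Seq itself),
  // no conversion is appended. Otherwise the generated code ends in a `to`
  // call, whose implicit CanBuildFrom[Nothing, A, C[A]] argument comes from
  // the target collection's companion object (List.canBuildFrom, etc.):
  val asList: List[String] = wrapped.to[List]
  val asQueue: Queue[String] = wrapped.to[Queue]

  println(asList)  // List(a, b, c)
  println(asQueue) // Queue(a, b, c)
}
```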
---
 .../spark/sql/catalyst/ScalaReflection.scala       |  40 +++++-
 .../sql/catalyst/ScalaReflectionSuite.scala        |  31 +++++
 .../org/apache/spark/sql/SQLImplicits.scala        | 115 ++++++++++++++----
 .../spark/sql/DatasetPrimitiveSuite.scala          |  67 ++++++++++
 4 files changed, 231 insertions(+), 22 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index ad218cf88de23..7f7dd51aa2650 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -312,12 +312,50 @@ object ScalaReflection extends ScalaReflection {
           "array",
           ObjectType(classOf[Array[Any]]))
 
-        StaticInvoke(
+        val wrappedArray = StaticInvoke(
           scala.collection.mutable.WrappedArray.getClass,
           ObjectType(classOf[Seq[_]]),
           "make",
           array :: Nil)
+        if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
+          wrappedArray
+        } else {
+          // Convert to another type using `to`
+          val cls = mirror.runtimeClass(t.typeSymbol.asClass)
+          import scala.collection.generic.CanBuildFrom
+          import scala.reflect.ClassTag
+
+          // Some canBuildFrom methods take an implicit ClassTag parameter
+          val cbfParams = try {
+            cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
+            StaticInvoke(
+              ClassTag.getClass,
+              ObjectType(classOf[ClassTag[_]]),
+              "apply",
+              StaticInvoke(
+                cls,
+                ObjectType(classOf[Class[_]]),
+                "getClass"
+              ) :: Nil
+            ) :: Nil
+          } catch {
+            case _: NoSuchMethodException => Nil
+          }
+
+          Invoke(
+            wrappedArray,
+            "to",
+            ObjectType(cls),
+            StaticInvoke(
+              cls,
+              ObjectType(classOf[CanBuildFrom[_, _, _]]),
+              "canBuildFrom",
+              cbfParams
+            ) :: Nil
+          )
+        }
+
       case t if t <:< localTypeOf[Map[_, _]] =>
         // TODO: add walked type path for map
         val TypeRef(_, _, Seq(keyType, valueType)) = t
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 43b6afd9ad896..650a35398f3e8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -291,6 +291,37 @@ class ScalaReflectionSuite extends SparkFunSuite {
       .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
   }
 
+  test("SPARK-16792: Get correct deserializer for List[_]") {
+    val listDeserializer = deserializerFor[List[Int]]
+    assert(listDeserializer.dataType == ObjectType(classOf[List[_]]))
+  }
+
+  test("serialize and deserialize arbitrary sequence types") {
+    import scala.collection.immutable.Queue
+    val queueSerializer = serializerFor[Queue[Int]](BoundReference(
+      0, ObjectType(classOf[Queue[Int]]), nullable = false))
+    assert(queueSerializer.dataType.head.dataType ==
+      ArrayType(IntegerType, containsNull = false))
+    val queueDeserializer = deserializerFor[Queue[Int]]
+    assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]]))
+
+    import scala.collection.mutable.ArrayBuffer
+    val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference(
+      0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false))
+    assert(arrayBufferSerializer.dataType.head.dataType ==
+      ArrayType(IntegerType, containsNull = false))
+    val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
+    assert(arrayBufferDeserializer.dataType ==
+      ObjectType(classOf[ArrayBuffer[_]]))
+
+    // Check whether conversion is skipped when using WrappedArray[_] supertype
+    // (would otherwise needlessly add overhead)
+    import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+    val seqDeserializer = deserializerFor[Seq[Int]]
+    assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
+      scala.collection.mutable.WrappedArray.getClass)
+    assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
+  }
 
   private val dataTypeForComplexData = dataTypeFor[ComplexData]
   private val typeOfComplexData = typeOf[ComplexData]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 872a78b578d28..2caf723669f63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
  * @since 1.6.0
  */
 @InterfaceStability.Evolving
-abstract class SQLImplicits {
+abstract class SQLImplicits extends LowPrioritySQLImplicits {
 
   protected def _sqlContext: SQLContext
 
@@ -45,9 +45,6 @@ abstract class SQLImplicits {
     }
   }
 
-  /** @since 1.6.0 */
-  implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]
-
   // Primitives
 
   /** @since 1.6.0 */
@@ -112,33 +109,96 @@ abstract class SQLImplicits {
 
   // Seqs
 
-  /** @since 1.6.1 */
-  implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newIntSequenceEncoder]]
+   */
+  def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
-  implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newLongSequenceEncoder]]
+   */
+  def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
-  implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newDoubleSequenceEncoder]]
+   */
+  def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
-  implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newFloatSequenceEncoder]]
+   */
+  def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
-  implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newByteSequenceEncoder]]
+   */
+  def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
-  implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newShortSequenceEncoder]]
+   */
+  def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
-  implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newBooleanSequenceEncoder]]
+   */
+  def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
-  implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newStringSequenceEncoder]]
+   */
+  def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newProductSequenceEncoder]]
+   */
   implicit def newProductSeqEncoder[A <: Product : TypeTag]:
     Encoder[Seq[A]] = ExpressionEncoder()
 
+  /** @since 2.2.0 */
+  implicit def newIntSequenceEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newLongSequenceEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newDoubleSequenceEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newFloatSequenceEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newByteSequenceEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newShortSequenceEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newBooleanSequenceEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newStringSequenceEncoder[T <: Seq[String] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newProductSequenceEncoder[T <: Seq[Product] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
   // Arrays
 
   /** @since 1.6.1 */
@@ -193,3 +253,16 @@ abstract class SQLImplicits {
   implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
 
 }
+
+/**
+ * Lower priority implicit methods for converting Scala objects into [[Dataset]]s.
+ * Conflicting implicits are placed here to disambiguate resolution.
+ *
+ * Reasons for including specific implicits:
+ * newProductEncoder - to disambiguate for [[List]]s which are both [[Seq]] and [[Product]]
+ */
+trait LowPrioritySQLImplicits {
+  /** @since 1.6.0 */
+  implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index f8d4c61967f95..6b50cb3e48c76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -17,10 +17,21 @@
 
 package org.apache.spark.sql
 
+import scala.collection.immutable.Queue
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.sql.test.SharedSQLContext
 
 case class IntClass(value: Int)
 
+case class SeqClass(s: Seq[Int])
+
+case class ListClass(l: List[Int])
+
+case class QueueClass(q: Queue[Int])
+
+case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass)
+
 package object packageobject {
   case class PackageClass(value: Int)
 }
@@ -130,6 +141,62 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
   }
 
+  test("arbitrary sequences") {
+    checkDataset(Seq(Queue(1)).toDS(), Queue(1))
+    checkDataset(Seq(Queue(1.toLong)).toDS(), Queue(1.toLong))
+    checkDataset(Seq(Queue(1.toDouble)).toDS(), Queue(1.toDouble))
+    checkDataset(Seq(Queue(1.toFloat)).toDS(), Queue(1.toFloat))
+    checkDataset(Seq(Queue(1.toByte)).toDS(), Queue(1.toByte))
+    checkDataset(Seq(Queue(1.toShort)).toDS(), Queue(1.toShort))
+    checkDataset(Seq(Queue(true)).toDS(), Queue(true))
+    checkDataset(Seq(Queue("test")).toDS(), Queue("test"))
+    checkDataset(Seq(Queue(Tuple1(1))).toDS(), Queue(Tuple1(1)))
+
+    checkDataset(Seq(ArrayBuffer(1)).toDS(), ArrayBuffer(1))
+    checkDataset(Seq(ArrayBuffer(1.toLong)).toDS(), ArrayBuffer(1.toLong))
+    checkDataset(Seq(ArrayBuffer(1.toDouble)).toDS(),
+      ArrayBuffer(1.toDouble))
+    checkDataset(Seq(ArrayBuffer(1.toFloat)).toDS(), ArrayBuffer(1.toFloat))
+    checkDataset(Seq(ArrayBuffer(1.toByte)).toDS(), ArrayBuffer(1.toByte))
+    checkDataset(Seq(ArrayBuffer(1.toShort)).toDS(), ArrayBuffer(1.toShort))
+    checkDataset(Seq(ArrayBuffer(true)).toDS(), ArrayBuffer(true))
+    checkDataset(Seq(ArrayBuffer("test")).toDS(), ArrayBuffer("test"))
+    checkDataset(Seq(ArrayBuffer(Tuple1(1))).toDS(), ArrayBuffer(Tuple1(1)))
+  }
+
+  test("sequence and product combinations") {
+    // Case classes
+    checkDataset(Seq(SeqClass(Seq(1))).toDS(), SeqClass(Seq(1)))
+    checkDataset(Seq(Seq(SeqClass(Seq(1)))).toDS(), Seq(SeqClass(Seq(1))))
+    checkDataset(Seq(List(SeqClass(Seq(1)))).toDS(), List(SeqClass(Seq(1))))
+    checkDataset(Seq(Queue(SeqClass(Seq(1)))).toDS(), Queue(SeqClass(Seq(1))))
+
+    checkDataset(Seq(ListClass(List(1))).toDS(), ListClass(List(1)))
+    checkDataset(Seq(Seq(ListClass(List(1)))).toDS(), Seq(ListClass(List(1))))
+    checkDataset(Seq(List(ListClass(List(1)))).toDS(), List(ListClass(List(1))))
+    checkDataset(Seq(Queue(ListClass(List(1)))).toDS(), Queue(ListClass(List(1))))
+
+    checkDataset(Seq(QueueClass(Queue(1))).toDS(), QueueClass(Queue(1)))
+    checkDataset(Seq(Seq(QueueClass(Queue(1)))).toDS(), Seq(QueueClass(Queue(1))))
+    checkDataset(Seq(List(QueueClass(Queue(1)))).toDS(), List(QueueClass(Queue(1))))
+    checkDataset(Seq(Queue(QueueClass(Queue(1)))).toDS(), Queue(QueueClass(Queue(1))))
+
+    val complex = ComplexClass(SeqClass(Seq(1)), ListClass(List(2)), QueueClass(Queue(3)))
+    checkDataset(Seq(complex).toDS(), complex)
+    checkDataset(Seq(Seq(complex)).toDS(), Seq(complex))
+    checkDataset(Seq(List(complex)).toDS(), List(complex))
+    checkDataset(Seq(Queue(complex)).toDS(), Queue(complex))
+
+    // Tuples
+    checkDataset(Seq(Seq(1) -> Seq(2)).toDS(), Seq(1) -> Seq(2))
+    checkDataset(Seq(List(1) -> Queue(2)).toDS(), List(1) -> Queue(2))
+    checkDataset(Seq(List(Seq("test1") -> List(Queue("test2")))).toDS(),
+      List(Seq("test1") -> List(Queue("test2"))))
+
+    // Complex
+    checkDataset(Seq(ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))).toDS(),
+      ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2))))
+  }
+
   test("package objects") {
     import packageobject._
     checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
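A note on the `LowPrioritySQLImplicits` trait introduced in the
`SQLImplicits.scala` hunk above: Scala prefers an implicit defined in a
derived class or object over one inherited from a parent, so moving
`newProductEncoder` into a parent trait lets the new sequence encoders win
for types such as `List[Int]`, which is both a `Seq` and a `Product`. The
following minimal, self-contained sketch shows the pattern with a
hypothetical `Show` type class standing in for `Encoder` (not Spark code):

```scala
// Hypothetical type class standing in for Encoder.
trait Show[T] { def show(t: T): String }

trait LowPriorityShows {
  // Also matches List[Int], because List extends Product.
  implicit def forProduct[T <: Product]: Show[T] =
    new Show[T] { def show(t: T): String = "product" }
}

object Shows extends LowPriorityShows {
  // Defined in the derived object, so it takes precedence for List[Int].
  implicit def forSeq[T <: Seq[Int]]: Show[T] =
    new Show[T] { def show(t: T): String = "seq" }
}

object ShowDemo extends App {
  import Shows._
  // Resolves to forSeq; with both implicits declared in the same object the
  // lookup would instead fail with an ambiguous-implicit error.
  println(implicitly[Show[List[Int]]].show(List(1, 2)))  // prints "seq"
}
```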