From 0f4d4a41492c0d0d8e3b314cab9e06f38ac629a8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 Nov 2016 22:06:57 +0800 Subject: [PATCH 1/3] the type of Dataset can't be Option of non-flat type --- .../apache/spark/sql/catalyst/ScalaReflection.scala | 12 ++++++++++++ .../sql/catalyst/encoders/ExpressionEncoder.scala | 8 ++++++++ .../scala/org/apache/spark/sql/DatasetSuite.scala | 9 +++++++++ .../org/apache/spark/sql/JsonFunctionsSuite.scala | 2 +- 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7bcaea7ea2f79..668ac8e2cabc6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -605,6 +605,18 @@ object ScalaReflection extends ScalaReflection { } + /** + * Returns true if the given type is option of non flat type, e.g. `Option[Tuple2]`. + */ + def optionOfNonFlatType(tpe: `Type`): Boolean = ScalaReflectionLock.synchronized { + tpe match { + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + definedByConstructorParams(optType) + case _ => false + } + } + /** * Returns the parameter names and types for the primary constructor of this class. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 82e1a8a7cad96..1231e49ab11e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -47,6 +47,14 @@ object ExpressionEncoder { // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = typeTag[T].mirror val tpe = typeTag[T].tpe + + if (ScalaReflection.optionOfNonFlatType(tpe)) { + throw new UnsupportedOperationException( + "Cannot create encoder for Option of non-flat type, as non-flat type is represented " + + "as a row, and the entire row can not be null in Spark SQL like normal databases. " + + "You can wrap your type with Tuple1 if you do want top level null objects.") + } + val cls = mirror.runtimeClass(tpe) val flat = !ScalaReflection.definedByConstructorParams(tpe) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 81fa8cbf22384..58c26306dd903 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1051,6 +1051,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(dsDouble, arrayDouble) checkDataset(dsString, arrayString) } + + test("SPARK-18251: the type of Dataset can't be Option of non-flat type") { + checkDataset(Seq(Some(1), None).toDS(), Some(1), None) + + val e = intercept[UnsupportedOperationException] { + Seq(Some(1 -> "a"), None).toDS() + } + assert(e.getMessage.contains("Cannot create encoder for Option of non-flat type")) + } } case class Generic[T](id: T, value: Double) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 7d63d31d9b979..890cc5b560d02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -143,7 +143,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { } test("roundtrip in to_json and from_json") { - val dfOne = Seq(Some(Tuple1(Tuple1(1))), None).toDF("struct") + val dfOne = Seq(Tuple1(Tuple1(1)), Tuple1(null)).toDF("struct") val schemaOne = dfOne.schema(0).dataType.asInstanceOf[StructType] val readBackOne = dfOne.select(to_json($"struct").as("json")) .select(from_json($"json", schemaOne).as("struct")) From 01b072d07a21179d28a13fc39d9b18560f10df4a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 29 Nov 2016 19:42:33 +0800 Subject: [PATCH 2/3] address comment --- .../apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 1231e49ab11e0..7a02b309d8718 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -52,7 +52,8 @@ object ExpressionEncoder { throw new UnsupportedOperationException( "Cannot create encoder for Option of non-flat type, as non-flat type is represented " + "as a row, and the entire row can not be null in Spark SQL like normal databases. " + - "You can wrap your type with Tuple1 if you do want top level null objects.") + "You can wrap your type with Tuple1 if you do want top level null objects, e.g. " + + "val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS") } val cls = mirror.runtimeClass(tpe) From 876e5c7e59e5acc5ef76611e0ea53ec06043e154 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 30 Nov 2016 15:00:52 +0800 Subject: [PATCH 3/3] address comments --- .../apache/spark/sql/catalyst/ScalaReflection.scala | 5 +++-- .../sql/catalyst/encoders/ExpressionEncoder.scala | 13 +++++++------ .../scala/org/apache/spark/sql/DatasetSuite.scala | 8 ++++---- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 668ac8e2cabc6..0aa21b9347a9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -606,9 +606,10 @@ object ScalaReflection extends ScalaReflection { } /** - * Returns true if the given type is option of non flat type, e.g. `Option[Tuple2]`. + * Returns true if the given type is option of product type, e.g. `Option[Tuple2]`. Note that, + * we also treat [[DefinedByConstructorParams]] as product type. */ - def optionOfNonFlatType(tpe: `Type`): Boolean = ScalaReflectionLock.synchronized { + def optionOfProductType(tpe: `Type`): Boolean = ScalaReflectionLock.synchronized { tpe match { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 7a02b309d8718..9c4818db6333b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -48,12 +48,13 @@ object ExpressionEncoder { val mirror = typeTag[T].mirror val tpe = typeTag[T].tpe - if (ScalaReflection.optionOfNonFlatType(tpe)) { + if (ScalaReflection.optionOfProductType(tpe)) { throw new UnsupportedOperationException( - "Cannot create encoder for Option of non-flat type, as non-flat type is represented " + + "Cannot create encoder for Option of Product type, because Product type is represented " + "as a row, and the entire row can not be null in Spark SQL like normal databases. " + - "You can wrap your type with Tuple1 if you do want top level null objects, e.g. " + - "val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS") + "You can wrap your type with Tuple1 if you do want top level null Product objects, " + + "e.g. instead of creating `Dataset[Option[MyClass]]`, you can do something like " + + "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`") } val cls = mirror.runtimeClass(tpe) @@ -63,9 +64,9 @@ object ExpressionEncoder { val nullSafeInput = if (flat) { inputObject } else { - // For input object of non-flat type, we can't encode it to row if it's null, as Spark SQL + // For input object of Product type, we can't encode it to row if it's null, as Spark SQL // doesn't allow top-level row to be null, only its columns can be null. - AssertNotNull(inputObject, Seq("top level non-flat input object")) + AssertNotNull(inputObject, Seq("top level Product input object")) } val serializer = ScalaReflection.serializerFor[T](nullSafeInput) val deserializer = ScalaReflection.deserializerFor[T] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 58c26306dd903..1174d7354f931 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -867,10 +867,10 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(Seq("a", null).toDS(), "a", null) } - test("Dataset should throw RuntimeException if non-flat input object is null") { + test("Dataset should throw RuntimeException if top-level product input object is null") { val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS()) assert(e.getMessage.contains("Null value appeared in non-nullable field")) - assert(e.getMessage.contains("top level non-flat input object")) + assert(e.getMessage.contains("top level Product input object")) } test("dropDuplicates") { @@ -1052,13 +1052,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(dsString, arrayString) } - test("SPARK-18251: the type of Dataset can't be Option of non-flat type") { + test("SPARK-18251: the type of Dataset can't be Option of Product type") { checkDataset(Seq(Some(1), None).toDS(), Some(1), None) val e = intercept[UnsupportedOperationException] { Seq(Some(1 -> "a"), None).toDS() } - assert(e.getMessage.contains("Cannot create encoder for Option of non-flat type")) + assert(e.getMessage.contains("Cannot create encoder for Option of Product type")) } }