diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 7bcaea7ea2f79..0aa21b9347a9d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -605,6 +605,19 @@ object ScalaReflection extends ScalaReflection {
   }
 
+  /**
+   * Returns true if the given type is option of product type, e.g. `Option[Tuple2]`. Note that,
+   * we also treat [[DefinedByConstructorParams]] as product type.
+   */
+  def optionOfProductType(tpe: `Type`): Boolean = ScalaReflectionLock.synchronized {
+    tpe match {
+      case t if t <:< localTypeOf[Option[_]] =>
+        val TypeRef(_, _, Seq(optType)) = t
+        definedByConstructorParams(optType)
+      case _ => false
+    }
+  }
+
   /**
    * Returns the parameter names and types for the primary constructor of this class.
    *
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 82e1a8a7cad96..9c4818db6333b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -47,6 +47,16 @@ object ExpressionEncoder {
     // We convert the not-serializable TypeTag into StructType and ClassTag.
     val mirror = typeTag[T].mirror
     val tpe = typeTag[T].tpe
+
+    if (ScalaReflection.optionOfProductType(tpe)) {
+      throw new UnsupportedOperationException(
+        "Cannot create encoder for Option of Product type, because Product type is represented " +
+          "as a row, and the entire row can not be null in Spark SQL like normal databases. " +
+          "You can wrap your type with Tuple1 if you do want top level null Product objects, " +
+          "e.g. instead of creating `Dataset[Option[MyClass]]`, you can do something like " +
+          "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`")
+    }
+
     val cls = mirror.runtimeClass(tpe)
     val flat = !ScalaReflection.definedByConstructorParams(tpe)
 
@@ -54,9 +64,9 @@ object ExpressionEncoder {
     val nullSafeInput = if (flat) {
       inputObject
     } else {
-      // For input object of non-flat type, we can't encode it to row if it's null, as Spark SQL
+      // For input object of Product type, we can't encode it to row if it's null, as Spark SQL
       // doesn't allow top-level row to be null, only its columns can be null.
-      AssertNotNull(inputObject, Seq("top level non-flat input object"))
+      AssertNotNull(inputObject, Seq("top level Product input object"))
     }
     val serializer = ScalaReflection.serializerFor[T](nullSafeInput)
     val deserializer = ScalaReflection.deserializerFor[T]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 81fa8cbf22384..1174d7354f931 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -867,10 +867,10 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkDataset(Seq("a", null).toDS(), "a", null)
   }
 
-  test("Dataset should throw RuntimeException if non-flat input object is null") {
+  test("Dataset should throw RuntimeException if top-level product input object is null") {
     val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS())
     assert(e.getMessage.contains("Null value appeared in non-nullable field"))
-    assert(e.getMessage.contains("top level non-flat input object"))
+    assert(e.getMessage.contains("top level Product input object"))
   }
 
   test("dropDuplicates") {
@@ -1051,6 +1051,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkDataset(dsDouble, arrayDouble)
     checkDataset(dsString, arrayString)
   }
+
+  test("SPARK-18251: the type of Dataset can't be Option of Product type") {
+    checkDataset(Seq(Some(1), None).toDS(), Some(1), None)
+
+    val e = intercept[UnsupportedOperationException] {
+      Seq(Some(1 -> "a"), None).toDS()
+    }
+    assert(e.getMessage.contains("Cannot create encoder for Option of Product type"))
+  }
 }
 
 case class Generic[T](id: T, value: Double)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 7d63d31d9b979..890cc5b560d02 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -143,7 +143,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
   }
 
   test("roundtrip in to_json and from_json") {
-    val dfOne = Seq(Some(Tuple1(Tuple1(1))), None).toDF("struct")
+    val dfOne = Seq(Tuple1(Tuple1(1)), Tuple1(null)).toDF("struct")
     val schemaOne = dfOne.schema(0).dataType.asInstanceOf[StructType]
     val readBackOne = dfOne.select(to_json($"struct").as("json"))
       .select(from_json($"json", schemaOne).as("struct"))
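
For reference, a minimal user-facing sketch of the behaviour this patch introduces; it is not part of the patch, and it assumes a Spark 2.x SparkSession named `spark` with its implicits imported (the element types and values are illustrative only):

    import org.apache.spark.sql.Dataset
    import spark.implicits._

    // Option of a non-Product element type still works as before.
    val ints: Dataset[Option[Int]] = Seq(Some(1), None).toDS()

    // Option of a Product element type now fails fast at encoder creation with
    // UnsupportedOperationException("Cannot create encoder for Option of Product type ...").
    // Seq(Some(1 -> "a"), None).toDS()

    // Workaround suggested by the new error message: wrap the Product value in Tuple1,
    // so the null lives in a nullable column instead of being the top-level row.
    val wrapped: Dataset[Tuple1[(Int, String)]] =
      Seq(Tuple1(1 -> "a"), Tuple1(null)).toDS()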