[SPARK-18251][SQL] the type of Dataset can't be Option of non-flat type
## What changes were proposed in this pull request?

For an input object of non-flat type, we can't encode it to a row if it's null, as Spark SQL doesn't allow the entire row to be null; only its columns can be null. That's the reason we forbade top-level null objects in apache#13469.

However, if users wrap a non-flat type with `Option`, then we may still encode a top-level null object to a row, which is not allowed.

This PR fixes this case, and suggests users wrap their type with `Tuple1` if they do want top-level null objects.
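
As a sketch of the new behavior (a hypothetical `spark-shell` session; `MyClass` is an illustrative name, and `import spark.implicits._` is assumed):

```scala
case class MyClass(i: Int, s: String) // any Product type

// Rejected up front after this PR: a top-level None would have to become
// an all-null row, which Spark SQL does not allow.
Seq(Some(MyClass(1, "a")), None).toDS() // throws UnsupportedOperationException

// Option of a flat type is still fine; None maps to a null column value.
Seq(Some(1), None).toDS()

// Suggested workaround: wrap with Tuple1, so the null lives in a nullable
// struct column instead of being the row itself.
val ds = Seq(Tuple1(MyClass(1, "a")), Tuple1(null)).toDS()
```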

## How was this patch tested?

A new test in `DatasetSuite`.

Author: Wenchen Fan <[email protected]>

Closes apache#15979 from cloud-fan/option.
cloud-fan authored and uzadude committed Jan 27, 2017
1 parent fa59136 commit 81b8a46
Showing 4 changed files with 37 additions and 5 deletions.
13 changes: 13 additions & 0 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -605,6 +605,19 @@ object ScalaReflection extends ScalaReflection {

   }

+  /**
+   * Returns true if the given type is option of product type, e.g. `Option[Tuple2]`. Note that,
+   * we also treat [[DefinedByConstructorParams]] as product type.
+   */
+  def optionOfProductType(tpe: `Type`): Boolean = ScalaReflectionLock.synchronized {
+    tpe match {
+      case t if t <:< localTypeOf[Option[_]] =>
+        val TypeRef(_, _, Seq(optType)) = t
+        definedByConstructorParams(optType)
+      case _ => false
+    }
+  }
+
   /**
    * Returns the parameter names and types for the primary constructor of this class.
    *
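
A sketch of how the new helper classifies types (hypothetical REPL usage, not part of this diff; `typeOf` comes from the runtime universe, which matches the `Type` alias used by `ScalaReflection`):

```scala
import scala.reflect.runtime.universe._
import org.apache.spark.sql.catalyst.ScalaReflection

ScalaReflection.optionOfProductType(typeOf[Option[(Int, String)]]) // true: Tuple2 is a Product
ScalaReflection.optionOfProductType(typeOf[Option[Int]])           // false: Int is flat
ScalaReflection.optionOfProductType(typeOf[(Int, String)])         // false: not an Option
```
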
14 changes: 12 additions & 2 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -47,16 +47,26 @@ object ExpressionEncoder {
     // We convert the not-serializable TypeTag into StructType and ClassTag.
     val mirror = typeTag[T].mirror
     val tpe = typeTag[T].tpe
+
+    if (ScalaReflection.optionOfProductType(tpe)) {
+      throw new UnsupportedOperationException(
+        "Cannot create encoder for Option of Product type, because Product type is represented " +
+          "as a row, and the entire row can not be null in Spark SQL like normal databases. " +
+          "You can wrap your type with Tuple1 if you do want top level null Product objects, " +
+          "e.g. instead of creating `Dataset[Option[MyClass]]`, you can do something like " +
+          "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`")
+    }
+
     val cls = mirror.runtimeClass(tpe)
     val flat = !ScalaReflection.definedByConstructorParams(tpe)

     val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
     val nullSafeInput = if (flat) {
       inputObject
     } else {
-      // For input object of non-flat type, we can't encode it to row if it's null, as Spark SQL
+      // For input object of Product type, we can't encode it to row if it's null, as Spark SQL
       // doesn't allow top-level row to be null, only its columns can be null.
-      AssertNotNull(inputObject, Seq("top level non-flat input object"))
+      AssertNotNull(inputObject, Seq("top level Product input object"))
     }
     val serializer = ScalaReflection.serializerFor[T](nullSafeInput)
     val deserializer = ScalaReflection.deserializerFor[T]
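
The new guard fires eagerly, when the encoder is created, complementing the existing `AssertNotNull` check, which only fires once a null top-level Product is actually serialized. A sketch of the two failure modes (hypothetical session; `ClassData` is the case class used in `DatasetSuite`, and `import spark.implicits._` is assumed):

```scala
case class ClassData(a: String, b: Int)

// Caught at encoder-creation time by the new optionOfProductType guard:
Seq(Some(ClassData("a", 1)), None).toDS()  // UnsupportedOperationException

// Caught at serialization time by AssertNotNull:
Seq(ClassData("a", 1), null).toDS()        // RuntimeException: null in non-nullable field
```
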
13 changes: 11 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -867,10 +867,10 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkDataset(Seq("a", null).toDS(), "a", null)
   }

-  test("Dataset should throw RuntimeException if non-flat input object is null") {
+  test("Dataset should throw RuntimeException if top-level product input object is null") {
     val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS())
     assert(e.getMessage.contains("Null value appeared in non-nullable field"))
-    assert(e.getMessage.contains("top level non-flat input object"))
+    assert(e.getMessage.contains("top level Product input object"))
   }

   test("dropDuplicates") {
@@ -1051,6 +1051,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkDataset(dsDouble, arrayDouble)
     checkDataset(dsString, arrayString)
   }
+
+  test("SPARK-18251: the type of Dataset can't be Option of Product type") {
+    checkDataset(Seq(Some(1), None).toDS(), Some(1), None)
+
+    val e = intercept[UnsupportedOperationException] {
+      Seq(Some(1 -> "a"), None).toDS()
+    }
+    assert(e.getMessage.contains("Cannot create encoder for Option of Product type"))
+  }
 }

 case class Generic[T](id: T, value: Double)
2 changes: 1 addition & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -143,7 +143,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
   }

   test("roundtrip in to_json and from_json") {
-    val dfOne = Seq(Some(Tuple1(Tuple1(1))), None).toDF("struct")
+    val dfOne = Seq(Tuple1(Tuple1(1)), Tuple1(null)).toDF("struct")
     val schemaOne = dfOne.schema(0).dataType.asInstanceOf[StructType]
     val readBackOne = dfOne.select(to_json($"struct").as("json"))
       .select(from_json($"json", schemaOne).as("struct"))
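
This test had to change because `Option[Tuple1[...]]` is itself an Option of Product type and is now rejected; `Tuple1(null)` expresses the same null struct legally. A sketch of the before/after (hypothetical session, `import spark.implicits._` assumed):

```scala
// Rejected after this PR, when the encoder is resolved:
Seq(Some(Tuple1(Tuple1(1))), None).toDF("struct")   // UnsupportedOperationException

// Equivalent data under the Tuple1 workaround: a single nullable struct
// column, with Tuple1(null) playing the role of the old top-level None.
Seq(Tuple1(Tuple1(1)), Tuple1(null)).toDF("struct")
```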
