Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][SQL] Put check in ExpressionEncoder.fromRow to ensure we can convert deserialized object to required type #16546

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,32 @@ object ScalaReflection extends ScalaReflection {
case _ => false
}

/**
* Returns the element type for Seq[_] and its subclass.
*/
def getElementTypeForSeq(t: `Type`): Option[`Type`] = {
if (!(t <:< localTypeOf[Seq[_]])) {
return None
}
val TypeRef(_, _, elementTypeList) = t
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would probably be better solved by using match and typeParams. Something like this:

t match {
  case TypeRef(_, _, Seq(elementType)) => Some(elementType)
  case _ =>
    t.baseClasses.find { c =>
      val cType = c.asClass.toType
      cType <:< localTypeOf[Seq[_]] && cType.typeParams.nonEmpty
    }.map(t.baseType(_).typeParams.head)
}

Also not sure whether types with more than one type parameter are handled correctly.

val elementType = if (elementTypeList.size == 0) {
val seqType = t.baseClasses.find { c =>
val cType = c.asClass.toType
val TypeRef(_, _, elementTypeList) = cType
cType <:< localTypeOf[Seq[_]] && elementTypeList.size > 0
}
if (seqType.isDefined) {
val TypeRef(_, _, Seq(elementType)) = t.baseType(seqType.get)
elementType
} else {
null
}
} else {
elementTypeList(0)
}
Option(elementType)
}

/**
* Returns an expression that can be used to deserialize an input row to an object of type `T`
* with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
Expand Down Expand Up @@ -293,7 +319,11 @@ object ScalaReflection extends ScalaReflection {
}

case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val elementType = getElementTypeForSeq(t).getOrElse {
throw new UnsupportedOperationException(
s"No Decoder found for $tpe\n" + walkedTypePath.mkString("\n"))
}

val Schema(dataType, nullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
Expand All @@ -318,31 +348,38 @@ object ScalaReflection extends ScalaReflection {
"make",
array :: Nil)

if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
wrappedArray
} else {
// Convert to another type using `to`
val cls = mirror.runtimeClass(t.typeSymbol.asClass)
import scala.collection.generic.CanBuildFrom
import scala.reflect.ClassTag
val cls = mirror.runtimeClass(t.typeSymbol.asClass)
import scala.collection.generic.CanBuildFrom
import scala.reflect.ClassTag

val cbfParams = try {
// Some canBuildFrom methods take an implicit ClassTag parameter
val cbfParams = try {
cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
Some(StaticInvoke(
ClassTag.getClass,
ObjectType(classOf[ClassTag[_]]),
"apply",
StaticInvoke(
ClassTag.getClass,
ObjectType(classOf[ClassTag[_]]),
"apply",
StaticInvoke(
cls,
ObjectType(classOf[Class[_]]),
"getClass"
) :: Nil
cls,
ObjectType(classOf[Class[_]]),
"getClass"
) :: Nil
} catch {
case _: NoSuchMethodException => Nil
}
) :: Nil)
} catch {
case _: NoSuchMethodException =>
try {
cls.getDeclaredMethod("canBuildFrom")
Some(Nil)
} catch {
case _: NoSuchMethodException => None
}
}

if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure ||
cbfParams.isEmpty) {
wrappedArray
} else {
// Convert to another type using `to`
Invoke(
wrappedArray,
"to",
Expand All @@ -351,7 +388,7 @@ object ScalaReflection extends ScalaReflection {
cls,
ObjectType(classOf[CanBuildFrom[_, _, _]]),
"canBuildFrom",
cbfParams
cbfParams.get
) :: Nil
)
}
Expand Down Expand Up @@ -479,8 +516,8 @@ object ScalaReflection extends ScalaReflection {
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
MapObjects(serializerFor(_, elementType, newPath), input, dt)

case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType) =>
case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType) =>
val cls = input.dataType.asInstanceOf[ObjectType].cls
if (cls.isArray && cls.getComponentType.isPrimitive) {
StaticInvoke(
Expand Down Expand Up @@ -517,7 +554,10 @@ object ScalaReflection extends ScalaReflection {
// "case t if definedByConstructorParams(t)" to make sure it will match to the
// case "localTypeOf[Seq[_]]"
case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val elementType = getElementTypeForSeq(t).getOrElse {
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
}
toCatalystArray(inputObject, elementType)

case t if t <:< localTypeOf[Array[_]] =>
Expand Down Expand Up @@ -730,7 +770,9 @@ object ScalaReflection extends ScalaReflection {
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val elementType = getElementTypeForSeq(t).getOrElse {
throw new UnsupportedOperationException(s"Schema for type $tpe is not supported")
}
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< localTypeOf[Map[_, _]] =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,24 @@ case class ExpressionEncoder[T](
* function.
*/
def fromRow(row: InternalRow): T = try {
constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
assert(deserializer.resolved, "This encoder must `resolveAndBind` to a specific schema " +
"before calling `fromRow`.")
val value = constructProjection(row).get(0, ObjectType(clsTag.runtimeClass))

// Sometimes, we can serialize a type to internal row, but we can't deserialize the row back to
// the type. For example, a `Range` is a `Seq[Int]`, so we can serialize it as `Seq[Int]`. But
// because we deserialize any type <:< `Seq[_]` to `WrappedArray[_]` and `WrappedArray` is not
// a `Range`, there will be a conversion error when converting the deserialized `WrappedArray` to
// `Range`.
// In these cases, we can still deserialize the internal row to external row with `RowEncoder` by
// converting the `Dataset` to `DataFrame`.
assert(value == null || clsTag.runtimeClass.isPrimitive ||
clsTag.runtimeClass.isAssignableFrom(value.getClass),
s"ExpressionEncoder.fromRow can't successfully deserialize type ${clsTag.runtimeClass} " +
"from the given internal row. You can try to use `RowEncoder` to deserialize the " +
"internal row to external row, or convert this strongly typed collection " +
"of data to generic DataFrame by using `Dataset.toDF()` API.")
value.asInstanceOf[T]
} catch {
case e: Exception =>
throw new RuntimeException(s"Error while decoding: $e\n${deserializer.treeString}", e)
Expand Down
17 changes: 16 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql
import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, OuterScopes, RowEncoder}
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
Expand Down Expand Up @@ -1143,6 +1143,21 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head ==
new java.sql.Timestamp(100000))
}

// Covers a Seq subclass that declares no type parameter of its own (`Range.Inclusive`):
// building the Dataset succeeds, collecting it as the typed object fails with an
// AssertionError, and the same data remains readable as a generic DataFrame.
test("Create dataset for subclass of Seq[_] which has no type parameter") {
// An encoder for Range.Inclusive is supplied explicitly here; it must be in implicit scope
// before `toDS()` is called below.
implicit def rangeEncoder: Encoder[Range.Inclusive] = ExpressionEncoder()
val data = Seq(0 to 2, 2 to 5)
val ds = data.toDS()

// Collecting is expected to fail: the message checked below is the one reported by
// ExpressionEncoder.fromRow when the decoded value cannot be used as the requested class.
val e = intercept[AssertionError] {
ds.collect()
}
assert(e.getMessage().contains("ExpressionEncoder.fromRow can't successfully deserialize type"))

// Converting to a DataFrame avoids the typed decode, so the ranges come back as plain
// sequences of their elements.
checkAnswer(
ds.toDF(),
Seq(Row(Seq(0, 1, 2)), Row(Seq(2, 3, 4, 5))))
}
}

case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
Expand Down