diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 7f7dd51aa2650..bfedb6b212795 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -119,6 +119,32 @@ object ScalaReflection extends ScalaReflection {
     case _ => false
   }
 
+  /**
+   * Returns the element type for Seq[_] and its subclasses.
+   */
+  def getElementTypeForSeq(t: `Type`): Option[`Type`] = {
+    if (!(t <:< localTypeOf[Seq[_]])) {
+      return None
+    }
+    val TypeRef(_, _, elementTypeList) = t
+    val elementType = if (elementTypeList.size == 0) {
+      val seqType = t.baseClasses.find { c =>
+        val cType = c.asClass.toType
+        val TypeRef(_, _, elementTypeList) = cType
+        cType <:< localTypeOf[Seq[_]] && elementTypeList.size > 0
+      }
+      if (seqType.isDefined) {
+        val TypeRef(_, _, Seq(elementType)) = t.baseType(seqType.get)
+        elementType
+      } else {
+        null
+      }
+    } else {
+      elementTypeList(0)
+    }
+    Option(elementType)
+  }
+
   /**
    * Returns an expression that can be used to deserialize an input row to an object of type `T`
    * with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
@@ -293,7 +319,11 @@ object ScalaReflection extends ScalaReflection {
         }
 
       case t if t <:< localTypeOf[Seq[_]] =>
-        val TypeRef(_, _, Seq(elementType)) = t
+        val elementType = getElementTypeForSeq(t).getOrElse {
+          throw new UnsupportedOperationException(
+            s"No Decoder found for $tpe\n" + walkedTypePath.mkString("\n"))
+        }
+
         val Schema(dataType, nullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
         val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
@@ -318,31 +348,38 @@ object ScalaReflection extends ScalaReflection {
             "make",
             array :: Nil)
 
-          if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
-            wrappedArray
-          } else {
-            // Convert to another type using `to`
-            val cls = mirror.runtimeClass(t.typeSymbol.asClass)
-            import scala.collection.generic.CanBuildFrom
-            import scala.reflect.ClassTag
+          val cls = mirror.runtimeClass(t.typeSymbol.asClass)
+          import scala.collection.generic.CanBuildFrom
+          import scala.reflect.ClassTag
+          val cbfParams = try {
 
             // Some canBuildFrom methods take an implicit ClassTag parameter
-            val cbfParams = try {
-              cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
+            cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
+            Some(StaticInvoke(
+              ClassTag.getClass,
+              ObjectType(classOf[ClassTag[_]]),
+              "apply",
               StaticInvoke(
-                ClassTag.getClass,
-                ObjectType(classOf[ClassTag[_]]),
-                "apply",
-                StaticInvoke(
-                  cls,
-                  ObjectType(classOf[Class[_]]),
-                  "getClass"
-                ) :: Nil
+                cls,
+                ObjectType(classOf[Class[_]]),
+                "getClass"
               ) :: Nil
-            } catch {
-              case _: NoSuchMethodException => Nil
-            }
+            ) :: Nil)
+          } catch {
+            case _: NoSuchMethodException =>
+              try {
+                cls.getDeclaredMethod("canBuildFrom")
+                Some(Nil)
+              } catch {
+                case _: NoSuchMethodException => None
+              }
+          }
 
+          if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure ||
+              cbfParams.isEmpty) {
+            wrappedArray
+          } else {
+            // Convert to another type using `to`
             Invoke(
               wrappedArray,
               "to",
@@ -351,7 +388,7 @@ object ScalaReflection extends ScalaReflection {
                 cls,
                 ObjectType(classOf[CanBuildFrom[_, _, _]]),
                 "canBuildFrom",
-                cbfParams
+                cbfParams.get
               ) :: Nil
             )
           }
@@ -479,8 +516,8 @@ object ScalaReflection extends ScalaReflection {
         val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
         MapObjects(serializerFor(_, elementType, newPath), input, dt)
 
-      case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
-                  FloatType | DoubleType) =>
+      case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
+          FloatType | DoubleType) =>
         val cls = input.dataType.asInstanceOf[ObjectType].cls
         if (cls.isArray && cls.getComponentType.isPrimitive) {
           StaticInvoke(
@@ -517,7 +554,10 @@ object ScalaReflection extends ScalaReflection {
       // "case t if definedByConstructorParams(t)" to make sure it will match to the
      // case "localTypeOf[Seq[_]]"
       case t if t <:< localTypeOf[Seq[_]] =>
-        val TypeRef(_, _, Seq(elementType)) = t
+        val elementType = getElementTypeForSeq(t).getOrElse {
+          throw new UnsupportedOperationException(
+            s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
+        }
         toCatalystArray(inputObject, elementType)
 
       case t if t <:< localTypeOf[Array[_]] =>
@@ -730,7 +770,9 @@ object ScalaReflection extends ScalaReflection {
         val Schema(dataType, nullable) = schemaFor(elementType)
         Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
       case t if t <:< localTypeOf[Seq[_]] =>
-        val TypeRef(_, _, Seq(elementType)) = t
+        val elementType = getElementTypeForSeq(t).getOrElse {
+          throw new UnsupportedOperationException(s"Schema for type $tpe is not supported")
+        }
         val Schema(dataType, nullable) = schemaFor(elementType)
         Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
       case t if t <:< localTypeOf[Map[_, _]] =>
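
How `getElementTypeForSeq` recovers the element type: a subclass such as `Range.Inclusive` declares no type parameter of its own, so its `TypeRef` carries an empty argument list and the old one-liner `val TypeRef(_, _, Seq(elementType)) = t` failed with a `MatchError`. The new method instead walks `t.baseClasses` until it reaches a parameterized `Seq` base type (for `Range.Inclusive`, that is `IndexedSeq[Int]`) and reads the element type from there. Below is a minimal standalone sketch of the same idea, assuming plain Scala 2.x with `scala-reflect` on the classpath; it is an approximation for illustration, not the patched Spark code:

```scala
import scala.reflect.runtime.universe._

object ElementTypeSketch extends App {
  // Approximates getElementTypeForSeq: use the type's own argument when it has
  // one; otherwise walk baseClasses until a parameterized Seq base type
  // (e.g. IndexedSeq[Int] for Range.Inclusive) yields an element type.
  def elementTypeOf(t: Type): Option[Type] =
    if (!(t <:< typeOf[Seq[_]])) {
      None
    } else t match {
      case TypeRef(_, _, elem :: Nil) => Some(elem)
      case _ =>
        t.baseClasses.iterator
          .map(t.baseType)
          .collectFirst { case bt @ TypeRef(_, _, elem :: Nil) if bt <:< typeOf[Seq[_]] => elem }
    }

  println(elementTypeOf(typeOf[List[String]]))    // Some(String)
  println(elementTypeOf(typeOf[Range.Inclusive])) // Some(Int)
  println(elementTypeOf(typeOf[String]))          // None
}
```

When no parameterized base type is found, the patched call sites turn the `None` into an `UnsupportedOperationException` carrying the walked type path, which is far easier to act on than the `MatchError` the old extractor threw.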
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 3757eccfa2dd8..ffc09f88a54aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -297,7 +297,24 @@ case class ExpressionEncoder[T](
    * function.
    */
   def fromRow(row: InternalRow): T = try {
-    constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
+    assert(deserializer.resolved, "This encoder must `resolveAndBind` to a specific schema " +
+      "before calling `fromRow`.")
+    val value = constructProjection(row).get(0, ObjectType(clsTag.runtimeClass))
+
+    // Sometimes we can serialize a type to an internal row but cannot deserialize the row back
+    // to that type. For example, a `Range` is a `Seq[Int]`, so we can serialize it as `Seq[Int]`,
+    // but because we deserialize any type <:< `Seq[_]` to `WrappedArray[_]` and `WrappedArray` is
+    // not a `Range`, there will be a conversion error when converting the deserialized
+    // `WrappedArray` back to `Range`.
+    // In these cases, we can still deserialize the internal row to an external row with
+    // `RowEncoder` by converting the `Dataset` to a `DataFrame`.
+    assert(value == null || clsTag.runtimeClass.isPrimitive ||
+      clsTag.runtimeClass.isAssignableFrom(value.getClass),
+      s"ExpressionEncoder.fromRow can't successfully deserialize type ${clsTag.runtimeClass} " +
+        "from the given internal row. You can try to use `RowEncoder` to deserialize the " +
+        "internal row to an external row, or convert this strongly typed collection " +
+        "of data to a generic DataFrame by using the `Dataset.toDF()` API.")
+    value.asInstanceOf[T]
   } catch {
     case e: Exception =>
       throw new RuntimeException(s"Error while decoding: $e\n${deserializer.treeString}", e)
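
The guard added to `fromRow` is worth spelling out: the deserializer generated for any `Seq[_]` type materializes a `scala.collection.mutable.WrappedArray`, which is a perfectly good `Seq` but not an instance of a narrower subclass such as `Range`. A minimal sketch of the runtime-class check, in plain Scala 2.x with no Spark dependency (the object name is illustrative):

```scala
import scala.collection.mutable.WrappedArray

object FromRowCheckSketch extends App {
  // What the Seq[_] deserializer produces, whatever the declared static type.
  val deserialized: Seq[Int] = WrappedArray.make(Array(0, 1, 2))

  // Mirrors the new check: clsTag.runtimeClass.isAssignableFrom(value.getClass)
  println(classOf[Seq[_]].isAssignableFrom(deserialized.getClass)) // true
  println(classOf[Range].isAssignableFrom(deserialized.getClass))  // false
}
```

Raising an `AssertionError` here fails fast with an actionable message instead of an opaque `ClassCastException` deep in generated code; and because `AssertionError` is an `Error` rather than an `Exception`, it is not re-wrapped by the surrounding `catch` block.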
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 731a28c237bae..1c4e535cd3947 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.sql.{Date, Timestamp}
 
-import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
@@ -1143,6 +1143,21 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head ==
       new java.sql.Timestamp(100000))
   }
+
+  test("Create dataset for subclass of Seq[_] which has no type parameter") {
+    implicit def rangeEncoder: Encoder[Range.Inclusive] = ExpressionEncoder()
+    val data = Seq(0 to 2, 2 to 5)
+    val ds = data.toDS()
+
+    val e = intercept[AssertionError] {
+      ds.collect()
+    }
+    assert(e.getMessage().contains("ExpressionEncoder.fromRow can't successfully deserialize type"))
+
+    checkAnswer(
+      ds.toDF(),
+      Seq(Row(Seq(0, 1, 2)), Row(Seq(2, 3, 4, 5))))
+  }
 }
 
 case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
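
Seen through the public API, the new test amounts to the following behavior. This sketch assumes a local `SparkSession`; the session setup and the object wrapper are illustrative, not part of the patch:

```scala
import org.apache.spark.sql.{Encoder, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

object RangeDatasetSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
  import spark.implicits._

  // With this patch an encoder can be derived for Range.Inclusive even though
  // the class itself takes no type parameter: Int comes from its Seq[Int] base.
  implicit def rangeEncoder: Encoder[Range.Inclusive] = ExpressionEncoder()

  val ds = Seq(0 to 2, 2 to 5).toDS()

  // Deserializing back to Range.Inclusive is still impossible (the deserializer
  // yields a WrappedArray), so the typed path fails fast with the new assertion...
  try ds.collect() catch {
    case e: AssertionError => println(e.getMessage)
  }

  // ...while the untyped DataFrame path keeps working.
  ds.toDF().show()

  spark.stop()
}
```

If a typed view is needed, declaring the element type as a deserializable collection (for example `Seq[Int]` instead of `Range.Inclusive`) avoids the limitation entirely, since `WrappedArray` is itself a `Seq`.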