Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-16792][SQL] Dataset containing a Case Class with a List type causes a CompileException (converting sequence to list) #16240

Closed
wants to merge 12 commits into from
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,50 @@ object ScalaReflection extends ScalaReflection {
"array",
ObjectType(classOf[Array[Any]]))

StaticInvoke(
val wrappedArray = StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
array :: Nil)

if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
wrappedArray
} else {
// Convert to another type using `to`
val cls = mirror.runtimeClass(t.typeSymbol.asClass)
import scala.collection.generic.CanBuildFrom
import scala.reflect.ClassTag

// Some canBuildFrom methods take an implicit ClassTag parameter
val cbfParams = try {
cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
StaticInvoke(
ClassTag.getClass,
ObjectType(classOf[ClassTag[_]]),
"apply",
StaticInvoke(
cls,
ObjectType(classOf[Class[_]]),
"getClass"
) :: Nil
) :: Nil
} catch {
case _: NoSuchMethodException => Nil
}

Invoke(
wrappedArray,
"to",
ObjectType(cls),
StaticInvoke(
cls,
ObjectType(classOf[CanBuildFrom[_, _, _]]),
"canBuildFrom",
cbfParams
) :: Nil
)
}

case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
val TypeRef(_, _, Seq(keyType, valueType)) = t
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,37 @@ class ScalaReflectionSuite extends SparkFunSuite {
.cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
}

test("SPARK 16792: Get correct deserializer for List[_]") {
val listDeserializer = deserializerFor[List[Int]]
assert(listDeserializer.dataType == ObjectType(classOf[List[_]]))
}

test("serialize and deserialize arbitrary sequence types") {
import scala.collection.immutable.Queue
val queueSerializer = serializerFor[Queue[Int]](BoundReference(
0, ObjectType(classOf[Queue[Int]]), nullable = false))
assert(queueSerializer.dataType.head.dataType ==
ArrayType(IntegerType, containsNull = false))
val queueDeserializer = deserializerFor[Queue[Int]]
assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]]))

import scala.collection.mutable.ArrayBuffer
val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference(
0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false))
assert(arrayBufferSerializer.dataType.head.dataType ==
ArrayType(IntegerType, containsNull = false))
val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))

// Check whether conversion is skipped when using WrappedArray[_] supertype
// (would otherwise needlessly add overhead)
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
val seqDeserializer = deserializerFor[Seq[Int]]
assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
scala.collection.mutable.WrappedArray.getClass)
assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
}

private val dataTypeForComplexData = dataTypeFor[ComplexData]
private val typeOfComplexData = typeOf[ComplexData]

Expand Down
37 changes: 24 additions & 13 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
* @since 1.6.0
*/
@InterfaceStability.Evolving
abstract class SQLImplicits {
abstract class SQLImplicits extends LowPrioritySQLImplicits {

protected def _sqlContext: SQLContext

Expand All @@ -45,9 +45,6 @@ abstract class SQLImplicits {
}
}

/** @since 1.6.0 */
implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]

// Primitives

/** @since 1.6.0 */
Expand Down Expand Up @@ -100,31 +97,32 @@ abstract class SQLImplicits {
// Seqs

/** @since 1.6.1 */
implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
implicit def newIntSeqEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
implicit def newLongSeqEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
implicit def newDoubleSeqEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
implicit def newFloatSeqEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
implicit def newByteSeqEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
implicit def newShortSeqEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
implicit def newBooleanSeqEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
implicit def newStringSeqEncoder[T <: Seq[String] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
implicit def newProductSeqEncoder[T <: Seq[Product] : TypeTag]: Encoder[T] =
ExpressionEncoder()

// Arrays

Expand Down Expand Up @@ -180,3 +178,16 @@ abstract class SQLImplicits {
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)

}

/**
* Lower priority implicit methods for converting Scala objects into [[Dataset]]s.
* Conflicting implicits are placed here to disambiguate resolution.
*
* Reasons for including specific implicits:
* newProductEncoder - to disambiguate for [[List]]s which are both [[Seq]] and [[Product]]
*/
trait LowPrioritySQLImplicits {
/** @since 1.6.0 */
implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,21 @@

package org.apache.spark.sql

import scala.collection.immutable.Queue
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.test.SharedSQLContext

case class IntClass(value: Int)

case class SeqCC(s: Seq[Int])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does CC short for? How about SeqClass?


case class ListCC(l: List[Int])

case class QueueCC(q: Queue[Int])

case class ComplexCC(seq: SeqCC, list: ListCC, queue: QueueCC)

package object packageobject {
case class PackageClass(value: Int)
}
Expand Down Expand Up @@ -130,6 +141,62 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
}

test("arbitrary sequences") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's also test nested sequences, e.g. List(Queue(1)), and sequences inside product, e.g. List(1) -> Queue(1)

Copy link
Contributor Author

@michalsenkyr michalsenkyr Jan 3, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some sequence-product combination tests.
Nested sequences were never supported (tried on master and 2.0.2). That would probably be worthy of another ticket.

checkDataset(Seq(Queue(1)).toDS(), Queue(1))
checkDataset(Seq(Queue(1.toLong)).toDS(), Queue(1.toLong))
checkDataset(Seq(Queue(1.toDouble)).toDS(), Queue(1.toDouble))
checkDataset(Seq(Queue(1.toFloat)).toDS(), Queue(1.toFloat))
checkDataset(Seq(Queue(1.toByte)).toDS(), Queue(1.toByte))
checkDataset(Seq(Queue(1.toShort)).toDS(), Queue(1.toShort))
checkDataset(Seq(Queue(true)).toDS(), Queue(true))
checkDataset(Seq(Queue("test")).toDS(), Queue("test"))
checkDataset(Seq(Queue(Tuple1(1))).toDS(), Queue(Tuple1(1)))

checkDataset(Seq(ArrayBuffer(1)).toDS(), ArrayBuffer(1))
checkDataset(Seq(ArrayBuffer(1.toLong)).toDS(), ArrayBuffer(1.toLong))
checkDataset(Seq(ArrayBuffer(1.toDouble)).toDS(), ArrayBuffer(1.toDouble))
checkDataset(Seq(ArrayBuffer(1.toFloat)).toDS(), ArrayBuffer(1.toFloat))
checkDataset(Seq(ArrayBuffer(1.toByte)).toDS(), ArrayBuffer(1.toByte))
checkDataset(Seq(ArrayBuffer(1.toShort)).toDS(), ArrayBuffer(1.toShort))
checkDataset(Seq(ArrayBuffer(true)).toDS(), ArrayBuffer(true))
checkDataset(Seq(ArrayBuffer("test")).toDS(), ArrayBuffer("test"))
checkDataset(Seq(ArrayBuffer(Tuple1(1))).toDS(), ArrayBuffer(Tuple1(1)))
}

test("sequence and product combinations") {
// Case classes
checkDataset(Seq(SeqCC(Seq(1))).toDS(), SeqCC(Seq(1)))
checkDataset(Seq(Seq(SeqCC(Seq(1)))).toDS(), Seq(SeqCC(Seq(1))))
checkDataset(Seq(List(SeqCC(Seq(1)))).toDS(), List(SeqCC(Seq(1))))
checkDataset(Seq(Queue(SeqCC(Seq(1)))).toDS(), Queue(SeqCC(Seq(1))))

checkDataset(Seq(ListCC(List(1))).toDS(), ListCC(List(1)))
checkDataset(Seq(Seq(ListCC(List(1)))).toDS(), Seq(ListCC(List(1))))
checkDataset(Seq(List(ListCC(List(1)))).toDS(), List(ListCC(List(1))))
checkDataset(Seq(Queue(ListCC(List(1)))).toDS(), Queue(ListCC(List(1))))

checkDataset(Seq(QueueCC(Queue(1))).toDS(), QueueCC(Queue(1)))
checkDataset(Seq(Seq(QueueCC(Queue(1)))).toDS(), Seq(QueueCC(Queue(1))))
checkDataset(Seq(List(QueueCC(Queue(1)))).toDS(), List(QueueCC(Queue(1))))
checkDataset(Seq(Queue(QueueCC(Queue(1)))).toDS(), Queue(QueueCC(Queue(1))))

val complexCC = ComplexCC(SeqCC(Seq(1)), ListCC(List(2)), QueueCC(Queue(3)))
checkDataset(Seq(complexCC).toDS(), complexCC)
checkDataset(Seq(Seq(complexCC)).toDS(), Seq(complexCC))
checkDataset(Seq(List(complexCC)).toDS(), List(complexCC))
checkDataset(Seq(Queue(complexCC)).toDS(), Queue(complexCC))

// Tuples
checkDataset(Seq(Seq(1) -> Seq(2)).toDS(), Seq(1) -> Seq(2))
checkDataset(Seq(List(1) -> Queue(2)).toDS(), List(1) -> Queue(2))
checkDataset(Seq(List(Seq("test1") -> List(Queue("test2")))).toDS(),
List(Seq("test1") -> List(Queue("test2"))))

// Complex
checkDataset(Seq(ListCC(List(1)) -> Queue("test" -> SeqCC(Seq(2)))).toDS(),
ListCC(List(1)) -> Queue("test" -> SeqCC(Seq(2))))
}

test("package objects") {
import packageobject._
checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
Expand Down