[SPARK-16792][SQL] Dataset containing a Case Class with a List type causes a CompileException (converting sequence to list)

## What changes were proposed in this pull request?

Added a `to` call at the end of the code generated by `ScalaReflection.deserializerFor` whenever the requested type is not a supertype of `WrappedArray[_]`. The call uses a `CanBuildFrom[_, _, _]` to convert the result into an arbitrary subtype of `Seq[_]`.
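For illustration, a minimal sketch of what that conversion amounts to in plain Scala (2.11/2.12) collections, outside the generated code; the values here are made up:

```scala
import scala.collection.mutable.WrappedArray

// The deserializer first materializes a WrappedArray, exactly as before...
val wrapped: Seq[String] = WrappedArray.make(Array("a", "b", "c"))

// ...and only when the requested type demands it, converts via `to`,
// which resolves a CanBuildFrom for the target collection type.
val asList: List[String] = wrapped.to[List]
```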

Care was taken to preserve the original deserialization path where possible, so that the overhead of conversion is avoided in cases where it is not needed.

`ScalaReflection.serializerFor` could already serialize any `Seq[_]`, so it was not altered.

`SQLImplicits` had to be altered, and new implicit encoders added, to permit serialization of other sequence types.
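A sketch of what the new encoders enable in a Spark shell with this patch applied, mirroring the cases exercised in `DatasetPrimitiveSuite`:

```scala
import scala.collection.immutable.Queue
import spark.implicits._

// Arbitrary Seq subtypes now resolve to the new sequence encoders:
val listDS  = Seq(List(1, 2, 3)).toDS()    // Dataset[List[Int]]
val queueDS = Seq(Queue("a", "b")).toDS()  // Dataset[Queue[String]]
```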

This also fixes [SPARK-16815]: `Dataset[List[T]]` leads to `ArrayStoreException`.
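For context, that symptom can be reproduced with the snippet from the manual tests below; the before/after behavior noted in the comments is inferred from the issue report:

```scala
// Before this patch: collect fails with java.lang.ArrayStoreException.
// After this patch: collect should return Array(List(1)).
spark.sqlContext.createDataset(sc.parallelize(List(1) :: Nil)).collect
```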

## How was this patch tested?
```bash
./build/mvn -DskipTests clean package && ./dev/run-tests
```

The following sets of commands were also executed manually in the Spark shell:
```scala
case class TestCC(key: Int, letters: List[String])

val ds1 = sc.makeRDD(Seq(
  List("D"),
  List("S", "H"),
  List("F", "H"),
  List("D", "L", "L")
)).map(x => (x.length, x)).toDF("key", "letters").as[TestCC]

val test1 = ds1.map(_.key)
test1.show
```

```scala
case class X(l: List[String])
spark.createDataset(Seq(List("A"))).map(X).show
```

```scala
spark.sqlContext.createDataset(sc.parallelize(List(1) :: Nil)).collect
```

After adding arbitrary sequence support, the following commands were also tested:

```scala
case class QueueClass(q: scala.collection.immutable.Queue[Int])

spark.createDataset(Seq(List(1,2,3))).map(x => QueueClass(scala.collection.immutable.Queue(x: _*))).map(_.q.dequeue).collect
```

Author: Michal Senkyr <[email protected]>

Closes apache#16240 from michalsenkyr/sql-caseclass-list-fix.
michalsenkyr authored and uzadude committed Jan 27, 2017
1 parent 48bce29 commit 65e34fa
Showing 4 changed files with 231 additions and 22 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -312,12 +312,50 @@ object ScalaReflection extends ScalaReflection {
"array",
ObjectType(classOf[Array[Any]]))

-StaticInvoke(
+val wrappedArray = StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
array :: Nil)

if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
wrappedArray
} else {
// Convert to another type using `to`
val cls = mirror.runtimeClass(t.typeSymbol.asClass)
import scala.collection.generic.CanBuildFrom
import scala.reflect.ClassTag

// Some canBuildFrom methods take an implicit ClassTag parameter
val cbfParams = try {
cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
StaticInvoke(
ClassTag.getClass,
ObjectType(classOf[ClassTag[_]]),
"apply",
StaticInvoke(
cls,
ObjectType(classOf[Class[_]]),
"getClass"
) :: Nil
) :: Nil
} catch {
case _: NoSuchMethodException => Nil
}

Invoke(
wrappedArray,
"to",
ObjectType(cls),
StaticInvoke(
cls,
ObjectType(classOf[CanBuildFrom[_, _, _]]),
"canBuildFrom",
cbfParams
) :: Nil
)
}

case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
val TypeRef(_, _, Seq(keyType, valueType)) = t
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -291,6 +291,37 @@ class ScalaReflectionSuite extends SparkFunSuite {
.cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
}

test("SPARK 16792: Get correct deserializer for List[_]") {
val listDeserializer = deserializerFor[List[Int]]
assert(listDeserializer.dataType == ObjectType(classOf[List[_]]))
}

test("serialize and deserialize arbitrary sequence types") {
import scala.collection.immutable.Queue
val queueSerializer = serializerFor[Queue[Int]](BoundReference(
0, ObjectType(classOf[Queue[Int]]), nullable = false))
assert(queueSerializer.dataType.head.dataType ==
ArrayType(IntegerType, containsNull = false))
val queueDeserializer = deserializerFor[Queue[Int]]
assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]]))

import scala.collection.mutable.ArrayBuffer
val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference(
0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false))
assert(arrayBufferSerializer.dataType.head.dataType ==
ArrayType(IntegerType, containsNull = false))
val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))

// Check whether conversion is skipped when using WrappedArray[_] supertype
// (would otherwise needlessly add overhead)
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
val seqDeserializer = deserializerFor[Seq[Int]]
assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
scala.collection.mutable.WrappedArray.getClass)
assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
}

private val dataTypeForComplexData = dataTypeFor[ComplexData]
private val typeOfComplexData = typeOf[ComplexData]

sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala (94 additions, 21 deletions)
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
* @since 1.6.0
*/
@InterfaceStability.Evolving
-abstract class SQLImplicits {
+abstract class SQLImplicits extends LowPrioritySQLImplicits {

protected def _sqlContext: SQLContext

@@ -45,9 +45,6 @@ abstract class SQLImplicits {
}
}

-/** @since 1.6.0 */
-implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]

// Primitives

/** @since 1.6.0 */
@@ -112,33 +109,96 @@

// Seqs

-/** @since 1.6.1 */
-implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+/**
+ * @since 1.6.1
+ * @deprecated use [[newIntSequenceEncoder]]
+ */
+def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()

-/** @since 1.6.1 */
-implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
+/**
+ * @since 1.6.1
+ * @deprecated use [[newLongSequenceEncoder]]
+ */
+def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()

-/** @since 1.6.1 */
-implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
+/**
+ * @since 1.6.1
+ * @deprecated use [[newDoubleSequenceEncoder]]
+ */
+def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()

-/** @since 1.6.1 */
-implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
+/**
+ * @since 1.6.1
+ * @deprecated use [[newFloatSequenceEncoder]]
+ */
+def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()

-/** @since 1.6.1 */
-implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
+/**
+ * @since 1.6.1
+ * @deprecated use [[newByteSequenceEncoder]]
+ */
+def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()

-/** @since 1.6.1 */
-implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
+/**
+ * @since 1.6.1
+ * @deprecated use [[newShortSequenceEncoder]]
+ */
+def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()

-/** @since 1.6.1 */
-implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
+/**
+ * @since 1.6.1
+ * @deprecated use [[newBooleanSequenceEncoder]]
+ */
+def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()

-/** @since 1.6.1 */
-implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
+/**
+ * @since 1.6.1
+ * @deprecated use [[newStringSequenceEncoder]]
+ */
+def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()

-/** @since 1.6.1 */
+/**
+ * @since 1.6.1
+ * @deprecated use [[newProductSequenceEncoder]]
+ */
implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()

/** @since 2.2.0 */
implicit def newIntSequenceEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newLongSequenceEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newDoubleSequenceEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newFloatSequenceEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newByteSequenceEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newShortSequenceEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newBooleanSequenceEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newStringSequenceEncoder[T <: Seq[String] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newProductSequenceEncoder[T <: Seq[Product] : TypeTag]: Encoder[T] =
ExpressionEncoder()

// Arrays

/** @since 1.6.1 */
@@ -193,3 +253,16 @@ abstract class SQLImplicits {
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)

}

/**
* Lower priority implicit methods for converting Scala objects into [[Dataset]]s.
* Conflicting implicits are placed here to disambiguate resolution.
*
* Reasons for including specific implicits:
* newProductEncoder - to disambiguate for [[List]]s which are both [[Seq]] and [[Product]]
*/
trait LowPrioritySQLImplicits {
/** @since 1.6.0 */
implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]

}
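As background, a standalone sketch of the low-priority-implicits pattern used here, in plain Scala with hypothetical names: an implicit defined directly in the extending object or class is considered more specific than one inherited from a supertrait, so it wins during implicit search. This is why moving `newProductEncoder` into `LowPrioritySQLImplicits` lets the sequence encoders take precedence for a `List`, which is both a `Seq` and a `Product`:

```scala
trait LowPriorityImplicits {
  // Inherited implicits lose to implicits defined directly in the extender
  implicit def fallback: String = "fallback"
}

object MyImplicits extends LowPriorityImplicits {
  implicit def preferred: String = "preferred"
}

import MyImplicits._
implicitly[String]  // resolves to "preferred": subclass members are more specific
```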
sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -17,10 +17,21 @@

package org.apache.spark.sql

import scala.collection.immutable.Queue
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.test.SharedSQLContext

case class IntClass(value: Int)

case class SeqClass(s: Seq[Int])

case class ListClass(l: List[Int])

case class QueueClass(q: Queue[Int])

case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass)

package object packageobject {
case class PackageClass(value: Int)
}
@@ -130,6 +141,62 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
}

test("arbitrary sequences") {
checkDataset(Seq(Queue(1)).toDS(), Queue(1))
checkDataset(Seq(Queue(1.toLong)).toDS(), Queue(1.toLong))
checkDataset(Seq(Queue(1.toDouble)).toDS(), Queue(1.toDouble))
checkDataset(Seq(Queue(1.toFloat)).toDS(), Queue(1.toFloat))
checkDataset(Seq(Queue(1.toByte)).toDS(), Queue(1.toByte))
checkDataset(Seq(Queue(1.toShort)).toDS(), Queue(1.toShort))
checkDataset(Seq(Queue(true)).toDS(), Queue(true))
checkDataset(Seq(Queue("test")).toDS(), Queue("test"))
checkDataset(Seq(Queue(Tuple1(1))).toDS(), Queue(Tuple1(1)))

checkDataset(Seq(ArrayBuffer(1)).toDS(), ArrayBuffer(1))
checkDataset(Seq(ArrayBuffer(1.toLong)).toDS(), ArrayBuffer(1.toLong))
checkDataset(Seq(ArrayBuffer(1.toDouble)).toDS(), ArrayBuffer(1.toDouble))
checkDataset(Seq(ArrayBuffer(1.toFloat)).toDS(), ArrayBuffer(1.toFloat))
checkDataset(Seq(ArrayBuffer(1.toByte)).toDS(), ArrayBuffer(1.toByte))
checkDataset(Seq(ArrayBuffer(1.toShort)).toDS(), ArrayBuffer(1.toShort))
checkDataset(Seq(ArrayBuffer(true)).toDS(), ArrayBuffer(true))
checkDataset(Seq(ArrayBuffer("test")).toDS(), ArrayBuffer("test"))
checkDataset(Seq(ArrayBuffer(Tuple1(1))).toDS(), ArrayBuffer(Tuple1(1)))
}

test("sequence and product combinations") {
// Case classes
checkDataset(Seq(SeqClass(Seq(1))).toDS(), SeqClass(Seq(1)))
checkDataset(Seq(Seq(SeqClass(Seq(1)))).toDS(), Seq(SeqClass(Seq(1))))
checkDataset(Seq(List(SeqClass(Seq(1)))).toDS(), List(SeqClass(Seq(1))))
checkDataset(Seq(Queue(SeqClass(Seq(1)))).toDS(), Queue(SeqClass(Seq(1))))

checkDataset(Seq(ListClass(List(1))).toDS(), ListClass(List(1)))
checkDataset(Seq(Seq(ListClass(List(1)))).toDS(), Seq(ListClass(List(1))))
checkDataset(Seq(List(ListClass(List(1)))).toDS(), List(ListClass(List(1))))
checkDataset(Seq(Queue(ListClass(List(1)))).toDS(), Queue(ListClass(List(1))))

checkDataset(Seq(QueueClass(Queue(1))).toDS(), QueueClass(Queue(1)))
checkDataset(Seq(Seq(QueueClass(Queue(1)))).toDS(), Seq(QueueClass(Queue(1))))
checkDataset(Seq(List(QueueClass(Queue(1)))).toDS(), List(QueueClass(Queue(1))))
checkDataset(Seq(Queue(QueueClass(Queue(1)))).toDS(), Queue(QueueClass(Queue(1))))

val complex = ComplexClass(SeqClass(Seq(1)), ListClass(List(2)), QueueClass(Queue(3)))
checkDataset(Seq(complex).toDS(), complex)
checkDataset(Seq(Seq(complex)).toDS(), Seq(complex))
checkDataset(Seq(List(complex)).toDS(), List(complex))
checkDataset(Seq(Queue(complex)).toDS(), Queue(complex))

// Tuples
checkDataset(Seq(Seq(1) -> Seq(2)).toDS(), Seq(1) -> Seq(2))
checkDataset(Seq(List(1) -> Queue(2)).toDS(), List(1) -> Queue(2))
checkDataset(Seq(List(Seq("test1") -> List(Queue("test2")))).toDS(),
List(Seq("test1") -> List(Queue("test2"))))

// Complex
checkDataset(Seq(ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))).toDS(),
ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2))))
}

test("package objects") {
import packageobject._
checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
