From ad6762a02f3ebdb9139d2c5164c960d2d7f86f56 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Mon, 12 Mar 2018 14:32:22 +0900
Subject: [PATCH 1/2] CatalystToExternalMap should support interpreted execution

---
 .../expressions/objects/objects.scala        | 20 ++++++++---
 .../expressions/ObjectExpressionsSuite.scala | 35 +++++++++++++++----
 2 files changed, 45 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 72b202b3a5020..6eea60bbda415 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -28,12 +28,12 @@ import scala.util.Try
 
 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -1009,8 +1009,20 @@ case class CatalystToExternalMap private(
   override def children: Seq[Expression] =
     keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil
 
-  override def eval(input: InternalRow): Any =
-    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+  private lazy val toScalaValue: Any => Any = {
+    assert(inputData.dataType.isInstanceOf[MapType])
+    val mapType = inputData.dataType.asInstanceOf[MapType]
+    CatalystTypeConverters.createToScalaConverter(mapType)
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val result = inputData.eval(input).asInstanceOf[MapData]
+    if (result != null) {
+      toScalaValue(result)
+    } else {
+      null
+    }
+  }
 
   override def dataType: DataType = ObjectType(collClass)
 
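For reference, the converter-based interpreted path added above can be exercised on its own. Below is a minimal, self-contained sketch (not part of the patch; the sample MapData value and the assertion are illustrative) of how CatalystTypeConverters.createToScalaConverter maps a Catalyst MapData to an external Scala map:

    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
    import org.apache.spark.sql.types.{IntegerType, MapType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    // Converter from the Catalyst representation of map<int, string>
    // to an external scala.collection.immutable.Map.
    val toScala = CatalystTypeConverters.createToScalaConverter(
      MapType(IntegerType, StringType))
    // A Catalyst-side map value {1 -> "a", 2 -> null} (strings are UTF8String).
    val mapData = new ArrayBasedMapData(
      new GenericArrayData(Array[Any](1, 2)),
      new GenericArrayData(Array[Any](UTF8String.fromString("a"), null)))
    assert(toScala(mapData) == Map(1 -> "a", 2 -> null))

Note that the converter always yields a default immutable Map and never consults collClass; judging by the diff, that is the gap PATCH 2/2 below closes.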
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index b0188b0098def..78b0dac1d1cd6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -27,12 +27,14 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
-import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -148,9 +150,10 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       "fromPrimitiveArray", ObjectType(classOf[Array[Int]]),
       Array[Int](1, 2, 3), UnsafeArrayData.fromPrimitiveArray(Array[Int](1, 2, 3))),
     (DateTimeUtils.getClass, ObjectType(classOf[Date]),
-      "toJavaDate", ObjectType(classOf[SQLDate]), 77777, DateTimeUtils.toJavaDate(77777)),
+      "toJavaDate", ObjectType(classOf[DateTimeUtils.SQLDate]), 77777,
+      DateTimeUtils.toJavaDate(77777)),
     (DateTimeUtils.getClass, ObjectType(classOf[Timestamp]),
-      "toJavaTimestamp", ObjectType(classOf[SQLTimestamp]),
+      "toJavaTimestamp", ObjectType(classOf[DateTimeUtils.SQLTimestamp]),
       88888888.toLong, DateTimeUtils.toJavaTimestamp(88888888))
   ).foreach { case (cls, dataType, methodName, argType, arg, expected) =>
     checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName,
@@ -383,6 +386,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
+<<<<<<< 0c94e48bc50717e1627c0d2acd5382d9adc73c97
   test("LambdaVariable should support interpreted execution") {
     def genSchema(dt: DataType): Seq[StructType] = {
       Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil),
@@ -415,6 +419,25 @@
       }
     }
   }
+
+  implicit private def mapIntStrEncoder = ExpressionEncoder[Map[Int, String]]()
+
+  test("SPARK-23588 CatalystToExternalMap should support interpreted execution") {
+    // To get a resolved `CatalystToExternalMap` expression, we build a deserializer plan
+    // with dummy input, resolve the plan by the analyzer, and replace the dummy input
+    // with a literal for tests.
+    val unresolvedDeser = UnresolvedDeserializer(encoderFor[Map[Int, String]].deserializer)
+    val dummyInputPlan = LocalRelation('value.map(MapType(IntegerType, StringType)))
+    val plan = Project(Alias(unresolvedDeser, "none")() :: Nil, dummyInputPlan)
+
+    val analyzedPlan = SimpleAnalyzer.execute(plan)
+    val Alias(toMapExpr: CatalystToExternalMap, _) = analyzedPlan.expressions.head
+
+    // Replaces the dummy input with a literal for tests here
+    val data = Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3")
+    val deserializer = toMapExpr.copy(inputData = Literal.create(data))
+    checkObjectExprEvaluation(deserializer, expected = data)
+  }
 }
 
 class TestBean extends Serializable {
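The resolution trick in the new test generalizes beyond Map[Int, String]: an encoder's raw deserializer is unresolved, and a resolved CatalystToExternalMap only comes out of analysis, hence the dummy plan. A hedged sketch of the same idea as a reusable helper (resolvedDeserializer is hypothetical, not in the patch; it assumes the suite's imports):

    import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedDeserializer}
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression}
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}

    // Resolve an encoder's deserializer against a dummy relation and
    // return the analyzed expression (the wrapping Alias is stripped off).
    def resolvedDeserializer[T](enc: ExpressionEncoder[T], dummyInput: Attribute): Expression = {
      val plan = Project(Alias(UnresolvedDeserializer(enc.deserializer), "out")() :: Nil,
        LocalRelation(dummyInput))
      val Alias(deserializer, _) = SimpleAnalyzer.execute(plan).expressions.head
      deserializer
    }
    // Usage, with the suite's dsl in scope:
    //   resolvedDeserializer(encoderFor[Map[Int, String]],
    //     'value.map(MapType(IntegerType, StringType)))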
From 07f4c824f48f6eeb03b2b44749a0ac0be1d35a6f Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Wed, 11 Apr 2018 23:56:05 +0900
Subject: [PATCH 2/2] Fix

---
 .../expressions/objects/objects.scala        | 29 +++++++++++++++----
 .../expressions/ObjectExpressionsSuite.scala |  1 -
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 6eea60bbda415..56e6f4b1908e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -1009,16 +1009,35 @@ case class CatalystToExternalMap private(
   override def children: Seq[Expression] =
     keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil
 
-  private lazy val toScalaValue: Any => Any = {
-    assert(inputData.dataType.isInstanceOf[MapType])
-    val mapType = inputData.dataType.asInstanceOf[MapType]
-    CatalystTypeConverters.createToScalaConverter(mapType)
+  private lazy val inputMapType = inputData.dataType.asInstanceOf[MapType]
+
+  private lazy val keyConverter =
+    CatalystTypeConverters.createToScalaConverter(inputMapType.keyType)
+  private lazy val valueConverter =
+    CatalystTypeConverters.createToScalaConverter(inputMapType.valueType)
+
+  private def newMapBuilder(): Builder[AnyRef, AnyRef] = {
+    val clazz = Utils.classForName(collClass.getCanonicalName + "$")
+    val module = clazz.getField("MODULE$").get(null)
+    val method = clazz.getMethod("newBuilder")
+    method.invoke(module).asInstanceOf[Builder[AnyRef, AnyRef]]
   }
 
   override def eval(input: InternalRow): Any = {
     val result = inputData.eval(input).asInstanceOf[MapData]
     if (result != null) {
-      toScalaValue(result)
+      val builder = newMapBuilder()
+      builder.sizeHint(result.numElements())
+      val keyArray = result.keyArray()
+      val valueArray = result.valueArray()
+      var i = 0
+      while (i < result.numElements()) {
+        val key = keyConverter(keyArray.get(i, inputMapType.keyType))
+        val value = valueConverter(valueArray.get(i, inputMapType.valueType))
+        builder += Tuple2(key, value)
+        i += 1
+      }
+      builder.result()
     } else {
       null
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 78b0dac1d1cd6..fd96b1aa14816 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -386,7 +386,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
-<<<<<<< 0c94e48bc50717e1627c0d2acd5382d9adc73c97
   test("LambdaVariable should support interpreted execution") {
     def genSchema(dt: DataType): Seq[StructType] = {
       Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil),
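The reflective lookup in newMapBuilder leans on how scalac compiles companion objects: for a collection class such as scala.collection.immutable.ListMap, the class ListMap$ holds the singleton in its static MODULE$ field, and Scala's map companions expose a newBuilder method. A standalone sketch of the same trick (builderFor and the ListMap example are illustrative, not part of the patch):

    import scala.collection.mutable.Builder

    // Fetch the companion object of `collClass` and ask it for a builder,
    // the same shape of lookup that newMapBuilder performs above.
    def builderFor(collClass: Class[_]): Builder[AnyRef, AnyRef] = {
      val companion = Class.forName(collClass.getCanonicalName + "$")
      val module = companion.getField("MODULE$").get(null)
      companion.getMethod("newBuilder").invoke(module)
        .asInstanceOf[Builder[AnyRef, AnyRef]]
    }

    val builder = builderFor(classOf[scala.collection.immutable.ListMap[_, _]])
    builder += Tuple2("k", "v")
    assert(builder.result().isInstanceOf[scala.collection.immutable.ListMap[_, _]])

Building through the collection's own companion lets interpreted eval honor collClass, matching the codegen path, instead of always returning the default immutable Map that the first commit's createToScalaConverter produced.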