From f695e50e38bd329db3b75951dd7af52fea3b3dde Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Mon, 17 Apr 2017 13:31:37 +0900
Subject: [PATCH] address review comments

---
 .../sql/catalyst/optimizer/objects.scala           | 14 ++------
 .../optimizer/EliminateMapObjectsSuite.scala       | 32 +++++++------------
 .../spark/sql/DatasetPrimitiveSuite.scala          |  2 --
 3 files changed, 15 insertions(+), 33 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index a6b1c55a5d750..55288ac654a65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -104,20 +104,12 @@ object CombineTypedFilters extends Rule[LogicalPlan] {
  *   1. Mapobject(e) where e is lambdavariable(), which means types for input output
  *      are primitive types
  *   2. no custom collection class specified
- *      representation of data item. For example back to back map operations.
+ *      representation of data item. For example back to back map operations.
  */
 object EliminateMapObjects extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case _ @ DeserializeToObject(Invoke(
-      MapObjects(_, _, _, Cast(LambdaVariable(_, _, dataType, _), castDataType, _),
-        inputData, None),
-      funcName, returnType: ObjectType, arguments, propagateNull, returnNullable),
-      outputObjAttr, child) if dataType == castDataType =>
-        DeserializeToObject(Invoke(
-          inputData, funcName, returnType, arguments, propagateNull, returnNullable),
-          outputObjAttr, child)
-    case _ @ DeserializeToObject(Invoke(
-      MapObjects(_, _, _, LambdaVariable(_, _, dataType, _), inputData, None),
+    case DeserializeToObject(Invoke(
+      MapObjects(_, _, _, _ : LambdaVariable, inputData, None),
       funcName, returnType: ObjectType, arguments, propagateNull, returnNullable),
       outputObjAttr, child) =>
         DeserializeToObject(Invoke(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala
index d274379f2294f..d4f37e2a5e877 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala
@@ -28,16 +28,12 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.types._
 
 class EliminateMapObjectsSuite extends PlanTest {
-  class Optimize(addSimplifyCast: Boolean) extends RuleExecutor[LogicalPlan] {
-    val batches = if (addSimplifyCast) {
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = {
       Batch("EliminateMapObjects", FixedPoint(50),
         NullPropagation(conf),
         SimplifyCasts,
         EliminateMapObjects) :: Nil
-    } else {
-      Batch("EliminateMapObjects", FixedPoint(50),
-        NullPropagation(conf),
-        EliminateMapObjects) :: Nil
     }
   }
 
@@ -48,23 +44,19 @@ class EliminateMapObjectsSuite extends PlanTest {
     val intObjType = ObjectType(classOf[Array[Int]])
     val intInput = LocalRelation('a.array(ArrayType(IntegerType, false)))
     val intQuery = intInput.deserialize[Array[Int]].analyze
-    Seq(true, false).foreach { addSimplifyCast =>
-      val intOptimized = new Optimize(addSimplifyCast).execute(intQuery)
-      val intExpected = DeserializeToObject(
-        Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false),
-        AttributeReference("obj", intObjType, true)(), intInput)
-      comparePlans(intOptimized, intExpected)
-    }
+    val intOptimized = Optimize.execute(intQuery)
+    val intExpected = DeserializeToObject(
+      Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false),
+      AttributeReference("obj", intObjType, true)(), intInput)
+    comparePlans(intOptimized, intExpected)
 
     val doubleObjType = ObjectType(classOf[Array[Double]])
     val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false)))
     val doubleQuery = doubleInput.deserialize[Array[Double]].analyze
-    Seq(true, false).foreach { addSimplifyCast =>
-      val doubleOptimized = new Optimize(addSimplifyCast).execute(doubleQuery)
-      val doubleExpected = DeserializeToObject(
-        Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false),
-        AttributeReference("obj", doubleObjType, true)(), doubleInput)
-      comparePlans(doubleOptimized, doubleExpected)
-    }
+    val doubleOptimized = Optimize.execute(doubleQuery)
+    val doubleExpected = DeserializeToObject(
+      Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false),
+      AttributeReference("obj", doubleObjType, true)(), doubleInput)
+    comparePlans(doubleOptimized, doubleExpected)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 42ef298d7d088..541565344f758 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -20,8 +20,6 @@ package org.apache.spark.sql
 import scala.collection.immutable.Queue
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.sql.catalyst.expressions.objects.Invoke
-import org.apache.spark.sql.execution.DeserializeToObjectExec
 import org.apache.spark.sql.test.SharedSQLContext
 
 case class IntClass(value: Int)
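For context beyond the patch itself, the sketch below is a minimal, hypothetical way to observe the EliminateMapObjects rule end to end from the Dataset API rather than through the Catalyst test harness above. It is not part of the patch; it assumes a local SparkSession on a Spark build that includes this rule, and the exact plan output may vary by version.

```scala
import org.apache.spark.sql.SparkSession

object EliminateMapObjectsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("EliminateMapObjectsExample")
      .getOrCreate()
    import spark.implicits._

    // A Dataset of primitive int arrays. Mapping over it forces a
    // DeserializeToObject node whose deserializer previously wrapped the input
    // column in MapObjects(LambdaVariable(...)), copying it element by element.
    val ds = Seq(Array(1, 2, 3), Array(4, 5)).toDS()
    val mapped = ds.map(_.sum)

    // With EliminateMapObjects applied, the deserializer is expected to call
    // toIntArray on the input column directly, so the optimized plan should
    // not contain a per-element MapObjects loop for this query.
    println(mapped.queryExecution.optimizedPlan)

    spark.stop()
  }
}
```

This mirrors, at the user-facing API level, what the updated EliminateMapObjectsSuite asserts with comparePlans against the expected DeserializeToObject(Invoke(..., "toIntArray", ...)) plan.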