Commit f695e50

address review comments

kiszk committed Apr 17, 2017
1 parent 791aad9 commit f695e50

Showing 3 changed files with 15 additions and 33 deletions.
@@ -104,20 +104,12 @@ object CombineTypedFilters extends Rule[LogicalPlan] {
  * 1. Mapobject(e) where e is lambdavariable(), which means types for input output
  *    are primitive types
  * 2. no custom collection class specified
- * representation of data item. For example back to back map operations.
+ * representation of data item. For example back to back map operations.
  */
 object EliminateMapObjects extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case _ @ DeserializeToObject(Invoke(
-        MapObjects(_, _, _, Cast(LambdaVariable(_, _, dataType, _), castDataType, _),
-          inputData, None),
-        funcName, returnType: ObjectType, arguments, propagateNull, returnNullable),
-        outputObjAttr, child) if dataType == castDataType =>
-      DeserializeToObject(Invoke(
-        inputData, funcName, returnType, arguments, propagateNull, returnNullable),
-        outputObjAttr, child)
-    case _ @ DeserializeToObject(Invoke(
-        MapObjects(_, _, _, LambdaVariable(_, _, dataType, _), inputData, None),
+    case DeserializeToObject(Invoke(
+        MapObjects(_, _, _, _ : LambdaVariable, inputData, None),
         funcName, returnType: ObjectType, arguments, propagateNull, returnNullable),
         outputObjAttr, child) =>
       DeserializeToObject(Invoke(
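For context, the rule above rewrites DeserializeToObject(Invoke(MapObjects(...), "toIntArray"/"toDoubleArray", ...)) into a direct Invoke on the input data, skipping the per-element conversion loop. Below is a minimal, hypothetical sketch of the Dataset shape that triggers it; the SparkSession setup and object name are illustrative, not part of this commit.

// Hypothetical usage sketch, not part of this commit: a Dataset of
// non-nullable primitive arrays is the case EliminateMapObjects targets
// when deserializing rows back to Array[Int].
import org.apache.spark.sql.SparkSession

object EliminateMapObjectsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("EliminateMapObjectsSketch")
      .getOrCreate()
    import spark.implicits._

    // Array[Int] elements are primitive and non-nullable, and no custom
    // collection class is requested, so the deserializer can be rewritten
    // from a MapObjects loop to a single Invoke(..., "toIntArray", ...).
    val ds = spark.createDataset(Seq(Array(1, 2, 3), Array(4, 5, 6)))
    val doubled = ds.map(a => a.map(_ * 2))
    doubled.collect().foreach(a => println(a.mkString(" ")))

    spark.stop()
  }
}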
@@ -28,16 +28,12 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.types._

 class EliminateMapObjectsSuite extends PlanTest {
-  class Optimize(addSimplifyCast: Boolean) extends RuleExecutor[LogicalPlan] {
-    val batches = if (addSimplifyCast) {
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = {
       Batch("EliminateMapObjects", FixedPoint(50),
         NullPropagation(conf),
         SimplifyCasts,
         EliminateMapObjects) :: Nil
-    } else {
-      Batch("EliminateMapObjects", FixedPoint(50),
-        NullPropagation(conf),
-        EliminateMapObjects) :: Nil
     }
   }

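The test optimizer now always runs SimplifyCasts before EliminateMapObjects, which is what allowed the Cast-matching case to be deleted from the rule in the first hunk: a cast of a LambdaVariable to its own type is a no-op that SimplifyCasts removes first. A self-contained conceptual sketch of that interaction follows, using simplified types rather than Spark's actual expression classes.

// Conceptual sketch, not Spark's actual source: once a SimplifyCasts-style
// rewrite drops no-op casts, a downstream rule only has to match the bare
// lambda-variable case kept in the diff above.
sealed trait Expr { def dataType: String }
case class LambdaVar(name: String, dataType: String) extends Expr
case class CastExpr(child: Expr, dataType: String) extends Expr

object SimplifyCastsSketch {
  def simplify(e: Expr): Expr = e match {
    case CastExpr(child, t) if child.dataType == t => child // no-op cast: drop it
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val v = LambdaVar("MapObjects_loopValue0", "int")
    assert(simplify(CastExpr(v, "int")) == v)  // redundant cast removed
    assert(simplify(CastExpr(v, "long")) != v) // real cast kept
    println("ok")
  }
}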
@@ -48,23 +44,19 @@ class EliminateMapObjectsSuite extends PlanTest {
     val intObjType = ObjectType(classOf[Array[Int]])
     val intInput = LocalRelation('a.array(ArrayType(IntegerType, false)))
     val intQuery = intInput.deserialize[Array[Int]].analyze
-    Seq(true, false).foreach { addSimplifyCast =>
-      val intOptimized = new Optimize(addSimplifyCast).execute(intQuery)
-      val intExpected = DeserializeToObject(
-        Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false),
-        AttributeReference("obj", intObjType, true)(), intInput)
-      comparePlans(intOptimized, intExpected)
-    }
+    val intOptimized = Optimize.execute(intQuery)
+    val intExpected = DeserializeToObject(
+      Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false),
+      AttributeReference("obj", intObjType, true)(), intInput)
+    comparePlans(intOptimized, intExpected)

     val doubleObjType = ObjectType(classOf[Array[Double]])
     val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false)))
     val doubleQuery = doubleInput.deserialize[Array[Double]].analyze
-    Seq(true, false).foreach { addSimplifyCast =>
-      val doubleOptimized = new Optimize(addSimplifyCast).execute(doubleQuery)
-      val doubleExpected = DeserializeToObject(
-        Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false),
-        AttributeReference("obj", doubleObjType, true)(), doubleInput)
-      comparePlans(doubleOptimized, doubleExpected)
-    }
+    val doubleOptimized = Optimize.execute(doubleQuery)
+    val doubleExpected = DeserializeToObject(
+      Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false),
+      AttributeReference("obj", doubleObjType, true)(), doubleInput)
+    comparePlans(doubleOptimized, doubleExpected)
   }
 }
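The expected plans in this test invoke toIntArray / toDoubleArray on the input ArrayData in place of the eliminated MapObjects loop. A rough behavioral sketch of the difference follows, assuming Spark's GenericArrayData utility class; it is an illustration only, not code from this commit.

// Rough behavioral sketch, assuming org.apache.spark.sql.catalyst.util
// classes; not code from this commit. The rewritten plan's
// Invoke(..., "toIntArray", ...) amounts to the single bulk conversion
// below, where the eliminated MapObjects read one element at a time.
import org.apache.spark.sql.catalyst.util.GenericArrayData

object ToIntArraySketch {
  def main(args: Array[String]): Unit = {
    val data = new GenericArrayData(Array(1, 2, 3))

    // What a MapObjects-style loop does: one getInt call per element.
    val perElement = Array.tabulate(data.numElements())(data.getInt)

    // What the optimized plan does: a single toIntArray call.
    val bulk = data.toIntArray()

    assert(perElement.sameElements(bulk))
    println(bulk.mkString(" "))
  }
}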
@@ -20,8 +20,6 @@ package org.apache.spark.sql
 import scala.collection.immutable.Queue
 import scala.collection.mutable.ArrayBuffer

-import org.apache.spark.sql.catalyst.expressions.objects.Invoke
-import org.apache.spark.sql.execution.DeserializeToObjectExec
 import org.apache.spark.sql.test.SharedSQLContext

 case class IntClass(value: Int)
