Skip to content

Commit

Permalink
address review comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
kiszk committed Apr 16, 2017
1 parent 0fd8c25 commit 791aad9
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
CombineUnions,
// Constant folding and strength reduction
NullPropagation(conf),
EliminateMapObjects,
FoldablePropagation,
OptimizeIn(conf),
ConstantFolding,
Expand Down Expand Up @@ -120,7 +119,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
CostBasedJoinReorder(conf)) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates(conf)) ::
Batch("Typed Filter Optimization", fixedPoint,
Batch("Object Expressions Optimization", fixedPoint,
EliminateMapObjects,
CombineTypedFilters) ::
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] {
case EqualNullSafe(Literal(null, _), r) => IsNull(r)
case EqualNullSafe(l, Literal(null, _)) => IsNull(l)

case _ @ AssertNotNull(c, _) if !c.nullable => c
case AssertNotNull(c, _) if !c.nullable => c

// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,18 @@ object EliminateMapObjects extends Rule[LogicalPlan] {
// Rewrites a plan by replacing matching `DeserializeToObject(Invoke(MapObjects(...), ...))`
// nodes with the same `DeserializeToObject`/`Invoke` pair whose invocation target is the
// `inputData` expression taken from the matched `MapObjects`; all other `Invoke` and
// `DeserializeToObject` fields are carried over unchanged. Nodes that match neither case
// are left as they are by `transform`.
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Case 1: the fourth `MapObjects` field is a `Cast` wrapping a `LambdaVariable`. The guard
// on the last pattern line restricts this case to casts whose target type equals the
// variable's own data type. `returnType: ObjectType` further limits the match to
// invocations whose return type is an `ObjectType`.
case _ @ DeserializeToObject(Invoke(
MapObjects(_, _, _, Cast(LambdaVariable(_, _, dataType, _), castDataType, _),
// NOTE(review): the next two lines look like the before/after variants of a single line in
// this diff view (7-field vs. 6-field `MapObjects` pattern) — confirm which one the current
// file actually contains; only one of them can be present in compilable code.
inputData, None, _),
inputData, None),
funcName, returnType: ObjectType, arguments, propagateNull, returnNullable),
outputObjAttr, child) if dataType == castDataType =>
DeserializeToObject(Invoke(
inputData, funcName, returnType, arguments, propagateNull, returnNullable),
outputObjAttr, child)
// Case 2: the fourth `MapObjects` field is a bare `LambdaVariable` (no enclosing `Cast`).
// NOTE(review): `dataType` is bound by this pattern but not used in the case body.
case _ @ DeserializeToObject(Invoke(
MapObjects(_, _, _, LambdaVariable(_, _, dataType, _), inputData, None),
funcName, returnType: ObjectType, arguments, propagateNull, returnNullable),
outputObjAttr, child) =>
DeserializeToObject(Invoke(
inputData, funcName, returnType, arguments, propagateNull, returnNullable),
outputObjAttr, child)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,17 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

class EliminateMapObjectsSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
/**
 * Test-only rule executor with a single "EliminateMapObjects" batch run to a fixed point
 * (at most 50 iterations).
 *
 * @param addSimplifyCast when true, `SimplifyCasts` is placed between `NullPropagation`
 *                        and `EliminateMapObjects`; when false it is left out.
 */
class Optimize(addSimplifyCast: Boolean) extends RuleExecutor[LogicalPlan] {
  val batches = {
    // Only the rule list differs between the two configurations; the batch name and
    // strategy are shared.
    val rules =
      if (addSimplifyCast) Seq(NullPropagation(conf), SimplifyCasts, EliminateMapObjects)
      else Seq(NullPropagation(conf), EliminateMapObjects)
    Batch("EliminateMapObjects", FixedPoint(50), rules: _*) :: Nil
  }
}

implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]()
Expand All @@ -42,19 +48,23 @@ class EliminateMapObjectsSuite extends PlanTest {
val intObjType = ObjectType(classOf[Array[Int]])
val intInput = LocalRelation('a.array(ArrayType(IntegerType, false)))
val intQuery = intInput.deserialize[Array[Int]].analyze
val intOptimized = Optimize.execute(intQuery)
val intExpected = DeserializeToObject(
Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false),
AttributeReference("obj", intObjType, true)(), intInput)
comparePlans(intOptimized, intExpected)
Seq(true, false).foreach { addSimplifyCast =>
val intOptimized = new Optimize(addSimplifyCast).execute(intQuery)
val intExpected = DeserializeToObject(
Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false),
AttributeReference("obj", intObjType, true)(), intInput)
comparePlans(intOptimized, intExpected)
}

val doubleObjType = ObjectType(classOf[Array[Double]])
val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false)))
val doubleQuery = doubleInput.deserialize[Array[Double]].analyze
val doubleOptimized = Optimize.execute(doubleQuery)
val doubleExpected = DeserializeToObject(
Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false),
AttributeReference("obj", doubleObjType, true)(), doubleInput)
comparePlans(doubleOptimized, doubleExpected)
Seq(true, false).foreach { addSimplifyCast =>
val doubleOptimized = new Optimize(addSimplifyCast).execute(doubleQuery)
val doubleExpected = DeserializeToObject(
Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false),
AttributeReference("obj", doubleObjType, true)(), doubleInput)
comparePlans(doubleOptimized, doubleExpected)
}
}
}

0 comments on commit 791aad9

Please sign in to comment.