[SPARK-12813][SQL] Eliminate serialization for back to back operations
The goal of this PR is to eliminate unnecessary conversions between the object and serialized (InternalRow) representations of data when there are back-to-back `MapPartitions` operations (see the sketch below). In order to achieve this I also made the following simplifications:

 - Operators no longer hold encoders; instead they hold only the expressions that they need. The benefits here are twofold: the expressions are visible to transformations, so they go through the normal resolution/binding process, and now that they are visible we can change them on a case-by-case basis.
 - Operators no longer have type parameters. Since the engine is responsible for its own type checking, having the types visible to the compiler was an unnecessary complication. We still leverage the Scala compiler in the companion factory when constructing a new operator (see the toy sketch below), but after that the types are discarded.
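
To illustrate the second point, here is a toy sketch of keeping types only in the companion factory. The `ToyEncoder`/`ToyMapPartitions` names are hypothetical stand-ins, not the PR's actual classes: the typed `apply` leverages the compiler to supply encoders, then builds an untyped node.

```scala
// Toy stand-in for an encoder; not Spark's ExpressionEncoder.
case class ToyEncoder[A](description: String)

// The operator itself carries no type parameters.
case class ToyMapPartitions(
    func: Iterator[Any] => Iterator[Any],
    inputDesc: String,
    outputDesc: String)

object ToyMapPartitions {
  // T and U exist only here, long enough for the compiler to supply the
  // encoders; the node that is returned has discarded them.
  def apply[T, U](func: Iterator[T] => Iterator[U])(
      implicit tEnc: ToyEncoder[T], uEnc: ToyEncoder[U]): ToyMapPartitions =
    ToyMapPartitions(
      func.asInstanceOf[Iterator[Any] => Iterator[Any]],
      tEnc.description,
      uEnc.description)
}

// Usage, with encoders passed explicitly for clarity:
// ToyMapPartitions((it: Iterator[Int]) => it.map(_.toString))(
//   ToyEncoder("int"), ToyEncoder("string"))
```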

Deferred to a follow-up PR:
 - Remove as much of the resolution/binding from Dataset/GroupedDataset as possible. We should still eagerly check resolution and throw an error in the case of mismatches for an `as` operation.
 - Eliminate serialization in more situations by adding further cases to `EliminateSerialization`; a user-level sketch of the fusion this enables follows below.
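
At the API level, the fusion looks like the following. This is a minimal sketch against the Spark 1.6-era Dataset API; `sqlContext` is assumed to be in scope.

```scala
import sqlContext.implicits._

val ds = Seq(1, 2, 3).toDS()

// In 1.6, Dataset.map is implemented via MapPartitions, so this chain is
// back-to-back MapPartitions. Previously the first operation serialized its
// results into the InternalRow format only for the second one to immediately
// deserialize them; EliminateSerialization removes that round trip.
val fused = ds.map(_ + 1).map(_ * 2)
```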

Author: Michael Armbrust <[email protected]>

Closes #10747 from marmbrus/encoderExpressions.
marmbrus committed Jan 15, 2016
1 parent 2578298 commit cc7af86
Showing 17 changed files with 518 additions and 274 deletions.
@@ -1214,6 +1214,10 @@ object CleanupAliases extends Rule[LogicalPlan] {
Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)

// Operators that operate on objects should only have expressions from encoders, which should
// never have extra aliases.
case o: ObjectOperator => o

case other =>
var stop = false
other transformExpressionsDown {
@@ -160,6 +160,7 @@ abstract class Star extends LeafExpression with NamedExpression {
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
override lazy val resolved = false

def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression]
@@ -246,6 +247,8 @@ case class MultiAlias(child: Expression, names: Seq[String])

override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")

override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")

override lazy val resolved = false

override def toString: String = s"$child AS $names"
@@ -259,6 +262,7 @@ case class MultiAlias(child: Expression, names: Seq[String])
* @param expressions Expressions to expand.
*/
case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable {
override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = expressions
override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
}
@@ -298,6 +302,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def name: String = throw new UnresolvedException(this, "name")
override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")

override lazy val resolved = false
}
@@ -207,6 +207,16 @@ case class ExpressionEncoder[T](
resolve(attrs, OuterScopes.outerScopes).bind(attrs)
}


/**
* Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
* of this object.
*/
def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(toRowExpressions).map {
case (_, ne: NamedExpression) => ne.newInstance()
case (name, e) => Alias(e, name)()
}

/**
* Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
* toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
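To make the naming rule in `namedExpressions` above concrete, here is a self-contained toy model. The `Named`/`Aliased` types are hand-rolled, not Spark's: serializer expressions that are already named keep their name but receive a fresh id, while anything else is wrapped in an alias carrying the schema column's name.

```scala
object NamedExpressionsSketch {
  private var counter = 0L
  private def freshId(): Long = { counter += 1; counter }

  sealed trait Expr
  case class Unnamed(desc: String) extends Expr
  case class Named(name: String, id: Long) extends Expr {
    def newInstance(): Named = copy(id = freshId())
  }
  case class Aliased(child: Expr, name: String, id: Long) extends Expr

  // Mirrors the zip/match in the diff above: schema names drive the aliasing.
  def namedExpressions(schemaNames: Seq[String], toRowExprs: Seq[Expr]): Seq[Expr] =
    schemaNames.zip(toRowExprs).map {
      case (_, ne: Named) => ne.newInstance()            // already named: only a new id
      case (name, e)      => Aliased(e, name, freshId()) // otherwise: use the column name
    }

  def main(args: Array[String]): Unit =
    println(namedExpressions(Seq("a", "b"), Seq(Named("a", 1L), Unnamed("input[0, int]"))))
}
```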
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
extends LeafExpression with NamedExpression {

override def toString: String = s"input[$ordinal, $dataType]"
override def toString: String = s"input[$ordinal, ${dataType.simpleString}]"

// Use special getter for primitive types (for UnsafeRow)
override def eval(input: InternalRow): Any = {
@@ -66,6 +66,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)

override def exprId: ExprId = throw new UnsupportedOperationException

override def newInstance(): NamedExpression = this

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
@@ -79,6 +79,9 @@ trait NamedExpression extends Expression {
/** Returns the metadata when an expression is a reference to another expression with metadata. */
def metadata: Metadata = Metadata.empty

/** Returns a copy of this expression with a new `exprId`. */
def newInstance(): NamedExpression

protected def typeSuffix =
if (resolved) {
dataType match {
@@ -144,6 +147,9 @@ case class Alias(child: Expression, name: String)(
}
}

def newInstance(): NamedExpression =
Alias(child, name)(qualifiers = qualifiers, explicitMetadata = explicitMetadata)

override def toAttribute: Attribute = {
if (resolved) {
AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers)
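The `newInstance()` above copies `Alias` through its first parameter list only; because the second parameter list defaults `exprId`, the copy mints a fresh id. A minimal sketch of that Scala pattern, using toy `AliasLike`/`ExprId` types rather than Spark's:

```scala
case class ExprId(id: Long)
object ExprId { private var n = 0L; def fresh(): ExprId = { n += 1; ExprId(n) } }

case class AliasLike(child: String, name: String)(val exprId: ExprId = ExprId.fresh()) {
  // Re-invoking the constructor without an explicit exprId picks up a new
  // default value, which is how the Alias.newInstance above gets a fresh id.
  def newInstance(): AliasLike = AliasLike(child, name)()
}

// Usage: b keeps the same child and name but gets a distinct id.
// val a = AliasLike("x + 1", "y")()
// val b = a.newInstance()   // b.exprId != a.exprId
```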
@@ -172,6 +172,8 @@ case class Invoke(
$objNullCheck
"""
}

override def toString: String = s"$targetObject.$functionName"
}

object NewInstance {
@@ -253,6 +255,8 @@ case class NewInstance(
"""
}
}

override def toString: String = s"newInstance($cls)"
}

/**
@@ -67,7 +67,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
RemoveDispensableExpressions,
SimplifyFilters,
SimplifyCasts,
SimplifyCaseConversionExpressions) ::
SimplifyCaseConversionExpressions,
EliminateSerialization) ::
Batch("Decimal Optimizations", FixedPoint(100),
DecimalAggregates) ::
Batch("LocalRelation", FixedPoint(100),
@@ -96,6 +97,19 @@ object SamplePushDown extends Rule[LogicalPlan] {
}
}

/**
 * Removes cases where we unnecessarily convert between the object and the serialized
 * (InternalRow) representation of a data item, for example back-to-back map operations.
*/
object EliminateSerialization extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case m @ MapPartitions(_, input, _, child: ObjectOperator)
if !input.isInstanceOf[Attribute] && m.input.dataType == child.outputObject.dataType =>
val childWithoutSerialization = child.withObjectOutput
m.copy(input = childWithoutSerialization.output.head, child = childWithoutSerialization)
}
}

/**
* Pushes certain operations to both sides of a Union, Intersect or Except operator.
* Operations that are safe to pushdown are listed as follows.
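The shape of the `EliminateSerialization` rewrite above can be rendered as a standalone sketch. The plan classes are toys, not Spark's: when a `MapPartitions` consumes the serialized output of an object operator and the object types match, the rule re-points it at the child's object-level output, skipping the intermediate InternalRow.

```scala
object EliminateSerializationSketch {
  sealed trait ToyPlan
  case class ToyObjectOp(objectType: String, serializeOutput: Boolean) extends ToyPlan {
    // Analogue of withObjectOutput: expose raw objects instead of rows.
    def withObjectOutput: ToyObjectOp = copy(serializeOutput = false)
  }
  case class ToyMapPartitions(inputType: String, child: ToyPlan) extends ToyPlan

  def eliminateSerialization(plan: ToyPlan): ToyPlan = plan match {
    // Fire only when the child still serializes and the types line up,
    // mirroring the dataType check in the real rule.
    case m @ ToyMapPartitions(tpe, child: ToyObjectOp)
        if child.serializeOutput && tpe == child.objectType =>
      m.copy(child = child.withObjectOutput)
    case other => other
  }
}
```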
@@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
@@ -480,120 +478,3 @@ case object OneRowRelation extends LeafNode {
*/
override def statistics: Statistics = Statistics(sizeInBytes = 1)
}

/**
 * A relation produced by applying `func` to each partition of the `child`. tEncoder/uEncoder are
 * used respectively to decode and encode the JVM object representation expected by `func`.
*/
case class MapPartitions[T, U](
func: Iterator[T] => Iterator[U],
tEncoder: ExpressionEncoder[T],
uEncoder: ExpressionEncoder[U],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def producedAttributes: AttributeSet = outputSet
}

/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
def apply[T, U : Encoder](
func: T => U,
tEncoder: ExpressionEncoder[T],
child: LogicalPlan): AppendColumns[T, U] = {
val attrs = encoderFor[U].schema.toAttributes
new AppendColumns[T, U](func, tEncoder, encoderFor[U], attrs, child)
}
}

/**
* A relation produced by applying `func` to each partition of the `child`, concatenating the
* resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to
 * decode and encode the JVM object representation expected by `func`.
*/
case class AppendColumns[T, U](
func: T => U,
tEncoder: ExpressionEncoder[T],
uEncoder: ExpressionEncoder[U],
newColumns: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output ++ newColumns
override def producedAttributes: AttributeSet = AttributeSet(newColumns)
}

/** Factory for constructing new `MapGroups` nodes. */
object MapGroups {
def apply[K, T, U : Encoder](
func: (K, Iterator[T]) => TraversableOnce[U],
kEncoder: ExpressionEncoder[K],
tEncoder: ExpressionEncoder[T],
groupingAttributes: Seq[Attribute],
child: LogicalPlan): MapGroups[K, T, U] = {
new MapGroups(
func,
kEncoder,
tEncoder,
encoderFor[U],
groupingAttributes,
encoderFor[U].schema.toAttributes,
child)
}
}

/**
* Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`.
 * Func is invoked with an object representation of the grouping key and an iterator containing
* object representation of all the rows with that key.
*/
case class MapGroups[K, T, U](
func: (K, Iterator[T]) => TraversableOnce[U],
kEncoder: ExpressionEncoder[K],
tEncoder: ExpressionEncoder[T],
uEncoder: ExpressionEncoder[U],
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def producedAttributes: AttributeSet = outputSet
}

/** Factory for constructing new `CoGroup` nodes. */
object CoGroup {
def apply[Key, Left, Right, Result : Encoder](
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
keyEnc: ExpressionEncoder[Key],
leftEnc: ExpressionEncoder[Left],
rightEnc: ExpressionEncoder[Right],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan): CoGroup[Key, Left, Right, Result] = {
CoGroup(
func,
keyEnc,
leftEnc,
rightEnc,
encoderFor[Result],
encoderFor[Result].schema.toAttributes,
leftGroup,
rightGroup,
left,
right)
}
}

/**
* A relation produced by applying `func` to each grouping key and associated values from left and
* right children.
*/
case class CoGroup[Key, Left, Right, Result](
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
keyEnc: ExpressionEncoder[Key],
leftEnc: ExpressionEncoder[Left],
rightEnc: ExpressionEncoder[Right],
resultEnc: ExpressionEncoder[Result],
output: Seq[Attribute],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan) extends BinaryNode {
override def producedAttributes: AttributeSet = outputSet
}