[SPARK-2054][SQL] Code Generation for Expression Evaluation
Adds a new method for evaluating expressions using code that is generated through Scala reflection.  This functionality is configured by the SQLConf option `spark.sql.codegen` and is currently turned off by default.
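
A minimal usage sketch of turning the flag on (illustrative only: it assumes a local `SparkContext` and the generic `setConf` setter that `SQLContext` picks up from `SQLConf` in this version):

```scala
// Illustrative only: enable the new code-generation path for a session.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setAppName("codegen-demo").setMaster("local"))
val sqlContext = new SQLContext(sc)

// Codegen is off by default; opt in explicitly.
sqlContext.setConf("spark.sql.codegen", "true")
```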

Evaluation can be done in several specialized ways (a usage sketch follows this list):
 - *Projection* - Given an input row, produce a new row from a set of expressions that define each column in terms of the input row.  This can either produce a new Row object or perform the projection in-place on an existing Row (MutableProjection).
 - *Ordering* - Compares two rows based on a list of `SortOrder` expressions.
 - *Condition* - Returns `true` or `false` given an input row.
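
A rough sketch of the interpreted projection path, using the `InterpretedProjection` class introduced in this patch (the generated versions implement the same `Projection` interface; orderings and conditions are built analogously from `SortOrder` lists and boolean expressions — the attributes and row below are illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._

// Input schema: two non-nullable int columns, a and b.
val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = false)()

// Project a single output column, a + b, binding against the schema [a, b].
val project = new InterpretedProjection(Seq(Add(a, b)), Seq(a, b))
val out = project(new GenericRow(Array[Any](1, 2)))
// out(0) == 3
```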

For each of the above operations there is both a Generated and an Interpreted version.  When generation for a given expression type is undefined, the code generator falls back on calling the `eval` function of the expression class.  Even without custom code for an expression, there is still a potential speed-up, as loops are unrolled and the resulting code can be inlined by the JIT.
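
To make the fallback concrete, here is a hand-written analogue of what an unrolled projection might look like when the first expression has dedicated codegen but the second does not (illustrative; the real generator emits a Scala AST rather than source like this):

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericRow, Row}

// e1 has no dedicated codegen, so the "generated" code calls its eval().
class UnrolledProjection(e1: Expression) extends (Row => Row) {
  def apply(input: Row): Row = {
    val out = new Array[Any](2)
    out(0) = input(0).asInstanceOf[Int] + input(1).asInstanceOf[Int] // inlined a + b
    out(1) = e1.eval(input)                                          // interpreted fallback
    new GenericRow(out)
  }
}
```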

This PR also contains a new type of Aggregation operator, `GeneratedAggregate`, that performs aggregation by using generated `Projection` code.  Currently the required expression rewriting only works for simple aggregations like `SUM` and `COUNT`.  This functionality will be extended in a future PR.
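
Conceptually, the rewrite turns each supported aggregate into updates of a mutable buffer that a projection can compute; a plain-Scala sketch of the SUM/COUNT shape (not the actual `GeneratedAggregate` internals):

```scala
// buffer(0) = buffer(0) + input(valueOrdinal)  -- SUM update expression
// buffer(1) = buffer(1) + 1                    -- COUNT update expression
def sumAndCount(rows: Iterator[Array[Any]], valueOrdinal: Int): (Long, Long) = {
  var sum = 0L   // buffer slot 0
  var count = 0L // buffer slot 1
  for (row <- rows) {
    sum += row(valueOrdinal).asInstanceOf[Int]
    count += 1
  }
  (sum, count)
}
```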

This PR also performs several cleanups that simplified the implementation:
 - The notion of `Binding` all expressions in a tree automatically before query execution has been removed.  Instead, it is the responsibility of an operator to provide the input schema when creating one of the specialized evaluators defined above (a binding sketch follows this list).  In cases where the standard `eval` method is going to be called, binding can still be done manually using `BindReferences`.  There are a few reasons for this change.  First, there were many operators where automatic binding simply didn't work before, for example operators with more than one child and operators, like aggregation, that do significant rewriting of their expressions.  Second, the semantics of equality with `BoundReferences` are broken; specifically, we have had a few bugs where partitioning breaks because of the binding.
 - A copy of the current `SQLContext` is automatically propagated to all `SparkPlan` nodes by the query planner.  Previously this was done ad hoc for the nodes that needed it, which required a lot of boilerplate: one had to always remember to mark the field `transient` and to modify `otherCopyArgs`.
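
A minimal sketch of manual binding with `BindReferences.bindReference` (signature as in this patch; the attributes and row are illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._

val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = false)()

// Rewrites the AttributeReferences into BoundReference(ordinal, dataType, nullable).
val bound = BindReferences.bindReference(Add(a, b), input = Seq(a, b))

// The bound expression evaluates positionally, without attribute lookup.
val result = bound.eval(new GenericRow(Array[Any](3, 4))) // 7
```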

Author: Michael Armbrust <[email protected]>

Closes #993 from marmbrus/newCodeGen and squashes the following commits:

96ef82c [Michael Armbrust] Merge remote-tracking branch 'apache/master' into newCodeGen
f34122d [Michael Armbrust] Merge remote-tracking branch 'apache/master' into newCodeGen
67b1c48 [Michael Armbrust] Use conf variable in SQLConf object
4bdc42c [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
41a40c9 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
de22aac [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
fed3634 [Michael Armbrust] Inspectors are not serializable.
ef8d42b [Michael Armbrust] comments
533fdfd [Michael Armbrust] More logging of expression rewriting for GeneratedAggregate.
3cd773e [Michael Armbrust] Allow codegen for Generate.
64b2ee1 [Michael Armbrust] Implement copy
3587460 [Michael Armbrust] Drop unused string builder function.
9cce346 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
1a61293 [Michael Armbrust] Address review comments.
0672e8a [Michael Armbrust] Address comments.
1ec2d6e [Michael Armbrust] Address comments
033abc6 [Michael Armbrust] off by default
4771fab [Michael Armbrust] Docs, more test coverage.
d30fee2 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
d2ad5c5 [Michael Armbrust] Refactor putting SQLContext into SparkPlan. Fix ordering, other test cases.
be2cd6b [Michael Armbrust] WIP: Remove old method for reference binding, more work on configuration.
bc88ecd [Michael Armbrust] Style
6cc97ca [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
4220f1e [Michael Armbrust] Better config, docs, etc.
ca6cc6b [Michael Armbrust] WIP
9d67d85 [Michael Armbrust] Fix hive planner
fc522d5 [Michael Armbrust] Hook generated aggregation in to the planner.
e742640 [Michael Armbrust] Remove unneeded changes and code.
675e679 [Michael Armbrust] Upgrade paradise.
0093376 [Michael Armbrust] Comment / indenting cleanup.
d81f998 [Michael Armbrust] include schema for binding.
0e889e8 [Michael Armbrust] Use typeOf instead tq
f623ffd [Michael Armbrust] Quiet logging from test suite.
efad14f [Michael Armbrust] Remove some half finished functions.
92e74a4 [Michael Armbrust] add overrides
a2b5408 [Michael Armbrust] WIP: Code generation with scala reflection.
marmbrus committed Jul 30, 2014
1 parent 22649b6 commit 8446746
Showing 53 changed files with 1,889 additions and 297 deletions.
10 changes: 10 additions & 0 deletions pom.xml
@@ -114,6 +114,7 @@
<sbt.project.name>spark</sbt.project.name>
<scala.version>2.10.4</scala.version>
<scala.binary.version>2.10</scala.binary.version>
<scala.macros.version>2.0.1</scala.macros.version>
<mesos.version>0.18.1</mesos.version>
<mesos.classifier>shaded-protobuf</mesos.classifier>
<akka.group>org.spark-project.akka</akka.group>
@@ -825,6 +826,15 @@
<javacArg>-target</javacArg>
<javacArg>${java.version}</javacArg>
</javacArgs>
<!-- The following plugin is required to use quasiquotes in Scala 2.10 and is used
by Spark SQL for code generation. -->
<compilerPlugins>
<compilerPlugin>
<groupId>org.scalamacros</groupId>
<artifactId>paradise_${scala.version}</artifactId>
<version>${scala.macros.version}</version>
</compilerPlugin>
</compilerPlugins>
</configuration>
</plugin>
<plugin>
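
The plugin is needed because the generator builds Scala ASTs with quasiquotes, which Scala 2.10 only supports via macro paradise. A standalone sketch of quasiquote-driven code generation using the runtime `ToolBox` (this mirrors the idea, not Spark's actual generator, and additionally requires `scala-compiler` on the classpath):

```scala
import scala.reflect.runtime.currentMirror
import scala.reflect.runtime.universe._
import scala.tools.reflect.ToolBox

val toolBox = currentMirror.mkToolBox()

// Build an AST for a tiny "projection" and compile it at runtime.
val tree = q"(row: Array[Any]) => row(0).asInstanceOf[Int] + row(1).asInstanceOf[Int]"
val projection = toolBox.eval(tree).asInstanceOf[Array[Any] => Int]

println(projection(Array[Any](1, 2))) // 3
```
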
11 changes: 8 additions & 3 deletions project/SparkBuild.scala
@@ -167,6 +167,9 @@ object SparkBuild extends PomBuild {
/* Enable unidoc only for the root spark project */
enable(Unidoc.settings)(spark)

/* Catalyst macro settings */
enable(Catalyst.settings)(catalyst)

/* Spark SQL Core console settings */
enable(SQL.settings)(sql)

@@ -189,10 +192,13 @@ object Flume {
lazy val settings = sbtavro.SbtAvro.avroSettings
}

object SQL {

object Catalyst {
lazy val settings = Seq(
addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full))
}

object SQL {
lazy val settings = Seq(
initialCommands in console :=
"""
|import org.apache.spark.sql.catalyst.analysis._
@@ -207,7 +213,6 @@ object SQL {
|import org.apache.spark.sql.test.TestSQLContext._
|import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin
)

}

object Hive {
9 changes: 9 additions & 0 deletions sql/catalyst/pom.xml
@@ -36,10 +36,19 @@
</properties>

<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-compiler</artifactId>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
</dependency>
<dependency>
<groupId>org.scalamacros</groupId>
<artifactId>quasiquotes_${scala.binary.version}</artifactId>
<version>${scala.macros.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
@@ -202,7 +202,7 @@ package object dsl {
// Protobuf terminology
def required = a.withNullability(false)

def at(ordinal: Int) = BoundReference(ordinal, a)
def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable)
}
}

@@ -17,72 +17,40 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.trees

import org.apache.spark.sql.Logging

/**
* A bound reference points to a specific slot in the input tuple, allowing the actual value
* to be retrieved more efficiently. However, since operations like column pruning can change
* the layout of intermediate tuples, BindReferences should be run after all such transformations.
*/
case class BoundReference(ordinal: Int, baseReference: Attribute)
extends Attribute with trees.LeafNode[Expression] {
case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
extends Expression with trees.LeafNode[Expression] {

type EvaluatedType = Any

override def nullable = baseReference.nullable
override def dataType = baseReference.dataType
override def exprId = baseReference.exprId
override def qualifiers = baseReference.qualifiers
override def name = baseReference.name
override def references = Set.empty

override def newInstance = BoundReference(ordinal, baseReference.newInstance)
override def withNullability(newNullability: Boolean) =
BoundReference(ordinal, baseReference.withNullability(newNullability))
override def withQualifiers(newQualifiers: Seq[String]) =
BoundReference(ordinal, baseReference.withQualifiers(newQualifiers))

override def toString = s"$baseReference:$ordinal"
override def toString = s"input[$ordinal]"

override def eval(input: Row): Any = input(ordinal)
}

/**
* Used to denote operators that do their own binding of attributes internally.
*/
trait NoBind { self: trees.TreeNode[_] => }

class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] {
import BindReferences._

def apply(plan: TreeNode): TreeNode = {
plan.transform {
case n: NoBind => n.asInstanceOf[TreeNode]
case leafNode if leafNode.children.isEmpty => leafNode
case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e =>
bindReference(e, unaryNode.children.head.output)
}
}
}
}

object BindReferences extends Logging {
def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = {
expression.transform { case a: AttributeReference =>
attachTree(a, "Binding attribute") {
val ordinal = input.indexWhere(_.exprId == a.exprId)
if (ordinal == -1) {
// TODO: This fallback is required because some operators (such as ScriptTransform)
// produce new attributes that can't be bound. Likely the right thing to do is remove
// this rule and require all operators to explicitly bind to the input schema that
// they specify.
logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
a
sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
} else {
BoundReference(ordinal, a)
BoundReference(ordinal, a.dataType, a.nullable)
}
}
}.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
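
After this change a `BoundReference` carries only an ordinal plus type and nullability, so evaluation is pure positional access; a small sketch (row contents illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericRow}
import org.apache.spark.sql.catalyst.types.StringType

val ref = BoundReference(1, StringType, nullable = true)
val row = new GenericRow(Array[Any](42, "hello"))
ref.eval(row) // "hello" -- simply input(1)
```
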
@@ -17,12 +17,13 @@

package org.apache.spark.sql.catalyst.expressions


/**
* Converts a [[Row]] to another Row given a sequence of expressions that define each column of the
* new row. If the schema of the input row is specified, then the given expression will be bound to
* that schema.
* A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
* @param expressions a sequence of expressions that determine the value of each column of the
* output row.
*/
class Projection(expressions: Seq[Expression]) extends (Row => Row) {
class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))

@@ -40,25 +41,25 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) {
}

/**
* Converts a [[Row]] to another Row given a sequence of expressions that define each column of the
* new row. If the schema of the input row is specified, then the given expression will be bound to
* that schema.
*
* In contrast to a normal projection, a MutableProjection reuses the same underlying row object
* each time an input row is added. This significantly reduces the cost of calculating the
* projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()`
* has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()`
* and hold on to the returned [[Row]] before calling `next()`.
* A [[MutableProjection]] that is calculated by calling `eval` on each of the specified
* expressions.
* @param expressions a sequence of expressions that determine the value of each column of the
* output row.
*/
case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) {
case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))

private[this] val exprArray = expressions.toArray
private[this] val mutableRow = new GenericMutableRow(exprArray.size)
private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size)
def currentValue: Row = mutableRow

def apply(input: Row): Row = {
override def target(row: MutableRow): MutableProjection = {
mutableRow = row
this
}

override def apply(input: Row): Row = {
var i = 0
while (i < exprArray.length) {
mutableRow(i) = exprArray(i).eval(input)
@@ -76,6 +77,12 @@ class JoinedRow extends Row {
private[this] var row1: Row = _
private[this] var row2: Row = _

def this(left: Row, right: Row) = {
this()
row1 = left
row2 = right
}

/** Updates this JoinedRow to point at two new base rows. Returns itself. */
def apply(r1: Row, r2: Row): Row = {
row1 = r1
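
A usage sketch for the new `target` hook on `MutableProjection`, which redirects output into a caller-supplied row so no new row is allocated per input (expressions and values illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._

val a = AttributeReference("a", IntegerType, nullable = false)()
val proj = new InterpretedMutableProjection(Seq(Add(a, a)), Seq(a))

val buffer = new GenericMutableRow(1)
proj.target(buffer)                   // write results into `buffer`
proj(new GenericRow(Array[Any](21)))
// buffer(0) == 42; copy() the row before holding it across inputs.
```
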
@@ -88,15 +88,6 @@ trait MutableRow extends Row {
def setByte(ordinal: Int, value: Byte)
def setFloat(ordinal: Int, value: Float)
def setString(ordinal: Int, value: String)

/**
* Experimental
*
* Returns a mutable string builder for the specified column. A given row should return the
* result of any mutations made to the returned buffer next time getString is called for the same
* column.
*/
def getStringBuilder(ordinal: Int): StringBuilder
}

/**
@@ -180,15 +171,42 @@ class GenericRow(protected[catalyst] val values: Array[Any]) extends Row {
values(i).asInstanceOf[String]
}

// Custom hashCode function that matches the efficient code generated version.
override def hashCode(): Int = {
var result: Int = 37

var i = 0
while (i < values.length) {
val update: Int =
if (isNullAt(i)) {
0
} else {
apply(i) match {
case b: Boolean => if (b) 0 else 1
case b: Byte => b.toInt
case s: Short => s.toInt
case i: Int => i
case l: Long => (l ^ (l >>> 32)).toInt
case f: Float => java.lang.Float.floatToIntBits(f)
case d: Double =>
val b = java.lang.Double.doubleToLongBits(d)
(b ^ (b >>> 32)).toInt
case other => other.hashCode()
}
}
result = 37 * result + update
i += 1
}
result
}

def copy() = this
}

class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
/** No-arg constructor for serialization. */
def this() = this(0)

def getStringBuilder(ordinal: Int): StringBuilder = ???

override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value }
override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value }
override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value }
@@ -29,6 +29,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi

override def eval(input: Row): Any = {
children.size match {
case 0 => function.asInstanceOf[() => Any]()
case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
case 2 =>
function.asInstanceOf[(Any, Any) => Any](
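
A sketch of the newly supported zero-argument case: with no children, `eval` just invokes the `Function0` (constructed directly here for illustration; UDFs are normally registered through the SQL context):

```scala
import org.apache.spark.sql.catalyst.expressions.ScalaUdf
import org.apache.spark.sql.catalyst.types.LongType

val now = ScalaUdf(() => System.currentTimeMillis(), LongType, children = Nil)
now.eval(null) // current time in millis; the input row is ignored
```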