Skip to content

Commit

Permalink
[SPARK-2327] [SQL] Fix nullabilities of Join/Generate/Aggregate.
Browse files Browse the repository at this point in the history
Fix nullabilities of `Join`/`Generate`/`Aggregate` because:
- Output attributes of opposite side of `OuterJoin` should be nullable.
- Output attributes on the generator side of `Generate` should be nullable if `join` is `true` and `outer` is `true`.
- The nullability of each `AttributeReference` in `computedAggregates` of `Aggregate` should match that of its corresponding `aggregateExpression`.

Author: Takuya UESHIN <[email protected]>

Closes #1266 from ueshin/issues/SPARK-2327 and squashes the following commits:

3ace83a [Takuya UESHIN] Add withNullability to Attribute and use it to change nullabilities.
df1ae53 [Takuya UESHIN] Modify nullabilize to leave attribute if not resolved.
799ce56 [Takuya UESHIN] Add nullabilization to Generate of SparkPlan.
a0fc9bc [Takuya UESHIN] Fix scalastyle errors.
0e31e37 [Takuya UESHIN] Fix Aggregate resultAttribute nullabilities.
09532ec [Takuya UESHIN] Fix Generate output nullabilities.
f20f196 [Takuya UESHIN] Fix Join output nullabilities.
  • Loading branch information
ueshin authored and marmbrus committed Jul 5, 2014
1 parent 3da8df9 commit 9d5ecf8
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
override lazy val resolved = false

override def newInstance = this
override def withNullability(newNullability: Boolean) = this
override def withQualifiers(newQualifiers: Seq[String]) = this

// Unresolved attributes are transient at compile time and don't get evaluated during execution.
Expand Down Expand Up @@ -95,6 +96,7 @@ case class Star(
override lazy val resolved = false

override def newInstance = this
override def withNullability(newNullability: Boolean) = this
override def withQualifiers(newQualifiers: Seq[String]) = this

def expand(input: Seq[Attribute]): Seq[NamedExpression] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,16 @@ case class BoundReference(ordinal: Int, baseReference: Attribute)

type EvaluatedType = Any

def nullable = baseReference.nullable
def dataType = baseReference.dataType
def exprId = baseReference.exprId
def qualifiers = baseReference.qualifiers
def name = baseReference.name
override def nullable = baseReference.nullable
override def dataType = baseReference.dataType
override def exprId = baseReference.exprId
override def qualifiers = baseReference.qualifiers
override def name = baseReference.name

def newInstance = BoundReference(ordinal, baseReference.newInstance)
def withQualifiers(newQualifiers: Seq[String]) =
override def newInstance = BoundReference(ordinal, baseReference.newInstance)
override def withNullability(newNullability: Boolean) =
BoundReference(ordinal, baseReference.withNullability(newNullability))
override def withQualifiers(newQualifiers: Seq[String]) =
BoundReference(ordinal, baseReference.withQualifiers(newQualifiers))

override def toString = s"$baseReference:$ordinal"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ abstract class NamedExpression extends Expression {
abstract class Attribute extends NamedExpression {
self: Product =>

def withNullability(newNullability: Boolean): Attribute
def withQualifiers(newQualifiers: Seq[String]): Attribute

def toAttribute = this
Expand Down Expand Up @@ -133,7 +134,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
/**
* Returns a copy of this [[AttributeReference]] with changed nullability.
*/
def withNullability(newNullability: Boolean) = {
override def withNullability(newNullability: Boolean) = {
if (nullable == newNullability) {
this
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.types._

case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
Expand Down Expand Up @@ -46,10 +46,16 @@ case class Generate(
child: LogicalPlan)
extends UnaryNode {

protected def generatorOutput: Seq[Attribute] =
alias
/**
 * Attributes produced by the generator.
 *
 * When an alias is given, the generator's output attributes are re-qualified
 * with it. When this Generate both joins against the child (`join`) and is
 * outer (`outer`), every generator-side attribute is marked nullable, since
 * rows with no generated output are padded with nulls on this side.
 */
protected def generatorOutput: Seq[Attribute] = {
  val qualified = alias match {
    case Some(a) => generator.output.map(_.withQualifiers(a :: Nil))
    case None    => generator.output
  }
  if (join && outer) qualified.map(_.withNullability(true)) else qualified
}

override def output =
if (join) child.output ++ generatorOutput else generatorOutput
Expand Down Expand Up @@ -81,11 +87,20 @@ case class Join(
condition: Option[Expression]) extends BinaryNode {

override def references = condition.map(_.references).getOrElse(Set.empty)
override def output = joinType match {
case LeftSemi =>
left.output
case _ =>
left.output ++ right.output

/**
 * Output attributes of the join, with nullability adjusted per join type.
 *
 * For outer joins, attributes coming from the side that may be null-extended
 * (the right side for LEFT OUTER, the left side for RIGHT OUTER, both sides
 * for FULL OUTER) are forced nullable. LEFT SEMI only exposes the left side.
 */
override def output = {
  // Mark every attribute of the given side as nullable.
  def nullablized(side: Seq[Attribute]): Seq[Attribute] =
    side.map(_.withNullability(true))
  joinType match {
    case LeftSemi   => left.output
    case LeftOuter  => left.output ++ nullablized(right.output)
    case RightOuter => nullablized(left.output) ++ right.output
    case FullOuter  => nullablized(left.output) ++ nullablized(right.output)
    case _          => left.output ++ right.output
  }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ case class Aggregate(
case a: AggregateExpression =>
ComputedAggregate(
a,
BindReferences.bindReference(a, childOutput).asInstanceOf[AggregateExpression],
AttributeReference(s"aggResult:$a", a.dataType, nullable = true)())
BindReferences.bindReference(a, childOutput),
AttributeReference(s"aggResult:$a", a.dataType, a.nullable)())
}
}.toArray

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.expressions.{Generator, JoinedRow, Literal, Projection}
import org.apache.spark.sql.catalyst.expressions._

/**
* :: DeveloperApi ::
Expand All @@ -39,8 +39,16 @@ case class Generate(
child: SparkPlan)
extends UnaryNode {

/**
 * Generator-side output attributes; forced nullable when this is an outer
 * join-generate (`join && outer`), because unmatched rows are null-padded.
 */
protected def generatorOutput: Seq[Attribute] =
  if (join && outer) generator.output.map(_.withNullability(true))
  else generator.output

/** Full output: child columns followed by generator columns when joining. */
override def output = {
  val genOut = generatorOutput
  if (join) child.output ++ genOut else genOut
}

override def execute() = {
if (join) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,18 @@ case class BroadcastNestedLoopJoin(

override def otherCopyArgs = sqlContext :: Nil

def output = left.output ++ right.output
/**
 * Output schema adjusted for outer-join null padding.
 *
 * Attributes from whichever side can be null-extended by the join type are
 * marked nullable; all other join types keep the child outputs unchanged.
 */
override def output = {
  def asNullable(attrs: Seq[Attribute]): Seq[Attribute] =
    attrs.map(_.withNullability(true))
  joinType match {
    case LeftOuter  => left.output ++ asNullable(right.output)
    case RightOuter => asNullable(left.output) ++ right.output
    case FullOuter  => asNullable(left.output) ++ asNullable(right.output)
    case _          => left.output ++ right.output
  }
}

/** The Streamed Relation */
def left = streamed
Expand Down

0 comments on commit 9d5ecf8

Please sign in to comment.