Skip to content

Commit

Permalink
[SPARK-13427][SQL] Support USING clause in JOIN.
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Support queries that JOIN tables with a USING clause:
SELECT * FROM table1 JOIN table2 USING (<column_list>)

The USING clause can be used to simplify the join condition
when:

1) Equijoin semantics are desired, and
2) The columns being compared have the same name on both sides of the join.

We already have the support for Natural Join in Spark. This PR makes
use of the already existing infrastructure for natural join to
form the join condition and also the projection list.

## How was this patch tested?

Added unit tests to SQLQuerySuite, CatalystQlSuite, and ResolveNaturalJoinSuite.

Author: Dilip Biswal <[email protected]>

Closes #11297 from dilipbiswal/spark-13427.
  • Loading branch information
dilipbiswal authored and marmbrus committed Mar 17, 2016
1 parent 65b75e6 commit 637a78f
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,17 @@ fromClause
joinSource
@init { gParent.pushMsg("join source", state); }
@after { gParent.popMsg(state); }
: fromSource ( joinToken^ fromSource ( KW_ON! expression {$joinToken.start.getType() != COMMA}? )? )*
: fromSource ( joinToken^ fromSource ( joinCond {$joinToken.start.getType() != COMMA}? )? )*
| uniqueJoinToken^ uniqueJoinSource (COMMA! uniqueJoinSource)+
;
// Join condition attached to a join: either `ON <expression>` or `USING (<column list>)`.
// - ON form: the `!` suffix drops KW_ON from the tree, so only the expression remains.
// - USING form: rewritten into a TOK_USING node whose child is the column name list.
joinCond
@init { gParent.pushMsg("join expression list", state); }
@after { gParent.popMsg(state); }
: KW_ON! expression
| KW_USING LPAREN columnNameList RPAREN -> ^(TOK_USING columnNameList)
;
uniqueJoinSource
@init { gParent.pushMsg("unique join source", state); }
@after { gParent.popMsg(state); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ TOK_SETCONFIG;
TOK_DFS;
TOK_ADDFILE;
TOK_ADDJAR;
TOK_USING;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class Analyzer(
ResolveSubquery ::
ResolveWindowOrder ::
ResolveWindowFrame ::
ResolveNaturalJoin ::
ResolveNaturalAndUsingJoin ::
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
Expand Down Expand Up @@ -1329,48 +1329,69 @@ class Analyzer(
}

/**
* Removes natural joins by calculating output columns based on output from two sides,
* Then apply a Project on a normal Join to eliminate natural join.
* Removes natural or using joins by calculating output columns based on output from two sides,
* Then apply a Project on a normal Join to eliminate natural or using join.
*/
object ResolveNaturalJoin extends Rule[LogicalPlan] {
object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
if left.resolved && right.resolved && j.duplicateResolved =>
// Resolve the column names referenced in using clause from both the legs of join.
val lCols = usingCols.flatMap(col => left.resolveQuoted(col.name, resolver))
val rCols = usingCols.flatMap(col => right.resolveQuoted(col.name, resolver))
if ((lCols.length == usingCols.length) && (rCols.length == usingCols.length)) {
val joinNames = lCols.map(exp => exp.name)
commonNaturalJoinProcessing(left, right, joinType, joinNames, None)
} else {
j
}
case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
// find common column names from both sides
val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
val joinPairs = leftKeys.zip(rightKeys)

// Add joinPairs to joinConditions
val newCondition = (condition ++ joinPairs.map {
case (l, r) => EqualTo(l, r)
}).reduceOption(And)

// columns not in joinPairs
val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))

// the output list looks like: join keys, columns from left, columns from right
val projectList = joinType match {
case LeftOuter =>
leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
case RightOuter =>
rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
case FullOuter =>
// in full outer join, joinCols should be non-null if there is.
val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
joinedCols ++
lUniqueOutput.map(_.withNullability(true)) ++
rUniqueOutput.map(_.withNullability(true))
case Inner =>
rightKeys ++ lUniqueOutput ++ rUniqueOutput
case _ =>
sys.error("Unsupported natural join type " + joinType)
}
// use Project to trim unnecessary fields
Project(projectList, Join(left, right, joinType, newCondition))
commonNaturalJoinProcessing(left, right, joinType, joinNames, condition)
}
}

/**
 * Rewrites a natural or USING join into a plain [[Join]] wrapped in a [[Project]].
 *
 * For every name in `joinNames` the matching attribute is looked up on both sides, an
 * equality predicate between the two is AND-ed onto `condition`, and a projection is built
 * whose output is: join keys, remaining columns from left, remaining columns from right.
 *
 * @param left      left side of the join
 * @param right     right side of the join
 * @param joinType  the underlying join type (already unwrapped from NaturalJoin / UsingJoin)
 * @param joinNames names of the columns to join on
 * @param condition extra join condition, if any, to combine with the generated equalities
 * @return a Project over the rewritten Join
 */
private def commonNaturalJoinProcessing(
    left: LogicalPlan,
    right: LogicalPlan,
    joinType: JoinType,
    joinNames: Seq[String],
    condition: Option[Expression]) = {
  // Finds the join key called `keyName` in the output of `plan`.
  // An exact name match is preferred so existing behaviour is preserved; otherwise fall back
  // to the analyzer's resolver. USING columns are resolved through the resolver and their
  // names are taken from the left side, so the right side may only match under the resolver
  // (e.g. a different case). Fails with a descriptive error instead of a bare
  // NoSuchElementException when the key cannot be found at all.
  def findJoinKey(plan: LogicalPlan, keyName: String, side: String) =
    plan.output.find(_.name == keyName)
      .orElse(plan.output.find(attr => resolver(attr.name, keyName)))
      .getOrElse(sys.error(
        s"Join column '$keyName' cannot be resolved on the $side side of the join. " +
          s"The $side-side columns: [${plan.output.map(_.name).mkString(", ")}]"))

  val leftKeys = joinNames.map(keyName => findJoinKey(left, keyName, "left"))
  val rightKeys = joinNames.map(keyName => findJoinKey(right, keyName, "right"))
  val joinPairs = leftKeys.zip(rightKeys)

  // Combine the optional extra condition with one equality per join key pair.
  val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And)

  // columns not in joinPairs
  val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
  val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))

  // the output list looks like: join keys, columns from left, columns from right
  val projectList = joinType match {
    case LeftOuter =>
      // right-side columns are marked nullable
      leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
    case LeftSemi =>
      // only left-side columns are projected
      leftKeys ++ lUniqueOutput
    case RightOuter =>
      // left-side columns are marked nullable
      rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
    case FullOuter =>
      // Either side's key may be absent in a full outer join, so each key column is
      // projected as the Coalesce of the left and right key, aliased to the left key's name.
      val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
      joinedCols ++
        lUniqueOutput.map(_.withNullability(true)) ++
        rUniqueOutput.map(_.withNullability(true))
    case Inner =>
      leftKeys ++ lUniqueOutput ++ rUniqueOutput
    case _ =>
      sys.error("Unsupported natural join type " + joinType)
  }
  // use Project to trim unnecessary fields
  Project(projectList, Join(left, right, joinType, newCondition))
}


}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.UsingJoin
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -109,6 +110,12 @@ trait CheckAnalysis {
s"filter expression '${f.condition.sql}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")

case j @ Join(_, _, UsingJoin(_, cols), _) =>
val from = operator.inputSet.map(_.name).mkString(", ")
failAnalysis(
s"using columns [${cols.mkString(",")}] " +
s"can not be resolved given input columns: [$from] ")

case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
failAnalysis(
s"join condition '${condition.sql}' " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
case FullOuter => f // DO Nothing for Full Outer Join
case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
case UsingJoin(_, _) => sys.error("Untransformed Using join node")
}

// push down the join filter into sub query scanning if applicable
Expand Down Expand Up @@ -1168,6 +1169,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
Join(newLeft, newRight, LeftOuter, newJoinCond)
case FullOuter => f
case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
case UsingJoin(_, _) => sys.error("Untransformed Using join node")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,30 +419,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
sys.error(s"Unsupported join operation: $other")
}

val joinType = joinToken match {
case "TOK_JOIN" => Inner
case "TOK_CROSSJOIN" => Inner
case "TOK_RIGHTOUTERJOIN" => RightOuter
case "TOK_LEFTOUTERJOIN" => LeftOuter
case "TOK_FULLOUTERJOIN" => FullOuter
case "TOK_LEFTSEMIJOIN" => LeftSemi
case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
case "TOK_NATURALJOIN" => NaturalJoin(Inner)
case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
}
val (joinType, joinCondition) = getJoinInfo(joinToken, other, node)

Join(nodeToRelation(relation1),
nodeToRelation(relation2),
joinType,
other.headOption.map(nodeToExpr))

joinCondition)
case _ =>
noParseRule("Relation", node)
}
}

/**
 * Derives the join type and the optional join condition from the parsed join tokens.
 *
 * @param joinToken          name of the join token (e.g. "TOK_JOIN", "TOK_NATURALJOIN")
 * @param joinConditionToken AST nodes following the two relations: either a single
 *                           TOK_USING node or the ON expression (possibly absent)
 * @param node               the enclosing join node, passed to `noParseRule` for the
 *                           join kinds that are not supported
 * @return for a USING clause, a [[UsingJoin]] wrapping the join type and the listed
 *         columns, with no condition; otherwise the join type together with the
 *         expression built from the first condition node, if there is one
 */
protected def getJoinInfo(
    joinToken: String,
    joinConditionToken: Seq[ASTNode],
    node: ASTNode): (JoinType, Option[Expression]) = {
  // The join type is determined before looking at the condition, so unsupported join
  // tokens are handled the same way whichever condition form follows.
  val baseJoinType = joinToken match {
    case "TOK_JOIN" | "TOK_CROSSJOIN" => Inner
    case "TOK_LEFTOUTERJOIN" => LeftOuter
    case "TOK_RIGHTOUTERJOIN" => RightOuter
    case "TOK_FULLOUTERJOIN" => FullOuter
    case "TOK_LEFTSEMIJOIN" => LeftSemi
    case "TOK_NATURALJOIN" => NaturalJoin(Inner)
    case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
    case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
    case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
    case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
    case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
  }

  joinConditionToken match {
    // JOIN ... USING (c1, c2, ...): each leaf token under the column list names a column.
    case Token("TOK_USING", columnList :: Nil) :: Nil =>
      val usingColumns = columnList.children.collect {
        case Token(columnName, Nil) => UnresolvedAttribute(columnName)
      }
      (UsingJoin(baseJoinType, usingColumns), None)
    // JOIN ... ON <expr>, or no condition at all.
    case onClause =>
      (baseJoinType, onClause.headOption.map(nodeToExpr))
  }
}

protected def nodeToSortOrder(node: ASTNode): SortOrder = node match {
case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) =>
SortOrder(nodeToExpr(sortExpr), Ascending)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.plans

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute

object JoinType {
def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match {
case "inner" => Inner
Expand Down Expand Up @@ -66,3 +68,9 @@ case class NaturalJoin(tpe: JoinType) extends JoinType {
"Unsupported natural join type " + tpe)
override def sql: String = "NATURAL " + tpe.sql
}

/**
 * Join type for `JOIN ... USING (columns)`: pairs the underlying join type with the
 * (still unresolved) columns named in the USING clause.
 *
 * Only inner, left outer, left semi, right outer and full outer joins may be wrapped;
 * any other `tpe` is rejected at construction time.
 */
case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) extends JoinType {
  require(
    tpe match {
      case Inner | LeftOuter | LeftSemi | RightOuter | FullOuter => true
      case _ => false
    },
    s"Unsupported using join type $tpe")

  override def sql: String = s"USING ${tpe.sql}"
}
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,11 @@ case class Join(
condition.forall(_.dataType == BooleanType)
}

// if not a natural join, use `resolvedExceptNatural`. if it is a natural join, we still need
// to eliminate natural before we mark it resolved.
// if not a natural join, use `resolvedExceptNatural`. if it is a natural join or
// using join, we still need to eliminate natural or using before we mark it resolved.
override lazy val resolved: Boolean = joinType match {
case NaturalJoin(_) => false
case UsingJoin(_, _) => false
case _ => resolvedExceptNatural
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
Expand All @@ -35,56 +36,81 @@ class ResolveNaturalJoinSuite extends AnalysisTest {
lazy val r3 = LocalRelation(aNotNull, bNotNull)
lazy val r4 = LocalRelation(cNotNull, bNotNull)

test("natural inner join") {
val plan = r1.join(r2, NaturalJoin(Inner), None)
test("natural/using inner join") {
val naturalPlan = r1.join(r2, NaturalJoin(Inner), None)
val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None)
val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("natural left join") {
val plan = r1.join(r2, NaturalJoin(LeftOuter), None)
test("natural/using left join") {
val naturalPlan = r1.join(r2, NaturalJoin(LeftOuter), None)
val usingPlan = r1.join(r2, UsingJoin(LeftOuter, Seq(UnresolvedAttribute("a"))), None)
val expected = r1.join(r2, LeftOuter, Some(EqualTo(a, a))).select(a, b, c)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("natural right join") {
val plan = r1.join(r2, NaturalJoin(RightOuter), None)
test("natural/using right join") {
val naturalPlan = r1.join(r2, NaturalJoin(RightOuter), None)
val usingPlan = r1.join(r2, UsingJoin(RightOuter, Seq(UnresolvedAttribute("a"))), None)
val expected = r1.join(r2, RightOuter, Some(EqualTo(a, a))).select(a, b, c)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("natural full outer join") {
val plan = r1.join(r2, NaturalJoin(FullOuter), None)
test("natural/using full outer join") {
val naturalPlan = r1.join(r2, NaturalJoin(FullOuter), None)
val usingPlan = r1.join(r2, UsingJoin(FullOuter, Seq(UnresolvedAttribute("a"))), None)
val expected = r1.join(r2, FullOuter, Some(EqualTo(a, a))).select(
Alias(Coalesce(Seq(a, a)), "a")(), b, c)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("natural inner join with no nullability") {
val plan = r3.join(r4, NaturalJoin(Inner), None)
test("natural/using inner join with no nullability") {
val naturalPlan = r3.join(r4, NaturalJoin(Inner), None)
val usingPlan = r3.join(r4, UsingJoin(Inner, Seq(UnresolvedAttribute("b"))), None)
val expected = r3.join(r4, Inner, Some(EqualTo(bNotNull, bNotNull))).select(
bNotNull, aNotNull, cNotNull)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("natural left join with no nullability") {
val plan = r3.join(r4, NaturalJoin(LeftOuter), None)
test("natural/using left join with no nullability") {
val naturalPlan = r3.join(r4, NaturalJoin(LeftOuter), None)
val usingPlan = r3.join(r4, UsingJoin(LeftOuter, Seq(UnresolvedAttribute("b"))), None)
val expected = r3.join(r4, LeftOuter, Some(EqualTo(bNotNull, bNotNull))).select(
bNotNull, aNotNull, c)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("natural right join with no nullability") {
val plan = r3.join(r4, NaturalJoin(RightOuter), None)
test("natural/using right join with no nullability") {
val naturalPlan = r3.join(r4, NaturalJoin(RightOuter), None)
val usingPlan = r3.join(r4, UsingJoin(RightOuter, Seq(UnresolvedAttribute("b"))), None)
val expected = r3.join(r4, RightOuter, Some(EqualTo(bNotNull, bNotNull))).select(
bNotNull, a, cNotNull)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("natural full outer join with no nullability") {
val plan = r3.join(r4, NaturalJoin(FullOuter), None)
test("natural/using full outer join with no nullability") {
val naturalPlan = r3.join(r4, NaturalJoin(FullOuter), None)
val usingPlan = r3.join(r4, UsingJoin(FullOuter, Seq(UnresolvedAttribute("b"))), None)
val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select(
Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

// checkAnalysis must reject a USING join whose column ('d') is not among the join's input
// columns, raising an AnalysisException whose message names the offending USING column and
// lists the available input columns.
test("using unresolved attribute") {
val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("d"))), None)
val error = intercept[AnalysisException] {
SimpleAnalyzer.checkAnalysis(usingPlan)
}
assert(error.message.contains(
"using columns ['d] can not be resolved given input columns: [b, a, c]"))
}
}
Loading

0 comments on commit 637a78f

Please sign in to comment.