[SPARK-7269] [SQL] Incorrect analysis for aggregation (use semanticEquals)

A modified version of #6110 that uses `semanticEquals` to make it more efficient.

Author: Wenchen Fan <[email protected]>

Closes #6173 from cloud-fan/7269 and squashes the following commits:

e4a3cc7 [Wenchen Fan] address comments
cc02045 [Wenchen Fan] consider elements length equal
d7ff8f4 [Wenchen Fan] fix 7269
cloud-fan authored and marmbrus committed May 18, 2015
1 parent fc2480e commit 103c863
Showing 6 changed files with 48 additions and 26 deletions.
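Not part of the commit, but for orientation: a minimal standalone sketch (object name invented; it assumes Spark 1.4's Catalyst classes, including this change, are on the classpath) of the behaviour `semanticEquals` is meant to capture. Two resolved references to the same column that differ only in capitalization are not `==`, yet they are semantically equal, and the default implementation recurses into children so composed expressions compare the same way.

import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, Literal}
import org.apache.spark.sql.types.IntegerType

// Standalone sketch, not from the commit; object name is invented.
object SemanticEqualsSketch {
  def main(args: Array[String]): Unit = {
    // Model what case-insensitive resolution produces: the same underlying column
    // (same exprId) referenced under two different capitalizations of its name.
    val key = AttributeReference("key", IntegerType)()
    val kEy = AttributeReference("kEy", IntegerType)(exprId = key.exprId)

    println(key == kEy)              // false: structural equality takes the attribute name into account
    println(key semanticEquals kEy)  // true: the AttributeReference override compares the reference (sameRef)

    // The default Expression.semanticEquals recurses into Expression children,
    // so `key + 3` and `kEy + 3` are also semantically equal.
    println(Add(key, Literal(3)) semanticEquals Add(kEy, Literal(3)))  // true
  }
}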
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.OpenHashSet
 
 /**
  * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -142,25 +141,6 @@ class Analyzer(
   }
 
   object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
-    /**
-     * Extract attribute set according to the grouping id
-     * @param bitmask bitmask to represent the selected of the attribute sequence
-     * @param exprs the attributes in sequence
-     * @return the attributes of non selected specified via bitmask (with the bit set to 1)
-     */
-    private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
-      : OpenHashSet[Expression] = {
-      val set = new OpenHashSet[Expression](2)
-
-      var bit = exprs.length - 1
-      while (bit >= 0) {
-        if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
-        bit -= 1
-      }
-
-      set
-    }
-
     /*
      * GROUP BY a, b, c WITH ROLLUP
      * is equivalent to
@@ -197,10 +177,15 @@ class Analyzer(
 
       g.bitmasks.foreach { bitmask =>
         // get the non selected grouping attributes according to the bit mask
-        val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs)
+        val nonSelectedGroupExprs = ArrayBuffer.empty[Expression]
+        var bit = g.groupByExprs.length - 1
+        while (bit >= 0) {
+          if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit)
+          bit -= 1
+        }
 
         val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
-          case x: Expression if nonSelectedGroupExprSet.contains(x) =>
+          case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined =>
             // if the input attribute in the Invalid Grouping Expression set of for this group
             // replace it with constant null
             Literal.create(null, expr.dataType)
@@ -86,12 +86,12 @@ trait CheckAnalysis {
         case Aggregate(groupingExprs, aggregateExprs, child) =>
           def checkValidAggregateExpression(expr: Expression): Unit = expr match {
             case _: AggregateExpression => // OK
-            case e: Attribute if !groupingExprs.contains(e) =>
+            case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty =>
               failAnalysis(
                 s"expression '${e.prettyString}' is neither present in the group by, " +
                   s"nor is it an aggregate function. " +
                   "Add to group by or wrap in first() if you don't care which value you get.")
-            case e if groupingExprs.contains(e) => // OK
+            case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK
             case e if e.references.isEmpty => // OK
             case e => e.children.foreach(checkValidAggregateExpression)
           }
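A hypothetical side-by-side of the two membership tests above (again a standalone sketch, not from the commit), showing why a query such as `SELECT kEy+3 FROM df_analysis GROUP BY key+3` failed the old `contains` check but passes the `semanticEquals` scan:

import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, Expression, Literal}
import org.apache.spark.sql.types.IntegerType

// Standalone sketch, not from the commit; object name is invented.
object GroupByCheckSketch {
  def main(args: Array[String]): Unit = {
    val key = AttributeReference("key", IntegerType)()
    val kEy = AttributeReference("kEy", IntegerType)(exprId = key.exprId)

    // `key + 3` as resolved from the GROUP BY clause ...
    val groupingExprs: Seq[Expression] = Seq(Add(key, Literal(3)))
    // ... and the same expression as written (with different casing) in the SELECT list.
    val selectExpr: Expression = Add(kEy, Literal(3))

    // Old check: Seq.contains uses ==, so the expression is not recognised as a grouping expression.
    println(groupingExprs.contains(selectExpr))                         // false
    // New check: a linear scan with semanticEquals accepts the expression.
    println(groupingExprs.find(_ semanticEquals selectExpr).isDefined)  // true
  }
}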
@@ -76,6 +76,19 @@ abstract class Expression extends TreeNode[Expression] {
       case u: UnresolvedAttribute => PrettyAttribute(u.name)
     }.toString
   }
+
+  /**
+   * Returns true when two expressions will always compute the same result, even if they differ
+   * cosmetically (i.e. capitalization of names in attributes may be different).
+   */
+  def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
+    val elements1 = this.productIterator.toSeq
+    val elements2 = other.asInstanceOf[Product].productIterator.toSeq
+    elements1.length == elements2.length && elements1.zip(elements2).forall {
+      case (e1: Expression, e2: Expression) => e1 semanticEquals e2
+      case (i1, i2) => i1 == i2
+    }
+  }
 }
 
 abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
@@ -181,6 +181,11 @@ case class AttributeReference(
     case _ => false
   }
 
+  override def semanticEquals(other: Expression): Boolean = other match {
+    case ar: AttributeReference => sameRef(ar)
+    case _ => false
+  }
+
   override def hashCode: Int = {
     // See http://stackoverflow.com/questions/113511/hash-code-implementation
     var h = 17
@@ -159,9 +159,10 @@ object PartialAggregation {
           // Should trim aliases around `GetField`s. These aliases are introduced while
           // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
           // (Should we just turn `GetField` into a `NamedExpression`?)
+          val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
           namedGroupingExpressions
-            .get(e.transform { case Alias(g: ExtractValue, _) => g })
-            .map(_.toAttribute)
+            .find { case (k, v) => k semanticEquals trimmed }
+            .map(_._2.toAttribute)
             .getOrElse(e)
         }).asInstanceOf[Seq[NamedExpression]]
 
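The lookup in the hunk above follows the same pattern: `namedGroupingExpressions` is keyed on expressions, and a plain `Map.get` relies on `==`/`hashCode`, so it can miss an entry that differs only cosmetically from the probe, while the linear `find` with `semanticEquals` does not. A hypothetical sketch (alias name and object name invented, not taken from the commit):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

// Standalone sketch, not from the commit; names are invented.
object PartialAggLookupSketch {
  def main(args: Array[String]): Unit = {
    val key = AttributeReference("key", IntegerType)()
    val kEy = AttributeReference("kEy", IntegerType)(exprId = key.exprId)

    // A grouping expression `key + 3`, registered under an alias as PartialAggregation does.
    val grouping: Expression = Add(key, Literal(3))
    val namedGroupingExpressions: Map[Expression, NamedExpression] =
      Map(grouping -> Alias(grouping, "_groupingExpr")())

    // The equivalent expression coming from the aggregate list, written as `kEy + 3`.
    val trimmed: Expression = Add(kEy, Literal(3))

    // Keyed lookup misses: Map.get compares keys with == (and hashCode).
    println(namedGroupingExpressions.get(trimmed).map(_.toAttribute))  // None
    // Linear scan with semanticEquals finds the aliased grouping expression.
    println(namedGroupingExpressions
      .find { case (k, _) => k semanticEquals trimmed }
      .map(_._2.toAttribute))                                          // Some(_groupingExpr#...)
  }
}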
@@ -773,4 +773,22 @@ class SQLQuerySuite extends QueryTest {
         | select * from v2 order by key limit 1
       """.stripMargin), Row(0, 3))
   }
+
+  test("SPARK-7269 Check analysis failed in case in-sensitive") {
+    Seq(1, 2, 3).map { i =>
+      (i.toString, i.toString)
+    }.toDF("key", "value").registerTempTable("df_analysis")
+    sql("SELECT kEy from df_analysis group by key").collect()
+    sql("SELECT kEy+3 from df_analysis group by key+3").collect()
+    sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect()
+    sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect()
+    sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect()
+    sql("SELECT 2 from df_analysis A group by key+1").collect()
+    intercept[AnalysisException] {
+      sql("SELECT kEy+1 from df_analysis group by key+3")
+    }
+    intercept[AnalysisException] {
+      sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)")
+    }
+  }
 }
