Skip to content

Commit

Permalink
fix 7269
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed May 15, 2015
1 parent 6d0633e commit d7ff8f4
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet

/**
* A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
Expand Down Expand Up @@ -148,17 +147,17 @@ class Analyzer(
* @param exprs the attributes in sequence
* @return the attributes of non selected specified via bitmask (with the bit set to 1)
*/
private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
: OpenHashSet[Expression] = {
val set = new OpenHashSet[Expression](2)
private def buildNonSelectExprs(bitmask: Int, exprs: Seq[Expression])
: Seq[Expression] = {
val buffer = ArrayBuffer.empty[Expression]

var bit = exprs.length - 1
while (bit >= 0) {
if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
if (((bitmask >> bit) & 1) == 0) buffer += exprs(bit)
bit -= 1
}

set
buffer
}

/*
Expand Down Expand Up @@ -197,10 +196,10 @@ class Analyzer(

g.bitmasks.foreach { bitmask =>
// get the non selected grouping attributes according to the bit mask
val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs)
val nonSelectedGroupExprs = buildNonSelectExprs(bitmask, g.groupByExprs)

val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
case x: Expression if nonSelectedGroupExprSet.contains(x) =>
case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined =>
// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
Literal.create(null, expr.dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ trait CheckAnalysis {
case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK
case e: Attribute if !groupingExprs.contains(e) =>
case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty =>
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
s"nor is it an aggregate function. " +
"Add to group by or wrap in first() if you don't care which value you get.")
case e if groupingExprs.contains(e) => // OK
case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK
case e if e.references.isEmpty => // OK
case e => e.children.foreach(checkValidAggregateExpression)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ abstract class Expression extends TreeNode[Expression] {
case u: UnresolvedAttribute => PrettyAttribute(u.name)
}.toString
}

def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass &&
this.productIterator.zip(other.asInstanceOf[Product].productIterator).forall {
case (e1: Expression, e2: Expression) => e1 semanticEquals e2
case (i1, i2) => i1 == i2
}
}

abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,11 @@ case class AttributeReference(
case _ => false
}

override def semanticEquals(other: Expression): Boolean = other match {
case ar: AttributeReference => exprId == ar.exprId
case _ => false
}

override def hashCode: Int = {
// See http://stackoverflow.com/questions/113511/hash-code-implementation
var h = 17
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,10 @@ object PartialAggregation {
// Should trim aliases around `GetField`s. These aliases are introduced while
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
namedGroupingExpressions
.get(e.transform { case Alias(g: ExtractValue, _) => g })
.map(_.toAttribute)
.find { case (k, v) => k semanticEquals trimmed }
.map(_._2.toAttribute)
.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -773,4 +773,22 @@ class SQLQuerySuite extends QueryTest {
| select * from v2 order by key limit 1
""".stripMargin), Row(0, 3))
}

test("SPARK-7269 Check analysis failed in case in-sensitive") {
Seq(1, 2, 3).map { i =>
(i.toString, i.toString)
}.toDF("key", "value").registerTempTable("df_analysis")
sql("SELECT kEy from df_analysis group by key").collect()
sql("SELECT kEy+3 from df_analysis group by key+3").collect()
sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect()
sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect()
sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect()
sql("SELECT 2 from df_analysis A group by key+1").collect()
intercept[AnalysisException] {
sql("SELECT kEy+1 from df_analysis group by key+3")
}
intercept[AnalysisException] {
sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)")
}
}
}

0 comments on commit d7ff8f4

Please sign in to comment.