Skip to content

Commit

Permalink
fix UT
Browse files Browse the repository at this point in the history
  • Loading branch information
AngersZhuuuu committed Jun 7, 2020
1 parent c4ff823 commit 6d1b60e
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -510,12 +510,10 @@ class Analyzer(
// collect all the found AggregateExpression, so we can check an expression is part of
// any AggregateExpression or not.
val aggsBuffer = ArrayBuffer[Expression]()

// Returns whether the expression belongs to any expressions in `aggsBuffer` or not.
def isPartOfAggregation(e: Expression): Boolean = {
aggsBuffer.exists(a => a.find(_ eq e).isDefined)
}

replaceGroupingFunc(_, groupByExprs, gid).transformDown {
// AggregateExpression should be computed on the unmodified value of its argument
// expressions, so we should not replace any references to grouping expression
Expand Down Expand Up @@ -1489,14 +1487,14 @@ class Analyzer(
case p: LogicalPlan if needResolveStructField(p) =>
logTrace(s"Attempting to resolve ${p.simpleString(SQLConf.get.maxToStringFields)}")
val resolved = p.mapExpressions(resolveExpressionTopDown(_, p))
val structFieldMap = new mutable.HashMap[String, Alias]
val structFieldMap = mutable.Map[String, Alias]()
resolved.transformExpressions {
case a @ Alias(struct: GetStructField, _) =>
if (structFieldMap.contains(struct.sql)) {
val exprId = structFieldMap.getOrElse(struct.sql, a).exprId
Alias(a.child, a.name)(exprId, a.qualifier, a.explicitMetadata)
} else {
structFieldMap.put(struct.sql, a)
structFieldMap += (struct.sql -> a)
a
}
case e => e
Expand All @@ -1507,12 +1505,8 @@ class Analyzer(
q.mapExpressions(resolveExpressionTopDown(_, q))
}

def needResolveStructField(plan: LogicalPlan): Boolean = {
private def needResolveStructField(plan: LogicalPlan): Boolean = {
plan match {
case UnresolvedHaving(havingCondition, a: Aggregate)
if containSameStructFields(a.groupingExpressions.flatMap(_.references),
a.aggregateExpressions.flatMap(_.references),
Some(havingCondition.references.toSeq)) => true
case Aggregate(groupingExpressions, aggregateExpressions, _)
if containSameStructFields(groupingExpressions.flatMap(_.references),
aggregateExpressions.flatMap(_.references)) => true
Expand All @@ -1524,8 +1518,8 @@ class Analyzer(
}
}

def containSameStructFields(
grpExprs: Seq[Attribute],
private def containSameStructFields(
groupExprs: Seq[Attribute],
aggExprs: Seq[Attribute],
extra: Option[Seq[Attribute]] = None): Boolean = {

Expand All @@ -1534,7 +1528,7 @@ class Analyzer(
attr.asInstanceOf[UnresolvedAttribute].nameParts.size == 2
}

val grpAttrs = grpExprs.filter(isStructField)
val grpAttrs = groupExprs.filter(isStructField)
.map(_.asInstanceOf[UnresolvedAttribute].name)
val aggAttrs = aggExprs.filter(isStructField)
.map(_.asInstanceOf[UnresolvedAttribute].name)
Expand Down
93 changes: 55 additions & 38 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3496,85 +3496,102 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
checkIfSeedExistsInExplain(df2)
}

test("SPARK-31670: Struct Field in groupByExpr with CUBE") {
test("SPARK-31670: Resolve Struct Field in Grouping Aggregate with same ExprId") {
withTable("t") {
sql(
"""CREATE TABLE t(
|a STRING,
|b INT,
|c ARRAY<STRUCT<row_id:INT,json_string:STRING>>,
|d ARRAY<ARRAY<STRING>>,
|e ARRAY<MAP<STRING, INT>>)
|c STRUCT<row_id:INT,json_string:STRING>)
|USING ORC""".stripMargin)

sql(
"""
|INSERT INTO TABLE t
|SELECT * FROM VALUES
|('A', 1, NAMED_STRUCT('row_id', 1, 'json_string', '{"i": 1}')),
|('A', 2, NAMED_STRUCT('row_id', 2, 'json_string', '{"i": 1}')),
|('A', 2, NAMED_STRUCT('row_id', 2, 'json_string', '{"i": 2}')),
|('B', 1, NAMED_STRUCT('row_id', 3, 'json_string', '{"i": 1}')),
|('C', 3, NAMED_STRUCT('row_id', 4, 'json_string', '{"i": 1}'))
""".stripMargin)

checkAnswer(
sql(
"""
|SELECT a, each.json_string, SUM(b)
|SELECT a, c.json_string, SUM(b)
|FROM t
|LATERAL VIEW EXPLODE(c) x AS each
|GROUP BY a, each.json_string
|GROUP BY a, c.json_string
|WITH CUBE
|""".stripMargin), Nil)
|""".stripMargin),
Row("A", "{\"i\": 1}", 3) :: Row("A", "{\"i\": 2}", 2) :: Row("A", null, 5) ::
Row("B", "{\"i\": 1}", 1) :: Row("B", null, 1) ::
Row("C", "{\"i\": 1}", 3) :: Row("C", null, 3) ::
Row(null, "{\"i\": 1}", 7) :: Row(null, "{\"i\": 2}", 2) :: Row(null, null, 9) :: Nil)

checkAnswer(
sql(
"""
|SELECT a, get_json_object(each.json_string, '$.i'), SUM(b)
|SELECT a, get_json_object(c.json_string, '$.i'), SUM(b)
|FROM t
|LATERAL VIEW EXPLODE(c) X AS each
|GROUP BY a, get_json_object(each.json_string, '$.i')
|GROUP BY a, get_json_object(c.json_string, '$.i')
|WITH CUBE
|""".stripMargin), Nil)
|""".stripMargin),
Row("A", "1", 3) :: Row("A", "2", 2) :: Row("A", null, 5) ::
Row("B", "1", 1) :: Row("B", null, 1) ::
Row("C", "1", 3) :: Row("C", null, 3) ::
Row(null, "1", 7) :: Row(null, "2", 2) :: Row(null, null, 9) :: Nil)

checkAnswer(
sql(
"""
|SELECT a, each.json_string AS json_string, SUM(b)
|SELECT a, c.json_string AS json_string, SUM(b)
|FROM t
|LATERAL VIEW EXPLODE(c) x AS each
|GROUP BY a, each.json_string
|GROUP BY a, c.json_string
|WITH CUBE
|""".stripMargin), Nil)
|""".stripMargin),
Row("A", null, 5) :: Row("A", "{\"i\": 1}", 3) :: Row("A", "{\"i\": 2}", 2) ::
Row("B", null, 1) :: Row("B", "{\"i\": 1}", 1) ::
Row("C", null, 3) :: Row("C", "{\"i\": 1}", 3) ::
Row(null, null, 9) :: Row(null, "{\"i\": 1}", 7) :: Row(null, "{\"i\": 2}", 2) :: Nil)

checkAnswer(
sql(
"""
|SELECT a, each.json_string as js, SUM(b)
|SELECT a, c.json_string as js, SUM(b)
|FROM t
|LATERAL VIEW EXPLODE(c) X AS each
|GROUP BY a, each.json_string
|GROUP BY a, c.json_string
|WITH CUBE
|""".stripMargin), Nil)
|""".stripMargin),
Row("A", null, 5) :: Row("A", "{\"i\": 1}", 3) :: Row("A", "{\"i\": 2}", 2) ::
Row("B", null, 1) :: Row("B", "{\"i\": 1}", 1) ::
Row("C", null, 3) :: Row("C", "{\"i\": 1}", 3) ::
Row(null, null, 9) :: Row(null, "{\"i\": 1}", 7) :: Row(null, "{\"i\": 2}", 2) :: Nil)

checkAnswer(
sql(
"""
|SELECT a, each.json_string as js, SUM(b)
|SELECT a, c.json_string as js, SUM(b)
|FROM t
|LATERAL VIEW EXPLODE(c) X AS each
|GROUP BY a, each.json_string
|GROUP BY a, c.json_string
|WITH ROLLUP
|""".stripMargin), Nil)

sql(
"""
|SELECT a, each.json_string, SUM(b)
|FROM t
|LATERAL VIEW EXPLODE(c) X AS each
|GROUP BY a, each.json_string
|GROUPING sets((a),(a, each.json_string))
|""".stripMargin).explain(true)
|""".stripMargin),
Row("A", null, 5) :: Row("A", "{\"i\": 1}", 3) :: Row("A", "{\"i\": 2}", 2) ::
Row("B", null, 1) :: Row("B", "{\"i\": 1}", 1) ::
Row("C", null, 3) :: Row("C", "{\"i\": 1}", 3) ::
Row(null, null, 9) :: Nil)

checkAnswer(
sql(
"""
|SELECT a, each.json_string, SUM(b)
|SELECT a, c.json_string, SUM(b)
|FROM t
|LATERAL VIEW EXPLODE(c) X AS each
|GROUP BY a, each.json_string
|GROUPING sets((a),(a, each.json_string))
|""".stripMargin), Nil)
|GROUP BY a, c.json_string
|GROUPING sets((a),(a, c.json_string))
|""".stripMargin),
Row("A", null, 5) :: Row("A", "{\"i\": 1}", 3) :: Row("A", "{\"i\": 2}", 2) ::
Row("B", null, 1) :: Row("B", "{\"i\": 1}", 1) ::
Row("C", null, 3) :: Row("C", "{\"i\": 1}", 3) :: Nil)
}
}

Expand Down

0 comments on commit 6d1b60e

Please sign in to comment.