Skip to content

Commit

Permalink
UT: Add more test
Browse files Browse the repository at this point in the history
  • Loading branch information
karuppayya committed Jun 19, 2020
1 parent 1c0399a commit 73de4c8
Showing 1 changed file with 29 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val data = Seq(("James", 1), ("James", 1), ("Phil", 1))
val aggDF = data.toDF("name", "values").groupBy("name").sum("values")
val partAggNode = aggDF.queryExecution.executedPlan.find {
case h: HashAggregateExec
if AggUtils.areAggExpressionsPartial(h.aggregateExpressions) => true
case h: HashAggregateExec =>
AggUtils.areAggExpressionsPartial(h.aggregateExpressions.map(_.mode))
case _ => false
}
checkAnswer(aggDF, Seq(Row("James", 2), Row("Phil", 1)))
Expand All @@ -69,6 +69,33 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
}
}

test(s"Partial aggregation should not happen when no Aggregate expr" ) {
withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) {
val aggDF = testData2.select(sumDistinct($"a"))
val aggNodes = aggDF.queryExecution.executedPlan.collect {
case h: HashAggregateExec => h
}
checkAnswer(aggDF, Row(6))
assert(aggNodes.nonEmpty)
Thread.sleep(1000000)
assert(aggNodes.forall(_.metrics("partialAggSkipped").value == 0))
}
}

test(s"Distinct: Partial aggregation should happen for" +
s" HashAggregate nodes performing partial Aggregate operations " ) {
withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) {
val aggDF = testData2.select(sumDistinct($"a"), sum($"b"))
val aggNodes = aggDF.queryExecution.executedPlan.collect {
case h: HashAggregateExec => h
}
val (baseNodes, other) = aggNodes.partition(_.child.isInstanceOf[SerializeFromObjectExec])
checkAnswer(aggDF, Row(6, 9))
assert(baseNodes.size == 1 )
assert(baseNodes.head.metrics("partialAggSkipped").value == testData2.count())
assert(other.forall(_.metrics("partialAggSkipped").value == 0))
}
}

test("Aggregate with grouping keys should be included in WholeStageCodegen") {
val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2)
Expand Down

0 comments on commit 73de4c8

Please sign in to comment.