[SPARK-34037][SQL] Remove unnecessary upcasting for Avg & Sum, which handle it themselves internally

### What changes were proposed in this pull request?
The type coercion for the numeric inputs of average and sum is not necessary at all, as their internal `resultType` and sum data type already prevent overflow.
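
For context, here is a self-contained sketch (plain Scala, no Spark dependency) of the widening these aggregates already do internally. The function names `sumResultType`/`avgSumType` are ad-hoc for this sketch — the real Catalyst members are along the lines of `resultType`/`sumDataType` on the aggregate classes — and the decimal precision arithmetic is illustrative, not a copy of the actual source:

```scala
// Illustrative sketch only: type names mirror Catalyst's, but this is not
// the actual Spark code. Sum accumulates integral inputs in a long and
// fractional inputs in a double; decimals grow by 10 digits of precision
// (capped at the assumed max of 38). An analyzer-side upcast adds nothing.
sealed trait DataType
case object IntegerType extends DataType
case object LongType extends DataType
case object FloatType extends DataType
case object DoubleType extends DataType
final case class DecimalType(precision: Int, scale: Int) extends DataType

def sumResultType(child: DataType): DataType = child match {
  case DecimalType(p, s)      => DecimalType(math.min(p + 10, 38), s)
  case IntegerType | LongType => LongType   // integral inputs sum as long
  case _                      => DoubleType // fractional inputs sum as double
}

// Average keeps a (sum, count) buffer; only the sum side needs widening,
// and for non-decimal inputs it accumulates as double.
def avgSumType(child: DataType): DataType = child match {
  case DecimalType(p, s) => DecimalType(math.min(p + 10, 38), s)
  case _                 => DoubleType
}
```

So the coercion rule removed below only duplicated a decision the aggregates already make, at the cost of an extra `Cast` node in every plan.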

### Why are the changes needed?

Removes unnecessary logic whose redundant casts may cause performance regressions.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

TPC-DS plan stability tests (the golden plan files are regenerated accordingly).
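
A quick interactive check (a hypothetical spark-shell session against a build that includes this patch — not part of the PR's test suite) shows the cast disappearing from the analyzed plan:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
spark.range(10).selectExpr("cast(id as int) as val").createOrReplaceTempView("t")

// Extended explain prints the analyzed, optimized, and physical plans.
spark.sql("SELECT sum(val) FROM t").explain(true)
// Before: Aggregate [sum(cast(val#x as bigint)) AS sum(val)#xL]
// After:  Aggregate [sum(val#x) AS sum(val)#xL]
```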

Closes #31079 from yaooqinn/SPARK-34037.

Authored-by: Kent Yao <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
yaooqinn authored and viirya committed Jan 15, 2021
1 parent c75c29d commit a235c3b
Showing 279 changed files with 1,485 additions and 1,496 deletions.
@@ -634,17 +634,6 @@ object TypeCoercion {

m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })

- // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
- case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
- case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
- case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType))
-
- case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest.
- case Average(e @ IntegralType()) if e.dataType != LongType =>
-   Average(Cast(e, LongType))
- case Average(e @ FractionalType()) if e.dataType != DoubleType =>
-   Average(Cast(e, DoubleType))
-
// Hive lets you do aggregation of timestamps... for some reason
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
46 changes: 23 additions & 23 deletions sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
@@ -55,23 +55,23 @@ struct<plan:string>

== Analyzed Logical Plan ==
sum(DISTINCT val): bigint
- Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
+ Aggregate [sum(distinct val#x) AS sum(DISTINCT val)#xL]
+- SubqueryAlias spark_catalog.default.explain_temp1
+- Relation[key#x,val#x] parquet

== Optimized Logical Plan ==
- Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
+ Aggregate [sum(distinct val#x) AS sum(DISTINCT val)#xL]
+- Project [val#x]
+- Relation[key#x,val#x] parquet

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
- +- HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL])
+ +- HashAggregate(keys=[], functions=[sum(distinct val#x)], output=[sum(DISTINCT val)#xL])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
- +- HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL])
- +- HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
- +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x]
- +- HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
+ +- HashAggregate(keys=[], functions=[partial_sum(distinct val#x)], output=[sum#xL])
+ +- HashAggregate(keys=[val#x], functions=[], output=[val#x])
+ +- Exchange hashpartitioning(val#x, 4), ENSURE_REQUIREMENTS, [id=#x]
+ +- HashAggregate(keys=[val#x], functions=[], output=[val#x])
+- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<val:int>


@@ -615,7 +615,7 @@ Input [2]: [key#x, val#x]
(14) HashAggregate
Input [1]: [key#x]
Keys: []
- Functions [1]: [partial_avg(cast(key#x as bigint))]
+ Functions [1]: [partial_avg(key#x)]
Aggregate Attributes [2]: [sum#x, count#xL]
Results [2]: [sum#x, count#xL]

@@ -626,9 +626,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(16) HashAggregate
Input [2]: [sum#x, count#xL]
Keys: []
- Functions [1]: [avg(cast(key#x as bigint))]
- Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
- Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
+ Functions [1]: [avg(key#x)]
+ Aggregate Attributes [1]: [avg(key#x)#x]
+ Results [1]: [avg(key#x)#x AS avg(key)#x]

(17) AdaptiveSparkPlan
Output [1]: [avg(key)#x]
@@ -681,7 +681,7 @@ ReadSchema: struct<key:int>
(5) HashAggregate
Input [1]: [key#x]
Keys: []
- Functions [1]: [partial_avg(cast(key#x as bigint))]
+ Functions [1]: [partial_avg(key#x)]
Aggregate Attributes [2]: [sum#x, count#xL]
Results [2]: [sum#x, count#xL]

@@ -692,9 +692,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(7) HashAggregate
Input [2]: [sum#x, count#xL]
Keys: []
- Functions [1]: [avg(cast(key#x as bigint))]
- Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
- Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
+ Functions [1]: [avg(key#x)]
+ Aggregate Attributes [1]: [avg(key#x)#x]
+ Results [1]: [avg(key#x)#x AS avg(key)#x]

(8) AdaptiveSparkPlan
Output [1]: [avg(key)#x]
@@ -717,7 +717,7 @@ ReadSchema: struct<key:int>
(10) HashAggregate
Input [1]: [key#x]
Keys: []
- Functions [1]: [partial_avg(cast(key#x as bigint))]
+ Functions [1]: [partial_avg(key#x)]
Aggregate Attributes [2]: [sum#x, count#xL]
Results [2]: [sum#x, count#xL]

@@ -728,9 +728,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(12) HashAggregate
Input [2]: [sum#x, count#xL]
Keys: []
- Functions [1]: [avg(cast(key#x as bigint))]
- Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
- Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
+ Functions [1]: [avg(key#x)]
+ Aggregate Attributes [1]: [avg(key#x)#x]
+ Results [1]: [avg(key#x)#x AS avg(key)#x]

(13) AdaptiveSparkPlan
Output [1]: [avg(key)#x]
@@ -947,7 +947,7 @@ ReadSchema: struct<key:int,val:int>
(2) HashAggregate
Input [2]: [key#x, val#x]
Keys: []
- Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))]
+ Functions [3]: [partial_count(val#x), partial_sum(key#x), partial_count(key#x) FILTER (WHERE (val#x > 1))]
Aggregate Attributes [3]: [count#xL, sum#xL, count#xL]
Results [3]: [count#xL, sum#xL, count#xL]

@@ -958,9 +958,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(4) HashAggregate
Input [3]: [count#xL, sum#xL, count#xL]
Keys: []
- Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)]
- Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL]
- Results [2]: [(count(val#x)#xL + sum(cast(key#x as bigint))#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL]
+ Functions [3]: [count(val#x), sum(key#x), count(key#x)]
+ Aggregate Attributes [3]: [count(val#x)#xL, sum(key#x)#xL, count(key#x)#xL]
+ Results [2]: [(count(val#x)#xL + sum(key#x)#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL]

(5) AdaptiveSparkPlan
Output [2]: [TOTAL#xL, count(key) FILTER (WHERE (val > 1))#xL]
38 changes: 19 additions & 19 deletions sql/core/src/test/resources/sql-tests/results/explain.sql.out
@@ -55,22 +55,22 @@ struct<plan:string>

== Analyzed Logical Plan ==
sum(DISTINCT val): bigint
- Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
+ Aggregate [sum(distinct val#x) AS sum(DISTINCT val)#xL]
+- SubqueryAlias spark_catalog.default.explain_temp1
+- Relation[key#x,val#x] parquet

== Optimized Logical Plan ==
- Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
+ Aggregate [sum(distinct val#x) AS sum(DISTINCT val)#xL]
+- Project [val#x]
+- Relation[key#x,val#x] parquet

== Physical Plan ==
- *HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL])
+ *HashAggregate(keys=[], functions=[sum(distinct val#x)], output=[sum(DISTINCT val)#xL])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
- +- *HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL])
- +- *HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
- +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x]
- +- *HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
+ +- *HashAggregate(keys=[], functions=[partial_sum(distinct val#x)], output=[sum#xL])
+ +- *HashAggregate(keys=[val#x], functions=[], output=[val#x])
+ +- Exchange hashpartitioning(val#x, 4), ENSURE_REQUIREMENTS, [id=#x]
+ +- *HashAggregate(keys=[val#x], functions=[], output=[val#x])
+- *ColumnarToRow
+- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<val:int>

@@ -620,7 +620,7 @@ Input [2]: [key#x, val#x]
(15) HashAggregate [codegen id : 1]
Input [1]: [key#x]
Keys: []
- Functions [1]: [partial_avg(cast(key#x as bigint))]
+ Functions [1]: [partial_avg(key#x)]
Aggregate Attributes [2]: [sum#x, count#xL]
Results [2]: [sum#x, count#xL]

@@ -631,9 +631,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(17) HashAggregate [codegen id : 2]
Input [2]: [sum#x, count#xL]
Keys: []
- Functions [1]: [avg(cast(key#x as bigint))]
- Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
- Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
+ Functions [1]: [avg(key#x)]
+ Aggregate Attributes [1]: [avg(key#x)#x]
+ Results [1]: [avg(key#x)#x AS avg(key)#x]


-- !query
@@ -684,7 +684,7 @@ Input [1]: [key#x]
(6) HashAggregate [codegen id : 1]
Input [1]: [key#x]
Keys: []
- Functions [1]: [partial_avg(cast(key#x as bigint))]
+ Functions [1]: [partial_avg(key#x)]
Aggregate Attributes [2]: [sum#x, count#xL]
Results [2]: [sum#x, count#xL]

@@ -695,9 +695,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(8) HashAggregate [codegen id : 2]
Input [2]: [sum#x, count#xL]
Keys: []
- Functions [1]: [avg(cast(key#x as bigint))]
- Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
- Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
+ Functions [1]: [avg(key#x)]
+ Aggregate Attributes [1]: [avg(key#x)#x]
+ Results [1]: [avg(key#x)#x AS avg(key)#x]

Subquery:2 Hosting operator id = 3 Hosting Expression = ReusedSubquery Subquery scalar-subquery#x, [id=#x]

@@ -895,7 +895,7 @@ Input [2]: [key#x, val#x]
(3) HashAggregate [codegen id : 1]
Input [2]: [key#x, val#x]
Keys: []
- Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))]
+ Functions [3]: [partial_count(val#x), partial_sum(key#x), partial_count(key#x) FILTER (WHERE (val#x > 1))]
Aggregate Attributes [3]: [count#xL, sum#xL, count#xL]
Results [3]: [count#xL, sum#xL, count#xL]

@@ -906,9 +906,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(5) HashAggregate [codegen id : 2]
Input [3]: [count#xL, sum#xL, count#xL]
Keys: []
- Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)]
- Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL]
- Results [2]: [(count(val#x)#xL + sum(cast(key#x as bigint))#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL]
+ Functions [3]: [count(val#x), sum(key#x), count(key#x)]
+ Aggregate Attributes [3]: [count(val#x)#xL, sum(key#x)#xL, count(key#x)#xL]
+ Results [2]: [(count(val#x)#xL + sum(key#x)#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL]


-- !query
@@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
- aggregate functions are not allowed in GROUP BY, but found sum(CAST(data.`b` AS BIGINT))
+ aggregate functions are not allowed in GROUP BY, but found sum(data.`b`)


-- !query
@@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
- aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS BIGINT)) + CAST(2 AS BIGINT))
+ aggregate functions are not allowed in GROUP BY, but found (sum(data.`b`) + CAST(2 AS BIGINT))


-- !query
@@ -381,8 +381,8 @@ struct<>
org.apache.spark.sql.AnalysisException

Aggregate/Window/Generate expressions are not valid in where clause of the query.
- Expression in where clause: [(sum(DISTINCT CAST((outer(a.`four`) + b.`four`) AS BIGINT)) = CAST(b.`four` AS BIGINT))]
- Invalid expressions: [sum(DISTINCT CAST((outer(a.`four`) + b.`four`) AS BIGINT))]
+ Expression in where clause: [(sum(DISTINCT (outer(a.`four`) + b.`four`)) = CAST(b.`four` AS BIGINT))]
+ Invalid expressions: [sum(DISTINCT (outer(a.`four`) + b.`four`))]


-- !query
@@ -46,7 +46,7 @@ AND t2b = (SELECT max(avg)
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
- grouping expressions sequence is empty, and 't2.`t2b`' is not an aggregate function. Wrap '(avg(CAST(t2.`t2b` AS BIGINT)) AS `avg`)' in windowing function(s) or wrap 't2.`t2b`' in first() (or first_value) if you don't care which value you get.
+ grouping expressions sequence is empty, and 't2.`t2b`' is not an aggregate function. Wrap '(avg(t2.`t2b`) AS `avg`)' in windowing function(s) or wrap 't2.`t2b`' in first() (or first_value) if you don't care which value you get.


-- !query
@@ -372,8 +372,8 @@ struct<>
org.apache.spark.sql.AnalysisException

Aggregate/Window/Generate expressions are not valid in where clause of the query.
- Expression in where clause: [(sum(DISTINCT CAST((outer(a.`four`) + b.`four`) AS BIGINT)) = CAST(CAST(udf(ansi_cast(four as string)) AS INT) AS BIGINT))]
- Invalid expressions: [sum(DISTINCT CAST((outer(a.`four`) + b.`four`) AS BIGINT))]
+ Expression in where clause: [(sum(DISTINCT (outer(a.`four`) + b.`four`)) = CAST(CAST(udf(ansi_cast(four as string)) AS INT) AS BIGINT))]
+ Invalid expressions: [sum(DISTINCT (outer(a.`four`) + b.`four`))]


-- !query
@@ -206,7 +206,7 @@ Results [5]: [i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, sum#27

(37) Exchange
Input [5]: [i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, sum#27]
- Arguments: hashpartitioning(i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, 5), true, [id=#28]
+ Arguments: hashpartitioning(i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, 5), ENSURE_REQUIREMENTS, [id=#28]

(38) HashAggregate [codegen id : 7]
Input [5]: [i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, sum#27]
@@ -206,7 +206,7 @@ Results [5]: [i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, sum#27

(37) Exchange
Input [5]: [i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, sum#27]
- Arguments: hashpartitioning(i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, 5), true, [id=#28]
+ Arguments: hashpartitioning(i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, 5), ENSURE_REQUIREMENTS, [id=#28]

(38) HashAggregate [codegen id : 7]
Input [5]: [i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, sum#27]