Skip to content

Commit

Permalink
[Gluten-343] Fix wrong results when executing avg(int), avg(long), av…
Browse files Browse the repository at this point in the history
…g(boolean) for ClickHouse backend (oap-project#345)
  • Loading branch information
zzcclp authored Sep 1, 2022
1 parent bfd7d04 commit 9c7d5f9
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package io.glutenproject.execution

import java.util
import java.util.Locale

import com.google.protobuf.Any
import io.glutenproject.expression._
Expand Down Expand Up @@ -89,14 +90,31 @@ case class CHHashAggregateExecTransformer(
(child.output, aggregateResultAttributes)
} else {
// For every result attribute, append exactly one column name to nameList and
// one type node to typeList.
for (attr <- aggregateResultAttributes) {
  val baseName = ConverterUtils.genColumnNameWithExprId(attr)
  val colName = if (aggregateAttributes.contains(attr)) {
    // for aggregate func
    baseName + "#Partial#" + ConverterUtils.getShortAttributeName(attr)
  } else {
    // for group by cols
    baseName
  }
  nameList.add(colName)
  // In final stage, when the output attr is the output of the avg func,
  // CH needs to get the original data type as input type.
  val inputType = if (colName.toLowerCase(Locale.ROOT).startsWith("avg#")) {
    // Only an Average expression whose result attribute is `attr` overrides the type;
    // anything else (e.g. a column that merely starts with "avg#") keeps attr.dataType.
    aggregateExpressions
      .find(_.resultAttribute == attr)
      .map(_.aggregateFunction)
      .collect { case avg: Average => avg.child.dataType }
      .getOrElse(attr.dataType)
  } else {
    attr.dataType
  }
  typeList.add(ConverterUtils.getTypeNode(inputType, attr.nullable))
}
(aggregateResultAttributes, output)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ object CHExecUtil {
case "Short" => ShortType
case "String" => StringType
case "Binary" => BinaryType
case "Boolean" => BooleanType
}
// scalastyle:off argcount
def genShuffleDependency(rdd: RDD[ColumnarBatch],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSui
assert(result(0).getDouble(0) == 379.21313271604936)
}

// Checks avg() over the two catalog_sales columns named in the query (int and long,
// per the test name) against exact expected double values.
test("test select avg(int), avg(long)") {
  val avgQuery =
    """
|select avg(cs_item_sk), avg(cs_order_number)
| from catalog_sales
|""".stripMargin
  // Both averages are read from the first collected row.
  val firstRow = spark.sql(avgQuery).collect()(0)
  assert(firstRow.getDouble(0) == 8998.463336886734)
  assert(firstRow.getDouble(1) == 80037.12727449503)
}

test("TPCDS Q9") {
withSQLConf(
("spark.gluten.sql.columnar.columnartorow", "true")) {
Expand Down

0 comments on commit 9c7d5f9

Please sign in to comment.