From 9c7d5f95418e8692daf962c124bec8f293c8379e Mon Sep 17 00:00:00 2001 From: Zhichao Zhang Date: Thu, 1 Sep 2022 09:17:43 +0800 Subject: [PATCH] [Gluten-343] Fix wrong results when executing avg(int), avg(long), avg(boolean) for ClickHouse backend (#345) --- .../CHHashAggregateExecTransformer.scala | 20 ++++++++++++++++++- .../sql/execution/utils/CHExecUtil.scala | 1 + .../GlutenClickHouseTPCDSParquetSuite.scala | 11 ++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala index 0aff55251f12..987201109aba 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala @@ -18,6 +18,7 @@ package io.glutenproject.execution import java.util +import java.util.Locale import com.google.protobuf.Any import io.glutenproject.expression._ @@ -89,14 +90,31 @@ case class CHHashAggregateExecTransformer( (child.output, aggregateResultAttributes) } else { for (attr <- aggregateResultAttributes) { - typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) val colName = if (aggregateAttributes.exists(_ == attr)) { + // for aggregate func ConverterUtils.genColumnNameWithExprId(attr) + "#Partial#" + ConverterUtils.getShortAttributeName(attr) } else { + // for group by cols ConverterUtils.genColumnNameWithExprId(attr) } nameList.add(colName) + // In final stage, when the output attr is the output of the avg func, + // CH needs to get the original data type as input type. + if (colName.toLowerCase(Locale.ROOT).startsWith("avg#")) { + val originalExpr = aggregateExpressions.find(_.resultAttribute == attr) + val originalType = if (originalExpr.isDefined && + originalExpr.get.asInstanceOf[AggregateExpression] + .aggregateFunction.isInstanceOf[Average]) { + originalExpr.get.asInstanceOf[AggregateExpression].aggregateFunction + .asInstanceOf[Average].child.dataType + } else { + attr.dataType + } + typeList.add(ConverterUtils.getTypeNode(originalType, attr.nullable)) + } else { + typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) + } } (aggregateResultAttributes, output) } diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/utils/CHExecUtil.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/utils/CHExecUtil.scala index 26e74c6d5d9e..aaed92f81e2b 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/utils/CHExecUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/utils/CHExecUtil.scala @@ -49,6 +49,7 @@ object CHExecUtil { case "Short" => ShortType case "String" => StringType case "Binary" => BinaryType + case "Boolean" => BooleanType } // scalastyle:off argcount def genShuffleDependency(rdd: RDD[ColumnarBatch], diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetSuite.scala index af0d0e75751b..9d5d3893661a 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetSuite.scala @@ -69,6 +69,17 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSui assert(result(0).getDouble(0) == 379.21313271604936) } + test("test select avg(int), avg(long)") { + val testSql = + """ + |select avg(cs_item_sk), avg(cs_order_number) + | from catalog_sales + |""".stripMargin + val result = spark.sql(testSql).collect() + assert(result(0).getDouble(0) == 8998.463336886734) + assert(result(0).getDouble(1) == 80037.12727449503) + } + test("TPCDS Q9") { withSQLConf( ("spark.gluten.sql.columnar.columnartorow", "true")) {