From d84302043371c6978e3c462d3316e4d29bf00854 Mon Sep 17 00:00:00 2001
From: Joey
Date: Fri, 10 Mar 2023 17:36:08 +0800
Subject: [PATCH] decimal unscaled value (#3)

* support decimal UnscaledValue
* rename DecimalUtil to GlutenDecimalUtil
---
Reviewer note: this patch was edited after export (Scaladoc added to both new
files, final newlines added, avg intermediate sum type derived from the input
type). The post-image blob ids on the `index` lines of the edited files are
therefore stale.

 .../DecimalExpressionsTransformer.scala       | 59 +++++++++++++++++++
 .../expression/ExpressionConverter.scala      |  3 ++
 .../expression/ExpressionMappings.scala       |  7 ++-
 .../utils/GlutenDecimalUtil.scala             | 41 ++++++++++++++
 .../GlutenHashAggregateExecTransformer.scala  | 24 ++++++----
 5 files changed, 122 insertions(+), 12 deletions(-)
 create mode 100644 gluten-core/src/main/scala/io/glutenproject/expression/DecimalExpressionsTransformer.scala
 create mode 100644 gluten-core/src/main/scala/io/glutenproject/utils/GlutenDecimalUtil.scala

diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/DecimalExpressionsTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/expression/DecimalExpressionsTransformer.scala
new file mode 100644
index 0000000000000..0690387830f14
--- /dev/null
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/DecimalExpressionsTransformer.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.glutenproject.expression
+
+import com.google.common.collect.Lists
+import io.glutenproject.expression.ConverterUtils.FunctionConfig
+import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode}
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+/**
+ * Transforms Spark's `UnscaledValue` expression into a Substrait scalar function call.
+ *
+ * @param substraitExprName name of the Substrait function to call.
+ * @param child transformer of the expression's single input.
+ * @param original the Spark expression being converted; provides the argument types for the
+ *                 function signature as well as the result type and nullability.
+ */
+class UnscaledValueTransformer(
+    substraitExprName: String,
+    child: ExpressionTransformer,
+    original: Expression) extends ExpressionTransformer {
+
+  /**
+   * Builds the Substrait node that applies the function to the transformed child.
+   *
+   * @param args must be a `java.util.HashMap[String, java.lang.Long]`; it is handed to
+   *             `ExpressionBuilder.newScalarFunction` together with the function name
+   *             (presumably the plan's function-name-to-id registry).
+   * @return the scalar function node typed after `original`.
+   */
+  override def doTransform(args: java.lang.Object): ExpressionNode = {
+    val childNode = child.doTransform(args)
+    val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
+    val functionName = ConverterUtils.makeFuncName(
+      substraitExprName,
+      original.children.map(_.dataType),
+      FunctionConfig.REQ)
+    val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName)
+    val expressNodes = Lists.newArrayList(childNode)
+    val typeNode = ConverterUtils.getTypeNode(original.dataType, original.nullable)
+    ExpressionBuilder.makeScalarFunction(functionId, expressNodes, typeNode)
+  }
+}
diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala
index f7e94113315b9..751ee686de1ee 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala
@@ -323,6 +323,9 @@ object ExpressionConverter extends Logging {
       case mapValues: MapValues =>
         new UnaryArgumentCollectionOperationTransformer(substraitExprName,
           replaceWithExpressionTransformer(mapValues.child, attributeSeq), mapValues)
+      case unscaled: UnscaledValue =>
+        new UnscaledValueTransformer(substraitExprName,
+          replaceWithExpressionTransformer(unscaled.child, attributeSeq), unscaled)
       case expr =>
         logWarning(s"${expr.getClass} or ${expr} is not currently supported.")
         throw new UnsupportedOperationException(
diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
index df1588359f618..7071b9ba0e10e 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
@@ -214,6 +214,9 @@ object ExpressionMappings {
   final val CUME_DIST = "cume_dist"
   final val PERCENT_RANK = "percent_rank"
 
+  // Decimal functions
+  final val UNSCALED_VALUE = "unscaled_value"
+
   /**
    * Mapping Spark scalar expression to Substrait function name
    */
@@ -375,7 +378,9 @@ object ExpressionMappings {
     Sig[InSet](IN_SET),
     Sig[ScalarSubquery](SCALAR_SUBQUERY),
     Sig[CheckOverflow](CHECK_OVERFLOW),
-    Sig[PromotePrecision](PROMOTE_PRECISION)
+    Sig[PromotePrecision](PROMOTE_PRECISION),
+    // Decimal
+    Sig[UnscaledValue](UNSCALED_VALUE)
   ) ++ SparkShimLoader.getSparkShims.expressionMappings
 
   /**
diff --git a/gluten-core/src/main/scala/io/glutenproject/utils/GlutenDecimalUtil.scala b/gluten-core/src/main/scala/io/glutenproject/utils/GlutenDecimalUtil.scala
new file mode 100644
index 0000000000000..5a8ac598323b4
--- /dev/null
+++ b/gluten-core/src/main/scala/io/glutenproject/utils/GlutenDecimalUtil.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.utils
+
+import scala.math.min
+
+import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.types.DecimalType.{MAX_PRECISION, MAX_SCALE}
+
+/** Helpers for working with Spark decimal types. */
+object GlutenDecimalUtil {
+
+  /** Extractor exposing a decimal type as its `(precision, scale)` pair. */
+  object Fixed {
+    def unapply(t: DecimalType): Option[(Int, Int)] = Some((t.precision, t.scale))
+  }
+
+  /**
+   * Creates a decimal type whose precision and scale are capped at Spark's maximums.
+   *
+   * @param precision requested precision; values above `MAX_PRECISION` are clamped to it.
+   * @param scale requested scale; values above `MAX_SCALE` are clamped to it.
+   * @return the clamped decimal type.
+   */
+  def bounded(precision: Int, scale: Int): DecimalType =
+    DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
+}
diff --git a/gluten-data/src/main/scala/io/glutenproject/execution/GlutenHashAggregateExecTransformer.scala b/gluten-data/src/main/scala/io/glutenproject/execution/GlutenHashAggregateExecTransformer.scala
index 39a158e02492c..c329c3f580672 100644
--- a/gluten-data/src/main/scala/io/glutenproject/execution/GlutenHashAggregateExecTransformer.scala
+++ b/gluten-data/src/main/scala/io/glutenproject/execution/GlutenHashAggregateExecTransformer.scala
@@ -18,6 +18,7 @@
 package io.glutenproject.execution
 
 import scala.collection.JavaConverters._
+
 import com.google.protobuf.Any
 import io.glutenproject.expression._
 import io.glutenproject.expression.ConverterUtils.FunctionConfig
@@ -28,11 +29,13 @@ import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
 import java.util
 
 import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
+import io.glutenproject.utils.GlutenDecimalUtil
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DecimalType, DoubleType, LongType}
+import org.apache.spark.sql.types.{DoubleType, LongType}
 
 case class GlutenHashAggregateExecTransformer(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -136,16 +139,15 @@ case class GlutenHashAggregateExecTransformer(
     val structTypeNodes = new util.ArrayList[TypeNode]()
     aggregateFunction match {
       case avg: Average =>
-        avg.dataType match {
-          case _: DecimalType =>
-            // Use struct type to represent Velox Row(DECIMAL, BIGINT).
-            structTypeNodes.add(ConverterUtils.getTypeNode(avg.dataType, nullable = true))
-            structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = true))
-          case _ =>
-            // Use struct type to represent Velox Row(DOUBLE, BIGINT).
-            structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = true))
-            structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = true))
-        }
+        // Use struct type to represent Velox Row(sum, BIGINT). The sum column follows Spark's
+        // avg buffer: a DECIMAL(p, s) input accumulates as DECIMAL(p + 10, s), capped at the
+        // maximum precision, and any other input accumulates as DOUBLE.
+        val sumType = avg.child.dataType match {
+          case GlutenDecimalUtil.Fixed(p, s) => GlutenDecimalUtil.bounded(p + 10, s)
+          case _ => DoubleType
+        }
+        structTypeNodes.add(ConverterUtils.getTypeNode(sumType, nullable = true))
+        structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = true))
       case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
         // Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE).
         structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = true))