diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala index 8467952d84a8..ffa5b91162f1 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala @@ -199,6 +199,7 @@ object ExpressionMappings { final val SCALAR_SUBQUERY = "scalar_subquery" final val EXPLODE = "explode" final val CHECK_OVERFLOW = "ss_check_overflow" + final val MAKE_DECIMAL = "make_decimal" final val PROMOTE_PRECISION = "promote_precision" final val ROW_CONSTRUCTOR = "row_constructor" @@ -376,6 +377,7 @@ object ExpressionMappings { Sig[InSet](IN_SET), Sig[ScalarSubquery](SCALAR_SUBQUERY), Sig[CheckOverflow](CHECK_OVERFLOW), + Sig[MakeDecimal](MAKE_DECIMAL), Sig[PromotePrecision](PROMOTE_PRECISION), // Decimal Sig[UnscaledValue](UNSCALED_VALUE) diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/UnaryExpressionTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/expression/UnaryExpressionTransformer.scala index 9db8b8e44851..15ad682c1163 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/UnaryExpressionTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/UnaryExpressionTransformer.scala @@ -128,6 +128,32 @@ class CheckOverflowTransformer( } } +class MakeDecimalTransformer(substraitExprName: String, + child: ExpressionTransformer, + original: MakeDecimal) + extends ExpressionTransformer { + + override def doTransform(args: java.lang.Object): ExpressionNode = { + val childNode = child.doTransform(args) + val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + val functionId = ExpressionBuilder.newScalarFunction( + functionMap, + ConverterUtils.makeFuncName( + substraitExprName, + Seq(original.dataType, BooleanType), + FunctionConfig.OPT)) + + // use fake decimal literal, because velox function signature need to get return type + // scale and precision by input type variable + val toTypeNodes = ExpressionBuilder.makeDecimalLiteral( + new Decimal().set(0, original.precision, original.scale)) + val expressionNodes = Lists.newArrayList(childNode, toTypeNodes, + new BooleanLiteralNode(original.nullOnOverflow)) + val typeNode = ConverterUtils.getTypeNode(original.dataType, original.nullable) + ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode) + } +} + case class Md5Transformer(substraitExprName: String, child: ExpressionTransformer, original: Md5) extends ExpressionTransformer with Logging { @@ -196,6 +222,8 @@ object UnaryExpressionTransformer { original match { case c: CheckOverflow => new CheckOverflowTransformer(substraitExprName, child, c) + case m: MakeDecimal => + new MakeDecimalTransformer(substraitExprName, child, m) case p: PromotePrecision => new PromotePrecisionTransformer(child, p) case extract if extract.isInstanceOf[GetDateField] || extract.isInstanceOf[GetTimeField] =>