decimal unscaled value (apache#3)
* support decimal UnscaledValue

* rename DecimalUtil to GlutenDecimalUtil
liujiayi771 authored and JkSelf committed Mar 22, 2023
1 parent 5bf8a1a commit d843020
Showing 5 changed files with 94 additions and 12 deletions.
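
For context: Spark's UnscaledValue expression converts a DecimalType value into its unscaled LongType representation, i.e. the digits with the decimal point dropped (the scale is tracked separately by the type). A minimal sketch of that semantics using only the JDK, independent of this commit (names and values are illustrative):

import java.math.BigDecimal

object UnscaledValueDemo extends App {
  val d = new BigDecimal("123.45") // precision 5, scale 2
  // The unscaled value is the digit string read as an integer; Spark's
  // UnscaledValue produces the same quantity as a Long.
  println(d.unscaledValue()) // 12345
  println(d.scale())         // 2
}
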
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.glutenproject.expression

import com.google.common.collect.Lists
import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode}

import org.apache.spark.sql.catalyst.expressions.Expression

class UnscaledValueTransformer(
    substraitExprName: String,
    child: ExpressionTransformer,
    original: Expression) extends ExpressionTransformer {
  override def doTransform(args: java.lang.Object): ExpressionNode = {
    val childNode = child.doTransform(args)
    val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
    val functionName =
      ConverterUtils.makeFuncName(
        substraitExprName,
        original.children.map(_.dataType),
        FunctionConfig.REQ)
    val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName)
    val expressNodes = Lists.newArrayList(childNode)
    val typeNode = ConverterUtils.getTypeNode(original.dataType, original.nullable)
    ExpressionBuilder.makeScalarFunction(functionId, expressNodes, typeNode)
  }
}
@@ -323,6 +323,9 @@ object ExpressionConverter extends Logging {
       case mapValues: MapValues =>
         new UnaryArgumentCollectionOperationTransformer(substraitExprName,
           replaceWithExpressionTransformer(mapValues.child, attributeSeq), mapValues)
+      case unscaled: UnscaledValue =>
+        new UnscaledValueTransformer(substraitExprName,
+          replaceWithExpressionTransformer(unscaled.child, attributeSeq), unscaled)
       case expr =>
         logWarning(s"${expr.getClass} or ${expr} is not currently supported.")
         throw new UnsupportedOperationException(
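
The new case above routes UnscaledValue to the transformer instead of falling through to the unsupported-expression fallback. UnscaledValue usually enters a plan via Spark's DecimalAggregates optimizer rule, which rewrites aggregates over narrow decimals into long arithmetic. A hedged spark-shell sketch of a query that can exercise this path (spark is an existing SparkSession; the exact rewrite is Spark-version dependent):

// For DECIMAL(8, 2), 8 + 10 = 18 digits still fit in a Long, so Spark's
// DecimalAggregates rule may rewrite sum(x) into
// MakeDecimal(sum(UnscaledValue(x)), 18, 2) -- the shape this commit offloads.
val df = spark.range(100).selectExpr("CAST(id AS DECIMAL(8, 2)) AS x")
df.selectExpr("sum(x)").explain() // look for unscaledvalue(...) in the plan
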
@@ -214,6 +214,9 @@ object ExpressionMappings {
   final val CUME_DIST = "cume_dist"
   final val PERCENT_RANK = "percent_rank"
 
+  // Decimal functions
+  final val UNSCALED_VALUE = "unscaled_value"
+
   /**
    * Mapping Spark scalar expression to Substrait function name
    */
@@ -375,7 +378,9 @@ object ExpressionMappings {
     Sig[InSet](IN_SET),
     Sig[ScalarSubquery](SCALAR_SUBQUERY),
     Sig[CheckOverflow](CHECK_OVERFLOW),
-    Sig[PromotePrecision](PROMOTE_PRECISION)
+    Sig[PromotePrecision](PROMOTE_PRECISION),
+    // Decimal
+    Sig[UnscaledValue](UNSCALED_VALUE)
   ) ++ SparkShimLoader.getSparkShims.expressionMappings
 
   /**
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.glutenproject.utils

import scala.math.min

import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.types.DecimalType.{MAX_PRECISION, MAX_SCALE}

object GlutenDecimalUtil {
  object Fixed {
    def unapply(t: DecimalType): Option[(Int, Int)] = Some((t.precision, t.scale))
  }

  def bounded(precision: Int, scale: Int): DecimalType = {
    DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
  }
}
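
A REPL-style usage sketch for the helper above (assumes spark-catalyst on the classpath; Spark caps decimals at MAX_PRECISION = MAX_SCALE = 38):

import org.apache.spark.sql.types.DecimalType
import io.glutenproject.utils.GlutenDecimalUtil

// Fixed extracts (precision, scale) from any DecimalType.
DecimalType(12, 2) match {
  case GlutenDecimalUtil.Fixed(p, s) => println(s"p=$p, s=$s") // p=12, s=2
}
// bounded() clamps a widened type back into Spark's representable range.
println(GlutenDecimalUtil.bounded(48, 10)) // DecimalType(38,10)
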
@@ -18,6 +18,7 @@
 package io.glutenproject.execution
 
 import scala.collection.JavaConverters._
+
 import com.google.protobuf.Any
 import io.glutenproject.expression._
 import io.glutenproject.expression.ConverterUtils.FunctionConfig
@@ -28,11 +29,13 @@ import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
 
 import java.util
 import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
+import io.glutenproject.utils.GlutenDecimalUtil
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DecimalType, DoubleType, LongType}
+import org.apache.spark.sql.types.{DoubleType, LongType}
 
 case class GlutenHashAggregateExecTransformer(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -136,16 +139,12 @@ case class GlutenHashAggregateExecTransformer(
     val structTypeNodes = new util.ArrayList[TypeNode]()
     aggregateFunction match {
       case avg: Average =>
-        avg.dataType match {
-          case _: DecimalType =>
-            // Use struct type to represent Velox Row(DECIMAL, BIGINT).
-            structTypeNodes.add(ConverterUtils.getTypeNode(avg.dataType, nullable = true))
-            structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = true))
-          case _ =>
-            // Use struct type to represent Velox Row(DOUBLE, BIGINT).
-            structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = true))
-            structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = true))
-        }
+        structTypeNodes.add(ConverterUtils.getTypeNode(
+          avg.dataType match {
+            case _ @ GlutenDecimalUtil.Fixed(p, s) => GlutenDecimalUtil.bounded(p + 10, s)
+            case _ => DoubleType
+          }, nullable = true))
+        structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = true))
       case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
         // Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE).
         structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = true))
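
The replacement above folds the two branches into a single getTypeNode call: avg's partial state is still Row(sum, count), but a decimal result now gets a sum column widened by 10 digits and clamped to precision 38, which appears to mirror the extra headroom Spark itself gives a decimal sum; non-decimal results keep the old Row(DOUBLE, BIGINT) shape. A worked sketch of the type selection (not part of the commit):

import org.apache.spark.sql.types.{DataType, DoubleType}
import io.glutenproject.utils.GlutenDecimalUtil

// avg over DECIMAL(12, 2): sum column DecimalType(min(12 + 10, 38), 2) = DecimalType(22, 2)
// avg over DOUBLE (or any non-decimal): sum column DOUBLE; count is BIGINT either way.
def avgSumColumnType(avgResultType: DataType): DataType = avgResultType match {
  case GlutenDecimalUtil.Fixed(p, s) => GlutenDecimalUtil.bounded(p + 10, s)
  case _ => DoubleType
}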
