Skip to content

Commit

Permalink
[SPARK-23179][SQL] Support option to throw exception if overflow occurs during Decimal arithmetic
Browse files Browse the repository at this point in the history

## What changes were proposed in this pull request?

SQL ANSI 2011 states that in case of overflow during arithmetic operations, an exception should be thrown. This is what most of the SQL DBs do (e.g. SQL Server, DB2). Hive currently returns NULL (as Spark does) but HIVE-18291 is open to be SQL compliant.

The PR introduces an option to decide which behavior Spark should follow, i.e. returning NULL on overflow or throwing an exception.

## How was this patch tested?

added UTs

Closes #20350 from mgaido91/SPARK-23179.

Authored-by: Marco Gaido <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
mgaido91 authored and cloud-fan committed Jun 27, 2019
1 parent 7cbe01e commit 3139d64
Show file tree
Hide file tree
Showing 10 changed files with 223 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ object DecimalPrecision extends TypeCoercionRule {
PromotePrecision(Cast(e, dataType))
}

private def nullOnOverflow: Boolean = SQLConf.get.decimalOperationsNullOnOverflow

override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// fix decimal precision for expressions
case q => q.transformExpressionsUp(
Expand All @@ -105,7 +107,7 @@ object DecimalPrecision extends TypeCoercionRule {
DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
}
CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
resultType)
resultType, nullOnOverflow)

case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultScale = max(s1, s2)
Expand All @@ -116,7 +118,7 @@ object DecimalPrecision extends TypeCoercionRule {
DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
}
CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
resultType)
resultType, nullOnOverflow)

case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
Expand All @@ -126,7 +128,7 @@ object DecimalPrecision extends TypeCoercionRule {
}
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType)
resultType, nullOnOverflow)

case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
Expand All @@ -148,7 +150,7 @@ object DecimalPrecision extends TypeCoercionRule {
}
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType)
resultType, nullOnOverflow)

case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
Expand All @@ -159,7 +161,7 @@ object DecimalPrecision extends TypeCoercionRule {
// resultType may have lower precision, so we cast them into wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType)
resultType, nullOnOverflow)

case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
Expand All @@ -170,7 +172,7 @@ object DecimalPrecision extends TypeCoercionRule {
// resultType may have lower precision, so we cast them into wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType)
resultType, nullOnOverflow)

case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
collect(left, negate) ++ collect(right, !negate)
case UnaryMinus(child) =>
collect(child, !negate)
case CheckOverflow(child, _) =>
case CheckOverflow(child, _, _) =>
collect(child, negate)
case PromotePrecision(child) =>
collect(child, negate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ object RowEncoder {
d,
"fromDecimal",
inputObject :: Nil,
returnNullable = false), d)
returnNullable = false), d, SQLConf.get.decimalOperationsNullOnOverflow)

case StringType => createSerializerForString(inputObject)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,30 +81,34 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {

/**
* Rounds the decimal to given scale and check whether the decimal can fit in provided precision
* or not, returns null if not.
* or not. If not, if `nullOnOverflow` is `true`, it returns `null`; otherwise an
* `ArithmeticException` is thrown.
*/
case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {
case class CheckOverflow(
child: Expression,
dataType: DecimalType,
nullOnOverflow: Boolean) extends UnaryExpression {

override def nullable: Boolean = true

override def nullSafeEval(input: Any): Any =
input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)
input.asInstanceOf[Decimal].toPrecision(
dataType.precision,
dataType.scale,
Decimal.ROUND_HALF_UP,
nullOnOverflow)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
val tmp = ctx.freshName("tmp")
s"""
| Decimal $tmp = $eval.clone();
| if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
| ${ev.value} = $tmp;
| } else {
| ${ev.isNull} = true;
| }
|${ev.value} = $eval.toPrecision(
| ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow);
|${ev.isNull} = ${ev.value} == null;
""".stripMargin
})
}

override def toString: String = s"CheckOverflow($child, $dataType)"
override def toString: String = s"CheckOverflow($child, $dataType, $nullOnOverflow)"

override def sql: String = child.sql
}
Original file line number Diff line number Diff line change
Expand Up @@ -1138,8 +1138,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
val evaluationCode = dataType match {
case DecimalType.Fixed(_, s) =>
s"""
${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr());
${ev.isNull} = ${ev.value} == null;"""
|${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s,
| Decimal.$modeStr(), true);
|${ev.isNull} = ${ev.value} == null;
""".stripMargin
case ByteType =>
if (_scale < 0) {
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1441,6 +1441,16 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val DECIMAL_OPERATIONS_NULL_ON_OVERFLOW =
buildConf("spark.sql.decimalOperations.nullOnOverflow")
.internal()
.doc("When true (default), if an overflow on a decimal occurs, then NULL is returned. " +
"Spark's older versions and Hive behave in this way. If turned to false, SQL ANSI 2011 " +
"specification will be followed instead: an arithmetic exception is thrown, as most " +
"of the SQL databases do.")
.booleanConf
.createWithDefault(true)

val LITERAL_PICK_MINIMUM_PRECISION =
buildConf("spark.sql.legacy.literal.pickMinimumPrecision")
.internal()
Expand Down Expand Up @@ -2205,6 +2215,8 @@ class SQLConf extends Serializable with Logging {

def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)

def decimalOperationsNullOnOverflow: Boolean = getConf(DECIMAL_OPERATIONS_NULL_ON_OVERFLOW)

def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)

def continuousStreamingEpochBacklogQueueSize: Int =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,14 +249,25 @@ final class Decimal extends Ordered[Decimal] with Serializable {
/**
* Create new `Decimal` with given precision and scale.
*
* @return a non-null `Decimal` value if successful or `null` if overflow would occur.
* @return a non-null `Decimal` value if successful. Otherwise, if `nullOnOverflow` is true, null
* is returned; if `nullOnOverflow` is false, an `ArithmeticException` is thrown.
*/
private[sql] def toPrecision(
precision: Int,
scale: Int,
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = {
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP,
nullOnOverflow: Boolean = true): Decimal = {
val copy = clone()
if (copy.changePrecision(precision, scale, roundMode)) copy else null
if (copy.changePrecision(precision, scale, roundMode)) {
copy
} else {
if (nullOnOverflow) {
null
} else {
throw new ArithmeticException(
s"$toDebugString cannot be represented as Decimal($precision, $scale).")
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,26 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {

test("CheckOverflow") {
val d1 = Decimal("10.1")
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10"))
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1)
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1)
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null)
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0), true), Decimal("10"))
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1), true), d1)
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2), true), d1)
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3), true), null)
intercept[ArithmeticException](CheckOverflow(Literal(d1), DecimalType(4, 3), false).eval())
intercept[ArithmeticException](checkEvaluationWithMutableProjection(
CheckOverflow(Literal(d1), DecimalType(4, 3), false), null))

val d2 = Decimal(101, 3, 1)
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10"))
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2)
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2)
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null)
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0), true), Decimal("10"))
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1), true), d2)
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2), true), d2)
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3), true), null)
intercept[ArithmeticException](CheckOverflow(Literal(d2), DecimalType(4, 3), false).eval())
intercept[ArithmeticException](checkEvaluationWithMutableProjection(
CheckOverflow(Literal(d2), DecimalType(4, 3), false), null))

checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null)
checkEvaluation(CheckOverflow(
Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), true), null)
checkEvaluation(CheckOverflow(
Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), false), null)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,28 @@ select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.1
select 123456789123456789.1234567890 * 1.123456789123456789;
select 12345678912345.123456789123 / 0.000000012345678;

-- throw an exception instead of returning NULL, according to SQL ANSI 2011
set spark.sql.decimalOperations.nullOnOverflow=false;

-- test operations between decimals and constants
select id, a*10, b/10 from decimals_test order by id;

-- test operations on constants
select 10.3 * 3.0;
select 10.3000 * 3.0;
select 10.30000 * 30.0;
select 10.300000000000000000 * 3.000000000000000000;
select 10.300000000000000000 * 3.0000000000000000000;

-- arithmetic operations causing an overflow throw exception
select (5e36 + 0.1) + 5e36;
select (-4e36 - 0.1) - 7e36;
select 12345678901234567890.0 * 12345678901234567890.0;
select 1e35 / 0.1;

-- arithmetic operations causing a precision loss throw exception
select 123456789123456789.1234567890 * 1.123456789123456789;
select 123456789123456789.1234567890 * 1.123456789123456789;
select 12345678912345.123456789123 / 0.000000012345678;

drop table decimals_test;
Loading

0 comments on commit 3139d64

Please sign in to comment.