Skip to content

Commit

Permalink
[SPARK-33008][SQL] Division by zero on divide-like operations returns…
Browse files Browse the repository at this point in the history
… incorrect result

### What changes were proposed in this pull request?
In ANSI mode, when a division by zero occurs performing a divide-like operation (Divide, IntegralDivide, Remainder or Pmod), we are returning an incorrect value. Instead, we should throw an exception, as stated in the SQL standard.

### Why are the changes needed?
Result corrupt.

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
added UT + existing UTs (improved)

Closes #29882 from luluorta/SPARK-33008.

Authored-by: luluorta <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
luluorta authored and cloud-fan committed Oct 29, 2020
1 parent fa63117 commit cbd3fde
Show file tree
Hide file tree
Showing 24 changed files with 379 additions and 181 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -307,35 +307,35 @@ class Analyzer(
object ResolveBinaryArithmetic extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p: LogicalPlan => p.transformExpressionsUp {
case a @ Add(l, r) if a.childrenResolved => (l.dataType, r.dataType) match {
case a @ Add(l, r, f) if a.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, CalendarIntervalType) => a
case (DateType, CalendarIntervalType) => DateAddInterval(l, r)
case (DateType, CalendarIntervalType) => DateAddInterval(l, r, ansiEnabled = f)
case (_, CalendarIntervalType) => Cast(TimeAdd(l, r), l.dataType)
case (CalendarIntervalType, DateType) => DateAddInterval(r, l)
case (CalendarIntervalType, DateType) => DateAddInterval(r, l, ansiEnabled = f)
case (CalendarIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
case (DateType, dt) if dt != StringType => DateAdd(l, r)
case (dt, DateType) if dt != StringType => DateAdd(r, l)
case _ => a
}
case s @ Subtract(l, r) if s.childrenResolved => (l.dataType, r.dataType) match {
case s @ Subtract(l, r, f) if s.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, CalendarIntervalType) => s
case (DateType, CalendarIntervalType) =>
DatetimeSub(l, r, DateAddInterval(l, UnaryMinus(r)))
DatetimeSub(l, r, DateAddInterval(l, UnaryMinus(r, f), ansiEnabled = f))
case (_, CalendarIntervalType) =>
Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r))), l.dataType)
Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, f))), l.dataType)
case (TimestampType, _) => SubtractTimestamps(l, r)
case (_, TimestampType) => SubtractTimestamps(l, r)
case (_, DateType) => SubtractDates(l, r)
case (DateType, dt) if dt != StringType => DateSub(l, r)
case _ => s
}
case m @ Multiply(l, r) if m.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, _) => MultiplyInterval(l, r)
case (_, CalendarIntervalType) => MultiplyInterval(r, l)
case m @ Multiply(l, r, f) if m.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, _) => MultiplyInterval(l, r, f)
case (_, CalendarIntervalType) => MultiplyInterval(r, l, f)
case _ => m
}
case d @ Divide(l, r) if d.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, _) => DivideInterval(l, r)
case d @ Divide(l, r, f) if d.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, _) => DivideInterval(l, r, f)
case _ => d
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,39 +98,44 @@ object DecimalPrecision extends TypeCoercionRule {
// Skip nodes who is already promoted
case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e

case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
case a @ Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
val resultScale = max(s1, s2)
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
resultScale)
} else {
DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
}
CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
CheckOverflow(
a.withNewChildren(Seq(promotePrecision(e1, resultType), promotePrecision(e2, resultType))),
resultType, nullOnOverflow)

case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
case s @ Subtract(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2), _) =>
val resultScale = max(s1, s2)
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
resultScale)
} else {
DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
}
CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
CheckOverflow(
s.withNewChildren(Seq(promotePrecision(e1, resultType), promotePrecision(e2, resultType))),
resultType, nullOnOverflow)

case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
case m @ Multiply(
e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
} else {
DecimalType.bounded(p1 + p2 + 1, s1 + s2)
}
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
CheckOverflow(
m.withNewChildren(Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType))),
resultType, nullOnOverflow)

case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
case d @ Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
// Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
// Scale: max(6, s1 + p2 + 1)
Expand All @@ -149,37 +154,40 @@ object DecimalPrecision extends TypeCoercionRule {
DecimalType.bounded(intDig + decDig, decDig)
}
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
CheckOverflow(
d.withNewChildren(Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType))),
resultType, nullOnOverflow)

case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
case r @ Remainder(
e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
} else {
DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
}
// resultType may have lower precision, so we cast them into wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
CheckOverflow(
r.withNewChildren(Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType))),
resultType, nullOnOverflow)

case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
case p @ Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
} else {
DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
}
// resultType may have lower precision, so we cast them into wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
CheckOverflow(
p.withNewChildren(Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType))),
resultType, nullOnOverflow)

case expr @ IntegralDivide(
e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
val widerType = widerDecimalType(p1, s1, p2, s2)
val promotedExpr = IntegralDivide(
promotePrecision(e1, widerType),
promotePrecision(e2, widerType))
val promotedExpr = expr.withNewChildren(
Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType)))
if (expr.dataType.isInstanceOf[DecimalType]) {
// This follows division rule
val intDig = p1 - s1 + s2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
if attributesWithEventWatermark.contains(a) && metadata.contains(delayKey) =>
Multiply(Literal(eventWatermark.get.toDouble), Literal(1000.0))
}
}.reduceLeft(Add)
}.reduceLeft(Add(_, _))

// Calculate the constraint value
logInfo(s"Final expression to evaluate constraint:\t$exprWithWatermarkSubstituted")
Expand Down Expand Up @@ -226,14 +226,14 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
*/
def collect(expr: Expression, negate: Boolean): Seq[Expression] = {
expr match {
case Add(left, right) =>
case Add(left, right, _) =>
collect(left, negate) ++ collect(right, negate)
case Subtract(left, right) =>
case Subtract(left, right, _) =>
collect(left, negate) ++ collect(right, !negate)
case TimeAdd(left, right, _) =>
collect(left, negate) ++ collect(right, negate)
case DatetimeSub(_, _, child) => collect(child, negate)
case UnaryMinus(child) =>
case UnaryMinus(child, _) =>
collect(child, !negate)
case CheckOverflow(child, _, _) =>
collect(child, negate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ object TypeCoercion {
s.withNewChildren(Seq(Cast(e, DoubleType)))
case s @ StddevSamp(e @ StringType(), _) =>
s.withNewChildren(Seq(Cast(e, DoubleType)))
case UnaryMinus(e @ StringType()) => UnaryMinus(Cast(e, DoubleType))
case m @ UnaryMinus(e @ StringType(), _) => m.withNewChildren(Seq(Cast(e, DoubleType)))
case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType))
case v @ VariancePop(e @ StringType(), _) =>
v.withNewChildren(Seq(Cast(e, DoubleType)))
Expand Down Expand Up @@ -698,8 +698,8 @@ object TypeCoercion {
// Decimal and Double remain the same
case d: Divide if d.dataType == DoubleType => d
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) =>
Divide(Cast(left, DoubleType), Cast(right, DoubleType))
case d @ Divide(left, right, _) if isNumericOrNull(left) && isNumericOrNull(right) =>
d.withNewChildren(Seq(Cast(left, DoubleType), Cast(right, DoubleType)))
}

private def isNumericOrNull(ex: Expression): Boolean = {
Expand All @@ -715,8 +715,8 @@ object TypeCoercion {
object IntegralDivision extends TypeCoercionRule {
override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
case e if !e.childrenResolved => e
case d @ IntegralDivide(left, right) =>
IntegralDivide(mayCastToLong(left), mayCastToLong(right))
case d @ IntegralDivide(left, right, _) =>
d.withNewChildren(Seq(mayCastToLong(left), mayCastToLong(right)))
}

private def mayCastToLong(expr: Expression): Expression = expr.dataType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,11 @@ object Canonicalize {

/** Rearrange expressions that are commutative or associative. */
private def expressionReorder(e: Expression): Expression = e match {
case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add)
case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)
// TODO: do not reorder consecutive `Add`s or `Multiply`s with different `failOnError` flags
case a @ Add(_, _, f) =>
orderCommutative(a, { case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, f))
case m @ Multiply(_, _, f) =>
orderCommutative(m, { case Multiply(l, r, _) => Seq(l, r) }).reduce(Multiply(_, _, f))

case o: Or =>
orderCommutative(o, { case Or(l, r) if l.deterministic && r.deterministic => Seq(l, r) })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,13 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
)

// If all input are nulls, count will be 0 and we will get null after the division.
// We can't directly use `/` as it throws an exception under ansi mode.
override lazy val evaluateExpression = child.dataType match {
case _: DecimalType =>
DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType)
DecimalPrecision.decimalAndDecimal(
Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
case _ =>
sum.cast(resultType) / count.cast(resultType)
Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
}

override lazy val updateExpressions: Seq[Expression] = Seq(
Expand Down
Loading

0 comments on commit cbd3fde

Please sign in to comment.