From 683e270c16258490c7c1deedb05fe44d45c39616 Mon Sep 17 00:00:00 2001
From: Mick Jermsurawong
Date: Fri, 5 Jul 2019 22:05:26 +0800
Subject: [PATCH] [SPARK-28200][SQL] Decimal overflow handling in ExpressionEncoder

## What changes were proposed in this pull request?

- Currently, `ExpressionEncoder` does not handle BigDecimal overflow: round-tripping an overflowing java/scala BigDecimal/BigInteger returns null.
  - The serializer encodes a java/scala BigDecimal to a sql Decimal, which still holds the underlying data of the former.
  - When writing out to UnsafeRow, `changePrecision` returns false and the row ends up with a null value.
    https://github.com/apache/spark/blob/24e1e41648de58d3437e008b187b84828830e238/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java#L202-L206
- In [SPARK-23179](https://github.com/apache/spark/pull/20350), an option to throw an exception on decimal overflow was introduced.
- This PR wires that option into `ExpressionEncoder`, so that an overflowing BigDecimal/BigInteger is detected and an exception thrown before the corresponding Decimal gets written to the Row. This gives consistent behavior between decimal arithmetic on sql expressions (DecimalPrecision) and getting decimals from a dataframe (RowEncoder); see the sketch below.
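To illustrate, here is a minimal sketch mirroring the new tests (`withSQLConf` and `intercept` are test-suite helpers, not public API; note the encoder must be constructed inside the conf scope, because `nullOnOverflow` is captured when the serializer is built):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.internal.SQLConf

// 39 nines cannot be represented at DecimalType.SYSTEM_DEFAULT's precision of 38.
val overflowing = BigDecimal("9" * 39)

withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
  // Default behavior, unchanged: the overflowing value silently round-trips to null.
  val encoder = ExpressionEncoder[BigDecimal]()
  assert(encoder.resolveAndBind().fromRow(encoder.toRow(overflowing)) == null)
}

withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
  // New behavior: encoding fails fast instead of producing null.
  val encoder = ExpressionEncoder[BigDecimal]()
  val e = intercept[RuntimeException](encoder.toRow(overflowing))
  assert(e.getMessage.contains("Error while encoding"))
  assert(e.getCause.isInstanceOf[ArithmeticException])
}
```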
Thanks to mgaido91 for the very first PR `SPARK-23179` and follow-up discussion on this change.
Thanks to JoshRosen for working with me on this.

## How was this patch tested?

Added unit tests: `ExpressionEncoderSuite` exercises round-tripping and overflow of scala/java BigDecimal/BigInteger under both settings of the conf, and `RowEncoderSuite` covers the `RowEncoder` path.
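On the `RowEncoder` side, the added test pins down the same contract, roughly as follows (again a sketch using test-suite helpers):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

val schema = new StructType().add("decimal", DecimalType.SYSTEM_DEFAULT)
val row = Row(BigDecimal("9" * 100)) // far beyond the maximum precision of 38

withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
  val encoder = RowEncoder(schema).resolveAndBind()
  // Throws an ArithmeticException ("... cannot be represented as Decimal ..."),
  // either directly or as the cause of a RuntimeException.
  intercept[Exception](encoder.toRow(row))
}

withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
  val encoder = RowEncoder(schema).resolveAndBind()
  assert(encoder.fromRow(encoder.toRow(row)).get(0) == null)
}
```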
+ "9999" * 100), + "java big decimal with long fractional part") + + // Scala / Java big integers ---------------------------------------------------------- + + encodeDecodeTest(BigInt("9" * 38), "scala big integer within precision limit") + encodeDecodeTest(new BigInteger("9" * 38), "java big integer within precision limit") + + encodeDecodeTest(-BigInt("9" * 38), + "negative scala big integer within precision limit") + encodeDecodeTest(new BigInteger("9" * 38).negate(), + "negative java big integer within precision limit") + + testOverflowingBigNumeric(BigInt("1" * 39), "scala big int") + testOverflowingBigNumeric(new BigInteger("1" * 39), "java big integer") + + testOverflowingBigNumeric(-BigInt("1" * 39), "negative scala big int") + testOverflowingBigNumeric(new BigInteger("1" * 39).negate, "negative java big integer") + + testOverflowingBigNumeric(BigInt("9" * 100), "scala very large big int") + testOverflowingBigNumeric(new BigInteger("9" * 100), "java very big int") + + private def testOverflowingBigNumeric[T: TypeTag](bigNumeric: T, testName: String): Unit = { + Seq(true, false).foreach { allowNullOnOverflow => + testAndVerifyNotLeakingReflectionObjects( + s"overflowing $testName, allowNullOnOverflow=$allowNullOnOverflow") { + withSQLConf( + SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> allowNullOnOverflow.toString + ) { + // Need to construct Encoder here rather than implicitly resolving it + // so that SQLConf changes are respected. + val encoder = ExpressionEncoder[T]() + if (allowNullOnOverflow) { + val convertedBack = encoder.resolveAndBind().fromRow(encoder.toRow(bigNumeric)) + assert(convertedBack === null) + } else { + val e = intercept[RuntimeException] { + encoder.toRow(bigNumeric) + } + assert(e.getMessage.contains("Error while encoding")) + assert(e.getCause.getClass === classOf[ArithmeticException]) + } + } + } + } + } + private def encodeDecodeTest[T : ExpressionEncoder]( input: T, testName: String): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 4b3d0612a9b85..5d21e4a2a83ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -162,6 +162,32 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { assert(row.toSeq(schema).head == decimal) } + test("SPARK-23179: RowEncoder should respect nullOnOverflow for decimals") { + val schema = new StructType().add("decimal", DecimalType.SYSTEM_DEFAULT) + testDecimalOverflow(schema, Row(BigDecimal("9" * 100))) + testDecimalOverflow(schema, Row(new java.math.BigDecimal("9" * 100))) + } + + private def testDecimalOverflow(schema: StructType, row: Row): Unit = { + withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") { + val encoder = RowEncoder(schema).resolveAndBind() + intercept[Exception] { + encoder.toRow(row) + } match { + case e: ArithmeticException => + assert(e.getMessage.contains("cannot be represented as Decimal")) + case e: RuntimeException => + assert(e.getCause.isInstanceOf[ArithmeticException]) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) + } + } + + withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") { + val encoder = RowEncoder(schema).resolveAndBind() + assert(encoder.fromRow(encoder.toRow(row)).get(0) == null) + } + } + 
test("RowEncoder should preserve schema nullability") { val schema = new StructType().add("int", IntegerType, nullable = false) val encoder = RowEncoder(schema).resolveAndBind()