diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index 6d7b6e56e4e3c..d33c7f2eeb4e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -359,7 +359,17 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { !fromExp.foldable && fromExp.dataType.isInstanceOf[NumericType] && toType.isInstanceOf[NumericType] && - Cast.canUpCast(fromExp.dataType, toType) + canUnwrapCast(fromExp.dataType, toType) + } + + private def canUnwrapCast(from: DataType, to: DataType): Boolean = (from, to) match { + // SPARK-39476: It's not safe to unwrap cast from Integer to Float or from Long to Float/Double, + // since the length of Integer/Long may exceed the significant digits of Float/Double. + case (IntegerType, FloatType) => false + case (LongType, FloatType) => false + case (LongType, DoubleType) => false + case _ if from.isInstanceOf[NumericType] => Cast.canUpCast(from, to) + case _ => false } private[optimizer] def getRange(dt: DataType): Option[(Any, Any)] = dt match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala index e6f0426428bd4..c27097562e5c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala @@ -190,5 +190,36 @@ class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSess } } + test("SPARK-39476: Should not unwrap cast from Long to Double/Float") { + withTable(t) { + Seq((6470759586864300301L)) + .toDF("c1").write.saveAsTable(t) + val df = spark.table(t) + + checkAnswer( + df.where("cast(c1 as double) == cast(6470759586864300301L as double)") + .select("c1"), + Row(6470759586864300301L)) + + checkAnswer( + df.where("cast(c1 as float) == cast(6470759586864300301L as float)") + .select("c1"), + Row(6470759586864300301L)) + } + } + + test("SPARK-39476: Should not unwrap cast from Integer to Float") { + withTable(t) { + Seq((33554435)) + .toDF("c1").write.saveAsTable(t) + val df = spark.table(t) + + checkAnswer( + df.where("cast(c1 as float) == cast(33554435 as float)") + .select("c1"), + Row(33554435)) + } + } + private def decimal(v: BigDecimal): Decimal = Decimal(v, 5, 2) }