diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index d9ec2c0d4b4cd..94e27379b7465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -358,8 +358,18 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { toType.sameType(literalType) && !fromExp.foldable && toType.isInstanceOf[NumericType] && - ((fromExp.dataType.isInstanceOf[NumericType] && Cast.canUpCast(fromExp.dataType, toType)) || - fromExp.dataType.isInstanceOf[BooleanType]) + canUnwrapCast(fromExp.dataType, toType) + } + + private def canUnwrapCast(from: DataType, to: DataType): Boolean = (from, to) match { + case (BooleanType, _) => true + // SPARK-39476: It's not safe to unwrap cast from Integer to Float or from Long to Float/Double, + // since Integer/Long values can have more significant digits than Float/Double can represent. 
+ case (IntegerType, FloatType) => false + case (LongType, FloatType) => false + case (LongType, DoubleType) => false + case _ if from.isInstanceOf[NumericType] => Cast.canUpCast(from, to) + case _ => false } private[optimizer] def getRange(dt: DataType): Option[(Any, Any)] = dt match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala index 2c361299b173d..1d7af84ef6096 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala @@ -209,5 +209,36 @@ class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSess } } + test("SPARK-39476: Should not unwrap cast from Long to Double/Float") { + withTable(t) { + Seq((6470759586864300301L)) + .toDF("c1").write.saveAsTable(t) + val df = spark.table(t) + + checkAnswer( + df.where("cast(c1 as double) == cast(6470759586864300301L as double)") + .select("c1"), + Row(6470759586864300301L)) + + checkAnswer( + df.where("cast(c1 as float) == cast(6470759586864300301L as float)") + .select("c1"), + Row(6470759586864300301L)) + } + } + + test("SPARK-39476: Should not unwrap cast from Integer to Float") { + withTable(t) { + Seq((33554435)) + .toDF("c1").write.saveAsTable(t) + val df = spark.table(t) + + checkAnswer( + df.where("cast(c1 as float) == cast(33554435 as float)") + .select("c1"), + Row(33554435)) + } + } + private def decimal(v: BigDecimal): Decimal = Decimal(v, 5, 2) }