Skip to content

Commit

Permalink
Fix Row.equals()
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Jul 19, 2015
1 parent a702e2e commit 88bd73c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 20 deletions.
14 changes: 6 additions & 8 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
Original file line number Diff line number Diff line change
Expand Up @@ -406,17 +406,15 @@ trait Row extends Serializable {
o1 match {
case b1: Array[Byte] =>
if (!o2.isInstanceOf[Array[Byte]] ||
!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
return false
}
case f1: Float =>
if (!o2.isInstanceOf[Float] ||
(java.lang.Float.isNaN(f1) && !java.lang.Float.isNaN(o2.asInstanceOf[Float]))) {
return false
case f1: Float if java.lang.Float.isNaN(f1) =>
if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
return false
}
case d1: Double =>
if (!o2.isInstanceOf[Double] ||
(java.lang.Double.isNaN(d1) && !java.lang.Double.isNaN(o2.asInstanceOf[Double]))) {
case d1: Double if java.lang.Double.isNaN(d1) =>
if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
return false
}
case _ => if (o1 != o2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils


object InterpretedPredicate {
Expand Down Expand Up @@ -257,13 +258,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
if (left.dataType == FloatType) {
val f1 = input1.asInstanceOf[Float]
val f2 = input2.asInstanceOf[Float]
(java.lang.Float.isNaN(f1) && java.lang.Float.isNaN(f2)) || f1 == f2
Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
} else if (left.dataType == DoubleType) {
val d1 = input1.asInstanceOf[Double]
val d2 = input2.asInstanceOf[Double]
(java.lang.Double.isNaN(d1) && java.lang.Double.isNaN(d2)) || d1 == d2
Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
} else if (left.dataType != BinaryType) {
input1 == input2
} else {
Expand Down Expand Up @@ -294,13 +291,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
false
} else {
if (left.dataType == FloatType) {
val f1 = input1.asInstanceOf[Float]
val f2 = input2.asInstanceOf[Float]
(java.lang.Float.isNaN(f1) && java.lang.Float.isNaN(f2)) || f1 == f2
Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
} else if (left.dataType == DoubleType) {
val d1 = input1.asInstanceOf[Double]
val d2 = input2.asInstanceOf[Double]
(java.lang.Double.isNaN(d1) && java.lang.Double.isNaN(d2)) || d1 == d2
Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
} else if (left.dataType != BinaryType) {
input1 == input2
} else {
Expand Down

0 comments on commit 88bd73c

Please sign in to comment.