Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8797] [SPARK-9146] [SPARK-9145] [SPARK-9147] Support NaN ordering and equality comparisons in Spark SQL #7194

Closed
wants to merge 31 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d2b4a4a
Add random data generator test utilities to Spark SQL.
JoshRosen Jul 2, 2015
ab76cbd
Move code to Catalyst package.
JoshRosen Jul 2, 2015
5acdd5c
Infinity and NaN are interesting.
JoshRosen Jul 2, 2015
b55875a
Generate doubles and floats over entire possible range.
JoshRosen Jul 2, 2015
7d5c13e
Add regression test for SPARK-8782 (ORDER BY NULL)
JoshRosen Jul 2, 2015
e7dc4fb
Add very generic test for ordering
JoshRosen Jul 2, 2015
f9efbb5
Fix ORDER BY NULL
JoshRosen Jul 2, 2015
13fc06a
Add regression test for NaN sorting issue
JoshRosen Jul 2, 2015
9bf195a
Re-enable NaNs in CodeGenerationSuite to produce more regression tests
JoshRosen Jul 2, 2015
630ebc5
Specify an ordering for NaN values.
JoshRosen Jul 2, 2015
d907b5b
Merge remote-tracking branch 'origin/master' into nan
JoshRosen Jul 18, 2015
5b88b2b
Fix compilation of CodeGenerationSuite
JoshRosen Jul 18, 2015
b20837b
Add failing test for new NaN comparison ordering
JoshRosen Jul 18, 2015
8d7be61
Update randomized test to use ScalaTest's assume()
JoshRosen Jul 18, 2015
bfca524
Change ordering so that NaN is maximum value.
JoshRosen Jul 18, 2015
42a1ad5
Stop filtering NaNs in UnsafeExternalSortSuite
JoshRosen Jul 18, 2015
6f03f85
Fix bug in Double / Float ordering
JoshRosen Jul 18, 2015
a30d371
Compare rows' string representations to work around NaN incomparability.
JoshRosen Jul 18, 2015
a2ba2e7
Fix prefix comparison for NaNs
JoshRosen Jul 18, 2015
3998ef2
Remove unused code
JoshRosen Jul 18, 2015
fc6b4d2
Update CodeGenerator
JoshRosen Jul 18, 2015
58bad2c
Revert "Compare rows' string representations to work around NaN incom…
JoshRosen Jul 19, 2015
7fe67af
Support NaN == NaN (SPARK-9145)
JoshRosen Jul 19, 2015
b31eb19
Uncomment failing tests
JoshRosen Jul 19, 2015
c1fd4fe
Fold NaN test into existing test framework
JoshRosen Jul 19, 2015
fbb2a29
Fix NaN comparisons in BinaryComparison expressions
JoshRosen Jul 19, 2015
fe629ae
Merge remote-tracking branch 'origin/master' into nan
JoshRosen Jul 19, 2015
a7267cf
Normalize NaNs in UnsafeRow
JoshRosen Jul 19, 2015
a702e2e
normalization -> canonicalization
JoshRosen Jul 19, 2015
88bd73c
Fix Row.equals()
JoshRosen Jul 19, 2015
983d4fc
Merge remote-tracking branch 'origin/master' into nan
JoshRosen Jul 19, 2015
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import org.apache.spark.annotation.Private;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.util.Utils;

@Private
public class PrefixComparators {
Expand Down Expand Up @@ -82,7 +83,7 @@ public static final class FloatPrefixComparator extends PrefixComparator {
public int compare(long aPrefix, long bPrefix) {
float a = Float.intBitsToFloat((int) aPrefix);
float b = Float.intBitsToFloat((int) bPrefix);
return (a < b) ? -1 : (a > b) ? 1 : 0;
return Utils.nanSafeCompareFloats(a, b);
}

public long computePrefix(float value) {
Expand All @@ -97,7 +98,7 @@ public static final class DoublePrefixComparator extends PrefixComparator {
public int compare(long aPrefix, long bPrefix) {
double a = Double.longBitsToDouble(aPrefix);
double b = Double.longBitsToDouble(bPrefix);
return (a < b) ? -1 : (a > b) ? 1 : 0;
return Utils.nanSafeCompareDoubles(a, b);
}

public long computePrefix(double value) {
Expand Down
28 changes: 28 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1586,6 +1586,34 @@ private[spark] object Utils extends Logging {
hashAbs
}

/**
 * NaN-safe counterpart of [[java.lang.Double.compare()]]: every NaN compares equal to any
 * other NaN and strictly greater than every non-NaN double (unlike the IEEE-754 semantics
 * of the primitive comparison operators, under which NaN is unordered).
 */
def nanSafeCompareDoubles(x: Double, y: Double): Int = {
  val leftIsNan = java.lang.Double.isNaN(x)
  val rightIsNan = java.lang.Double.isNaN(y)
  if (leftIsNan) {
    // NaN == NaN; NaN > any non-NaN value.
    if (rightIsNan) 0 else 1
  } else if (rightIsNan) {
    -1
  } else if (x == y) {
    // Note: like the original, this treats 0.0 and -0.0 as equal.
    0
  } else if (x > y) {
    1
  } else {
    -1
  }
}

/**
 * NaN-safe counterpart of [[java.lang.Float.compare()]]: every NaN compares equal to any
 * other NaN and strictly greater than every non-NaN float (unlike the IEEE-754 semantics
 * of the primitive comparison operators, under which NaN is unordered).
 */
def nanSafeCompareFloats(x: Float, y: Float): Int = {
  val leftIsNan = java.lang.Float.isNaN(x)
  val rightIsNan = java.lang.Float.isNaN(y)
  if (leftIsNan) {
    // NaN == NaN; NaN > any non-NaN value.
    if (rightIsNan) 0 else 1
  } else if (rightIsNan) {
    -1
  } else if (x == y) {
    // Note: like the original, this treats 0.0f and -0.0f as equal.
    0
  } else if (x > y) {
    1
  } else {
    -1
  }
}

/** Returns the system properties map that is thread-safe to iterator over. It gets the
* properties which have been set explicitly, as well as those for which only a default value
* has been defined. */
Expand Down
31 changes: 31 additions & 0 deletions core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.util

import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream}
import java.lang.{Double => JDouble, Float => JFloat}
import java.net.{BindException, ServerSocket, URI}
import java.nio.{ByteBuffer, ByteOrder}
import java.text.DecimalFormatSymbols
Expand Down Expand Up @@ -689,4 +690,34 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
// scalastyle:on println
assert(buffer.toString === "t circular test circular\n")
}

test("nanSafeCompareDoubles") {
  // For non-NaN inputs the NaN-safe comparator must agree with
  // java.lang.Double.compare, in both argument orders.
  def checkAgreesWithJavaCompare(a: Double, b: Double): Unit = {
    assert(Utils.nanSafeCompareDoubles(a, b) === JDouble.compare(a, b))
    assert(Utils.nanSafeCompareDoubles(b, a) === JDouble.compare(b, a))
  }
  checkAgreesWithJavaCompare(0d, 0d)
  checkAgreesWithJavaCompare(0d, 1d)
  checkAgreesWithJavaCompare(Double.MinValue, Double.MaxValue)
  // NaN equals NaN and exceeds every other double, including the infinities.
  assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NaN) === 0)
  assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity) === 1)
  assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NegativeInfinity) === 1)
  assert(Utils.nanSafeCompareDoubles(Double.PositiveInfinity, Double.NaN) === -1)
  assert(Utils.nanSafeCompareDoubles(Double.NegativeInfinity, Double.NaN) === -1)
}

test("nanSafeCompareFloats") {
  // For non-NaN inputs the NaN-safe comparator must agree with
  // java.lang.Float.compare, in both argument orders.
  def shouldMatchDefaultOrder(a: Float, b: Float): Unit = {
    assert(Utils.nanSafeCompareFloats(a, b) === JFloat.compare(a, b))
    assert(Utils.nanSafeCompareFloats(b, a) === JFloat.compare(b, a))
  }
  shouldMatchDefaultOrder(0f, 0f)
  // Unequal non-NaN pair, mirroring the (0d, 1d) case in the doubles test;
  // without it the a < b / a > b paths were never exercised here.
  shouldMatchDefaultOrder(0f, 1f)
  shouldMatchDefaultOrder(1f, 1f)
  shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue)
  // NaN equals NaN and exceeds every other float, including the infinities.
  assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NaN) === 0)
  assert(Utils.nanSafeCompareFloats(Float.NaN, Float.PositiveInfinity) === 1)
  assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NegativeInfinity) === 1)
  assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1)
  assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,29 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
}

test("float prefix comparator handles NaN properly") {
  // Two distinct NaN bit patterns for floats.
  val nanA: Float = java.lang.Float.intBitsToFloat(0x7f800001)
  val nanB: Float = java.lang.Float.intBitsToFloat(0x7fffffff)
  assert(nanA.isNaN)
  assert(nanB.isNaN)
  // Every NaN must collapse to a single canonical prefix...
  val prefixA = PrefixComparators.FLOAT.computePrefix(nanA)
  val prefixB = PrefixComparators.FLOAT.computePrefix(nanB)
  assert(prefixA === prefixB)
  // ...and that prefix must sort after the largest finite float.
  val maxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue)
  assert(PrefixComparators.FLOAT.compare(prefixA, maxPrefix) === 1)
}

test("double prefix comparator handles NaNs properly") {
  // Two distinct NaN bit patterns for doubles.
  val nanA: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
  val nanB: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
  assert(nanA.isNaN)
  assert(nanB.isNaN)
  // Every NaN must collapse to a single canonical prefix...
  val prefixA = PrefixComparators.DOUBLE.computePrefix(nanA)
  val prefixB = PrefixComparators.DOUBLE.computePrefix(nanB)
  assert(prefixA === prefixB)
  // ...and that prefix must sort after the largest finite double.
  val maxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue)
  assert(PrefixComparators.DOUBLE.compare(prefixA, maxPrefix) === 1)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ public void setLong(int ordinal, long value) {
public void setDouble(int ordinal, double value) {
  assertIndexIsValid(ordinal);
  setNotNullAt(ordinal);
  // Collapse every NaN bit pattern to the canonical Double.NaN so rows holding
  // semantically equal NaN values are also equal byte-for-byte.
  final double canonical = Double.isNaN(value) ? Double.NaN : value;
  PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), canonical);
}

Expand Down Expand Up @@ -243,6 +246,9 @@ public void setByte(int ordinal, byte value) {
public void setFloat(int ordinal, float value) {
  assertIndexIsValid(ordinal);
  setNotNullAt(ordinal);
  // Collapse every NaN bit pattern to the canonical Float.NaN so rows holding
  // semantically equal NaN values are also equal byte-for-byte.
  final float canonical = Float.isNaN(value) ? Float.NaN : value;
  PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), canonical);
}

Expand Down
24 changes: 16 additions & 8 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
Original file line number Diff line number Diff line change
Expand Up @@ -403,20 +403,28 @@ trait Row extends Serializable {
if (!isNullAt(i)) {
val o1 = get(i)
val o2 = other.get(i)
if (o1.isInstanceOf[Array[Byte]]) {
// handle equality of Array[Byte]
val b1 = o1.asInstanceOf[Array[Byte]]
if (!o2.isInstanceOf[Array[Byte]] ||
!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
o1 match {
case b1: Array[Byte] =>
if (!o2.isInstanceOf[Array[Byte]] ||
!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
return false
}
case f1: Float if java.lang.Float.isNaN(f1) =>
if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
return false
}
case d1: Double if java.lang.Double.isNaN(d1) =>
if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
return false
}
case _ => if (o1 != o2) {
return false
}
} else if (o1 != o2) {
return false
}
}
i += 1
}
return true
true
}

/* ---------------------- utility methods for Scala ---------------------- */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ class CodeGenContext {
*/
def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
  case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
  // NaN == NaN for floating point, matching the NaN-safe interpreted semantics.
  // The whole expression is wrapped in outer parentheses because callers splice
  // this fragment into larger boolean expressions: without them the low-precedence
  // `||` would bind against the caller's surrounding operators (e.g. a leading `!`
  // or an adjoining `&&`) and change the meaning of the generated code.
  case FloatType =>
    s"((java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2)"
  case DoubleType =>
    s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)"
  case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
  case other => s"$c1.equals($c2)"
}
Expand All @@ -194,6 +196,8 @@ class CodeGenContext {
def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
// java boolean doesn't support > or < operator
case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))"
case DoubleType => s"org.apache.spark.util.Utils.nanSafeCompareDoubles($c1, $c2)"
case FloatType => s"org.apache.spark.util.Utils.nanSafeCompareFloats($c1, $c2)"
// use c1 - c2 may overflow
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils


object InterpretedPredicate {
Expand Down Expand Up @@ -222,7 +223,9 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
abstract class BinaryComparison extends BinaryOperator with Predicate {

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
if (ctx.isPrimitiveType(left.dataType)) {
if (ctx.isPrimitiveType(left.dataType)
&& left.dataType != FloatType
&& left.dataType != DoubleType) {
// faster version
defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2")
} else {
Expand Down Expand Up @@ -254,8 +257,15 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
override def symbol: String = "="

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
  // Equality with NaN == NaN semantics for floating-point types, element-wise
  // comparison for binary, and plain value equality for everything else.
  left.dataType match {
    case FloatType =>
      Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
    case DoubleType =>
      Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
    case BinaryType =>
      java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
    case _ =>
      input1 == input2
  }
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
Expand All @@ -280,7 +290,11 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
} else if (input1 == null || input2 == null) {
false
} else {
if (left.dataType != BinaryType) {
if (left.dataType == FloatType) {
Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
} else if (left.dataType == DoubleType) {
Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
} else if (left.dataType != BinaryType) {
input1 == input2
} else {
java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.util.Utils

/**
* :: DeveloperApi ::
Expand All @@ -37,7 +38,9 @@ class DoubleType private() extends FractionalType {
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
private[sql] val numeric = implicitly[Numeric[Double]]
private[sql] val fractional = implicitly[Fractional[Double]]
private[sql] val ordering = implicitly[Ordering[InternalType]]
// NaN-aware ordering: NaN sorts equal to NaN and greater than every other double.
private[sql] val ordering = new Ordering[Double] {
  override def compare(a: Double, b: Double): Int = Utils.nanSafeCompareDoubles(a, b)
}
private[sql] val asIntegral = DoubleAsIfIntegral

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.util.Utils

/**
* :: DeveloperApi ::
Expand All @@ -37,7 +38,9 @@ class FloatType private() extends FractionalType {
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
private[sql] val numeric = implicitly[Numeric[Float]]
private[sql] val fractional = implicitly[Fractional[Float]]
private[sql] val ordering = implicitly[Ordering[InternalType]]
// NaN-aware ordering: NaN sorts equal to NaN and greater than every other float.
private[sql] val ordering = new Ordering[Float] {
  override def compare(a: Float, b: Float): Int = Utils.nanSafeCompareFloats(a, b)
}
private[sql] val asIntegral = FloatAsIfIntegral

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@

package org.apache.spark.sql.catalyst.expressions

import scala.math._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType}

/**
* Additional tests for code generation.
Expand All @@ -43,6 +48,40 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
futures.foreach(Await.result(_, 10.seconds))
}

// Test GenerateOrdering for all common types. For each type, we construct random input rows that
// contain two columns of that type, then for pairs of randomly-generated rows we check that
// GenerateOrdering agrees with RowOrdering.
(DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType =>
test(s"GenerateOrdering with $dataType") {
// Interpreted ordering over a two-column schema of this type; the generated
// ordering below must always agree with it.
val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType))
val genOrdering = GenerateOrdering.generate(
BoundReference(0, dataType, nullable = true).asc ::
BoundReference(1, dataType, nullable = true).asc :: Nil)
val rowType = StructType(
StructField("a", dataType, nullable = true) ::
StructField("b", dataType, nullable = true) :: Nil)
val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false)
// Skip (rather than fail) types for which no random generator exists.
assume(maybeDataGenerator.isDefined)
val randGenerator = maybeDataGenerator.get
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
for (_ <- 1 to 50) {
val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
withClue(s"a = $a, b = $b") {
// Reflexivity: each row must compare equal to itself under both orderings.
assert(genOrdering.compare(a, a) === 0)
assert(genOrdering.compare(b, b) === 0)
assert(rowOrdering.compare(a, a) === 0)
assert(rowOrdering.compare(b, b) === 0)
// Antisymmetry: swapping the arguments must flip the sign of the result.
assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a)))
assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a)))
// The generated ordering must agree (in sign) with the interpreted one.
assert(
signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)),
"Generated and non-generated orderings should agree")
}
}
}
}

test("SPARK-8443: split wide projections into blocks due to JVM code size limit") {
val length = 5000
val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,14 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
}

private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
private val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b").map(Literal(_))

private val equalValues1 = smallValues
private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d).map(Literal(_))
private val largeValues =
Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_))

private val equalValues1 =
Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
private val equalValues2 =
Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))

test("BinaryComparison: <") {
for (i <- 0 until smallValues.length) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,4 +316,26 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
}

test("NaN canonicalization") {
  val fieldTypes: Array[DataType] = Array(FloatType, DoubleType)

  // Builds a row whose float and double columns are set from raw NaN bit patterns.
  def rowWithNaNs(floatBits: Int, doubleBits: Long): SpecificMutableRow = {
    val row = new SpecificMutableRow(fieldTypes)
    row.setFloat(0, java.lang.Float.intBitsToFloat(floatBits))
    row.setDouble(1, java.lang.Double.longBitsToDouble(doubleBits))
    row
  }
  // Two rows carrying different NaN bit patterns in each column.
  val row1 = rowWithNaNs(0x7f800001, 0x7ff0000000000001L)
  val row2 = rowWithNaNs(0x7fffffff, 0x7fffffffffffffffL)

  val converter = new UnsafeRowConverter(fieldTypes)
  val row1Buffer = new Array[Byte](converter.getSizeRequirement(row1))
  val row2Buffer = new Array[Byte](converter.getSizeRequirement(row2))
  converter.writeRow(
    row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row1Buffer.length, null)
  converter.writeRow(
    row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length, null)

  // After canonicalization, both rows must serialize to identical bytes.
  assert(row1Buffer.toSeq === row2Buffer.toSeq)
}

}
Loading