Skip to content

Commit

Permalink
Compare rows' string representations to work around NaN incomparability.
Browse files: browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Jul 18, 2015
1 parent 6f03f85 commit a30d371
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,17 @@ class SparkPlanTest extends SparkFunSuite {
* treated as the source-of-truth for the test.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
* @param compareStrings if true, the answers will be converted to strings before being compared
*/
protected def checkThatPlansAgree(
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean = true): Unit = {
SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match {
sortAnswers: Boolean = true,
compareStrings: Boolean = false): Unit = {
val result = SparkPlanTest.checkAnswer(
input, planFunction, expectedPlanFunction, sortAnswers, compareStrings)
result match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
Expand All @@ -142,12 +146,14 @@ object SparkPlanTest {
* instantiate a reference implementation of the physical operator
* that's being tested. The result of executing this plan will be
* treated as the source-of-truth for the test.
* @param compareStrings if true, the answers will be converted to strings before being compared
*/
def checkAnswer(
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean): Option[String] = {
sortAnswers: Boolean,
compareStrings: Boolean): Option[String] = {

val outputPlan = planFunction(input.queryExecution.sparkPlan)
val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
Expand Down Expand Up @@ -182,7 +188,7 @@ object SparkPlanTest {
return Some(errorMessage)
}

compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
compareAnswers(actualAnswer, expectedAnswer, sortAnswers, compareStrings).map { errorMessage =>
s"""
| Results do not match.
| Actual result Spark plan:
Expand Down Expand Up @@ -226,7 +232,7 @@ object SparkPlanTest {
return Some(errorMessage)
}

compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
compareAnswers(sparkAnswer, expectedAnswer, sortAnswers, false).map { errorMessage =>
s"""
| Results do not match for Spark plan:
| $outputPlan
Expand All @@ -238,7 +244,8 @@ object SparkPlanTest {
private def compareAnswers(
sparkAnswer: Seq[Row],
expectedAnswer: Seq[Row],
sort: Boolean): Option[String] = {
sort: Boolean,
compareStrings: Boolean): Option[String] = {
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
Expand All @@ -253,11 +260,16 @@ object SparkPlanTest {
case o => o
})
}
if (sort) {
val maybeSorted = if (sort) {
converted.sortBy(_.toString())
} else {
converted
}
if (compareStrings) {
maybeSorted.map(r => Row.fromSeq(r.toSeq.map(String.valueOf))) // valueOf handles nulls
} else {
maybeSorted
}
}
if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
val errorMessage =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
inputDf,
UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 23),
Sort(sortOrder, global = true, _: SparkPlan),
sortAnswers = false
sortAnswers = false,
compareStrings = true
)
}
}
Expand Down

0 comments on commit a30d371

Please sign in to comment.