From d9ab1e46c457a378b623b2e249ef1b53e0289eae Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 18 Jun 2015 20:49:58 +0000 Subject: [PATCH] Add simple resolver --- .../spark/sql/execution/SortSuite.scala | 20 +++++---- .../spark/sql/execution/SparkPlanTest.scala | 43 ++++++++++++++++++- 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 5f2db13d8202d..05be27437ae96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -18,9 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, Ascending, SortOrder} +import org.apache.spark.sql.catalyst.dsl.expressions._ + +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types.{IntegerType, StringType} class SortSuite extends SparkPlanTest { + import TestSQLContext.implicits.localSeqToDataFrameHolder test("basic sorting using ExternalSort") { @@ -30,16 +34,14 @@ class SortSuite extends SparkPlanTest { ("World", 8) ) - val sortOrder = Seq( - SortOrder(BoundReference(0, StringType, nullable = false), Ascending), - SortOrder(BoundReference(1, IntegerType, nullable = false), Ascending) - ) - checkAnswer( - input, - (child: SparkPlan) => new ExternalSort(sortOrder, global = false, child), - input.sorted - ) + input.toDF("a", "b"), + ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan), + input.sorted) + checkAnswer( + input.toDF("a", "b"), + ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan), + input.sortBy(t => (t._2, t._1))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index b4f37cf8f69ae..39e561ec3e434 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -21,9 +21,13 @@ import scala.util.control.NonFatal import scala.reflect.runtime.universe.TypeTag import org.apache.spark.SparkFunSuite + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.util._ + import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{Row, DataFrame} -import org.apache.spark.sql.catalyst.util._ /** * Base class for writing tests for individual physical operators. For an example of how this @@ -48,6 +52,24 @@ class SparkPlanTest extends SparkFunSuite { } } + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + /** * Runs the plan and makes sure the answer matches the expected result. * @param input the input data to be used. @@ -87,6 +109,23 @@ object SparkPlanTest { val outputPlan = planFunction(input.queryExecution.sparkPlan) + // A very simple resolver to make writing tests easier. In contrast to the real resolver + // this is always case sensitive and does not try to handle scoping or complex type resolution. + val resolvedPlan = outputPlan transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).zipWithIndex.map { + case (a, i) => + (a.name, BoundReference(i, a.dataType, a.nullable)) + }.toMap + + plan.transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.get(u).getOrElse { + sys.error(s"Invalid Test: Cannot resolve $u given input ${inputMap}") + } + } + } + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to @@ -105,7 +144,7 @@ object SparkPlanTest { } val sparkAnswer: Seq[Row] = try { - outputPlan.executeCollect().toSeq + resolvedPlan.executeCollect().toSeq } catch { case NonFatal(e) => val errorMessage =