From 6181577e9935f46b646ba3925b873d031aa3d6ba Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sun, 2 Nov 2014 00:03:51 -0700 Subject: [PATCH 1/7] [SPARK-3466] Limit size of results that a driver collects for each action Right now, operations like collect() and take() can crash the driver with an OOM if they bring back too many data. This PR will introduce spark.driver.maxResultSize, after setting it, the driver will abort a job if its result is bigger than it. By default, it's 1g (for backward compatibility for most the cases). In local mode, the driver and executor share the same JVM, the default setting can not protect JVM from OOM. cc mateiz Author: Davies Liu Closes #3003 from davies/collect and squashes the following commits: 248ed5e [Davies Liu] fix compile 272522e [Davies Liu] address comments 2c35773 [Davies Liu] add sizes in message of abort() 5d62303 [Davies Liu] address comments bc3c077 [Davies Liu] Merge branch 'master' of github.com:apache/spark into collect 11f97c5 [Davies Liu] address comments 47b144f [Davies Liu] check the size of result before send and fetch 3d81af2 [Davies Liu] address comments ca8267d [Davies Liu] limit the size of data by collect --- .../org/apache/spark/executor/Executor.scala | 25 ++++++++------ .../apache/spark/scheduler/TaskResult.scala | 4 +-- .../spark/scheduler/TaskResultGetter.scala | 20 ++++++++--- .../spark/scheduler/TaskSetManager.scala | 33 ++++++++++++++++--- .../scala/org/apache/spark/util/Utils.scala | 5 +++ .../scheduler/TaskResultGetterSuite.scala | 2 +- .../spark/scheduler/TaskSetManagerSuite.scala | 25 ++++++++++++++ docs/configuration.md | 12 +++++++ 8 files changed, 104 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c78e0ffca25bb..e24a15f015e1c 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -104,6 +104,9 @@ private[spark] class Executor( // to send the result back. private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) + // Limit of bytes for total size of results (default is 1GB) + private val maxResultSize = Utils.getMaxResultSize(conf) + // Start worker thread pool val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") @@ -210,25 +213,27 @@ private[spark] class Executor( val resultSize = serializedDirectResult.limit // directSend = sending directly back to the driver - val (serializedResult, directSend) = { - if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { + val serializedResult = { + if (resultSize > maxResultSize) { + logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " + + s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + + s"dropping it.") + ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize)) + } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) - (ser.serialize(new IndirectTaskResult[Any](blockId)), false) + logInfo( + s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") + ser.serialize(new IndirectTaskResult[Any](blockId, resultSize)) } else { - (serializedDirectResult, true) + logInfo(s"Finished $taskName (TID $taskId). 
$resultSize bytes result sent to driver") + serializedDirectResult } } execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) - if (directSend) { - logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver") - } else { - logInfo( - s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") - } } catch { case ffe: FetchFailedException => { val reason = ffe.toTaskEndReason diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 11c19eeb6e42c..1f114a0207f7b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -31,8 +31,8 @@ import org.apache.spark.util.Utils private[spark] sealed trait TaskResult[T] /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */ -private[spark] -case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable +private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int) + extends TaskResult[T] with Serializable /** A TaskResult that contains the task's return value and accumulator updates. */ private[spark] diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 4b5be68ec5f92..819b51e12ad8c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -47,9 +47,18 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { try { - val result = serializer.get().deserialize[TaskResult[_]](serializedData) match { - case directResult: DirectTaskResult[_] => directResult - case IndirectTaskResult(blockId) => + val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match { + case directResult: DirectTaskResult[_] => + if (!taskSetManager.canFetchMoreResults(serializedData.limit())) { + return + } + (directResult, serializedData.limit()) + case IndirectTaskResult(blockId, size) => + if (!taskSetManager.canFetchMoreResults(size)) { + // dropped by executor if size is larger than maxResultSize + sparkEnv.blockManager.master.removeBlock(blockId) + return + } logDebug("Fetching indirect task result for TID %s".format(tid)) scheduler.handleTaskGettingResult(taskSetManager, tid) val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId) @@ -64,9 +73,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]]( serializedTaskResult.get) sparkEnv.blockManager.master.removeBlock(blockId) - deserializedResult + (deserializedResult, size) } - result.metrics.resultSize = serializedData.limit() + + result.metrics.resultSize = size scheduler.handleSuccessfulTask(taskSetManager, tid, result) } catch { case cnf: ClassNotFoundException => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 376821f89c6b8..a9767340074a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -23,13 +23,12 @@ 
import java.util.Arrays import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min +import scala.math.{min, max} import org.apache.spark._ -import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.{Clock, SystemClock} +import org.apache.spark.TaskState.TaskState +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of @@ -68,6 +67,9 @@ private[spark] class TaskSetManager( val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5) + // Limit of bytes for total size of results (default is 1GB) + val maxResultSize = Utils.getMaxResultSize(conf) + // Serializer for closures and tasks. val env = SparkEnv.get val ser = env.closureSerializer.newInstance() @@ -89,6 +91,8 @@ private[spark] class TaskSetManager( var stageId = taskSet.stageId var name = "TaskSet_" + taskSet.stageId.toString var parent: Pool = null + var totalResultSize = 0L + var calculatedTasks = 0 val runningTasksSet = new HashSet[Long] override def runningTasks = runningTasksSet.size @@ -515,12 +519,33 @@ private[spark] class TaskSetManager( index } + /** + * Marks the task as getting result and notifies the DAG Scheduler + */ def handleTaskGettingResult(tid: Long) = { val info = taskInfos(tid) info.markGettingResult() sched.dagScheduler.taskGettingResult(info) } + /** + * Check whether has enough quota to fetch the result with `size` bytes + */ + def canFetchMoreResults(size: Long): Boolean = synchronized { + totalResultSize += size + calculatedTasks += 1 + if (maxResultSize > 0 && totalResultSize > maxResultSize) { + val msg = s"Total size of serialized results of ${calculatedTasks} tasks " + + s"(${Utils.bytesToString(totalResultSize)}) is bigger than maxResultSize " + + s"(${Utils.bytesToString(maxResultSize)})" + logError(msg) + abort(msg) + false + } else { + true + } + } + /** * Marks the task as successful and notifies the DAGScheduler that a task has ended. */ diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 68d378f3a212d..4e30d0d3813a2 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1720,6 +1720,11 @@ private[spark] object Utils extends Logging { method.invoke(obj, values.toSeq: _*) } + // Limit of bytes for total size of results (default is 1GB) + def getMaxResultSize(conf: SparkConf): Long = { + memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20 + } + /** * Return the current system LD_LIBRARY_PATH name */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index c4e7a4bb7d385..5768a3a733f00 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -40,7 +40,7 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule // Only remove the result once, since we'd like to test the case where the task eventually // succeeds. 
serializer.get().deserialize[TaskResult[_]](serializedData) match { - case IndirectTaskResult(blockId) => + case IndirectTaskResult(blockId, size) => sparkEnv.blockManager.master.removeBlock(blockId) case directResult: DirectTaskResult[_] => taskSetManager.abort("Internal error: expect only indirect results") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index c0b07649eb6dd..1809b5396d53e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -563,6 +563,31 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.emittedTaskSizeWarning) } + test("abort the job if total size of results is too large") { + val conf = new SparkConf().set("spark.driver.maxResultSize", "2m") + sc = new SparkContext("local", "test", conf) + + def genBytes(size: Int) = { (x: Int) => + val bytes = Array.ofDim[Byte](size) + scala.util.Random.nextBytes(bytes) + bytes + } + + // multiple 1k result + val r = sc.makeRDD(0 until 10, 10).map(genBytes(1024)).collect() + assert(10 === r.size ) + + // single 10M result + val thrown = intercept[SparkException] {sc.makeRDD(genBytes(10 << 20)(0), 1).collect()} + assert(thrown.getMessage().contains("bigger than maxResultSize")) + + // multiple 1M results + val thrown2 = intercept[SparkException] { + sc.makeRDD(0 until 10, 10).map(genBytes(1 << 20)).collect() + } + assert(thrown2.getMessage().contains("bigger than maxResultSize")) + } + test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) diff --git a/docs/configuration.md b/docs/configuration.md index 3007706a2586e..099972ca1af70 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -111,6 +111,18 @@ of the most common options to set are: (e.g. 512m, 2g). + + spark.driver.maxResultSize + 1g + + Limit of total size of serialized results of all partitions for each Spark action (e.g. collect). + Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size + is above this limit. + Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory + and memory overhead of objects in JVM). Setting a proper limit can protect the driver from + out-of-memory errors. + + spark.serializer org.apache.spark.serializer.
JavaSerializer From 4e6a7a0b3e55098374a22f3ae9500404f7e4e91a Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 2 Nov 2014 10:44:52 -0800 Subject: [PATCH 2/7] [SPARK-4166][Core][WebUI] Display the executor ID in the Web UI when ExecutorLostFailure happens Now when ExecutorLostFailure happens, it only displays `ExecutorLostFailure (executor lost)`. Adding the executor id will help locate the faulted executor. Author: zsxwing Closes #3033 from zsxwing/SPARK-4166 and squashes the following commits: ff4664c [zsxwing] Backward-compatible support c5c4cf2 [zsxwing] Display the executor ID in the Web UI when ExecutorLostFailure happens --- core/src/main/scala/org/apache/spark/TaskEndReason.scala | 4 ++-- .../scala/org/apache/spark/scheduler/TaskSetManager.scala | 2 +- .../main/scala/org/apache/spark/util/JsonProtocol.scala | 8 ++++++-- .../apache/spark/ui/jobs/JobProgressListenerSuite.scala | 2 +- .../scala/org/apache/spark/util/JsonProtocolSuite.scala | 5 +++-- 5 files changed, 13 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 8f0c5e78416c2..202fba699ab26 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -117,8 +117,8 @@ case object TaskKilled extends TaskFailedReason { * the task crashed the JVM. */ @DeveloperApi -case object ExecutorLostFailure extends TaskFailedReason { - override def toErrorString: String = "ExecutorLostFailure (executor lost)" +case class ExecutorLostFailure(execId: String) extends TaskFailedReason { + override def toErrorString: String = s"ExecutorLostFailure (executor ${execId} lost)" } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a9767340074a8..d8fb640350343 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -732,7 +732,7 @@ private[spark] class TaskSetManager( } // Also re-enqueue any tasks that were running on the node for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure) + handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(execId)) } // recalculate valid locality levels and waits when executor is lost recomputeLocality() diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 5b2e7d3a7edb9..43c7fba06694a 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -272,7 +272,7 @@ private[spark] object JsonProtocol { def taskEndReasonToJson(taskEndReason: TaskEndReason): JValue = { val reason = Utils.getFormattedClassName(taskEndReason) - val json = taskEndReason match { + val json: JObject = taskEndReason match { case fetchFailed: FetchFailed => val blockManagerAddress = Option(fetchFailed.bmAddress). 
map(blockManagerIdToJson).getOrElse(JNothing) @@ -287,6 +287,8 @@ private[spark] object JsonProtocol { ("Description" -> exceptionFailure.description) ~ ("Stack Trace" -> stackTrace) ~ ("Metrics" -> metrics) + case ExecutorLostFailure(executorId) => + ("Executor ID" -> executorId) case _ => Utils.emptyJson } ("Reason" -> reason) ~ json @@ -636,7 +638,9 @@ private[spark] object JsonProtocol { new ExceptionFailure(className, description, stackTrace, metrics) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled - case `executorLostFailure` => ExecutorLostFailure + case `executorLostFailure` => + val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String]) + ExecutorLostFailure(executorId.getOrElse("Unknown")) case `unknownReason` => UnknownReason } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 3370dd4156c3f..6567c5ab836e7 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -119,7 +119,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc new ExceptionFailure("Exception", "description", null, None), TaskResultLost, TaskKilled, - ExecutorLostFailure, + ExecutorLostFailure("0"), UnknownReason) var failCount = 0 for (reason <- taskFailedReasons) { diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index f1f88c5fd3634..d235d7a0ed839 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -115,7 +115,7 @@ class JsonProtocolSuite extends FunSuite { testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) testTaskEndReason(TaskKilled) - testTaskEndReason(ExecutorLostFailure) + testTaskEndReason(ExecutorLostFailure("100")) testTaskEndReason(UnknownReason) // BlockId @@ -403,7 +403,8 @@ class JsonProtocolSuite extends FunSuite { assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => - case (ExecutorLostFailure, ExecutorLostFailure) => + case (ExecutorLostFailure(execId1), ExecutorLostFailure(execId2)) => + assert(execId1 === execId2) case (UnknownReason, UnknownReason) => case _ => fail("Task end reasons don't match in types!") } From f0a4b630abf0766cc0c41e682691e0d435caca04 Mon Sep 17 00:00:00 2001 From: wangfei Date: Sun, 2 Nov 2014 14:59:41 -0800 Subject: [PATCH 3/7] [HOTFIX][SQL] hive test missing some golden files cc marmbrus Author: wangfei Closes #3055 from scwf/hotfix and squashes the following commits: d881bd7 [wangfei] miss golden files --- .../golden/truncate_table-1-7fc255c86d7c3a9ff088f9eb29a42565 | 0 .../golden/truncate_table-10-c32b771845f4d5a0330e2cfa09f89a7f | 0 .../golden/truncate_table-7-1ad5d350714e3d4ea17201153772d58d | 0 .../golden/truncate_table-8-76c754eac44c7254b45807255d4dbc3a | 0 .../golden/truncate_table-9-f4286b5657674a6a6b6bc6680f72f89a | 0 .../golden/udf_named_struct-0-50131c0ba7b7a6b65c789a5a8497bada | 1 + .../golden/udf_named_struct-1-8f0ea83364b78634fbb3752c5a5c725 | 1 + .../golden/udf_named_struct-2-380c9638cc6ea8ea42f187bf0cedf350 | 1 + .../golden/udf_named_struct-3-c069e28293a12a813f8e881f776bae90 | 0 .../golden/udf_named_struct-4-b499d4120e009f222f2fab160a9006d7 
| 1 + .../golden/udf_struct-0-50131c0ba7b7a6b65c789a5a8497bada | 1 + .../golden/udf_struct-1-f41043b7d9f14fa5e998c90454c7bdb1 | 1 + .../golden/udf_struct-2-8ccdb20153debdab789ea8ad0228e2eb | 1 + .../golden/udf_struct-3-71361a92b74c4d026ac7ae6e1e6746f1 | 0 .../golden/udf_struct-4-b196b5d8849d52bbe5e2ee683f29e051 | 1 + 15 files changed, 8 insertions(+) create mode 100644 sql/hive/src/test/resources/golden/truncate_table-1-7fc255c86d7c3a9ff088f9eb29a42565 create mode 100644 sql/hive/src/test/resources/golden/truncate_table-10-c32b771845f4d5a0330e2cfa09f89a7f create mode 100644 sql/hive/src/test/resources/golden/truncate_table-7-1ad5d350714e3d4ea17201153772d58d create mode 100644 sql/hive/src/test/resources/golden/truncate_table-8-76c754eac44c7254b45807255d4dbc3a create mode 100644 sql/hive/src/test/resources/golden/truncate_table-9-f4286b5657674a6a6b6bc6680f72f89a create mode 100644 sql/hive/src/test/resources/golden/udf_named_struct-0-50131c0ba7b7a6b65c789a5a8497bada create mode 100644 sql/hive/src/test/resources/golden/udf_named_struct-1-8f0ea83364b78634fbb3752c5a5c725 create mode 100644 sql/hive/src/test/resources/golden/udf_named_struct-2-380c9638cc6ea8ea42f187bf0cedf350 create mode 100644 sql/hive/src/test/resources/golden/udf_named_struct-3-c069e28293a12a813f8e881f776bae90 create mode 100644 sql/hive/src/test/resources/golden/udf_named_struct-4-b499d4120e009f222f2fab160a9006d7 create mode 100644 sql/hive/src/test/resources/golden/udf_struct-0-50131c0ba7b7a6b65c789a5a8497bada create mode 100644 sql/hive/src/test/resources/golden/udf_struct-1-f41043b7d9f14fa5e998c90454c7bdb1 create mode 100644 sql/hive/src/test/resources/golden/udf_struct-2-8ccdb20153debdab789ea8ad0228e2eb create mode 100644 sql/hive/src/test/resources/golden/udf_struct-3-71361a92b74c4d026ac7ae6e1e6746f1 create mode 100644 sql/hive/src/test/resources/golden/udf_struct-4-b196b5d8849d52bbe5e2ee683f29e051 diff --git a/sql/hive/src/test/resources/golden/truncate_table-1-7fc255c86d7c3a9ff088f9eb29a42565 b/sql/hive/src/test/resources/golden/truncate_table-1-7fc255c86d7c3a9ff088f9eb29a42565 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-10-c32b771845f4d5a0330e2cfa09f89a7f b/sql/hive/src/test/resources/golden/truncate_table-10-c32b771845f4d5a0330e2cfa09f89a7f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-7-1ad5d350714e3d4ea17201153772d58d b/sql/hive/src/test/resources/golden/truncate_table-7-1ad5d350714e3d4ea17201153772d58d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-8-76c754eac44c7254b45807255d4dbc3a b/sql/hive/src/test/resources/golden/truncate_table-8-76c754eac44c7254b45807255d4dbc3a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-9-f4286b5657674a6a6b6bc6680f72f89a b/sql/hive/src/test/resources/golden/truncate_table-9-f4286b5657674a6a6b6bc6680f72f89a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_named_struct-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git 
a/sql/hive/src/test/resources/golden/udf_named_struct-1-8f0ea83364b78634fbb3752c5a5c725 b/sql/hive/src/test/resources/golden/udf_named_struct-1-8f0ea83364b78634fbb3752c5a5c725 new file mode 100644 index 0000000000000..9bff96e7fa20e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-1-8f0ea83364b78634fbb3752c5a5c725 @@ -0,0 +1 @@ +named_struct(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-2-380c9638cc6ea8ea42f187bf0cedf350 b/sql/hive/src/test/resources/golden/udf_named_struct-2-380c9638cc6ea8ea42f187bf0cedf350 new file mode 100644 index 0000000000000..9bff96e7fa20e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-2-380c9638cc6ea8ea42f187bf0cedf350 @@ -0,0 +1 @@ +named_struct(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-3-c069e28293a12a813f8e881f776bae90 b/sql/hive/src/test/resources/golden/udf_named_struct-3-c069e28293a12a813f8e881f776bae90 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-4-b499d4120e009f222f2fab160a9006d7 b/sql/hive/src/test/resources/golden/udf_named_struct-4-b499d4120e009f222f2fab160a9006d7 new file mode 100644 index 0000000000000..de25f51b5b56d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-4-b499d4120e009f222f2fab160a9006d7 @@ -0,0 +1 @@ +{"foo":1,"bar":2} 1 diff --git a/sql/hive/src/test/resources/golden/udf_struct-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_struct-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/udf_struct-1-f41043b7d9f14fa5e998c90454c7bdb1 b/sql/hive/src/test/resources/golden/udf_struct-1-f41043b7d9f14fa5e998c90454c7bdb1 new file mode 100644 index 0000000000000..062cb1bc683b1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-1-f41043b7d9f14fa5e998c90454c7bdb1 @@ -0,0 +1 @@ +struct(col1, col2, col3, ...) - Creates a struct with the given field values diff --git a/sql/hive/src/test/resources/golden/udf_struct-2-8ccdb20153debdab789ea8ad0228e2eb b/sql/hive/src/test/resources/golden/udf_struct-2-8ccdb20153debdab789ea8ad0228e2eb new file mode 100644 index 0000000000000..062cb1bc683b1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-2-8ccdb20153debdab789ea8ad0228e2eb @@ -0,0 +1 @@ +struct(col1, col2, col3, ...) 
- Creates a struct with the given field values diff --git a/sql/hive/src/test/resources/golden/udf_struct-3-71361a92b74c4d026ac7ae6e1e6746f1 b/sql/hive/src/test/resources/golden/udf_struct-3-71361a92b74c4d026ac7ae6e1e6746f1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_struct-4-b196b5d8849d52bbe5e2ee683f29e051 b/sql/hive/src/test/resources/golden/udf_struct-4-b196b5d8849d52bbe5e2ee683f29e051 new file mode 100644 index 0000000000000..ff1a28fa47f18 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-4-b196b5d8849d52bbe5e2ee683f29e051 @@ -0,0 +1 @@ +{"col1":1} {"col1":1,"col2":"a"} 1 a From 9c0eb57c737dd7d97d2cbd4516ddd2cf5d06e4b2 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 2 Nov 2014 15:08:35 -0800 Subject: [PATCH 4/7] [SPARK-3247][SQL] An API for adding data sources to Spark SQL This PR introduces a new set of APIs to Spark SQL to allow other developers to add support for reading data from new sources in `org.apache.spark.sql.sources`. New sources must implement the interface `BaseRelation`, which is responsible for describing the schema of the data. BaseRelations have three `Scan` subclasses, which are responsible for producing an RDD containing row objects. The [various Scan interfaces](https://github.com/marmbrus/spark/blob/foreign/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala#L50) allow for optimizations such as column pruning and filter push down, when the underlying data source can handle these operations. By implementing a class that inherits from RelationProvider these data sources can be accessed using using pure SQL. I've used the functionality to update the JSON support so it can now be used in this way as follows: ```sql CREATE TEMPORARY TABLE jsonTableSQL USING org.apache.spark.sql.json OPTIONS ( path '/home/michael/data.json' ) ``` Further example usage can be found in the test cases: https://github.com/marmbrus/spark/tree/foreign/sql/core/src/test/scala/org/apache/spark/sql/sources There is also a library that uses this new API to read avro data available here: https://github.com/marmbrus/sql-avro Author: Michael Armbrust Closes #2475 from marmbrus/foreign and squashes the following commits: 1ed6010 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into foreign ab2c31f [Michael Armbrust] fix test 1d41bb5 [Michael Armbrust] unify argument names 5b47901 [Michael Armbrust] Remove sealed, more filter types fab154a [Michael Armbrust] Merge remote-tracking branch 'origin/master' into foreign e3e690e [Michael Armbrust] Add hook for extraStrategies a70d602 [Michael Armbrust] Fix style, more tests, FilteredSuite => PrunedFilteredSuite 70da6d9 [Michael Armbrust] Modify API to ease binary compatibility and interop with Java 7d948ae [Michael Armbrust] Fix equality of AttributeReference. 5545491 [Michael Armbrust] Address comments 5031ac3 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into foreign 22963ef [Michael Armbrust] package objects compile wierdly... 
b069146 [Michael Armbrust] traits => abstract classes 34f836a [Michael Armbrust] Make @DeveloperApi 0d74bcf [Michael Armbrust] Add documention on object life cycle 3e06776 [Michael Armbrust] remove line wraps de3b68c [Michael Armbrust] Remove empty file 360cb30 [Michael Armbrust] style and java api 2957875 [Michael Armbrust] add override 0fd3a07 [Michael Armbrust] Draft of data sources API --- .../expressions/namedExpressions.scala | 2 +- .../apache/spark/sql/catalyst/package.scala | 4 + .../sql/catalyst/planning/QueryPlanner.scala | 20 +- .../spark/sql/catalyst/types/dataTypes.scala | 2 +- .../org/apache/spark/sql/SQLContext.scala | 25 ++- .../spark/sql/api/java/JavaSQLContext.scala | 5 + .../spark/sql/execution/ExistingRDD.scala | 6 - .../spark/sql/execution/SparkStrategies.scala | 3 +- .../apache/spark/sql/execution/commands.scala | 35 +++- .../apache/spark/sql/json/JSONRelation.scala | 49 +++++ .../scala/org/apache/spark/sql/package.scala | 9 + .../sql/sources/DataSourceStrategy.scala | 112 +++++++++++ .../spark/sql/sources/LogicalRelation.scala | 54 ++++++ .../org/apache/spark/sql/sources/ddl.scala | 108 +++++++++++ .../apache/spark/sql/sources/filters.scala | 26 +++ .../apache/spark/sql/sources/interfaces.scala | 86 +++++++++ .../apache/spark/sql/sources/package.scala | 22 +++ .../apache/spark/sql/CachedTableSuite.scala | 12 -- .../org/apache/spark/sql/QueryTest.scala | 30 ++- .../org/apache/spark/sql/json/JsonSuite.scala | 26 +++ .../spark/sql/sources/DataSourceTest.scala | 34 ++++ .../spark/sql/sources/FilteredScanSuite.scala | 176 ++++++++++++++++++ .../spark/sql/sources/PrunedScanSuite.scala | 137 ++++++++++++++ .../spark/sql/sources/TableScanSuite.scala | 125 +++++++++++++ .../apache/spark/sql/hive/HiveContext.scala | 6 +- .../spark/sql/hive/HiveStrategies.scala | 2 +- 26 files changed, 1074 insertions(+), 42 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 3310566087b3d..fc90a54a58259 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -134,7 +134,7 @@ case class AttributeReference( val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { override def equals(other: Any) = other match { - case ar: AttributeReference => exprId == ar.exprId && dataType 
== ar.dataType + case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index bdd07bbeb2230..a38079ced34b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +/** + * Catalyst is a library for manipulating relational query plans. All classes in catalyst are + * considered an internal API to Spark SQL and are subject to change between minor releases. + */ package object catalyst { /** * A JVM-global lock that should be used to prevent thread safety issues when using things in diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 5839c9f7c43ef..51b5699affed5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -21,6 +21,15 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode +/** + * Given a [[plans.logical.LogicalPlan LogicalPlan]], returns a list of `PhysicalPlan`s that can + * be used for execution. If this strategy does not apply to the give logical operation then an + * empty list should be returned. + */ +abstract class GenericStrategy[PhysicalPlan <: TreeNode[PhysicalPlan]] extends Logging { + def apply(plan: LogicalPlan): Seq[PhysicalPlan] +} + /** * Abstract class for transforming [[plans.logical.LogicalPlan LogicalPlan]]s into physical plans. * Child classes are responsible for specifying a list of [[Strategy]] objects that each of which @@ -35,16 +44,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNode */ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { /** A list of execution strategies that can be used by the planner */ - def strategies: Seq[Strategy] - - /** - * Given a [[plans.logical.LogicalPlan LogicalPlan]], returns a list of `PhysicalPlan`s that can - * be used for execution. If this strategy does not apply to the give logical operation then an - * empty list should be returned. - */ - abstract protected class Strategy extends Logging { - def apply(plan: LogicalPlan): Seq[PhysicalPlan] - } + def strategies: Seq[GenericStrategy[PhysicalPlan]] /** * Returns a placeholder for a physical plan that executes `plan`. 
This placeholder will be diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 8dda0b182805c..d25f3a619dd75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -455,7 +455,7 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT case class StructField( name: String, dataType: DataType, - nullable: Boolean, + nullable: Boolean = true, metadata: Metadata = Metadata.empty) { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4953f8399a96b..4cded98c803f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.sources.{DataSourceStrategy, BaseRelation, DDLParser, LogicalRelation} /** * :: AlphaComponent :: @@ -68,13 +69,19 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer + @transient + protected[sql] val ddlParser = new DDLParser + @transient protected[sql] val sqlParser = { val fallback = new catalyst.SqlParser new catalyst.SparkSQLParser(fallback(_)) } - protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser(sql) + protected[sql] def parseSql(sql: String): LogicalPlan = { + ddlParser(sql).getOrElse(sqlParser(sql)) + } + protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } @@ -104,6 +111,10 @@ class SQLContext(@transient val sparkContext: SparkContext) LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self)) } + implicit def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = { + logicalPlanToSparkQuery(LogicalRelation(baseRelation)) + } + /** * :: DeveloperApi :: * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. @@ -283,6 +294,14 @@ class SQLContext(@transient val sparkContext: SparkContext) def table(tableName: String): SchemaRDD = new SchemaRDD(this, catalog.lookupRelation(None, tableName)) + /** + * :: DeveloperApi :: + * Allows extra strategies to be injected into the query planner at runtime. Note this API + * should be consider experimental and is not intended to be stable across releases. 
+ */ + @DeveloperApi + var extraStrategies: Seq[Strategy] = Nil + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext @@ -293,7 +312,9 @@ class SQLContext(@transient val sparkContext: SparkContext) def numPartitions = self.numShufflePartitions val strategies: Seq[Strategy] = + extraStrategies ++ ( CommandStrategy(self) :: + DataSourceStrategy :: TakeOrdered :: HashAggregation :: LeftSemiJoin :: @@ -302,7 +323,7 @@ class SQLContext(@transient val sparkContext: SparkContext) ParquetOperations :: BasicOperators :: CartesianProduct :: - BroadcastNestedLoopJoin :: Nil + BroadcastNestedLoopJoin :: Nil) /** * Used to build table scan operators where complex projection and filtering are done using diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 876b1c6edef20..60065509bfbbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation} import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.{SQLContext, StructType => SStructType} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} @@ -39,6 +40,10 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { def this(sparkContext: JavaSparkContext) = this(new SQLContext(sparkContext.sc)) + def baseRelationToSchemaRDD(baseRelation: BaseRelation): JavaSchemaRDD = { + new JavaSchemaRDD(sqlContext, LogicalRelation(baseRelation)) + } + /** * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is * used for SQL parsing can be configured with 'spark.sql.dialect'. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 04c51a1ee4b97..d64c5af89ec99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -50,12 +50,6 @@ object RDDConversions { } } } - - /* - def toLogicalPlan[A <: Product : TypeTag](productRdd: RDD[A]): LogicalPlan = { - LogicalRDD(ScalaReflection.attributesFor[A], productToRowRdd(productRdd)) - } - */ } case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 79e4ddb8c4f5d..2cd3063bc3097 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLContext, execution} +import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ @@ -304,6 +304,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case class CommandStrategy(context: SQLContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case r: RunnableCommand => ExecutedCommand(r) :: Nil case logical.SetCommand(kv) => Seq(execution.SetCommand(kv, plan.output)(context)) case logical.ExplainCommand(logicalPlan, extended) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 5859eba408ee1..e658e6fc4d5d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,10 +21,12 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} +import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.{Row, SQLConf, SQLContext} +import org.apache.spark.sql.{SQLConf, SQLContext} +// TODO: DELETE ME... trait Command { this: SparkPlan => @@ -44,6 +46,35 @@ trait Command { override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) } +// TODO: Replace command with runnable command. +trait RunnableCommand extends logical.Command { + self: Product => + + def output: Seq[Attribute] + def run(sqlContext: SQLContext): Seq[Row] +} + +case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { + /** + * A concrete command should override this lazy field to wrap up any side effects caused by the + * command or any other computation that should be evaluated exactly once. The value of this field + * can be used as the contents of the corresponding RDD generated from the physical plan of this + * command. 
+ * + * The `execute()` method of all the physical command classes should reference `sideEffectResult` + * so that the command can be executed eagerly right after the command query is created. + */ + protected[sql] lazy val sideEffectResult: Seq[Row] = cmd.run(sqlContext) + + override def output = cmd.output + + override def children = Nil + + override def executeCollect(): Array[Row] = sideEffectResult.toArray + + override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) +} + /** * :: DeveloperApi :: */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala new file mode 100644 index 0000000000000..fc70c183437f6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.json + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources._ + +private[sql] class DefaultSource extends RelationProvider { + /** Returns a new base relation with the given parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified")) + val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + + JSONRelation(fileName, samplingRatio)(sqlContext) + } +} + +private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)( + @transient val sqlContext: SQLContext) + extends TableScan { + + private def baseRDD = sqlContext.sparkContext.textFile(fileName) + + override val schema = + JsonRDD.inferSchema( + baseRDD, + samplingRatio, + sqlContext.columnNameOfCorruptRecord) + + override def buildScan() = + JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 05926a24c5307..51dad54f1a3f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.SparkPlan /** * Allows the execution of relational queries, including those expressed in SQL using Spark. @@ -432,6 +433,12 @@ package object sql { @DeveloperApi val StructField = catalyst.types.StructField + /** + * Converts a logical plan into zero or more SparkPlans. 
+ */ + @DeveloperApi + type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + /** * :: DeveloperApi :: * @@ -448,7 +455,9 @@ package object sql { type Metadata = catalyst.util.Metadata /** + * :: DeveloperApi :: * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. */ + @DeveloperApi type MetadataBuilder = catalyst.util.MetadataBuilder } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala new file mode 100644 index 0000000000000..9b8c6a56b94b4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan + +/** + * A Strategy for planning scans over data sources defined using the sources API. + */ +private[sql] object DataSourceStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedFilteredScan)) => + pruneFilterProject( + l, + projectList, + filters, + (a, f) => t.buildScan(a, f)) :: Nil + + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedScan)) => + pruneFilterProject( + l, + projectList, + filters, + (a, _) => t.buildScan(a)) :: Nil + + case l @ LogicalRelation(t: TableScan) => + execution.PhysicalRDD(l.output, t.buildScan()) :: Nil + + case _ => Nil + } + + protected def pruneFilterProject( + relation: LogicalRelation, + projectList: Seq[NamedExpression], + filterPredicates: Seq[Expression], + scanBuilder: (Array[String], Array[Filter]) => RDD[Row]) = { + + val projectSet = AttributeSet(projectList.flatMap(_.references)) + val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) + val filterCondition = filterPredicates.reduceLeftOption(And) + + val pushedFilters = selectFilters(filterPredicates.map { _ transform { + case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. 
+ }}).toArray + + if (projectList.map(_.toAttribute) == projectList && + projectSet.size == projectList.size && + filterSet.subsetOf(projectSet)) { + // When it is possible to just use column pruning to get the right projection and + // when the columns of this projection are enough to evaluate all filter conditions, + // just do a scan followed by a filter, with no extra project. + val requestedColumns = + projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. + .map(relation.attributeMap) // Match original case of attributes. + .map(_.name) + .toArray + + val scan = + execution.PhysicalRDD( + projectList.map(_.toAttribute), + scanBuilder(requestedColumns, pushedFilters)) + filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) + } else { + val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq + val columnNames = requestedColumns.map(_.name).toArray + + val scan = execution.PhysicalRDD(requestedColumns, scanBuilder(columnNames, pushedFilters)) + execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) + } + } + + protected def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { + case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v) + case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v) + + case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v) + case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v) + + case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v) + case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v) + + case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => + GreaterThanOrEqual(a.name, v) + case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => + LessThanOrEqual(a.name, v) + + case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala new file mode 100644 index 0000000000000..82a2cf8402f8f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.sources + +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.plans.logical.{Statistics, LeafNode, LogicalPlan} + +/** + * Used to link a [[BaseRelation]] in to a logical query plan. + */ +private[sql] case class LogicalRelation(relation: BaseRelation) + extends LeafNode + with MultiInstanceRelation { + + override val output = relation.schema.toAttributes + + // Logical Relations are distinct if they have different output for the sake of transformations. + override def equals(other: Any) = other match { + case l @ LogicalRelation(otherRelation) => relation == otherRelation && output == l.output + case _ => false + } + + override def sameResult(otherPlan: LogicalPlan) = otherPlan match { + case LogicalRelation(otherRelation) => relation == otherRelation + case _ => false + } + + @transient override lazy val statistics = Statistics( + // TODO: Allow datasources to provide statistics as well. + sizeInBytes = BigInt(relation.sqlContext.defaultSizeInBytes) + ) + + /** Used to lookup original attribute capitalization */ + val attributeMap = AttributeMap(output.map(o => (o, o))) + + def newInstance() = LogicalRelation(relation).asInstanceOf[this.type] + + override def simpleString = s"Relation[${output.mkString(",")}] $relation" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala new file mode 100644 index 0000000000000..9168ca2fc6fec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.util.Utils + +import scala.language.implicitConversions +import scala.util.parsing.combinator.lexical.StdLexical +import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.combinator.PackratParsers + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.SqlLexical + +/** + * A parser for foreign DDL commands. 
+ */ +private[sql] class DDLParser extends StandardTokenParsers with PackratParsers with Logging { + + def apply(input: String): Option[LogicalPlan] = { + phrase(ddl)(new lexical.Scanner(input)) match { + case Success(r, x) => Some(r) + case x => + logDebug(s"Not recognized as DDL: $x") + None + } + } + + protected case class Keyword(str: String) + + protected implicit def asParser(k: Keyword): Parser[String] = + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + + protected val CREATE = Keyword("CREATE") + protected val TEMPORARY = Keyword("TEMPORARY") + protected val TABLE = Keyword("TABLE") + protected val USING = Keyword("USING") + protected val OPTIONS = Keyword("OPTIONS") + + // Use reflection to find the reserved words defined in this class. + protected val reservedWords = + this.getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].str) + + override val lexical = new SqlLexical(reservedWords) + + protected lazy val ddl: Parser[LogicalPlan] = createTable + + /** + * CREATE FOREIGN TEMPORARY TABLE avroTable + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro") + */ + protected lazy val createTable: Parser[LogicalPlan] = + CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ { + case tableName ~ provider ~ opts => + CreateTableUsing(tableName, provider, opts) + } + + protected lazy val options: Parser[Map[String, String]] = + "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } + + protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} + + protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) } +} + +private[sql] case class CreateTableUsing( + tableName: String, + provider: String, + options: Map[String, String]) extends RunnableCommand { + + def run(sqlContext: SQLContext) = { + val loader = Utils.getContextOrSparkClassLoader + val clazz: Class[_] = try loader.loadClass(provider) catch { + case cnf: java.lang.ClassNotFoundException => + try loader.loadClass(provider + ".DefaultSource") catch { + case cnf: java.lang.ClassNotFoundException => + sys.error(s"Failed to load class for data source: $provider") + } + } + val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider] + val relation = dataSource.createRelation(sqlContext, options) + + sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName) + Seq.empty + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala new file mode 100644 index 0000000000000..e72a2aeb8f310 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +abstract class Filter + +case class EqualTo(attribute: String, value: Any) extends Filter +case class GreaterThan(attribute: String, value: Any) extends Filter +case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter +case class LessThan(attribute: String, value: Any) extends Filter +case class LessThanOrEqual(attribute: String, value: Any) extends Filter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala new file mode 100644 index 0000000000000..ac3bf9d8e1a21 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -0,0 +1,86 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.apache.spark.sql.sources + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext, StructType} +import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} + +/** + * Implemented by objects that produce relations for a specific kind of data source. When + * Spark SQL is given a DDL operation with a USING clause specified, this interface is used to + * pass in the parameters specified by a user. + * + * Users may specify the fully qualified class name of a given data source. When that class is + * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for + * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the + * data source 'org.apache.spark.sql.json.DefaultSource' + * + * A new instance of this class with be instantiated each time a DDL call is made. + */ +@DeveloperApi +trait RelationProvider { + /** Returns a new base relation with the given parameters. */ + def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation +} + +/** + * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must + * be able to produce the schema of their data in the form of a [[StructType]] Concrete + * implementation should inherit from one of the descendant `Scan` classes, which define various + * abstract methods for execution. 
+ * + * BaseRelations must also define a equality function that only returns true when the two + * instances will return the same data. This equality function is used when determining when + * it is safe to substitute cached results for a given relation. + */ +@DeveloperApi +abstract class BaseRelation { + def sqlContext: SQLContext + def schema: StructType +} + +/** + * A BaseRelation that can produce all of its tuples as an RDD of Row objects. + */ +@DeveloperApi +abstract class TableScan extends BaseRelation { + def buildScan(): RDD[Row] +} + +/** + * A BaseRelation that can eliminate unneeded columns before producing an RDD + * containing all of its tuples as Row objects. + */ +@DeveloperApi +abstract class PrunedScan extends BaseRelation { + def buildScan(requiredColumns: Array[String]): RDD[Row] +} + +/** + * A BaseRelation that can eliminate unneeded columns and filter using selected + * predicates before producing an RDD containing all matching tuples as Row objects. + * + * The pushed down filters are currently purely an optimization as they will all be evaluated + * again. This means it is safe to use them with methods that produce false positives such + * as filtering partitions based on a bloom filter. + */ +@DeveloperApi +abstract class PrunedFilteredScan extends BaseRelation { + def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala new file mode 100644 index 0000000000000..8393c510f4f6d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +/** + * A set of APIs for adding data sources to Spark SQL. + */ +package object sources diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1a5d87d5240e9..44a2961b27eda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -27,18 +27,6 @@ case class BigData(s: String) class CachedTableSuite extends QueryTest { TestData // Load test tables. 
- def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached - } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } - def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan executedPlan.collect { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 042f61f5a4113..3d9f0cbf80fe7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.columnar.InMemoryRelation class QueryTest extends PlanTest { + /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer @@ -78,11 +80,31 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${convertedAnswer.size} ==" +: - prepareAnswer(convertedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} + s"== Correct Answer - ${convertedAnswer.size} ==" +: + prepareAnswer(convertedAnswer).map(_.toString), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} """.stripMargin) } } + + def sqlTest(sqlString: String, expectedAnswer: Any)(implicit sqlContext: SQLContext): Unit = { + test(sqlString) { + checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + } + } + + /** Asserts that a given SchemaRDD will be executed using the given number of cached results. 
*/ + def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + assert( + cachedData.size == numCachedTables, + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 1cb6c23c58f36..362c7e1a52482 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -549,6 +549,32 @@ class JsonSuite extends QueryTest { ) } + test("Loading a JSON dataset from a text file with SQL") { + val file = getTempFilePath("json") + val path = file.toString + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + + sql( + s""" + |CREATE TEMPORARY TABLE jsonTableSQL + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + + checkAnswer( + sql("select * from jsonTableSQL"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + } + test("Applying schemas") { val file = getTempFilePath("json") val path = file.toString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala new file mode 100644 index 0000000000000..9626252e742e5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -0,0 +1,34 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.Analyzer +import org.apache.spark.sql.test.TestSQLContext +import org.scalatest.BeforeAndAfter + +abstract class DataSourceTest extends QueryTest with BeforeAndAfter { + // Case sensitivity is not configurable yet, but we want to test some edge cases. 
+ // TODO: Remove when it is configurable + implicit val caseInsensisitiveContext = new SQLContext(TestSQLContext.sparkContext) { + @transient + override protected[sql] lazy val analyzer: Analyzer = + new Analyzer(catalog, functionRegistry, caseSensitive = false) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala new file mode 100644 index 0000000000000..8b2f1591d5bf3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -0,0 +1,176 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import scala.language.existentials + +import org.apache.spark.sql._ + +class FilteredScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleFilteredScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends PrunedFilteredScan { + + override def schema = + StructType( + StructField("a", IntegerType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: Nil) + + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { + val rowBuilders = requiredColumns.map { + case "a" => (i: Int) => Seq(i) + case "b" => (i: Int) => Seq(i * 2) + } + + FiltersPushed.list = filters + + val filterFunctions = filters.collect { + case EqualTo("a", v) => (a: Int) => a == v + case LessThan("a", v: Int) => (a: Int) => a < v + case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v + case GreaterThan("a", v: Int) => (a: Int) => a > v + case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v + } + + def eval(a: Int) = !filterFunctions.map(_(a)).contains(false) + + sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => + Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) + } +} + +// A hack for better error messages when filter pushdown fails. 
+object FiltersPushed { + var list: Seq[Filter] = Nil +} + +class FilteredScanSuite extends DataSourceTest { + + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTenFiltered + |USING org.apache.spark.sql.sources.FilteredScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTenFiltered", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT b, a FROM oneToTenFiltered", + (1 to 10).map(i => Row(i * 2, i)).toSeq) + + sqlTest( + "SELECT a FROM oneToTenFiltered", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a * 2 FROM oneToTenFiltered", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT A AS b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT x.b, y.a FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b", + (1 to 5).map(i => Row(i * 4, i)).toSeq) + + sqlTest( + "SELECT x.a, y.b FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b", + (2 to 10 by 2).map(i => Row(i, i)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE a = 1", + Seq(1).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE A = 1", + Seq(1).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE b = 2", + Seq(1).map(i => Row(i, i * 2)).toSeq) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10) + + def testPushDown(sqlString: String, expectedCount: Int): Unit = { + test(s"PushDown Returns $expectedCount: $sqlString") { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawCount = rawPlan.execute().count() + + if (rawCount != expectedCount) { + fail( + s"Wrong # of results for pushed filter. 
Got $rawCount, Expected $expectedCount\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + } + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala new file mode 100644 index 0000000000000..fee2e22611cdc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -0,0 +1,137 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ + +class PrunedScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimplePrunedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends PrunedScan { + + override def schema = + StructType( + StructField("a", IntegerType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: Nil) + + override def buildScan(requiredColumns: Array[String]) = { + val rowBuilders = requiredColumns.map { + case "a" => (i: Int) => Seq(i) + case "b" => (i: Int) => Seq(i * 2) + } + + sqlContext.sparkContext.parallelize(from to to).map(i => + Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) + } +} + +class PrunedScanSuite extends DataSourceTest { + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTenPruned + |USING org.apache.spark.sql.sources.PrunedScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT a, b FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT b, a FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 2, i)).toSeq) + + sqlTest( + "SELECT a FROM oneToTenPruned", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT a, a FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i)).toSeq) + + sqlTest( + "SELECT b FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a * 2 FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT A AS b FROM oneToTenPruned", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT x.b, y.a FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b", + (1 to 5).map(i => Row(i * 4, i)).toSeq) + + sqlTest( + "SELECT x.a, y.b FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b", + (2 to 10 by 2).map(i => Row(i, i)).toSeq) + + testPruning("SELECT * FROM oneToTenPruned", "a", "b") + testPruning("SELECT a, b FROM oneToTenPruned", "a", 
"b") + testPruning("SELECT b, a FROM oneToTenPruned", "b", "a") + testPruning("SELECT b, b FROM oneToTenPruned", "b") + testPruning("SELECT a FROM oneToTenPruned", "a") + testPruning("SELECT b FROM oneToTenPruned", "b") + + def testPruning(sqlString: String, expectedColumns: String*): Unit = { + test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawColumns = rawPlan.output.map(_.name) + val rawOutput = rawPlan.execute().first() + + if (rawColumns != expectedColumns) { + fail( + s"Wrong column names. Got $rawColumns, Expected $expectedColumns\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + + if (rawOutput.size != expectedColumns.size) { + fail(s"Wrong output row. Got $rawOutput\n$queryExecution") + } + } + } + +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala new file mode 100644 index 0000000000000..b254b0620c779 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -0,0 +1,125 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ + +class DefaultSource extends SimpleScanSource + +class SimpleScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends TableScan { + + override def schema = + StructType(StructField("i", IntegerType, nullable = false) :: Nil) + + override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) +} + +class TableScanSuite extends DataSourceTest { + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTen + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTen", + (1 to 10).map(Row(_)).toSeq) + + sqlTest( + "SELECT i FROM oneToTen", + (1 to 10).map(Row(_)).toSeq) + + sqlTest( + "SELECT i FROM oneToTen WHERE i < 5", + (1 to 4).map(Row(_)).toSeq) + + sqlTest( + "SELECT i * 2 FROM oneToTen", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1", + (2 to 10).map(i => Row(i, i - 1)).toSeq) + + + test("Caching") { + // Cached Query Execution + cacheTable("oneToTen") + assertCached(sql("SELECT * FROM oneToTen")) + checkAnswer( + sql("SELECT * FROM oneToTen"), + (1 to 10).map(Row(_)).toSeq) + + assertCached(sql("SELECT i FROM oneToTen")) + checkAnswer( + sql("SELECT i FROM oneToTen"), + (1 to 10).map(Row(_)).toSeq) + + assertCached(sql("SELECT i FROM oneToTen WHERE i < 5")) + checkAnswer( + sql("SELECT i FROM oneToTen WHERE i < 5"), + (1 to 4).map(Row(_)).toSeq) + + assertCached(sql("SELECT i * 2 FROM oneToTen")) + checkAnswer( + sql("SELECT i * 2 FROM oneToTen"), + (1 to 10).map(i => Row(i * 2)).toSeq) + + assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) + checkAnswer( + sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), + (2 to 10).map(i => Row(i, i - 1)).toSeq) + + // Verify uncaching + uncacheTable("oneToTen") + assertCached(sql("SELECT * FROM oneToTen"), 0) + } + + test("defaultSource") { + sql( + """ + |CREATE TEMPORARY TABLE oneToTenDef + |USING org.apache.spark.sql.sources + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM oneToTenDef"), + (1 to 10).map(Row(_)).toSeq) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2e27817d60221..dca5367f244de 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -50,6 +50,7 @@ import org.apache.spark.sql.execution.ExtractPythonUdfs import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.{Command => PhysicalCommand} import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand +import org.apache.spark.sql.sources.DataSourceStrategy /** * DEPRECATED: Use HiveContext instead. 
@@ -99,7 +100,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { if (dialect == "sql") { super.sql(sqlText) } else if (dialect == "hiveql") { - new SchemaRDD(this, HiveQl.parseSql(sqlText)) + new SchemaRDD(this, ddlParser(sqlText).getOrElse(HiveQl.parseSql(sqlText))) } else { sys.error(s"Unsupported SQL dialect: $dialect. Try 'sql' or 'hiveql'") } @@ -345,7 +346,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val hivePlanner = new SparkPlanner with HiveStrategies { val hiveContext = self - override val strategies: Seq[Strategy] = Seq( + override val strategies: Seq[Strategy] = extraStrategies ++ Seq( + DataSourceStrategy, CommandStrategy(self), HiveCommandStrategy(self), TakeOrdered, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 3207ad81d9571..989740c8d43b6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan} import org.apache.spark.sql.hive import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy} import scala.collection.JavaConversions._ From e4b80894bdb72c0acf8832fd48421c546fbc37e6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 2 Nov 2014 15:14:44 -0800 Subject: [PATCH 5/7] [SPARK-4182][SQL] Fixes ColumnStats classes for boolean, binary and complex data types `NoopColumnStats` was once used for binary, boolean and complex data types. This `ColumnStats` doesn't return properly shaped column statistics and causes caching failure if a table contains columns of the aforementioned types. This PR adds `BooleanColumnStats`, `BinaryColumnStats` and `GenericColumnStats`, used for boolean, binary and all complex data types respectively. In addition, `NoopColumnStats` returns properly shaped column statistics containing null count and row count, but this class is now used for testing purpose only. 
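To make the failure mode concrete, the following standalone repro (illustrative only: `ComplexRecord`, the `local` master and the object name are not part of this patch, which exercises the same path through the new `TestData.complexData` fixture) caches a table containing boolean and complex columns:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// A row type mixing complex (Map, Seq) and boolean columns, i.e. the types that were
// previously backed by NoopColumnStats.
case class ComplexRecord(m: Map[Int, String], a: Seq[Int], b: Boolean)

object Spark4182Repro extends App {
  val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("SPARK-4182 repro"))
  val sqlContext = new SQLContext(sc)
  import sqlContext.createSchemaRDD  // implicit RDD[Product] => SchemaRDD conversion

  val complexData = sc.parallelize(
    ComplexRecord(Map(1 -> "1"), Seq(1), b = true) ::
    ComplexRecord(Map(2 -> "2"), Seq(2), b = false) :: Nil).toSchemaRDD

  complexData.cache().count()  // builds the in-memory columnar buffers and their ColumnStats
  complexData.count()          // scans the cached buffers; failed before this change
  complexData.unpersist()
  sc.stop()
}
```

With the properly shaped `BooleanColumnStats`, `BinaryColumnStats` and `GenericColumnStats` rows, both counts succeed; the regression test added below does the same thing against `TestSQLContext`.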
Author: Cheng Lian Closes #3059 from liancheng/spark-4182 and squashes the following commits: b398cfd [Cheng Lian] Fixes failed test case fb3ee85 [Cheng Lian] Fixes SPARK-4182 --- .../spark/sql/columnar/ColumnBuilder.scala | 10 +++-- .../spark/sql/columnar/ColumnStats.scala | 45 ++++++++++++++++++- .../columnar/InMemoryColumnarTableScan.scala | 3 ++ .../org/apache/spark/sql/SQLQuerySuite.scala | 7 +-- .../scala/org/apache/spark/sql/TestData.scala | 8 ++++ .../columnar/InMemoryColumnarQuerySuite.scala | 28 +++++++----- 6 files changed, 82 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 300cef15bf8a4..c68dceef3b142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -79,8 +79,9 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( } private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType]( + columnStats: ColumnStats, columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType) + extends BasicColumnBuilder[T, JvmType](columnStats, columnType) with NullableColumnBuilder private[sql] abstract class NativeColumnBuilder[T <: NativeType]( @@ -91,7 +92,7 @@ private[sql] abstract class NativeColumnBuilder[T <: NativeType]( with AllCompressionSchemes with CompressibleColumnBuilder[T] -private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new NoopColumnStats, BOOLEAN) +private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) @@ -112,10 +113,11 @@ private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnS private[sql] class TimestampColumnBuilder extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP) -private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY) +private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) // TODO (lian) Add support for array, struct and map -private[sql] class GenericColumnBuilder extends ComplexColumnBuilder(GENERIC) +private[sql] class GenericColumnBuilder + extends ComplexColumnBuilder(new GenericColumnStats, GENERIC) private[sql] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index b9f9f8270045c..668efe4a3b2a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -70,11 +70,30 @@ private[sql] sealed trait ColumnStats extends Serializable { def collectedStatistics: Row } +/** + * A no-op ColumnStats only used for testing purposes. 
+ */ private[sql] class NoopColumnStats extends ColumnStats { + override def gatherStats(row: Row, ordinal: Int): Unit = super.gatherStats(row, ordinal) + + def collectedStatistics = Row(null, null, nullCount, count, 0L) +} - override def gatherStats(row: Row, ordinal: Int): Unit = {} +private[sql] class BooleanColumnStats extends ColumnStats { + protected var upper = false + protected var lower = true - override def collectedStatistics = Row() + override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) + if (!row.isNullAt(ordinal)) { + val value = row.getBoolean(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += BOOLEAN.defaultSize + } + } + + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ByteColumnStats extends ColumnStats { @@ -229,3 +248,25 @@ private[sql] class TimestampColumnStats extends ColumnStats { def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } + +private[sql] class BinaryColumnStats extends ColumnStats { + override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) + if (!row.isNullAt(ordinal)) { + sizeInBytes += BINARY.actualSize(row, ordinal) + } + } + + def collectedStatistics = Row(null, null, nullCount, count, sizeInBytes) +} + +private[sql] class GenericColumnStats extends ColumnStats { + override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) + if (!row.isNullAt(ordinal)) { + sizeInBytes += GENERIC.actualSize(row, ordinal) + } + } + + def collectedStatistics = Row(null, null, nullCount, count, sizeInBytes) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index ee63134f56d8c..455b415d9d959 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -161,6 +161,9 @@ private[sql] case class InMemoryRelation( } def cachedColumnBuffers = _cachedColumnBuffers + + override protected def otherCopyArgs: Seq[AnyRef] = + Seq(_cachedColumnBuffers, statisticsToBePropagated) } private[sql] case class InMemoryColumnarTableScan( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6befe1b755cc6..6bf439377aa3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,11 +21,12 @@ import java.util.TimeZone import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +/* Implicits */ +import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext._ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { @@ -719,7 +720,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) } - + test("SPARK-3371 Renaming a function expression with group by gives error") { registerFunction("len", (s: String) => 
s.length) checkAnswer( @@ -934,7 +935,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { - checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), + checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), (11 to 100).map(i => Seq(i))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 836dd17fcc3a2..ef87a230639bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -177,4 +177,12 @@ object TestData { Salary(0, 2000.0) :: Salary(1, 1000.0) :: Nil) salary.registerTempTable("salary") + + case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean) + val complexData = + TestSQLContext.sparkContext.parallelize( + ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true) + :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false) + :: Nil).toSchemaRDD + complexData.registerTempTable("complexData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 9775dd26b7773..15903d07df29a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.columnar +import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY class InMemoryColumnarQuerySuite extends QueryTest { - import org.apache.spark.sql.TestData._ - import org.apache.spark.sql.test.TestSQLContext._ + // Make sure the tables are loaded. 
+ TestData test("simple columnar query") { - val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan + val plan = executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().toSeq) @@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("projection") { - val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = executePlan(testData.select('value, 'key).logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().map { @@ -51,7 +52,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan + val plan = executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().toSeq) @@ -63,7 +64,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq) - TestSQLContext.cacheTable("repeatedData") + cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -75,7 +76,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq) - TestSQLContext.cacheTable("nullableRepeatedData") + cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -87,7 +88,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) - TestSQLContext.cacheTable("timestamps") + cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), @@ -99,10 +100,17 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq) - TestSQLContext.cacheTable("withEmptyParts") + cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq) } + + test("SPARK-4182 Caching complex types") { + complexData.cache().count() + // Shouldn't throw + complexData.count() + complexData.unpersist() + } } From 495a132031ae002c787371f2fd0ba4be2437e7c8 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 2 Nov 2014 15:15:52 -0800 Subject: [PATCH 6/7] [SQL] Fixes race condition in CliSuite `CliSuite` has been flaky for a while, this PR tries to improve this situation by fixing a race condition in `CliSuite`. The `captureOutput` function is used to capture both stdout and stderr output of the forked external process in two background threads and search for expected strings, but wasn't been properly synchronized before. 
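Concretely, the per-line callback shared by the two `ProcessLogger` threads is now serialized behind a single monitor. Distilled out of the suite, the pattern is roughly the following sketch (the `OutputCapture` wrapper class is illustrative, not part of the patch):

```scala
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Promise

// Both stdout and stderr of the forked CLI process are fed into captureOutput, but from
// different ProcessLogger threads. The index into expectedAnswers and the shared output
// buffer are therefore only touched while holding `lock`.
class OutputCapture(expectedAnswers: Seq[String]) {
  private var next = 0
  private val buffer = new ArrayBuffer[String]()
  private val lock = new Object

  val foundAllExpectedAnswers = Promise[Unit]()

  def captureOutput(source: String)(line: String): Unit = lock.synchronized {
    buffer += s"$source> $line"
    if (next < expectedAnswers.size && line.startsWith(expectedAnswers(next))) {
      next += 1
      if (next == expectedAnswers.size) {
        foundAllExpectedAnswers.trySuccess(())
      }
    }
  }

  def capturedOutput: Seq[String] = lock.synchronized(buffer.toSeq)
}
```

Because the check-and-increment is now atomic with respect to both logger threads, the previous `AtomicInteger` counter is no longer needed and is replaced by a plain `var` guarded by the same lock.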
Author: Cheng Lian Closes #3060 from liancheng/fix-cli-suite and squashes the following commits: a70569c [Cheng Lian] Fixes race condition in CliSuite --- .../sql/hive/thriftserver/CliSuite.scala | 35 ++++++++----------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 8a72e9d2aef57..e8ffbc5b954d4 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -18,19 +18,17 @@ package org.apache.spark.sql.hive.thriftserver +import java.io._ + import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ -import scala.concurrent.{Await, Future, Promise} +import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} -import java.io._ -import java.util.concurrent.atomic.AtomicInteger - import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.util.getTempFilePath class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { @@ -53,23 +51,20 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { """.stripMargin.split("\\s+").toSeq ++ extraArgs } - // AtomicInteger is needed because stderr and stdout of the forked process are handled in - // different threads. - val next = new AtomicInteger(0) + var next = 0 val foundAllExpectedAnswers = Promise.apply[Unit]() val queryStream = new ByteArrayInputStream(queries.mkString("\n").getBytes) val buffer = new ArrayBuffer[String]() + val lock = new Object - def captureOutput(source: String)(line: String) { + def captureOutput(source: String)(line: String): Unit = lock.synchronized { buffer += s"$source> $line" - // If we haven't found all expected answers... - if (next.get() < expectedAnswers.size) { - // If another expected answer is found... - if (line.startsWith(expectedAnswers(next.get()))) { - // If all expected answers have been found... - if (next.incrementAndGet() == expectedAnswers.size) { - foundAllExpectedAnswers.trySuccess(()) - } + // If we haven't found all expected answers and another expected answer comes up... + if (next < expectedAnswers.size && line.startsWith(expectedAnswers(next))) { + next += 1 + // If all expected answers have been found... + if (next == expectedAnswers.size) { + foundAllExpectedAnswers.trySuccess(()) } } } @@ -88,8 +83,8 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { |======================= |Spark SQL CLI command line: ${command.mkString(" ")} | - |Executed query ${next.get()} "${queries(next.get())}", - |But failed to capture expected output "${expectedAnswers(next.get())}" within $timeout. + |Executed query $next "${queries(next)}", + |But failed to capture expected output "${expectedAnswers(next)}" within $timeout. 
| |${buffer.mkString("\n")} |=========================== From c9f840046f8c45b1137f0289eeb0c980de72ea5e Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 2 Nov 2014 15:18:29 -0800 Subject: [PATCH 7/7] [SPARK-3791][SQL] Provides Spark version and Hive version in HiveThriftServer2 This PR overrides the `GetInfo` Hive Thrift API to provide correct version information. Another property `spark.sql.hive.version` is added to reveal the underlying Hive version. These are generally useful for Spark SQL ODBC driver providers. The Spark version information is extracted from the jar manifest. Also took the chance to remove the `SET -v` hack, which was a workaround for Simba ODBC driver connectivity. TODO - [x] Find a general way to figure out Hive (or even any dependency) version. This [blog post](http://blog.soebes.de/blog/2014/01/02/version-information-into-your-appas-with-maven/) suggests several methods to inspect application version. In the case of Spark, this can be tricky because the chosen method: 1. must applies to both Maven build and SBT build For Maven builds, we can retrieve the version information from the META-INF/maven directory within the assembly jar. But this doesn't work for SBT builds. 2. must not rely on the original jars of dependencies to extract specific dependency version, because Spark uses assembly jar. This implies we can't read Hive version from Hive jar files since standard Spark distribution doesn't include them. 3. should play well with `SPARK_PREPEND_CLASSES` to ease local testing during development. `SPARK_PREPEND_CLASSES` prevents classes to be loaded from the assembly jar, thus we can't locate the jar file and read its manifest. Given these, maybe the only reliable method is to generate a source file containing version information at build time. pwendell Do you have any suggestions from the perspective of the build process? **Update** Hive version is now retrieved from the newly introduced `HiveShim` object. 
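For reference, the manifest lookup itself needs nothing Spark-specific. The sketch below shows the mechanism using the class's code-source location instead of the `SparkContext.jarOfObject` helper that the patch actually uses (the `ManifestVersion` object is illustrative only):

```scala
import java.net.URL
import java.util.jar.Attributes.Name
import java.util.jar.{Manifest => JarManifest}

// Reads Implementation-Version from the manifest of the jar containing `clazz`.
// When the class was not loaded from a jar (e.g. plain class files on the classpath,
// as with SPARK_PREPEND_CLASSES), there is no manifest to read and "Unknown" is returned.
object ManifestVersion {
  def of(clazz: Class[_]): String = {
    val jarPath = Option(clazz.getProtectionDomain.getCodeSource)
      .map(_.getLocation.getPath)
      .filter(_.endsWith(".jar"))
    jarPath.flatMap { path =>
      val manifestUrl = new URL(s"jar:file:$path!/META-INF/MANIFEST.MF")
      val manifest = new JarManifest(manifestUrl.openStream())
      Option(manifest.getMainAttributes.getValue(Name.IMPLEMENTATION_VERSION))
    }.getOrElse("Unknown")
  }
}
```

This is also why `SPARK_PREPEND_CLASSES` is explicitly disabled in the updated `HiveThriftServer2Suite`: with classes loaded ahead of the assembly jar there is no jar manifest to read and the reported Spark version degrades to "Unknown".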
Author: Cheng Lian Author: Cheng Lian Closes #2843 from liancheng/get-info and squashes the following commits: a873d0f [Cheng Lian] Updates test case 53f43cd [Cheng Lian] Retrieves underlying Hive verson via HiveShim 1d282b8 [Cheng Lian] Removes the Simba ODBC "SET -v" hack f857fce [Cheng Lian] Overrides Hive GetInfo Thrift API and adds Hive version property --- .../scala/org/apache/spark/util/Utils.scala | 8 + .../apache/spark/sql/execution/commands.scala | 69 ++++----- .../thriftserver/SparkSQLCLIService.scala | 14 +- .../sql/hive/thriftserver/SparkSQLEnv.scala | 11 +- .../thriftserver/HiveThriftServer2Suite.scala | 144 +++++++++++++----- .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../sql/hive/execution/HiveQuerySuite.scala | 35 +---- 7 files changed, 173 insertions(+), 112 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 4e30d0d3813a2..b402c5f334bb0 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -20,8 +20,10 @@ package org.apache.spark.util import java.io._ import java.net._ import java.nio.ByteBuffer +import java.util.jar.Attributes.Name import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor} +import java.util.jar.{Manifest => JarManifest} import scala.collection.JavaConversions._ import scala.collection.Map @@ -1759,6 +1761,12 @@ private[spark] object Utils extends Logging { s"$libraryPathEnvName=$libraryPath$ampersand" } + lazy val sparkVersion = + SparkContext.jarOfObject(this).map { path => + val manifestUrl = new URL(s"jar:file:$path!/META-INF/MANIFEST.MF") + val manifest = new JarManifest(manifestUrl.openStream()) + manifest.getMainAttributes.getValue(Name.IMPLEMENTATION_VERSION) + }.getOrElse("Unknown") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index e658e6fc4d5d5..f23b9c48cfb40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -84,50 +84,35 @@ case class SetCommand(kv: Option[(String, Option[String])], output: Seq[Attribut extends LeafNode with Command with Logging { override protected lazy val sideEffectResult: Seq[Row] = kv match { - // Set value for the key. - case Some((key, Some(value))) => - if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { - logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + // Configures the deprecated "mapred.reduce.tasks" property. + case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) => + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") - context.setConf(SQLConf.SHUFFLE_PARTITIONS, value) - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value")) - } else { - context.setConf(key, value) - Seq(Row(s"$key=$value")) - } - - // Query the value bound to the key. + context.setConf(SQLConf.SHUFFLE_PARTITIONS, value) + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value")) + + // Configures a single property. + case Some((key, Some(value))) => + context.setConf(key, value) + Seq(Row(s"$key=$value")) + + // Queries all key-value pairs that are set in the SQLConf of the context. 
Notice that different + // from Hive, here "SET -v" is an alias of "SET". (In Hive, "SET" returns all changed properties + // while "SET -v" returns all properties.) + case Some(("-v", None)) | None => + context.getAllConfs.map { case (k, v) => Row(s"$k=$v") }.toSeq + + // Queries the deprecated "mapred.reduce.tasks" property. + case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, None)) => + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}")) + + // Queries a single property. case Some((key, None)) => - // TODO (lian) This is just a workaround to make the Simba ODBC driver work. - // Should remove this once we get the ODBC driver updated. - if (key == "-v") { - val hiveJars = Seq( - "hive-exec-0.12.0.jar", - "hive-service-0.12.0.jar", - "hive-common-0.12.0.jar", - "hive-hwi-0.12.0.jar", - "hive-0.12.0.jar").mkString(":") - - context.getAllConfs.map { case (k, v) => - Row(s"$k=$v") - }.toSeq ++ Seq( - Row("system:java.class.path=" + hiveJars), - Row("system:sun.java.command=shark.SharkServer2")) - } else { - if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { - logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + - s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}")) - } else { - Seq(Row(s"$key=${context.getConf(key, "")}")) - } - } - - // Query all key-value pairs that are set in the SQLConf of the context. - case _ => - context.getAllConfs.map { case (k, v) => - Row(s"$k=$v") - }.toSeq + Seq(Row(s"$key=${context.getConf(key, "")}")) } override def otherCopyArgs = context :: Nil diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index a78311fc48635..ecfb74473e921 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.thriftserver +import java.util.jar.Attributes.Name + import scala.collection.JavaConversions._ import java.io.IOException @@ -29,11 +31,12 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.ShimLoader import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory -import org.apache.hive.service.cli.CLIService +import org.apache.hive.service.cli._ import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.util.Utils private[hive] class SparkSQLCLIService(hiveContext: HiveContext) extends CLIService @@ -60,6 +63,15 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) initCompositeService(hiveConf) } + + override def getInfo(sessionHandle: SessionHandle, getInfoType: GetInfoType): GetInfoValue = { + getInfoType match { + case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Spark SQL") + case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Spark SQL") + case GetInfoType.CLI_DBMS_VER => new GetInfoValue(Utils.sparkVersion) + case _ => super.getInfo(sessionHandle, getInfoType) + } + } } private[thriftserver] trait 
ReflectedCompositeService { this: AbstractService => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 50425863518c3..89732c939b0ec 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.hive.thriftserver +import scala.collection.JavaConversions._ + import org.apache.spark.scheduler.StatsReportListener -import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.{HiveShim, HiveContext} import org.apache.spark.{Logging, SparkConf, SparkContext} -import scala.collection.JavaConversions._ /** A singleton object for the master program. The slaves should not access this. */ private[hive] object SparkSQLEnv extends Logging { @@ -31,8 +32,10 @@ private[hive] object SparkSQLEnv extends Logging { def init() { if (hiveContext == null) { - sparkContext = new SparkContext(new SparkConf() - .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}")) + val sparkConf = new SparkConf() + .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}") + .set("spark.sql.hive.version", HiveShim.version) + sparkContext = new SparkContext(sparkConf) sparkContext.addSparkListener(new StatsReportListener()) hiveContext = new HiveContext(sparkContext) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index c60e8fa5b1259..65d910a0c3ffc 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -30,42 +30,95 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver +import org.apache.hive.service.auth.PlainSaslHelper +import org.apache.hive.service.cli.GetInfoType +import org.apache.hive.service.cli.thrift.TCLIService.Client +import org.apache.hive.service.cli.thrift._ +import org.apache.thrift.protocol.TBinaryProtocol +import org.apache.thrift.transport.TSocket import org.scalatest.FunSuite import org.apache.spark.Logging import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.hive.HiveShim /** * Tests for the HiveThriftServer2 using JDBC. + * + * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be + * rebuilt after changing HiveThriftServer2 related code. */ class HiveThriftServer2Suite extends FunSuite with Logging { Class.forName(classOf[HiveDriver].getCanonicalName) - def startThriftServerWithin(timeout: FiniteDuration = 1.minute)(f: Statement => Unit) { + def randomListeningPort = { + // Let the system to choose a random available port to avoid collision with other parallel + // builds. 
+    val socket = new ServerSocket(0)
+    val port = socket.getLocalPort
+    socket.close()
+    port
+  }
+
+  def withJdbcStatement(serverStartTimeout: FiniteDuration = 1.minute)(f: Statement => Unit) {
+    val port = randomListeningPort
+
+    startThriftServer(port, serverStartTimeout) {
+      val jdbcUri = s"jdbc:hive2://${"localhost"}:$port/"
+      val user = System.getProperty("user.name")
+      val connection = DriverManager.getConnection(jdbcUri, user, "")
+      val statement = connection.createStatement()
+
+      try {
+        f(statement)
+      } finally {
+        statement.close()
+        connection.close()
+      }
+    }
+  }
+
+  def withCLIServiceClient(
+      serverStartTimeout: FiniteDuration = 1.minute)(
+      f: ThriftCLIServiceClient => Unit) {
+    val port = randomListeningPort
+
+    startThriftServer(port) {
+      // Transport creation logics below mimics HiveConnection.createBinaryTransport
+      val rawTransport = new TSocket("localhost", port)
+      val user = System.getProperty("user.name")
+      val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
+      val protocol = new TBinaryProtocol(transport)
+      val client = new ThriftCLIServiceClient(new Client(protocol))
+
+      transport.open()
+
+      try {
+        f(client)
+      } finally {
+        transport.close()
+      }
+    }
+  }
+
+  def startThriftServer(
+      port: Int,
+      serverStartTimeout: FiniteDuration = 1.minute)(
+      f: => Unit) {
     val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
     val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator)
 
     val warehousePath = getTempFilePath("warehouse")
     val metastorePath = getTempFilePath("metastore")
     val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
-    val listeningHost = "localhost"
-    val listeningPort = {
-      // Let the system to choose a random available port to avoid collision with other parallel
-      // builds.
-      val socket = new ServerSocket(0)
-      val port = socket.getLocalPort
-      socket.close()
-      port
-    }
-
     val command =
       s"""$startScript
          | --master local
          | --hiveconf hive.root.logger=INFO,console
          | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
          | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
-         | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$listeningHost
-         | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort
+         | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=${"localhost"}
+         | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port
       """.stripMargin.split("\\s+").toSeq
 
     val serverRunning = Promise[Unit]()
@@ -92,31 +145,25 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
       }
     }
 
-    // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
-    Process(command, None, "SPARK_TESTING" -> "0").run(ProcessLogger(
+    val env = Seq(
+      // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
+      "SPARK_TESTING" -> "0",
+      // Prevents loading classes out of the assembly jar. Otherwise Utils.sparkVersion can't read
+      // proper version information from the jar manifest.
+ "SPARK_PREPEND_CLASSES" -> "") + + Process(command, None, env: _*).run(ProcessLogger( captureThriftServerOutput("stdout"), captureThriftServerOutput("stderr"))) - val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/" - val user = System.getProperty("user.name") - try { - Await.result(serverRunning.future, timeout) - - val connection = DriverManager.getConnection(jdbcUri, user, "") - val statement = connection.createStatement() - - try { - f(statement) - } finally { - statement.close() - connection.close() - } + Await.result(serverRunning.future, serverStartTimeout) + f } catch { case cause: Exception => cause match { case _: TimeoutException => - logError(s"Failed to start Hive Thrift server within $timeout", cause) + logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause) case _ => } logError( @@ -125,8 +172,8 @@ class HiveThriftServer2Suite extends FunSuite with Logging { |HiveThriftServer2Suite failure output |===================================== |HiveThriftServer2 command line: ${command.mkString(" ")} - |JDBC URI: $jdbcUri - |User: $user + |Binding port: $port + |System user: ${System.getProperty("user.name")} | |${buffer.mkString("\n")} |========================================= @@ -146,7 +193,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging { } test("Test JDBC query execution") { - startThriftServerWithin() { statement => + withJdbcStatement() { statement => val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") @@ -168,7 +215,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging { } test("SPARK-3004 regression: result set containing NULL") { - startThriftServerWithin() { statement => + withJdbcStatement() { statement => val dataFilePath = Thread.currentThread().getContextClassLoader.getResource( "data/files/small_kv_with_null.txt") @@ -191,4 +238,33 @@ class HiveThriftServer2Suite extends FunSuite with Logging { assert(!resultSet.next()) } } + + test("GetInfo Thrift API") { + withCLIServiceClient() { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + + assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue + } + + assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue + } + + assertResult(true, "Spark version shouldn't be \"Unknown\"") { + val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue + logInfo(s"Spark version: $version") + version != "Unknown" + } + } + } + + test("Checks Hive version") { + withJdbcStatement() { statement => + val resultSet = statement.executeQuery("SET spark.sql.hive.version") + resultSet.next() + assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index dca5367f244de..0fe59f42f21ff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -323,7 +323,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { driver.close() HiveShim.processResults(results) case _ => - sessionState.out.println(tokens(0) + " " + cmd_1) + if (sessionState.out != null) { + sessionState.out.println(tokens(0) + " " + 
+            }
             Seq(proc.run(cmd_1).getResponseCode.toString)
         }
       } catch {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 5918f888c8f4c..b897dff0159ff 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -769,7 +769,7 @@ class HiveQuerySuite extends HiveComparisonTest {
     }.toSet
     clear()
 
-    // "set" itself returns all config variables currently specified in SQLConf.
+    // "SET" itself returns all config variables currently specified in SQLConf.
     // TODO: Should we be listing the default here always? probably...
     assert(sql("SET").collect().size == 0)
 
@@ -778,44 +778,19 @@ class HiveQuerySuite extends HiveComparisonTest {
     }
     assert(hiveconf.get(testKey, "") == testVal)
 
-    assertResult(Set(testKey -> testVal)) {
-      collectResults(sql("SET"))
-    }
+    assertResult(Set(testKey -> testVal))(collectResults(sql("SET")))
+    assertResult(Set(testKey -> testVal))(collectResults(sql("SET -v")))
 
     sql(s"SET ${testKey + testKey}=${testVal + testVal}")
     assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
     assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
       collectResults(sql("SET"))
     }
-
-    // "set key"
-    assertResult(Set(testKey -> testVal)) {
-      collectResults(sql(s"SET $testKey"))
-    }
-
-    assertResult(Set(nonexistentKey -> "")) {
-      collectResults(sql(s"SET $nonexistentKey"))
-    }
-
-    // Assert that sql() should have the same effects as sql() by repeating the above using sql().
-    clear()
-    assert(sql("SET").collect().size == 0)
-
-    assertResult(Set(testKey -> testVal)) {
-      collectResults(sql(s"SET $testKey=$testVal"))
-    }
-
-    assert(hiveconf.get(testKey, "") == testVal)
-    assertResult(Set(testKey -> testVal)) {
-      collectResults(sql("SET"))
-    }
-
-    sql(s"SET ${testKey + testKey}=${testVal + testVal}")
-    assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
     assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
-      collectResults(sql("SET"))
+      collectResults(sql("SET -v"))
     }
 
+    // "SET key"
     assertResult(Set(testKey -> testVal)) {
       collectResults(sql(s"SET $testKey"))
     }