diff --git a/bin/spark-submit b/bin/spark-submit index aefd38a0a2b90..3e5cbdbb24394 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -44,7 +44,10 @@ while (($#)); do shift done -DEFAULT_PROPERTIES_FILE="$SPARK_HOME/conf/spark-defaults.conf" +if [ -z "$SPARK_CONF_DIR" ]; then + export SPARK_CONF_DIR="$SPARK_HOME/conf" +fi +DEFAULT_PROPERTIES_FILE="$SPARK_CONF_DIR/spark-defaults.conf" if [ "$MASTER" == "yarn-cluster" ]; then SPARK_SUBMIT_DEPLOY_MODE=cluster fi diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index daf0284db9230..12244a9cb04fb 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -24,7 +24,11 @@ set ORIG_ARGS=%* rem Reset the values of all variables used set SPARK_SUBMIT_DEPLOY_MODE=client -set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf + +if not defined SPARK_CONF_DIR ( + set SPARK_CONF_DIR=%SPARK_HOME%\conf +) +set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_CONF_DIR%\spark-defaults.conf set SPARK_SUBMIT_DRIVER_MEMORY= set SPARK_SUBMIT_LIBRARY_PATH= set SPARK_SUBMIT_CLASSPATH= diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index e9e90e3f2f65a..a0ee2a7cbb2a2 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -65,6 +65,9 @@ private[spark] class ExecutorAllocationManager( listenerBus: LiveListenerBus, conf: SparkConf) extends Logging { + + allocationManager => + import ExecutorAllocationManager._ // Lower and upper bounds on the number of executors. These are required. @@ -121,7 +124,7 @@ private[spark] class ExecutorAllocationManager( private var clock: Clock = new RealClock // Listener for Spark events that impact the allocation policy - private val listener = new ExecutorAllocationListener(this) + private val listener = new ExecutorAllocationListener /** * Verify that the settings specified through the config are valid. @@ -209,11 +212,12 @@ private[spark] class ExecutorAllocationManager( addTime += sustainedSchedulerBacklogTimeout * 1000 } - removeTimes.foreach { case (executorId, expireTime) => - if (now >= expireTime) { + removeTimes.retain { case (executorId, expireTime) => + val expired = now >= expireTime + if (expired) { removeExecutor(executorId) - removeTimes.remove(executorId) } + !expired } } @@ -291,7 +295,7 @@ private[spark] class ExecutorAllocationManager( // Do not kill the executor if we have already reached the lower bound val numExistingExecutors = executorIds.size - executorsPendingToRemove.size if (numExistingExecutors - 1 < minNumExecutors) { - logInfo(s"Not removing idle executor $executorId because there are only " + + logDebug(s"Not removing idle executor $executorId because there are only " + s"$numExistingExecutors executor(s) left (limit $minNumExecutors)") return false } @@ -315,7 +319,11 @@ private[spark] class ExecutorAllocationManager( private def onExecutorAdded(executorId: String): Unit = synchronized { if (!executorIds.contains(executorId)) { executorIds.add(executorId) - executorIds.foreach(onExecutorIdle) + // If an executor (call this executor X) is not removed because the lower bound + // has been reached, it will no longer be marked as idle. When new executors join, + // however, we are no longer at the lower bound, and so we must mark executor X + // as idle again so as not to forget that it is a candidate for removal.
(see SPARK-4951) + executorIds.filter(listener.isExecutorIdle).foreach(onExecutorIdle) logInfo(s"New executor $executorId has registered (new total is ${executorIds.size})") if (numExecutorsPending > 0) { numExecutorsPending -= 1 @@ -373,10 +381,14 @@ private[spark] class ExecutorAllocationManager( * the executor is not already marked as idle. */ private def onExecutorIdle(executorId: String): Unit = synchronized { - if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { - logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)") - removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000 + if (executorIds.contains(executorId)) { + if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { + logDebug(s"Starting idle timer for $executorId because there are no more tasks " + + s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)") + removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000 + } + } else { + logWarning(s"Attempted to mark unknown executor $executorId idle") } } @@ -396,25 +408,24 @@ private[spark] class ExecutorAllocationManager( * and consistency of events returned by the listener. For simplicity, it does not account * for speculated tasks. */ - private class ExecutorAllocationListener(allocationManager: ExecutorAllocationManager) - extends SparkListener { + private class ExecutorAllocationListener extends SparkListener { private val stageIdToNumTasks = new mutable.HashMap[Int, Int] private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]] private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]] override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { - synchronized { - val stageId = stageSubmitted.stageInfo.stageId - val numTasks = stageSubmitted.stageInfo.numTasks + val stageId = stageSubmitted.stageInfo.stageId + val numTasks = stageSubmitted.stageInfo.numTasks + allocationManager.synchronized { stageIdToNumTasks(stageId) = numTasks allocationManager.onSchedulerBacklogged() } } override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { - synchronized { - val stageId = stageCompleted.stageInfo.stageId + val stageId = stageCompleted.stageInfo.stageId + allocationManager.synchronized { stageIdToNumTasks -= stageId stageIdToTaskIndices -= stageId @@ -426,39 +437,49 @@ private[spark] class ExecutorAllocationManager( } } - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { val stageId = taskStart.stageId val taskId = taskStart.taskInfo.taskId val taskIndex = taskStart.taskInfo.index val executorId = taskStart.taskInfo.executorId - // If this is the last pending task, mark the scheduler queue as empty - stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex - val numTasksScheduled = stageIdToTaskIndices(stageId).size - val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1) - if (numTasksScheduled == numTasksTotal) { - // No more pending tasks for this stage - stageIdToNumTasks -= stageId - if (stageIdToNumTasks.isEmpty) { - allocationManager.onSchedulerQueueEmpty() + allocationManager.synchronized { + // This guards against the race condition in which the `SparkListenerTaskStart` + // event is posted before 
the `SparkListenerBlockManagerAdded` event, which is + // possible because these events are posted in different threads. (see SPARK-4951) + if (!allocationManager.executorIds.contains(executorId)) { + allocationManager.onExecutorAdded(executorId) + } + + // If this is the last pending task, mark the scheduler queue as empty + stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex + val numTasksScheduled = stageIdToTaskIndices(stageId).size + val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1) + if (numTasksScheduled == numTasksTotal) { + // No more pending tasks for this stage + stageIdToNumTasks -= stageId + if (stageIdToNumTasks.isEmpty) { + allocationManager.onSchedulerQueueEmpty() + } } - } - // Mark the executor on which this task is scheduled as busy - executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId - allocationManager.onExecutorBusy(executorId) + // Mark the executor on which this task is scheduled as busy + executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId + allocationManager.onExecutorBusy(executorId) + } } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { val executorId = taskEnd.taskInfo.executorId val taskId = taskEnd.taskInfo.taskId - - // If the executor is no longer running scheduled any tasks, mark it as idle - if (executorIdToTaskIds.contains(executorId)) { - executorIdToTaskIds(executorId) -= taskId - if (executorIdToTaskIds(executorId).isEmpty) { - executorIdToTaskIds -= executorId - allocationManager.onExecutorIdle(executorId) + allocationManager.synchronized { + // If the executor is no longer running any scheduled tasks, mark it as idle + if (executorIdToTaskIds.contains(executorId)) { + executorIdToTaskIds(executorId) -= taskId + if (executorIdToTaskIds(executorId).isEmpty) { + executorIdToTaskIds -= executorId + allocationManager.onExecutorIdle(executorId) + } } } } @@ -466,7 +487,12 @@ private[spark] class ExecutorAllocationManager( override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { val executorId = blockManagerAdded.blockManagerId.executorId if (executorId != SparkContext.DRIVER_IDENTIFIER) { - allocationManager.onExecutorAdded(executorId) + // This guards against the race condition in which the `SparkListenerTaskStart` + // event is posted before the `SparkListenerBlockManagerAdded` event, which is + // possible because these events are posted in different threads. (see SPARK-4951) + if (!allocationManager.executorIds.contains(executorId)) { + allocationManager.onExecutorAdded(executorId) + } } } @@ -478,12 +504,23 @@ private[spark] class ExecutorAllocationManager( /** * An estimate of the total number of pending tasks remaining for currently running stages. Does * not account for tasks which may have failed and been resubmitted. + * + * Note: This is not thread-safe without the caller owning the `allocationManager` lock. + */ def totalPendingTasks(): Int = { stageIdToNumTasks.map { case (stageId, numTasks) => numTasks - stageIdToTaskIndices.get(stageId).map(_.size).getOrElse(0) }.sum } + + /** + * Return true if an executor is not currently running a task, and false otherwise. + * + * Note: This is not thread-safe without the caller owning the `allocationManager` lock.
+ */ + def isExecutorIdle(executorId: String): Boolean = { + !executorIdToTaskIds.contains(executorId) + } } } diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index edc3889c9ae51..677c5e0f89d72 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -24,6 +24,7 @@ import com.google.common.io.Files import org.apache.spark.util.Utils private[spark] class HttpFileServer( + conf: SparkConf, securityManager: SecurityManager, requestedPort: Int = 0) extends Logging { @@ -41,7 +42,7 @@ private[spark] class HttpFileServer( fileDir.mkdir() jarDir.mkdir() logInfo("HTTP File server directory is " + baseDir) - httpServer = new HttpServer(baseDir, securityManager, requestedPort, "HTTP file server") + httpServer = new HttpServer(conf, baseDir, securityManager, requestedPort, "HTTP file server") httpServer.start() serverUri = httpServer.uri logDebug("HTTP file server started at: " + serverUri) diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 912558d0cab7d..fa22787ce7ea3 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -42,6 +42,7 @@ private[spark] class ServerStateException(message: String) extends Exception(mes * around a Jetty server. */ private[spark] class HttpServer( + conf: SparkConf, resourceBase: File, securityManager: SecurityManager, requestedPort: Int = 0, @@ -57,7 +58,7 @@ private[spark] class HttpServer( } else { logInfo("Starting HTTP Server") val (actualServer, actualPort) = - Utils.startServiceOnPort[Server](requestedPort, doStart, serverName) + Utils.startServiceOnPort[Server](requestedPort, doStart, conf, serverName) server = actualServer port = actualPort } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c14764f773982..a0ce107f43b16 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -370,7 +370,9 @@ private[spark] object SparkConf { } /** - * Return whether the given config is a Spark port config. + * Return true if the given config matches either `spark.*.port` or `spark.port.*`. */ - def isSparkPortConf(name: String): Boolean = name.startsWith("spark.") && name.endsWith(".port") + def isSparkPortConf(name: String): Boolean = { + (name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.") + } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3bf3acd245d8f..ff5d796ee2766 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -458,7 +458,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli Option(localProperties.get).map(_.getProperty(key)).getOrElse(null) /** Set a human readable description of the current job. 
*/ - @deprecated("use setJobGroup", "0.8.1") def setJobDescription(value: String) { setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 43436a1697000..4d418037bd33f 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -312,7 +312,7 @@ object SparkEnv extends Logging { val httpFileServer = if (isDriver) { val fileServerPort = conf.getInt("spark.fileserver.port", 0) - val server = new HttpFileServer(securityManager, fileServerPort) + val server = new HttpFileServer(conf, securityManager, fileServerPort) server.initialize() conf.set("spark.fileserver.uri", server.serverUri) server diff --git a/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala b/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala new file mode 100644 index 0000000000000..9df61062e1f85 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.annotation.DeveloperApi + +/** + * Exception thrown when a task cannot be serialized. 
+ */ +private[spark] class TaskNotSerializableException(error: Throwable) extends Exception(error) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 31f0a462f84d8..31d6958c403b3 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -153,7 +153,8 @@ private[broadcast] object HttpBroadcast extends Logging { private def createServer(conf: SparkConf) { broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf)) val broadcastPort = conf.getInt("spark.broadcast.port", 0) - server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server") + server = + new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server") server.start() serverUri = server.uri logInfo("Broadcast server started at " + serverUri) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 1faabe91f49a8..f14ef4d299383 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -405,7 +405,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | --queue QUEUE_NAME The YARN queue to submit to (Default: "default"). | --num-executors NUM Number of executors to launch (Default: 2). | --archives ARCHIVES Comma separated list of archives to be extracted into the - | working directory of each executor.""".stripMargin + | working directory of each executor. + """.stripMargin ) SparkSubmit.exitFn() } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 3340fca08014e..03c4137ca0a81 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -174,7 +174,7 @@ private[nio] class ConnectionManager( serverChannel.socket.bind(new InetSocketAddress(port)) (serverChannel, serverChannel.socket.getLocalPort) } - Utils.startServiceOnPort[ServerSocketChannel](port, startService, name) + Utils.startServiceOnPort[ServerSocketChannel](port, startService, conf, name) serverChannel.register(selector, SelectionKey.OP_ACCEPT) val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 259621d263d7c..61d09d73e17cb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -866,26 +866,6 @@ class DAGScheduler( } if (tasks.size > 0) { - // Preemptively serialize a task to make sure it can be serialized. We are catching this - // exception here because it would be fairly hard to catch the non-serializable exception - // down the road, where we have several different implementations for local scheduler and - // cluster schedulers. - // - // We've already serialized RDDs and closures in taskBinary, but here we check for all other - // objects such as Partition. 
- try { - closureSerializer.serialize(tasks.head) - } catch { - case e: NotSerializableException => - abortStage(stage, "Task not serializable: " + e.toString) - runningStages -= stage - return - case NonFatal(e) => // Other exceptions, such as IllegalArgumentException from Kryo. - abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") - runningStages -= stage - return - } - logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") stage.pendingTasks ++= tasks logDebug("New pending tasks: " + stage.pendingTasks) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index a41f3eef195d2..a1dfb01062591 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -31,6 +31,7 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.util.Utils import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId @@ -209,6 +210,40 @@ private[spark] class TaskSchedulerImpl( .format(manager.taskSet.id, manager.parent.name)) } + private def resourceOfferSingleTaskSet( + taskSet: TaskSetManager, + maxLocality: TaskLocality, + shuffledOffers: Seq[WorkerOffer], + availableCpus: Array[Int], + tasks: Seq[ArrayBuffer[TaskDescription]]) : Boolean = { + var launchedTask = false + for (i <- 0 until shuffledOffers.size) { + val execId = shuffledOffers(i).executorId + val host = shuffledOffers(i).host + if (availableCpus(i) >= CPUS_PER_TASK) { + try { + for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { + tasks(i) += task + val tid = task.taskId + taskIdToTaskSetId(tid) = taskSet.taskSet.id + taskIdToExecutorId(tid) = execId + executorsByHost(host) += execId + availableCpus(i) -= CPUS_PER_TASK + assert(availableCpus(i) >= 0) + launchedTask = true + } + } catch { + case e: TaskNotSerializableException => + logError(s"Resource offer failed, task set ${taskSet.name} was not serializable") + // Do not offer resources for this task, but don't throw an error to allow other + // task sets to be submitted. + return launchedTask + } + } + } + return launchedTask + } + /** * Called by cluster manager to offer resources on slaves. We respond by asking our active task * sets for tasks in order of priority. 
We fill each node with tasks in a round-robin manner so @@ -251,23 +286,8 @@ private[spark] class TaskSchedulerImpl( var launchedTask = false for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) { do { - launchedTask = false - for (i <- 0 until shuffledOffers.size) { - val execId = shuffledOffers(i).executorId - val host = shuffledOffers(i).host - if (availableCpus(i) >= CPUS_PER_TASK) { - for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { - tasks(i) += task - val tid = task.taskId - taskIdToTaskSetId(tid) = taskSet.taskSet.id - taskIdToExecutorId(tid) = execId - executorsByHost(host) += execId - availableCpus(i) -= CPUS_PER_TASK - assert(availableCpus(i) >= 0) - launchedTask = true - } - } - } + launchedTask = resourceOfferSingleTaskSet( + taskSet, maxLocality, shuffledOffers, availableCpus, tasks) } while (launchedTask) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 28e6147509f78..4667850917151 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -18,12 +18,14 @@ package org.apache.spark.scheduler import java.io.NotSerializableException +import java.nio.ByteBuffer import java.util.Arrays import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet import scala.math.{min, max} +import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.executor.TaskMetrics @@ -417,6 +419,7 @@ private[spark] class TaskSetManager( * @param host the host Id of the offered resource * @param maxLocality the maximum locality we want to schedule the tasks at */ + @throws[TaskNotSerializableException] def resourceOffer( execId: String, host: String, @@ -456,10 +459,17 @@ private[spark] class TaskSetManager( } // Serialize and return the task val startTime = clock.getTime() - // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here - // we assume the task can be serialized without exceptions. - val serializedTask = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) + val serializedTask: ByteBuffer = try { + Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) + } catch { + // If the task cannot be serialized, then there's no point to re-attempt the task, + // as it will always fail. So just abort the whole task-set. + case NonFatal(e) => + val msg = s"Failed to serialize task $taskId, not attempting to retry it." 
+ logError(msg, e) + abort(s"$msg Exception during serialization: $e") + throw new TaskNotSerializableException(e) + } if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 && !emittedTaskSizeWarning) { emittedTaskSizeWarning = true diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 50721b9d6cd6c..f14aaeea0a25c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler.cluster +import scala.concurrent.{Future, ExecutionContext} + import akka.actor.{Actor, ActorRef, Props} import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} @@ -24,7 +26,9 @@ import org.apache.spark.SparkContext import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.{AkkaUtils, Utils} + +import scala.util.control.NonFatal /** * Abstract Yarn scheduler backend that contains common logic @@ -97,6 +101,9 @@ private[spark] abstract class YarnSchedulerBackend( private class YarnSchedulerActor extends Actor { private var amActor: Option[ActorRef] = None + implicit val askAmActorExecutor = ExecutionContext.fromExecutor( + Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor")) + override def preStart(): Unit = { // Listen for disassociation events context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) @@ -110,7 +117,12 @@ private[spark] abstract class YarnSchedulerBackend( case r: RequestExecutors => amActor match { case Some(actor) => - sender ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + val driverActor = sender + Future { + driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + } onFailure { + case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e) + } case None => logWarning("Attempted to request executors before the AM has registered!") sender ! false @@ -119,7 +131,12 @@ private[spark] abstract class YarnSchedulerBackend( case k: KillExecutors => amActor match { case Some(actor) => - sender ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + val driverActor = sender + Future { + driverActor ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + } onFailure { + case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e) + } case None => logWarning("Attempted to kill executors before the AM has registered!") sender ! 
false diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index d2947dcea4f7c..d56e23ce4478a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -29,7 +29,7 @@ import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock} -import org.apache.spark.scheduler.MapStatus +import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ import org.apache.spark.util.BoundedPriorityQueue import org.apache.spark.util.collection.CompactBuffer @@ -207,7 +207,8 @@ private[serializer] object KryoSerializer { classOf[PutBlock], classOf[GotBlock], classOf[GetBlock], - classOf[MapStatus], + classOf[CompressedMapStatus], + classOf[HighlyCompressedMapStatus], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Byte]], diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 8dadf6794039e..61ef5ff168791 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -31,7 +31,8 @@ import org.apache.spark.util.Utils private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager) extends BlockStore(blockManager) with Logging { - val minMemoryMapBytes = blockManager.conf.getLong("spark.storage.memoryMapThreshold", 2 * 4096L) + val minMemoryMapBytes = blockManager.conf.getLong( + "spark.storage.memoryMapThreshold", 2 * 1024L * 1024L) override def getSize(blockId: BlockId): Long = { diskManager.getFile(blockId.name).length diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 2a27d49d2de05..88fed833f922d 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -201,7 +201,7 @@ private[spark] object JettyUtils extends Logging { } } - val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, serverName) + val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) ServerInfo(server, boundPort, collection) } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index db2531dc171f8..4c9b1e3c46f0f 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -53,7 +53,7 @@ private[spark] object AkkaUtils extends Logging { val startService: Int => (ActorSystem, Int) = { actualPort => doCreateActorSystem(name, host, actualPort, conf, securityManager) } - Utils.startServiceOnPort(port, startService, name) + Utils.startServiceOnPort(port, startService, conf, name) } private def doCreateActorSystem( diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c4f1898a2db15..2c04e4ddfbcb7 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -701,7 +701,7 @@ private[spark] object Utils extends Logging { } } - private var customHostname: 
Option[String] = None + private var customHostname: Option[String] = sys.env.get("SPARK_LOCAL_HOSTNAME") /** * Allow setting a custom host name because when we run on Mesos we need to use the same @@ -1690,17 +1690,15 @@ private[spark] object Utils extends Logging { } /** - * Default maximum number of retries when binding to a port before giving up. + * Maximum number of retries when binding to a port before giving up. */ - val portMaxRetries: Int = { - if (sys.props.contains("spark.testing")) { + def portMaxRetries(conf: SparkConf): Int = { + val maxRetries = conf.getOption("spark.port.maxRetries").map(_.toInt) + if (conf.contains("spark.testing")) { // Set a higher number of retries for tests... - sys.props.get("spark.port.maxRetries").map(_.toInt).getOrElse(100) + maxRetries.getOrElse(100) } else { - Option(SparkEnv.get) - .flatMap(_.conf.getOption("spark.port.maxRetries")) - .map(_.toInt) - .getOrElse(16) + maxRetries.getOrElse(16) } } @@ -1709,17 +1707,18 @@ private[spark] object Utils extends Logging { * Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0). * * @param startPort The initial port to start the service on. - * @param maxRetries Maximum number of retries to attempt. - * A value of 3 means attempting ports n, n+1, n+2, and n+3, for example. * @param startService Function to start service on a given port. * This is expected to throw java.net.BindException on port collision. + * @param conf A SparkConf used to get the maximum number of retries when binding to a port. + * @param serviceName Name of the service. */ def startServiceOnPort[T]( startPort: Int, startService: Int => (T, Int), - serviceName: String = "", - maxRetries: Int = portMaxRetries): (T, Int) = { + conf: SparkConf, + serviceName: String = ""): (T, Int) = { val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'" + val maxRetries = portMaxRetries(conf) for (offset <- 0 to maxRetries) { // Do not increment port if startPort is 0, which is treated as a special port val tryPort = if (startPort == 0) { diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index c817f6dcede75..0e4df17c1bf87 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark +import scala.collection.mutable + import org.scalatest.{FunSuite, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -143,11 +145,17 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { // Verify that running a task reduces the cap sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 3))) + sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( + 0L, BlockManagerId("executor-1", "host1", 1), 100L)) sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) + assert(numExecutorsPending(manager) === 4) assert(addExecutors(manager) === 1) - assert(numExecutorsPending(manager) === 6) + assert(numExecutorsPending(manager) === 5) assert(numExecutorsToAdd(manager) === 2) - assert(addExecutors(manager) === 1) + assert(addExecutors(manager) === 2) + assert(numExecutorsPending(manager) === 7) + assert(numExecutorsToAdd(manager) === 4) + assert(addExecutors(manager) === 0) assert(numExecutorsPending(manager) === 7) 
assert(numExecutorsToAdd(manager) === 1) @@ -325,6 +333,8 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { val manager = sc.executorAllocationManager.get manager.setClock(clock) + executorIds(manager).asInstanceOf[mutable.Set[String]] ++= List("1", "2", "3") + // Starting remove timer is idempotent for each executor assert(removeTimes(manager).isEmpty) onExecutorIdle(manager, "1") @@ -597,6 +607,41 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).size === 1) } + test("SPARK-4951: call onTaskStart before onBlockManagerAdded") { + sc = createSparkContext(2, 10) + val manager = sc.executorAllocationManager.get + assert(executorIds(manager).isEmpty) + assert(removeTimes(manager).isEmpty) + + sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) + sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( + 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + assert(executorIds(manager).size === 1) + assert(executorIds(manager).contains("executor-1")) + assert(removeTimes(manager).size === 0) + } + + test("SPARK-4951: onExecutorAdded should not add a busy executor to removeTimes") { + sc = createSparkContext(2, 10) + val manager = sc.executorAllocationManager.get + assert(executorIds(manager).isEmpty) + assert(removeTimes(manager).isEmpty) + sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( + 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) + + assert(executorIds(manager).size === 1) + assert(executorIds(manager).contains("executor-1")) + assert(removeTimes(manager).size === 0) + + sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( + 0L, BlockManagerId("executor-2", "host1", 1), 100L)) + assert(executorIds(manager).size === 2) + assert(executorIds(manager).contains("executor-2")) + assert(removeTimes(manager).size === 1) + assert(removeTimes(manager).contains("executor-2")) + assert(!removeTimes(manager).contains("executor-1")) + } } /** diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 0b6511a80df1d..3d2700b7e6be4 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -30,7 +30,7 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => var conf = new SparkConf(false) override def beforeAll() { - _sc = new SparkContext("local", "test", conf) + _sc = new SparkContext("local[4]", "test", conf) super.beforeAll() } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 6836e9ab0fd6b..0deb9b18b8688 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.rdd +import java.io.{ObjectInputStream, ObjectOutputStream, IOException} + +import com.esotericsoftware.kryo.KryoException + import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -887,6 +891,23 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(ancestors6.count(_.isInstanceOf[CyclicalDependencyRDD[_]]) === 3) } + test("task serialization exception should not hang scheduler") { + class BadSerializable extends Serializable { + 
@throws(classOf[IOException]) + private def writeObject(out: ObjectOutputStream): Unit = throw new KryoException("Bad serialization") + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = {} + } + // Note that in the original bug, SPARK-4349, which this test verifies, the job would only hang + // if there were more threads in the SparkContext than there were objects in this sequence. + intercept[Throwable] { + sc.parallelize(Seq(new BadSerializable, new BadSerializable)).collect + } + // Check that the context has not crashed + sc.parallelize(1 to 100).map(x => x*2).collect + } + /** A contrived RDD that allows the manual addition of dependencies after creation. */ private class CyclicalDependencyRDD[T: ClassTag] extends RDD[T](sc, Nil) { private val mutableDependencies: ArrayBuffer[Dependency[_]] = ArrayBuffer.empty diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala new file mode 100644 index 0000000000000..6b75c98839e03 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io.{ObjectInputStream, ObjectOutputStream, IOException} + +import org.apache.spark.TaskContext + +/** + * A Task implementation that fails to serialize.
+ */ +private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) extends Task[Array[Byte]](stageId, 0) { + override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] + override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() + + @throws(classOf[IOException]) + private def writeObject(out: ObjectOutputStream): Unit = { + if (stageId == 0) { + throw new IllegalStateException("Cannot serialize") + } + } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = {} +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 8874cf00e9993..add13f5b21765 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -100,4 +100,34 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin assert(1 === taskDescriptions.length) assert("executor0" === taskDescriptions(0).executorId) } + + test("Scheduler does not crash when tasks are not serializable") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskCpus = 2 + + sc.conf.set("spark.task.cpus", taskCpus.toString) + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + val dagScheduler = new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + val numFreeCores = 1 + taskScheduler.setDAGScheduler(dagScheduler) + var taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) + val multiCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", taskCpus), + new WorkerOffer("executor1", "host1", numFreeCores)) + taskScheduler.submitTasks(taskSet) + var taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten + assert(0 === taskDescriptions.length) + + // Now check that we can still submit tasks. Even if one of the task sets has tasks that are + // not serializable, the other task set should still be processed without error. + taskScheduler.submitTasks(taskSet) + taskScheduler.submitTasks(FakeTask.createTaskSet(1)) + taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten + assert(taskDescriptions.map(_.executorId) === Seq("executor0")) + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 472191551a01f..84b9b788237bf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.{ObjectInputStream, ObjectOutputStream, IOException} import java.util.Random import scala.collection.mutable.ArrayBuffer @@ -563,6 +564,19 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.emittedTaskSizeWarning) } + test("Not serializable exception thrown if the task cannot be serialized") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + + val taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new 
NotSerializableFakeTask(0, 1)), 0, 0, 0, null) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) + + intercept[TaskNotSerializableException] { + manager.resourceOffer("exec1", "host1", ANY) + } + assert(manager.isZombie) + } + test("abort the job if total size of results is too large") { val conf = new SparkConf().set("spark.driver.maxResultSize", "2m") sc = new SparkContext("local", "test", conf) diff --git a/dev/run-tests b/dev/run-tests index 20603fc089239..2257a566bb1bb 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -21,8 +21,10 @@ FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" -# Remove work directory +# Clean up work directory and caches rm -rf ./work +rm -rf ~/.ivy2/local/org.apache.spark +rm -rf ~/.ivy2/cache/org.apache.spark source "$FWDIR/dev/run-tests-codes.sh" diff --git a/docs/_config.yml b/docs/_config.yml index a96a76dd9ab5e..e2db274e1f619 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -17,6 +17,6 @@ SPARK_VERSION: 1.3.0-SNAPSHOT SPARK_VERSION_SHORT: 1.3.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" -MESOS_VERSION: 0.18.1 +MESOS_VERSION: 0.21.0 SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/docs/configuration.md b/docs/configuration.md index 2add48569bece..f292bfbb7dcd6 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -678,7 +678,7 @@ Apart from these, the following properties are also available, and may be useful spark.storage.memoryMapThreshold - 8192 + 2097152 Size of a block, in bytes, above which Spark memory maps when reading a block from disk. This prevents Spark from memory mapping very small blocks. In general, memory diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 183698ffe9304..4f273098c5db3 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -21,6 +21,14 @@ Most of the configs are the same for Spark on YARN as for other deployment modes + + + + + @@ -90,7 +98,14 @@ Most of the configs are the same for Spark on YARN as for other deployment modes + + + + + @@ -145,7 +160,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes diff --git a/examples/pom.xml b/examples/pom.xml index 002d4458c4b3e..4b92147725f6b 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -392,29 +392,6 @@ - - hbase-hadoop2 - - - hbase.profile - hadoop2 - - - - 0.98.7-hadoop2 - - - - hbase-hadoop1 - - - !hbase.profile - - - - 0.98.7-hadoop1 - - diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 9fbb0a800d735..35b8dd6c29b66 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -27,8 +27,8 @@ object SparkPi { val conf = new SparkConf().setAppName("Spark Pi") val spark = new SparkContext(conf) val slices = if (args.length > 0) args(0).toInt else 2 - val n = 100000 * slices - val count = spark.parallelize(1 to n, slices).map { i => + val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow + val count = spark.parallelize(1 until n, slices).map { i => val x = random * 2 - 1 val y = random * 2 - 1 if (x*x + y*y < 1) 1 else 0 diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 13943ed5442b9..f333e3891b5f0 100644 --- 
a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -80,7 +80,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) - })._2 + }, conf)._2 } /** Setup and start the streaming context */ diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 98fe6cb301f52..fe53a29cba0c9 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -19,16 +19,21 @@ package org.apache.spark.streaming.mqtt import java.net.{URI, ServerSocket} +import scala.concurrent.duration._ +import scala.language.postfixOps + import org.apache.activemq.broker.{TransportConnector, BrokerService} -import org.apache.spark.util.Utils +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence + import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Eventually -import scala.concurrent.duration._ + import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence +import org.apache.spark.SparkConf +import org.apache.spark.util.Utils class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { @@ -38,8 +43,9 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { private val freePort = findFreePort() private val brokerUri = "//localhost:" + freePort private val topic = "def" - private var ssc: StreamingContext = _ private val persistenceDir = Utils.createTempDir() + + private var ssc: StreamingContext = _ private var broker: BrokerService = _ private var connector: TransportConnector = _ @@ -101,7 +107,7 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) - })._2 + }, new SparkConf())._2 } def publishData(data: String): Unit = { @@ -115,8 +121,9 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { val message: MqttMessage = new MqttMessage(data.getBytes("utf-8")) message.setQos(1) message.setRetained(true) - for (i <- 0 to 100) + for (i <- 0 to 100) { msgTopic.publish(message) + } } } finally { client.disconnect() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala index 3a6c0e681e3fa..d8e134619411b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala @@ -20,10 +20,12 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.IndexedSeq import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose} -import org.apache.spark.rdd.RDD + import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS} -import org.apache.spark.mllib.stat.impl.MultivariateGaussian 
+import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * This class performs expectation maximization for multivariate Gaussian @@ -45,10 +47,11 @@ import org.apache.spark.mllib.util.MLUtils class GaussianMixtureEM private ( private var k: Int, private var convergenceTol: Double, - private var maxIterations: Int) extends Serializable { + private var maxIterations: Int, + private var seed: Long) extends Serializable { /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */ - def this() = this(2, 0.01, 100) + def this() = this(2, 0.01, 100, Utils.random.nextLong()) // number of samples per cluster to use when initializing Gaussians private val nSamples = 5 @@ -100,11 +103,21 @@ class GaussianMixtureEM private ( this } - /** Return the largest change in log-likelihood at which convergence is - * considered to have occurred. + /** + * Return the largest change in log-likelihood at which convergence is + * considered to have occurred. */ def getConvergenceTol: Double = convergenceTol - + + /** Set the random seed */ + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + + /** Return the random seed */ + def getSeed: Long = seed + /** Perform expectation maximization */ def run(data: RDD[Vector]): GaussianMixtureModel = { val sc = data.sparkContext @@ -113,7 +126,7 @@ class GaussianMixtureEM private ( val breezeData = data.map(u => u.toBreeze.toDenseVector).cache() // Get length of the input vectors - val d = breezeData.first.length + val d = breezeData.first().length // Determine initial weights and corresponding Gaussians. // If the user supplied an initial GMM, we use those values, otherwise @@ -122,11 +135,11 @@ class GaussianMixtureEM private ( // derived from the samples val (weights, gaussians) = initialModel match { case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) => - new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix) + new MultivariateGaussian(mu, sigma) }) case None => { - val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt) + val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) @@ -164,8 +177,8 @@ class GaussianMixtureEM private ( } // Need to convert the breeze matrices to MLlib matrices - val means = Array.tabulate(k) { i => Vectors.fromBreeze(gaussians(i).mu) } - val sigmas = Array.tabulate(k) { i => Matrices.fromBreeze(gaussians(i).sigma) } + val means = Array.tabulate(k) { i => gaussians(i).mu } + val sigmas = Array.tabulate(k) { i => gaussians(i).sigma } new GaussianMixtureModel(weights, means, sigmas) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index b461ea4f0f06e..416cad080c408 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -21,7 +21,7 @@ import breeze.linalg.{DenseVector => BreezeVector} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrix, Vector} -import org.apache.spark.mllib.stat.impl.MultivariateGaussian +import 
org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 36d8cadd2bdd7..181f507516485 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -102,6 +102,9 @@ class IndexedRowMatrix( k: Int, computeU: Boolean = false, rCond: Double = 1e-9): SingularValueDecomposition[IndexedRowMatrix, Matrix] = { + + val n = numCols().toInt + require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.") val indices = rows.map(_.index) val svd = toRowMatrix().computeSVD(k, computeU, rCond) val U = if (computeU) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index fbd35e372f9b1..d5abba6a4b645 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -212,7 +212,7 @@ class RowMatrix( tol: Double, mode: String): SingularValueDecomposition[RowMatrix, Matrix] = { val n = numCols().toInt - require(k > 0 && k <= n, s"Request up to n singular values but got k=$k and n=$n.") + require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.") object SVDMode extends Enumeration { val LocalARPACK, LocalLAPACK, DistARPACK = Value diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala similarity index 61% rename from mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala rename to mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index bc7f6c5197ac7..fd186b5ee6f72 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -15,13 +15,16 @@ * limitations under the License. */ -package org.apache.spark.mllib.stat.impl +package org.apache.spark.mllib.stat.distribution import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym} +import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} import org.apache.spark.mllib.util.MLUtils /** + * :: DeveloperApi :: * This class provides basic functionality for a Multivariate Gaussian (Normal) Distribution. In * the event that the covariance matrix is singular, the density will be computed in a * reduced dimensional subspace under which the distribution is supported. 
@@ -30,33 +33,64 @@ import org.apache.spark.mllib.util.MLUtils * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution */ -private[mllib] class MultivariateGaussian( - val mu: DBV[Double], - val sigma: DBM[Double]) extends Serializable { +@DeveloperApi +class MultivariateGaussian ( + val mu: Vector, + val sigma: Matrix) extends Serializable { + require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") + require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") + + private val breezeMu = mu.toBreeze.toDenseVector + + /** + * private[mllib] constructor + * + * @param mu The mean vector of the distribution + * @param sigma The covariance matrix of the distribution + */ + private[mllib] def this(mu: DBV[Double], sigma: DBM[Double]) = { + this(Vectors.fromBreeze(mu), Matrices.fromBreeze(sigma)) + } + /** * Compute distribution dependent constants: - * rootSigmaInv = D^(-1/2) * U, where sigma = U * D * U.t - * u = (2*pi)^(-k/2) * det(sigma)^(-1/2) + * rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t + * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants /** Returns density of this multivariate Gaussian at given point, x */ - def pdf(x: DBV[Double]): Double = { - val delta = x - mu + def pdf(x: Vector): Double = { + pdf(x.toBreeze.toDenseVector) + } + + /** Returns the log-density of this multivariate Gaussian at given point, x */ + def logpdf(x: Vector): Double = { + logpdf(x.toBreeze.toDenseVector) + } + + /** Returns density of this multivariate Gaussian at given point, x */ + private[mllib] def pdf(x: DBV[Double]): Double = { + math.exp(logpdf(x)) + } + + /** Returns the log-density of this multivariate Gaussian at given point, x */ + private[mllib] def logpdf(x: DBV[Double]): Double = { + val delta = x - breezeMu val v = rootSigmaInv * delta - u * math.exp(v.t * v * -0.5) + u + v.t * v * -0.5 } /** * Calculate distribution dependent components used for the density function: - * pdf(x) = (2*pi)^(-k/2) * det(sigma)^(-1/2) * exp( (-1/2) * (x-mu).t * inv(sigma) * (x-mu) ) + * pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu)) * where k is length of the mean vector. * * We here compute distribution-fixed parts - * (2*pi)^(-k/2) * det(sigma)^(-1/2) + * log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) * and - * D^(-1/2) * U, where sigma = U * D * U.t + * D^(-1/2)^ * U, where sigma = U * D * U.t * * Both the determinant and the inverse can be computed from the singular value decomposition * of sigma. Noting that covariance matrices are always symmetric and positive semi-definite, @@ -65,11 +99,11 @@ private[mllib] class MultivariateGaussian( * * sigma = U * D * U.t * inv(Sigma) = U * inv(D) * U.t - * = (D^{-1/2} * U).t * (D^{-1/2} * U) + * = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U) * * and thus * - * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2} * U * (x-mu))^2 + * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^ * * To guard against singular covariance matrices, this method computes both the * pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered @@ -77,21 +111,21 @@ private[mllib] class MultivariateGaussian( * relation to the maximum singular value (same tolerance used by, e.g., Octave). 
*/ private def calculateCovarianceConstants: (DBM[Double], Double) = { - val eigSym.EigSym(d, u) = eigSym(sigma) // sigma = u * diag(d) * u.t + val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t // For numerical stability, values are considered to be non-zero only if they exceed tol. // This prevents any inverted value from exceeding (eps * n * max(d))^-1 val tol = MLUtils.EPSILON * max(d) * d.length try { - // pseudo-determinant is product of all non-zero singular values - val pdetSigma = d.activeValuesIterator.filter(_ > tol).reduce(_ * _) + // log(pseudo-determinant) is sum of the logs of all non-zero singular values + val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum // calculate the root-pseudo-inverse of the diagonal matrix of singular values // by inverting the square root of all non-zero values val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray)) - (pinvS * u, math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(pdetSigma, -0.5)) + (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma)) } catch { case uex: UnsupportedOperationException => throw new IllegalArgumentException("Covariance matrix has no non-zero singular values") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala index 23feb82874b70..9da5495741a80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala @@ -35,12 +35,14 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex val Ew = 1.0 val Emu = Vectors.dense(5.0, 10.0) val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0)) - - val gmm = new GaussianMixtureEM().setK(1).run(data) - - assert(gmm.weight(0) ~== Ew absTol 1E-5) - assert(gmm.mu(0) ~== Emu absTol 1E-5) - assert(gmm.sigma(0) ~== Esigma absTol 1E-5) + + val seeds = Array(314589, 29032897, 50181, 494821, 4660) + seeds.foreach { seed => + val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data) + assert(gmm.weight(0) ~== Ew absTol 1E-5) + assert(gmm.mu(0) ~== Emu absTol 1E-5) + assert(gmm.sigma(0) ~== Esigma absTol 1E-5) + } } test("two clusters") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index e25bc02b06c9a..741cd4997b853 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -113,6 +113,13 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(closeToZero(U * brzDiag(s) * V.t - localA)) } + test("validate k in svd") { + val A = new IndexedRowMatrix(indexedRows) + intercept[IllegalArgumentException] { + A.computeSVD(-1) + } + } + def closeToZero(G: BDM[Double]): Boolean = { G.valuesIterator.map(math.abs).sum < 1e-6 } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index dbf55ff81ca99..3309713e91f87 100644 --- 
a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -171,6 +171,14 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { } } + test("validate k in svd") { + for (mat <- Seq(denseMat, sparseMat)) { + intercept[IllegalArgumentException] { + mat.computeSVD(-1) + } + } + } + def closeToZero(G: BDM[Double]): Boolean = { G.valuesIterator.map(math.abs).sum < 1e-6 } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala similarity index 72% rename from mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index d58f2587e55aa..fac2498e4dcb3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -15,54 +15,53 @@ * limitations under the License. */ -package org.apache.spark.mllib.stat.impl +package org.apache.spark.mllib.stat.distribution import org.scalatest.FunSuite -import breeze.linalg.{ DenseVector => BDV, DenseMatrix => BDM } - +import org.apache.spark.mllib.linalg.{ Vectors, Matrices } import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext { test("univariate") { - val x1 = new BDV(Array(0.0)) - val x2 = new BDV(Array(1.5)) + val x1 = Vectors.dense(0.0) + val x2 = Vectors.dense(1.5) - val mu = new BDV(Array(0.0)) - val sigma1 = new BDM(1, 1, Array(1.0)) + val mu = Vectors.dense(0.0) + val sigma1 = Matrices.dense(1, 1, Array(1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5) - val sigma2 = new BDM(1, 1, Array(4.0)) + val sigma2 = Matrices.dense(1, 1, Array(4.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5) } test("multivariate") { - val x1 = new BDV(Array(0.0, 0.0)) - val x2 = new BDV(Array(1.0, 1.0)) + val x1 = Vectors.dense(0.0, 0.0) + val x2 = Vectors.dense(1.0, 1.0) - val mu = new BDV(Array(0.0, 0.0)) - val sigma1 = new BDM(2, 2, Array(1.0, 0.0, 0.0, 1.0)) + val mu = Vectors.dense(0.0, 0.0) + val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5) - val sigma2 = new BDM(2, 2, Array(4.0, -1.0, -1.0, 2.0)) + val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5) } test("multivariate degenerate") { - val x1 = new BDV(Array(0.0, 0.0)) - val x2 = new BDV(Array(1.0, 1.0)) + val x1 = Vectors.dense(0.0, 0.0) + val x2 = Vectors.dense(1.0, 1.0) - val mu = new BDV(Array(0.0, 0.0)) - val sigma = new BDM(2, 2, Array(1.0, 1.0, 1.0, 1.0)) + val mu = Vectors.dense(0.0, 0.0) + val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0)) val dist = new MultivariateGaussian(mu, sigma) assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5) assert(dist.pdf(x2) 
~== 0.068259 absTol 1E-5) diff --git a/pom.xml b/pom.xml index 703e5c47bf59b..f4466e56c2a53 100644 --- a/pom.xml +++ b/pom.xml @@ -115,14 +115,14 @@ 1.6 spark 2.0.1 - 0.18.1 + 0.21.0 shaded-protobuf 1.7.5 1.2.17 1.0.4 2.4.1 ${hadoop.version} - 0.94.6 + 0.98.7-hadoop1 hbase 1.4.0 3.4.5 @@ -1130,6 +1130,7 @@ ${test_classpath} true + false @@ -1465,6 +1466,7 @@ 2.2.0 2.5.0 + 0.98.7-hadoop2 hadoop2 @@ -1475,6 +1477,7 @@ 2.3.02.5.00.9.0 + 0.98.7-hadoop23.1.1hadoop2 @@ -1486,6 +1489,7 @@ 2.4.02.5.00.9.0 + 0.98.7-hadoop23.1.1hadoop2 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 31d4c317ae569..51e8bd4cf6419 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,7 +36,6 @@ object MimaExcludes { case v if v.startsWith("1.3") => Seq( MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("graphx"), // These are needed if checking against the sbt build, since they are part of // the maven-generated artifacts in the 1.2 build. MimaBuild.excludeSparkPackage("unused"), diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 0e8b398fc6b97..014ac1791c849 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -807,14 +807,14 @@ def convert_struct(obj): return if isinstance(obj, tuple): - if hasattr(obj, "fields"): - d = dict(zip(obj.fields, obj)) - if hasattr(obj, "__FIELDS__"): + if hasattr(obj, "_fields"): + d = dict(zip(obj._fields, obj)) + elif hasattr(obj, "__FIELDS__"): d = dict(zip(obj.__FIELDS__, obj)) elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): d = dict(obj) else: - raise ValueError("unexpected tuple: %s" % obj) + raise ValueError("unexpected tuple: %s" % str(obj)) elif isinstance(obj, dict): d = obj @@ -1327,6 +1327,16 @@ def inferSchema(self, rdd, samplingRatio=None): >>> srdd = sqlCtx.inferSchema(nestedRdd2) >>> srdd.collect() [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] + + >>> from collections import namedtuple + >>> CustomRow = namedtuple('CustomRow', 'field1 field2') + >>> rdd = sc.parallelize( + ... [CustomRow(field1=1, field2="row1"), + ... CustomRow(field1=2, field2="row2"), + ... 
CustomRow(field1=3, field2="row3")]) + >>> srdd = sqlCtx.inferSchema(rdd) + >>> srdd.collect()[0] + Row(field1=1, field2=u'row1') """ if isinstance(rdd, SchemaRDD): diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 646c68e60c2e9..b646f0b6f0868 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -106,7 +106,7 @@ import org.apache.spark.util.Utils val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles /** Jetty server that will serve our classes to worker nodes */ val classServerPort = conf.getInt("spark.replClassServer.port", 0) - val classServer = new HttpServer(outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") + val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") private var currentSettings: Settings = initialSettings var printResults = true // whether to print result lines var totalSilence = false // whether to print anything diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 5e93a71995072..69e44d4f916e1 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -32,7 +32,7 @@ object Main extends Logging { val s = new Settings() s.processArguments(List("-Yrepl-class-based", "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true) - val classServer = new HttpServer(outputDir, new SecurityManager(conf)) + val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) var sparkContext: SparkContext = _ var interp = new SparkILoop // this is a public var because tests reset it. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala similarity index 62% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index f1a1ca6616a21..93d74adbcc957 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -105,72 +105,3 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical { } } } - -/** - * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL - * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. - * - * @param fallback A function that parses an input string to a logical plan - */ -private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { - - // A parser for the key-value part of the "SET [key = [value ]]" syntax - private object SetCommandParser extends RegexParsers { - private val key: Parser[String] = "(?m)[^=]+".r - - private val value: Parser[String] = "(?m).*$".r - - private val pair: Parser[LogicalPlan] = - (key ~ ("=".r ~> value).?).? 
^^ { - case None => SetCommand(None) - case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim))) - } - - def apply(input: String): LogicalPlan = parseAll(pair, input) match { - case Success(plan, _) => plan - case x => sys.error(x.toString) - } - } - - protected val AS = Keyword("AS") - protected val CACHE = Keyword("CACHE") - protected val LAZY = Keyword("LAZY") - protected val SET = Keyword("SET") - protected val TABLE = Keyword("TABLE") - protected val UNCACHE = Keyword("UNCACHE") - - protected implicit def asParser(k: Keyword): Parser[String] = - lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - - private val reservedWords: Seq[String] = - this - .getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].str) - - override val lexical = new SqlLexical(reservedWords) - - override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others - - private lazy val cache: Parser[LogicalPlan] = - CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { - case isLazy ~ tableName ~ plan => - CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined) - } - - private lazy val uncache: Parser[LogicalPlan] = - UNCACHE ~ TABLE ~> ident ^^ { - case tableName => UncacheTableCommand(tableName) - } - - private lazy val set: Parser[LogicalPlan] = - SET ~> restInput ^^ { - case input => SetCommandParser(input) - } - - private lazy val others: Parser[LogicalPlan] = - wholeInput ^^ { - case input => fallback(input) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index f79d4ff444dc0..5d974df98b699 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -125,7 +125,7 @@ class SqlParser extends AbstractSparkSQLParser { } protected lazy val start: Parser[LogicalPlan] = - ( select * + ( (select | ("(" ~> select <~ ")")) * ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} @@ -178,10 +178,10 @@ class SqlParser extends AbstractSparkSQLParser { joinedRelation | relationFactor protected lazy val relationFactor: Parser[LogicalPlan] = - ( ident ~ (opt(AS) ~> opt(ident)) ^^ { - case tableName ~ alias => UnresolvedRelation(None, tableName, alias) + ( rep1sep(ident, ".") ~ (opt(AS) ~> opt(ident)) ^^ { + case tableIdent ~ alias => UnresolvedRelation(tableIdent, alias) } - | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) } + | ("(" ~> start <~ ")") ~ (AS.? 
~> ident) ^^ { case s ~ a => Subquery(a, s) } ) protected lazy val joinedRelation: Parser[LogicalPlan] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 72680f37a0b4d..c009cc1e1e85c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -228,11 +228,11 @@ class Analyzer(catalog: Catalog, */ object ResolveRelations extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ InsertIntoTable(UnresolvedRelation(databaseName, name, alias), _, _, _) => + case i @ InsertIntoTable(UnresolvedRelation(tableIdentifier, alias), _, _, _) => i.copy( - table = EliminateAnalysisOperators(catalog.lookupRelation(databaseName, name, alias))) - case UnresolvedRelation(databaseName, name, alias) => - catalog.lookupRelation(databaseName, name, alias) + table = EliminateAnalysisOperators(catalog.lookupRelation(tableIdentifier, alias))) + case UnresolvedRelation(tableIdentifier, alias) => + catalog.lookupRelation(tableIdentifier, alias) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 0415d74bd8141..df8d03b86c533 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -28,77 +28,74 @@ trait Catalog { def caseSensitive: Boolean - def tableExists(db: Option[String], tableName: String): Boolean + def tableExists(tableIdentifier: Seq[String]): Boolean def lookupRelation( - databaseName: Option[String], - tableName: String, - alias: Option[String] = None): LogicalPlan + tableIdentifier: Seq[String], + alias: Option[String] = None): LogicalPlan - def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit + def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit - def unregisterTable(databaseName: Option[String], tableName: String): Unit + def unregisterTable(tableIdentifier: Seq[String]): Unit def unregisterAllTables(): Unit - protected def processDatabaseAndTableName( - databaseName: Option[String], - tableName: String): (Option[String], String) = { + protected def processTableIdentifier(tableIdentifier: Seq[String]): Seq[String] = { if (!caseSensitive) { - (databaseName.map(_.toLowerCase), tableName.toLowerCase) + tableIdentifier.map(_.toLowerCase) } else { - (databaseName, tableName) + tableIdentifier } } - protected def processDatabaseAndTableName( - databaseName: String, - tableName: String): (String, String) = { - if (!caseSensitive) { - (databaseName.toLowerCase, tableName.toLowerCase) + protected def getDbTableName(tableIdent: Seq[String]): String = { + val size = tableIdent.size + if (size <= 2) { + tableIdent.mkString(".") } else { - (databaseName, tableName) + tableIdent.slice(size - 2, size).mkString(".") } } + + protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = { + (tableIdent.lift(tableIdent.size - 2), tableIdent.last) + } } class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { val tables = new mutable.HashMap[String, LogicalPlan]() override def registerTable( - databaseName: Option[String], - tableName: String, + tableIdentifier: Seq[String], plan: 
LogicalPlan): Unit = { - val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) - tables += ((tblName, plan)) + val tableIdent = processTableIdentifier(tableIdentifier) + tables += ((getDbTableName(tableIdent), plan)) } - override def unregisterTable( - databaseName: Option[String], - tableName: String) = { - val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) - tables -= tblName + override def unregisterTable(tableIdentifier: Seq[String]) = { + val tableIdent = processTableIdentifier(tableIdentifier) + tables -= getDbTableName(tableIdent) } override def unregisterAllTables() = { tables.clear() } - override def tableExists(db: Option[String], tableName: String): Boolean = { - val (dbName, tblName) = processDatabaseAndTableName(db, tableName) - tables.get(tblName) match { + override def tableExists(tableIdentifier: Seq[String]): Boolean = { + val tableIdent = processTableIdentifier(tableIdentifier) + tables.get(getDbTableName(tableIdent)) match { case Some(_) => true case None => false } } override def lookupRelation( - databaseName: Option[String], - tableName: String, + tableIdentifier: Seq[String], alias: Option[String] = None): LogicalPlan = { - val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) - val table = tables.getOrElse(tblName, sys.error(s"Table Not Found: $tableName")) - val tableWithQualifiers = Subquery(tblName, table) + val tableIdent = processTableIdentifier(tableIdentifier) + val tableFullName = getDbTableName(tableIdent) + val table = tables.getOrElse(tableFullName, sys.error(s"Table Not Found: $tableFullName")) + val tableWithQualifiers = Subquery(tableIdent.last, table) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. @@ -117,41 +114,39 @@ trait OverrideCatalog extends Catalog { // TODO: This doesn't work when the database changes... val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]() - abstract override def tableExists(db: Option[String], tableName: String): Boolean = { - val (dbName, tblName) = processDatabaseAndTableName(db, tableName) - overrides.get((dbName, tblName)) match { + abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = { + val tableIdent = processTableIdentifier(tableIdentifier) + overrides.get(getDBTable(tableIdent)) match { case Some(_) => true - case None => super.tableExists(db, tableName) + case None => super.tableExists(tableIdentifier) } } abstract override def lookupRelation( - databaseName: Option[String], - tableName: String, + tableIdentifier: Seq[String], alias: Option[String] = None): LogicalPlan = { - val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) - val overriddenTable = overrides.get((dbName, tblName)) - val tableWithQualifers = overriddenTable.map(r => Subquery(tblName, r)) + val tableIdent = processTableIdentifier(tableIdentifier) + val overriddenTable = overrides.get(getDBTable(tableIdent)) + val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r)) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. 
val withAlias = tableWithQualifers.map(r => alias.map(a => Subquery(a, r)).getOrElse(r)) - withAlias.getOrElse(super.lookupRelation(dbName, tblName, alias)) + withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias)) } override def registerTable( - databaseName: Option[String], - tableName: String, + tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { - val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) - overrides.put((dbName, tblName), plan) + val tableIdent = processTableIdentifier(tableIdentifier) + overrides.put(getDBTable(tableIdent), plan) } - override def unregisterTable(databaseName: Option[String], tableName: String): Unit = { - val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) - overrides.remove((dbName, tblName)) + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { + val tableIdent = processTableIdentifier(tableIdentifier) + overrides.remove(getDBTable(tableIdent)) } override def unregisterAllTables(): Unit = { @@ -167,22 +162,21 @@ object EmptyCatalog extends Catalog { val caseSensitive: Boolean = true - def tableExists(db: Option[String], tableName: String): Boolean = { + def tableExists(tableIdentifier: Seq[String]): Boolean = { throw new UnsupportedOperationException } def lookupRelation( - databaseName: Option[String], - tableName: String, + tableIdentifier: Seq[String], alias: Option[String] = None) = { throw new UnsupportedOperationException } - def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = { + def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } - def unregisterTable(databaseName: Option[String], tableName: String): Unit = { + def unregisterTable(tableIdentifier: Seq[String]): Unit = { throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 77d84e1687e1b..71a738a0b2ca0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -34,8 +34,7 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str * Holds the name of a relation that has yet to be looked up in a [[Catalog]]. 
*/ case class UnresolvedRelation( - databaseName: Option[String], - tableName: String, + tableIdentifier: Seq[String], alias: Option[String] = None) extends LeafNode { override def output = Nil override lazy val resolved = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 9608e15c0f302..b2262e5e6efb6 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -290,7 +290,7 @@ package object dsl { def insertInto(tableName: String, overwrite: Boolean = false) = InsertIntoTable( - analysis.UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite) + analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite) def analyze = analysis.SimpleAnalyzer(logicalPlan) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 5a1863953eae9..45905f8ef98c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.types.StringType +import org.apache.spark.sql.catalyst.expressions.Attribute /** * A logical node that represents a non-query command to be executed by the system. For example, @@ -28,48 +27,3 @@ abstract class Command extends LeafNode { self: Product => def output: Seq[Attribute] = Seq.empty } - -/** - * - * Commands of the form "SET [key [= value] ]". - */ -case class SetCommand(kv: Option[(String, Option[String])]) extends Command { - override def output = Seq( - AttributeReference("", StringType, nullable = false)()) -} - -/** - * Returned by a parser when the users only wants to see what query plan would be executed, without - * actually performing the execution. - */ -case class ExplainCommand(plan: LogicalPlan, extended: Boolean = false) extends Command { - override def output = - Seq(AttributeReference("plan", StringType, nullable = false)()) -} - -/** - * Returned for the "CACHE TABLE tableName [AS SELECT ...]" command. - */ -case class CacheTableCommand(tableName: String, plan: Option[LogicalPlan], isLazy: Boolean) - extends Command - -/** - * Returned for the "UNCACHE TABLE tableName" command. - */ -case class UncacheTableCommand(tableName: String) extends Command - -/** - * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. - * @param table The table to be described. - * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. - * It is effective only when the table is a Hive table. - */ -case class DescribeCommand( - table: LogicalPlan, - isExtended: Boolean) extends Command { - override def output = Seq( - // Column names are based on Hive. 
- AttributeReference("col_name", StringType, nullable = false)(), - AttributeReference("data_type", StringType, nullable = false)(), - AttributeReference("comment", StringType, nullable = false)()) -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 82f2101d8ce17..f430057ef7191 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -44,8 +44,8 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("e", ShortType)()) before { - caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation) - caseInsensitiveCatalog.registerTable(None, "TaBlE", testRelation) + caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) } test("union project *") { @@ -64,45 +64,45 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert( caseSensitiveAnalyze( Project(Seq(UnresolvedAttribute("TbL.a")), - UnresolvedRelation(None, "TaBlE", Some("TbL")))) === + UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) val e = intercept[TreeNodeException[_]] { caseSensitiveAnalyze( Project(Seq(UnresolvedAttribute("tBl.a")), - UnresolvedRelation(None, "TaBlE", Some("TbL")))) + UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) } assert(e.getMessage().toLowerCase.contains("unresolved")) assert( caseInsensitiveAnalyze( Project(Seq(UnresolvedAttribute("TbL.a")), - UnresolvedRelation(None, "TaBlE", Some("TbL")))) === + UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) assert( caseInsensitiveAnalyze( Project(Seq(UnresolvedAttribute("tBl.a")), - UnresolvedRelation(None, "TaBlE", Some("TbL")))) === + UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) } test("resolve relations") { val e = intercept[RuntimeException] { - caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) + caseSensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) } assert(e.getMessage == "Table Not Found: tAbLe") assert( - caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) === + caseSensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) assert( - caseInsensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) === + caseInsensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation) assert( - caseInsensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) === + caseInsensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 3677a6e72e23a..bbbeb4f2e4fe3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -41,7 +41,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { val f: Expression = UnresolvedAttribute("f") before { - catalog.registerTable(None, "table", relation) + catalog.registerTable(Seq("table"), relation) } private def checkType(expression: Expression, expectedType: DataType): Unit 
= { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 6a1a4d995bf61..6c575dd727b46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -76,7 +76,7 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val sqlParser = { val fallback = new catalyst.SqlParser - new catalyst.SparkSQLParser(fallback(_)) + new SparkSQLParser(fallback(_)) } protected[sql] def parseSql(sql: String): LogicalPlan = { @@ -276,7 +276,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { - catalog.registerTable(None, tableName, rdd.queryExecution.logical) + catalog.registerTable(Seq(tableName), rdd.queryExecution.logical) } /** @@ -289,7 +289,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ def dropTempTable(tableName: String): Unit = { tryUncacheQuery(table(tableName)) - catalog.unregisterTable(None, tableName) + catalog.unregisterTable(Seq(tableName)) } /** @@ -308,7 +308,7 @@ class SQLContext(@transient val sparkContext: SparkContext) /** Returns the specified table as a SchemaRDD */ def table(tableName: String): SchemaRDD = - new SchemaRDD(this, catalog.lookupRelation(None, tableName)) + new SchemaRDD(this, catalog.lookupRelation(Seq(tableName))) /** * :: DeveloperApi :: @@ -329,7 +329,6 @@ class SQLContext(@transient val sparkContext: SparkContext) def strategies: Seq[Strategy] = extraStrategies ++ ( - CommandStrategy :: DataSourceStrategy :: TakeOrdered :: HashAggregation :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index fd5f4abcbcd65..3cf9209465b76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -97,8 +97,8 @@ private[sql] trait SchemaRDDLike { */ @Experimental def insertInto(tableName: String, overwrite: Boolean): Unit = - sqlContext.executePlan( - InsertIntoTable(UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite)).toRdd + sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), + Map.empty, logicalPlan, overwrite)).toRdd /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala new file mode 100644 index 0000000000000..65358b7d4ea8e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.{SqlLexical, AbstractSparkSQLParser} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{UncacheTableCommand, CacheTableCommand, SetCommand} + +import scala.util.parsing.combinator.RegexParsers + +/** + * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL + * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. + * + * @param fallback A function that parses an input string to a logical plan + */ +private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { + + // A parser for the key-value part of the "SET [key = [value ]]" syntax + private object SetCommandParser extends RegexParsers { + private val key: Parser[String] = "(?m)[^=]+".r + + private val value: Parser[String] = "(?m).*$".r + + private val output: Seq[Attribute] = Seq(AttributeReference("", StringType, nullable = false)()) + + private val pair: Parser[LogicalPlan] = + (key ~ ("=".r ~> value).?).? ^^ { + case None => SetCommand(None, output) + case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim)), output) + } + + def apply(input: String): LogicalPlan = parseAll(pair, input) match { + case Success(plan, _) => plan + case x => sys.error(x.toString) + } + } + + protected val AS = Keyword("AS") + protected val CACHE = Keyword("CACHE") + protected val LAZY = Keyword("LAZY") + protected val SET = Keyword("SET") + protected val TABLE = Keyword("TABLE") + protected val UNCACHE = Keyword("UNCACHE") + + protected implicit def asParser(k: Keyword): Parser[String] = + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + + private val reservedWords: Seq[String] = + this + .getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].str) + + override val lexical = new SqlLexical(reservedWords) + + override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others + + private lazy val cache: Parser[LogicalPlan] = + CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? 
^^ { + case isLazy ~ tableName ~ plan => + CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined) + } + + private lazy val uncache: Parser[LogicalPlan] = + UNCACHE ~ TABLE ~> ident ^^ { + case tableName => UncacheTableCommand(tableName) + } + + private lazy val set: Parser[LogicalPlan] = + SET ~> restInput ^^ { + case input => SetCommandParser(input) + } + + private lazy val others: Parser[LogicalPlan] = + wholeInput ^^ { + case input => fallback(input) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ce878c137e627..99b6611d3bbcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -259,6 +259,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def numPartitions = self.numPartitions def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case r: RunnableCommand => ExecutedCommand(r) :: Nil + case logical.Distinct(child) => execution.Distinct(partial = false, execution.Distinct(partial = true, planLater(child))) :: Nil @@ -308,22 +310,4 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } } - - case object CommandStrategy extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: RunnableCommand => ExecutedCommand(r) :: Nil - case logical.SetCommand(kv) => - Seq(ExecutedCommand(execution.SetCommand(kv, plan.output))) - case logical.ExplainCommand(logicalPlan, extended) => - Seq(ExecutedCommand( - execution.ExplainCommand(logicalPlan, plan.output, extended))) - case logical.CacheTableCommand(tableName, optPlan, isLazy) => - Seq(ExecutedCommand( - execution.CacheTableCommand(tableName, optPlan, isLazy))) - case logical.UncacheTableCommand(tableName) => - Seq(ExecutedCommand( - execution.UncacheTableCommand(tableName))) - case _ => Nil - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e53723c176569..16ca4be5587c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -70,7 +70,7 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: override def output = child.output // TODO: How to pick seed? - override def execute() = child.execute().sample(withReplacement, fraction, seed) + override def execute() = child.execute().map(_.copy()).sample(withReplacement, fraction, seed) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index b8fa4b019953e..0d765c4c92f85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -113,7 +113,7 @@ case class SetCommand( @DeveloperApi case class ExplainCommand( logicalPlan: LogicalPlan, - override val output: Seq[Attribute], extended: Boolean) extends RunnableCommand { + override val output: Seq[Attribute], extended: Boolean = false) extends RunnableCommand { // Run through the optimizer to generate the physical plan. 
override def run(sqlContext: SQLContext) = try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index fc70c183437f6..a9a6696cb15e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -18,31 +18,48 @@ package org.apache.spark.sql.json import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.sources._ -private[sql] class DefaultSource extends RelationProvider { - /** Returns a new base relation with the given parameters. */ +private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider { + + /** Returns a new base relation with the parameters. */ override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified")) val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - JSONRelation(fileName, samplingRatio)(sqlContext) + JSONRelation(fileName, samplingRatio, None)(sqlContext) + } + + /** Returns a new base relation with the given schema and parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { + val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified")) + val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + + JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext) } } -private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)( +private[sql] case class JSONRelation( + fileName: String, + samplingRatio: Double, + userSpecifiedSchema: Option[StructType])( @transient val sqlContext: SQLContext) extends TableScan { private def baseRDD = sqlContext.sparkContext.textFile(fileName) - override val schema = - JsonRDD.inferSchema( - baseRDD, - samplingRatio, - sqlContext.columnNameOfCorruptRecord) + override val schema = userSpecifiedSchema.getOrElse( + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema( + baseRDD, + samplingRatio, + sqlContext.columnNameOfCorruptRecord))) override def buildScan() = JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index b237a07c72d07..2835dc3408b96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -28,7 +28,7 @@ import parquet.schema.MessageType import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} /** @@ -67,6 +67,8 @@ private[sql] case class ParquetRelation( conf, sqlContext.isParquetBinaryAsString) + lazy val attributeMap = AttributeMap(output.map(o => o -> o)) + override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] // Equals must also take into account the output attributes so that we can distinguish 
between diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 96bace1769f71..f5487740d3af9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -64,18 +64,17 @@ case class ParquetTableScan( // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes // by exprId. note: output cannot be transient, see // https://issues.apache.org/jira/browse/SPARK-1367 - val normalOutput = - attributes - .filterNot(a => relation.partitioningAttributes.map(_.exprId).contains(a.exprId)) - .flatMap(a => relation.output.find(o => o.exprId == a.exprId)) + val output = attributes.map(relation.attributeMap) - val partOutput = - attributes.flatMap(a => relation.partitioningAttributes.find(o => o.exprId == a.exprId)) + // A mapping of ordinals partitionRow -> finalOutput. + val requestedPartitionOrdinals = { + val partitionAttributeOrdinals = AttributeMap(relation.partitioningAttributes.zipWithIndex) - def output = partOutput ++ normalOutput - - assert(normalOutput.size + partOutput.size == attributes.size, - s"$normalOutput + $partOutput != $attributes, ${relation.output}") + attributes.zipWithIndex.flatMap { + case (attribute, finalOrdinal) => + partitionAttributeOrdinals.get(attribute).map(_ -> finalOrdinal) + } + }.toArray override def execute(): RDD[Row] = { import parquet.filter2.compat.FilterCompat.FilterPredicateCompat @@ -97,7 +96,7 @@ case class ParquetTableScan( // Store both requested and original schema in `Configuration` conf.set( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(normalOutput)) + ParquetTypesConverter.convertToString(output)) conf.set( RowWriteSupport.SPARK_ROW_SCHEMA, ParquetTypesConverter.convertToString(relation.output)) @@ -125,7 +124,7 @@ case class ParquetTableScan( classOf[Row], conf) - if (partOutput.nonEmpty) { + if (requestedPartitionOrdinals.nonEmpty) { baseRDD.mapPartitionsWithInputSplit { case (split, iter) => val partValue = "([^=]+)=([^=]+)".r val partValues = @@ -138,15 +137,25 @@ case class ParquetTableScan( case _ => None }.toMap + // Convert the partitioning attributes into the correct types val partitionRowValues = - partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) + relation.partitioningAttributes + .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) new Iterator[Row] { - private[this] val joinedRow = new JoinedRow5(Row(partitionRowValues:_*), null) - def hasNext = iter.hasNext - - def next() = joinedRow.withRight(iter.next()._2) + def next() = { + val row = iter.next()._2.asInstanceOf[SpecificMutableRow] + + // Parquet will leave partitioning columns empty, so we fill them in here. 
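+              // Each entry of requestedPartitionOrdinals pairs the value's ordinal in
+              // relation.partitioningAttributes with its ordinal in the requested output,
+              // so row(outputOrdinal) is filled from partitionRowValues(partitionOrdinal).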
+ var i = 0 + while (i < requestedPartitionOrdinals.size) { + row(requestedPartitionOrdinals(i)._2) = + partitionRowValues(requestedPartitionOrdinals(i)._1) + i += 1 + } + row + } } } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 8a66ac31f2dfb..fe2c4d8436b2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql.sources -import org.apache.spark.Logging -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.util.Utils - import scala.language.implicitConversions -import scala.util.parsing.combinator.lexical.StdLexical import scala.util.parsing.combinator.syntactical.StandardTokenParsers import scala.util.parsing.combinator.PackratParsers +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.SqlLexical @@ -44,6 +43,14 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi } } + def parseType(input: String): DataType = { + phrase(dataType)(new lexical.Scanner(input)) match { + case Success(r, x) => r + case x => + sys.error(s"Unsupported dataType: $x") + } + } + protected case class Keyword(str: String) protected implicit def asParser(k: Keyword): Parser[String] = @@ -55,6 +62,24 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi protected val USING = Keyword("USING") protected val OPTIONS = Keyword("OPTIONS") + // Data types. + protected val STRING = Keyword("STRING") + protected val BINARY = Keyword("BINARY") + protected val BOOLEAN = Keyword("BOOLEAN") + protected val TINYINT = Keyword("TINYINT") + protected val SMALLINT = Keyword("SMALLINT") + protected val INT = Keyword("INT") + protected val BIGINT = Keyword("BIGINT") + protected val FLOAT = Keyword("FLOAT") + protected val DOUBLE = Keyword("DOUBLE") + protected val DECIMAL = Keyword("DECIMAL") + protected val DATE = Keyword("DATE") + protected val TIMESTAMP = Keyword("TIMESTAMP") + protected val VARCHAR = Keyword("VARCHAR") + protected val ARRAY = Keyword("ARRAY") + protected val MAP = Keyword("MAP") + protected val STRUCT = Keyword("STRUCT") + // Use reflection to find the reserved words defined in this class. protected val reservedWords = this.getClass @@ -67,15 +92,25 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi protected lazy val ddl: Parser[LogicalPlan] = createTable /** - * CREATE TEMPORARY TABLE avroTable + * `CREATE TEMPORARY TABLE avroTable * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro") + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * or + * `CREATE TEMPORARY TABLE avroTable(intField int, stringField string...) 
+ * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` */ protected lazy val createTable: Parser[LogicalPlan] = - CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ { - case tableName ~ provider ~ opts => - CreateTableUsing(tableName, provider, opts) + ( + CREATE ~ TEMPORARY ~ TABLE ~> ident + ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ { + case tableName ~ columns ~ provider ~ opts => + val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) + CreateTableUsing(tableName, userSpecifiedSchema, provider, opts) } + ) + + protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" protected lazy val options: Parser[Map[String, String]] = "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } @@ -83,10 +118,66 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) } + + protected lazy val column: Parser[StructField] = + ident ~ dataType ^^ { case columnName ~ typ => + StructField(columnName, typ) + } + + protected lazy val primitiveType: Parser[DataType] = + STRING ^^^ StringType | + BINARY ^^^ BinaryType | + BOOLEAN ^^^ BooleanType | + TINYINT ^^^ ByteType | + SMALLINT ^^^ ShortType | + INT ^^^ IntegerType | + BIGINT ^^^ LongType | + FLOAT ^^^ FloatType | + DOUBLE ^^^ DoubleType | + fixedDecimalType | // decimal with precision/scale + DECIMAL ^^^ DecimalType.Unlimited | // decimal with no precision/scale + DATE ^^^ DateType | + TIMESTAMP ^^^ TimestampType | + VARCHAR ~ "(" ~ numericLit ~ ")" ^^^ StringType + + protected lazy val fixedDecimalType: Parser[DataType] = + (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } + + protected lazy val arrayType: Parser[DataType] = + ARRAY ~> "<" ~> dataType <~ ">" ^^ { + case tpe => ArrayType(tpe) + } + + protected lazy val mapType: Parser[DataType] = + MAP ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { + case t1 ~ _ ~ t2 => MapType(t1, t2) + } + + protected lazy val structField: Parser[StructField] = + ident ~ ":" ~ dataType ^^ { + case fieldName ~ _ ~ tpe => StructField(fieldName, tpe, nullable = true) + } + + protected lazy val structType: Parser[DataType] = + (STRUCT ~> "<" ~> repsep(structField, ",") <~ ">" ^^ { + case fields => new StructType(fields) + }) | + (STRUCT ~> "<>" ^^ { + case fields => new StructType(Nil) + }) + + private[sql] lazy val dataType: Parser[DataType] = + arrayType | + mapType | + structType | + primitiveType } private[sql] case class CreateTableUsing( tableName: String, + userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String]) extends RunnableCommand { @@ -99,8 +190,29 @@ private[sql] case class CreateTableUsing( sys.error(s"Failed to load class for data source: $provider") } } - val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider] - val relation = dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) + + val relation = userSpecifiedSchema match { + case Some(schema: StructType) => { + clazz.newInstance match { + case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => + dataSource + 
.asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider] + .createRelation(sqlContext, new CaseInsensitiveMap(options), schema) + case _ => + sys.error(s"${clazz.getCanonicalName} should extend SchemaRelationProvider.") + } + } + case None => { + clazz.newInstance match { + case dataSource: org.apache.spark.sql.sources.RelationProvider => + dataSource + .asInstanceOf[org.apache.spark.sql.sources.RelationProvider] + .createRelation(sqlContext, new CaseInsensitiveMap(options)) + case _ => + sys.error(s"${clazz.getCanonicalName} should extend RelationProvider.") + } + } + } sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName) Seq.empty @@ -110,7 +222,8 @@ private[sql] case class CreateTableUsing( /** * Builds a map in which keys are case insensitive */ -protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] { +protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] + with Serializable { val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 02eff80456dbe..990f7e0e74bcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.sources import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLConf, Row, SQLContext, StructType} +import org.apache.spark.sql.{Row, SQLContext, StructType} import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} /** @@ -44,6 +44,33 @@ trait RelationProvider { def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation } +/** + * ::DeveloperApi:: + * Implemented by objects that produce relations for a specific kind of data source. When + * Spark SQL is given a DDL operation with + * 1. USING clause: to specify the implemented SchemaRelationProvider + * 2. User defined schema: users can define schema optionally when create table + * + * Users may specify the fully qualified class name of a given data source. When that class is + * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for + * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the + * data source 'org.apache.spark.sql.json.DefaultSource' + * + * A new instance of this class with be instantiated each time a DDL call is made. + */ +@DeveloperApi +trait SchemaRelationProvider { + /** + * Returns a new base relation with the given parameters and user defined schema. + * Note: the parameters' keywords are case insensitive and this insensitivity is enforced + * by the Map that is passed to the function. + */ + def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation +} + /** * ::DeveloperApi:: * Represents a collection of tuples with a known schema. 
Classes that extend BaseRelation must diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 1a4232dab86e7..c7e136388fce8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -302,8 +302,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { upperCaseData.where('N <= 4).registerTempTable("left") upperCaseData.where('N >= 3).registerTempTable("right") - val left = UnresolvedRelation(None, "left", None) - val right = UnresolvedRelation(None, "right", None) + val left = UnresolvedRelation(Seq("left"), None) + val right = UnresolvedRelation(Seq("right"), None) checkAnswer( left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index add4e218a22ee..d9de5686dce48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -272,6 +272,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { mapData.collect().take(1).toSeq) } + test("from follow multiple brackets") { + checkAnswer(sql( + "select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"), + 1 + ) + + checkAnswer(sql( + "select key from (select * from testData) x limit 1"), + 1 + ) + + checkAnswer(sql( + "select key from (select * from testData limit 1 union all select * from testData limit 1) x limit 1"), + 1 + ) + } + test("average") { checkAnswer( sql("SELECT AVG(a) FROM testData2"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 3cd7b0115d567..605190f5ae6a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.sources +import java.sql.{Timestamp, Date} + import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.types.DecimalType class DefaultSource extends SimpleScanSource @@ -38,9 +41,77 @@ case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) } +class AllDataTypesScanSource extends SchemaRelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { + AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext) + } +} + +case class AllDataTypesScan( + from: Int, + to: Int, + userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext) + extends TableScan { + + override def schema = userSpecifiedSchema + + override def buildScan() = { + sqlContext.sparkContext.parallelize(from to to).map { i => + Row( + s"str_$i", + s"str_$i".getBytes(), + i % 2 == 0, + i.toByte, + i.toShort, + i, + i.toLong, + i.toFloat, + i.toDouble, + BigDecimal(i), + BigDecimal(i), + new Date((i + 1) * 8640000), + new Timestamp(20000 + i), + s"varchar_$i", + Seq(i, i + 1), + Seq(Map(s"str_$i" -> Row(i.toLong))), + Map(i -> i.toString), + Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), + Row(i, i.toString), + Row(Seq(s"str_$i", s"str_${i + 1}"), 
Row(Seq(new Date((i + 2) * 8640000)))))
+    }
+  }
+}
+
 class TableScanSuite extends DataSourceTest {
   import caseInsensisitiveContext._
 
+  var tableWithSchemaExpected = (1 to 10).map { i =>
+    Row(
+      s"str_$i",
+      s"str_$i",
+      i % 2 == 0,
+      i.toByte,
+      i.toShort,
+      i,
+      i.toLong,
+      i.toFloat,
+      i.toDouble,
+      BigDecimal(i),
+      BigDecimal(i),
+      new Date((i + 1) * 8640000),
+      new Timestamp(20000 + i),
+      s"varchar_$i",
+      Seq(i, i + 1),
+      Seq(Map(s"str_$i" -> Row(i.toLong))),
+      Map(i -> i.toString),
+      Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
+      Row(i, i.toString),
+      Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
+  }.toSeq
+
   before {
     sql(
       """
@@ -51,6 +122,37 @@ class TableScanSuite extends DataSourceTest {
        |  To '10'
        |)
      """.stripMargin)
+
+    sql(
+      """
+        |CREATE TEMPORARY TABLE tableWithSchema (
+        |`string$%Field` stRIng,
+        |binaryField binary,
+        |`booleanField` boolean,
+        |ByteField tinyint,
+        |shortField smaLlint,
+        |int_Field iNt,
+        |`longField_:,<>=+/~^` Bigint,
+        |floatField flOat,
+        |doubleField doubLE,
+        |decimalField1 decimal,
+        |decimalField2 decimal(9,2),
+        |dateField dAte,
+        |timestampField tiMestamp,
+        |varcharField varchaR(12),
+        |arrayFieldSimple Array<inT>,
+        |arrayFieldComplex Array<Map<String, Struct<key:bigInt>>>,
+        |mapFieldSimple MAP<iNt, StRING>,
+        |mapFieldComplex Map<Map<stRING, fLOAT>, Struct<key:bigInt>>,
+        |structFieldSimple StRuct<key:INt, Value:STrING>,
+        |structFieldComplex StRuct<key:Array<String>, Value:struct<`value_(2)`:Array<date>>>
+        |)
+        |USING org.apache.spark.sql.sources.AllDataTypesScanSource
+        |OPTIONS (
+        |  From '1',
+        |  To '10'
+        |)
+      """.stripMargin)
   }
 
   sqlTest(
@@ -73,6 +175,96 @@ class TableScanSuite extends DataSourceTest {
     "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1",
     (2 to 10).map(i => Row(i, i - 1)).toSeq)
 
+  test("Schema and all fields") {
+    val expectedSchema = StructType(
+      StructField("string$%Field", StringType, true) ::
+      StructField("binaryField", BinaryType, true) ::
+      StructField("booleanField", BooleanType, true) ::
+      StructField("ByteField", ByteType, true) ::
+      StructField("shortField", ShortType, true) ::
+      StructField("int_Field", IntegerType, true) ::
+      StructField("longField_:,<>=+/~^", LongType, true) ::
+      StructField("floatField", FloatType, true) ::
+      StructField("doubleField", DoubleType, true) ::
+      StructField("decimalField1", DecimalType.Unlimited, true) ::
+      StructField("decimalField2", DecimalType(9, 2), true) ::
+      StructField("dateField", DateType, true) ::
+      StructField("timestampField", TimestampType, true) ::
+      StructField("varcharField", StringType, true) ::
+      StructField("arrayFieldSimple", ArrayType(IntegerType), true) ::
+      StructField("arrayFieldComplex",
+        ArrayType(
+          MapType(StringType, StructType(StructField("key", LongType, true) :: Nil))), true) ::
+      StructField("mapFieldSimple", MapType(IntegerType, StringType), true) ::
+      StructField("mapFieldComplex",
+        MapType(
+          MapType(StringType, FloatType),
+          StructType(StructField("key", LongType, true) :: Nil)), true) ::
+      StructField("structFieldSimple",
+        StructType(
+          StructField("key", IntegerType, true) ::
+          StructField("Value", StringType, true) :: Nil), true) ::
+      StructField("structFieldComplex",
+        StructType(
+          StructField("key", ArrayType(StringType), true) ::
+          StructField("Value",
+            StructType(
+              StructField("value_(2)", ArrayType(DateType), true) :: Nil), true) :: Nil), true) ::
+      Nil
+    )
+
+    assert(expectedSchema == table("tableWithSchema").schema)
+
+    checkAnswer(
+      sql(
+        """SELECT
+          | `string$%Field`,
+          | cast(binaryField as string),
+          | booleanField,
+          | byteField,
+          | shortField,
+          | int_Field,
+          | `longField_:,<>=+/~^`,
| floatField, + | doubleField, + | decimalField1, + | decimalField2, + | dateField, + | timestampField, + | varcharField, + | arrayFieldSimple, + | arrayFieldComplex, + | mapFieldSimple, + | mapFieldComplex, + | structFieldSimple, + | structFieldComplex FROM tableWithSchema""".stripMargin), + tableWithSchemaExpected + ) + } + + sqlTest( + "SELECT count(*) FROM tableWithSchema", + 10) + + sqlTest( + "SELECT `string$%Field` FROM tableWithSchema", + (1 to 10).map(i => Row(s"str_$i")).toSeq) + + sqlTest( + "SELECT int_Field FROM tableWithSchema WHERE int_Field < 5", + (1 to 4).map(Row(_)).toSeq) + + sqlTest( + "SELECT `longField_:,<>=+/~^` * 2 FROM tableWithSchema", + (1 to 10).map(i => Row(i * 2.toLong)).toSeq) + + sqlTest( + "SELECT structFieldSimple.key, arrayFieldSimple[1] FROM tableWithSchema a where int_Field=1", + Seq(Seq(1, 2))) + + sqlTest( + "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema", + (1 to 10).map(i => Row(Seq(new Date((i + 2) * 8640000)))).toSeq) test("Caching") { // Cached Query Execution diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 259eef0b80d03..123a1f629ab1c 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -76,13 +76,6 @@ - - org.apache.maven.plugins - maven-deploy-plugin - - true - - diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala index 7a3d76c61c3a1..59f3a75768082 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala @@ -53,6 +53,7 @@ private[hive] abstract class AbstractSparkSQLDriver( override def run(command: String): CommandProcessorResponse = { // TODO unify the error code try { + context.sparkContext.setJobDescription(command) val execution = context.executePlan(context.sql(command).logicalPlan) hiveResponse = execution.stringResult() tableSchema = getResultSetSchema(execution) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index 7814aa38f4146..b52a51d11e4ad 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -143,6 +143,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging { | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=http | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT}=$port | --driver-class-path ${sys.props("java.class.path")} + | --conf spark.ui.enabled=false """.stripMargin.split("\\s+").toSeq } else { s"""$startScript @@ -153,6 +154,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging { | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port | --driver-class-path ${sys.props("java.class.path")} + | --conf spark.ui.enabled=false """.stripMargin.split("\\s+").toSeq } diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala 
index 5550183621fb6..742acba58d776 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -22,7 +22,6 @@ import java.util.{ArrayList => JArrayList, Map => JMap} import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} -import scala.math._ import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.metastore.api.FieldSchema @@ -33,8 +32,8 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.plans.logical.SetCommand import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow} @@ -190,14 +189,12 @@ private[hive] class SparkExecuteStatementOperation( result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) result.queryExecution.logical match { - case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value)))) => + case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) => sessionToActivePool(parentSession.getSessionHandle) = value logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") case _ => } - - val groupId = round(random * 1000000).toString - hiveContext.sparkContext.setJobGroup(groupId, statement) + hiveContext.sparkContext.setJobDescription(statement) sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) } diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index 798a690a20427..b82156427a88c 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -22,7 +22,6 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} -import scala.math._ import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.security.UserGroupInformation @@ -31,7 +30,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.plans.logical.SetCommand +import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} @@ -161,14 +160,12 @@ private[hive] class SparkExecuteStatementOperation( result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) result.queryExecution.logical match { - case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value)))) => + case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) => sessionToActivePool(parentSession.getSessionHandle) = value logInfo(s"Setting 
spark.scheduler.pool=$value for future statements in this session.") case _ => } - - val groupId = round(random * 1000000).toString - hiveContext.sparkContext.setJobGroup(groupId, statement) + hiveContext.sparkContext.setJobDescription(statement) sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 982e0593fcfd1..02eac43b2103f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -38,8 +38,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators, OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types.DecimalType -import org.apache.spark.sql.catalyst.types.decimal.Decimal -import org.apache.spark.sql.execution.{SparkPlan, ExecutedCommand, ExtractPythonUdfs, QueryExecutionException} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand, QueryExecutionException} import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DescribeHiveTableCommand} import org.apache.spark.sql.sources.DataSourceStrategy @@ -124,7 +123,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * in the Hive metastore. */ def analyze(tableName: String) { - val relation = EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) + val relation = EliminateAnalysisOperators(catalog.lookupRelation(Seq(tableName))) relation match { case relation: MetastoreRelation => @@ -340,7 +339,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override def strategies: Seq[Strategy] = extraStrategies ++ Seq( DataSourceStrategy, - CommandStrategy, HiveCommandStrategy(self), TakeOrdered, ParquetOperations, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index b31a3ec25096b..c25288e000122 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -20,23 +20,18 @@ package org.apache.spark.sql.hive import java.io.IOException import java.util.{List => JList} -import org.apache.spark.sql.execution.SparkPlan - -import scala.util.parsing.combinator.RegexParsers - import org.apache.hadoop.util.ReflectionUtils - import org.apache.hadoop.hive.metastore.TableType import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table, HiveException} +import org.apache.hadoop.hive.ql.metadata.InvalidTableException import org.apache.hadoop.hive.ql.plan.CreateTableDesc import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException} import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{Catalog, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ @@ -57,18 
+52,25 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val caseSensitive: Boolean = false - def tableExists(db: Option[String], tableName: String): Boolean = { - val (databaseName, tblName) = processDatabaseAndTableName( - db.getOrElse(hive.sessionState.getCurrentDatabase), tableName) - client.getTable(databaseName, tblName, false) != null + def tableExists(tableIdentifier: Seq[String]): Boolean = { + val tableIdent = processTableIdentifier(tableIdentifier) + val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( + hive.sessionState.getCurrentDatabase) + val tblName = tableIdent.last + try { + client.getTable(databaseName, tblName) != null + } catch { + case ie: InvalidTableException => false + } } def lookupRelation( - db: Option[String], - tableName: String, + tableIdentifier: Seq[String], alias: Option[String]): LogicalPlan = synchronized { - val (databaseName, tblName) = - processDatabaseAndTableName(db.getOrElse(hive.sessionState.getCurrentDatabase), tableName) + val tableIdent = processTableIdentifier(tableIdentifier) + val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( + hive.sessionState.getCurrentDatabase) + val tblName = tableIdent.last val table = client.getTable(databaseName, tblName) if (table.isView) { // if the unresolved relation is from hive view @@ -251,6 +253,26 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } } + protected def processDatabaseAndTableName( + databaseName: Option[String], + tableName: String): (Option[String], String) = { + if (!caseSensitive) { + (databaseName.map(_.toLowerCase), tableName.toLowerCase) + } else { + (databaseName, tableName) + } + } + + protected def processDatabaseAndTableName( + databaseName: String, + tableName: String): (String, String) = { + if (!caseSensitive) { + (databaseName.toLowerCase, tableName.toLowerCase) + } else { + (databaseName, tableName) + } + } + /** * Creates any tables required for query execution. * For example, because of a CREATE TABLE X AS statement. @@ -270,7 +292,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) // Get the CreateTableDesc from Hive SemanticAnalyzer - val desc: Option[CreateTableDesc] = if (tableExists(Some(databaseName), tblName)) { + val desc: Option[CreateTableDesc] = if (tableExists(Seq(databaseName, tblName))) { None } else { val sa = new SemanticAnalyzer(hive.hiveconf) { @@ -352,15 +374,13 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def registerTable( - databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = ??? + override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = ??? /** * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def unregisterTable( - databaseName: Option[String], tableName: String): Unit = ??? + override def unregisterTable(tableIdentifier: Seq[String]): Unit = ??? 
override def unregisterAllTables() = {} } @@ -386,88 +406,6 @@ private[hive] case class InsertIntoHiveTable( } } -/** - * :: DeveloperApi :: - * Provides conversions between Spark SQL data types and Hive Metastore types. - */ -@DeveloperApi -object HiveMetastoreTypes extends RegexParsers { - protected lazy val primitiveType: Parser[DataType] = - "string" ^^^ StringType | - "float" ^^^ FloatType | - "int" ^^^ IntegerType | - "tinyint" ^^^ ByteType | - "smallint" ^^^ ShortType | - "double" ^^^ DoubleType | - "bigint" ^^^ LongType | - "binary" ^^^ BinaryType | - "boolean" ^^^ BooleanType | - fixedDecimalType | // Hive 0.13+ decimal with precision/scale - "decimal" ^^^ DecimalType.Unlimited | // Hive 0.12 decimal with no precision/scale - "date" ^^^ DateType | - "timestamp" ^^^ TimestampType | - "varchar\\((\\d+)\\)".r ^^^ StringType - - protected lazy val fixedDecimalType: Parser[DataType] = - ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ { - case precision ~ scale => - DecimalType(precision.toInt, scale.toInt) - } - - protected lazy val arrayType: Parser[DataType] = - "array" ~> "<" ~> dataType <~ ">" ^^ { - case tpe => ArrayType(tpe) - } - - protected lazy val mapType: Parser[DataType] = - "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { - case t1 ~ _ ~ t2 => MapType(t1, t2) - } - - protected lazy val structField: Parser[StructField] = - "[a-zA-Z0-9_]*".r ~ ":" ~ dataType ^^ { - case name ~ _ ~ tpe => StructField(name, tpe, nullable = true) - } - - protected lazy val structType: Parser[DataType] = - "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ { - case fields => new StructType(fields) - } - - protected lazy val dataType: Parser[DataType] = - arrayType | - mapType | - structType | - primitiveType - - def toDataType(metastoreType: String): DataType = parseAll(dataType, metastoreType) match { - case Success(result, _) => result - case failure: NoSuccess => sys.error(s"Unsupported dataType: $metastoreType") - } - - def toMetastoreType(dt: DataType): String = dt match { - case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" - case StructType(fields) => - s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>" - case MapType(keyType, valueType, _) => - s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>" - case StringType => "string" - case FloatType => "float" - case IntegerType => "int" - case ByteType => "tinyint" - case ShortType => "smallint" - case DoubleType => "double" - case LongType => "bigint" - case BinaryType => "binary" - case BooleanType => "boolean" - case DateType => "date" - case d: DecimalType => HiveShim.decimalMetastoreString(d) - case TimestampType => "timestamp" - case NullType => "void" - case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType) - } -} - private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: TTable, val partitions: Seq[TPartition]) @@ -525,7 +463,7 @@ private[hive] case class MetastoreRelation implicit class SchemaAttribute(f: FieldSchema) { def toAttribute = AttributeReference( f.getName, - HiveMetastoreTypes.toDataType(f.getType), + sqlContext.ddlParser.parseType(f.getType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true )(qualifiers = Seq(alias.getOrElse(tableName))) @@ -545,3 +483,27 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. 
*/ val columnOrdinals = AttributeMap(attributes.zipWithIndex) } + +object HiveMetastoreTypes { + def toMetastoreType(dt: DataType): String = dt match { + case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" + case StructType(fields) => + s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>" + case MapType(keyType, valueType, _) => + s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>" + case StringType => "string" + case FloatType => "float" + case IntegerType => "int" + case ByteType => "tinyint" + case ShortType => "smallint" + case DoubleType => "double" + case LongType => "bigint" + case BinaryType => "binary" + case BooleanType => "boolean" + case DateType => "date" + case d: DecimalType => HiveShim.decimalMetastoreString(d) + case TimestampType => "timestamp" + case NullType => "void" + case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 8a9613cf96e54..34622b5f57873 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -24,8 +24,8 @@ import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils +import org.apache.spark.sql.SparkSQLParser -import org.apache.spark.sql.catalyst.SparkSQLParser import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable} /* Implicit conversions */ @@ -45,6 +46,22 @@ import scala.collection.JavaConversions._ */ private[hive] case object NativePlaceholder extends Command +/** + * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. + * @param table The table to be described. + * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. + * It is effective only when the table is a Hive table. + */ +case class DescribeCommand( + table: LogicalPlan, + isExtended: Boolean) extends Command { + override def output = Seq( + // Column names are based on Hive. + AttributeReference("col_name", StringType, nullable = false)(), + AttributeReference("data_type", StringType, nullable = false)(), + AttributeReference("comment", StringType, nullable = false)()) +} + /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. 
*/ private[hive] object HiveQl { protected val nativeCommands = Seq( @@ -386,6 +403,15 @@ private[hive] object HiveQl { (db, tableName) } + protected def extractTableIdent(tableNameParts: Node): Seq[String] = { + tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { + case Seq(tableOnly) => Seq(tableOnly) + case Seq(databaseName, table) => Seq(databaseName, table) + case other => sys.error("Hive only supports tables names like 'tableName' " + + s"or 'databaseName.tableName', found '$other'") + } + } + /** * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) * is equivalent to @@ -448,17 +474,23 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Just fake explain for any of the native commands. case Token("TOK_EXPLAIN", explainArgs) if noExplainCommands.contains(explainArgs.head.getText) => - ExplainCommand(NoRelation) + ExplainCommand(NoRelation, Seq(AttributeReference("plan", StringType, nullable = false)())) case Token("TOK_EXPLAIN", explainArgs) if "TOK_CREATETABLE" == explainArgs.head.getText => val Some(crtTbl) :: _ :: extended :: Nil = getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) - ExplainCommand(nodeToPlan(crtTbl), extended != None) + ExplainCommand( + nodeToPlan(crtTbl), + Seq(AttributeReference("plan", StringType,nullable = false)()), + extended != None) case Token("TOK_EXPLAIN", explainArgs) => // Ignore FORMATTED if present. val Some(query) :: _ :: extended :: Nil = getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) - ExplainCommand(nodeToPlan(query), extended != None) + ExplainCommand( + nodeToPlan(query), + Seq(AttributeReference("plan", StringType, nullable = false)()), + extended != None) case Token("TOK_DESCTABLE", describeArgs) => // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -475,16 +507,16 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token(".", dbName :: tableName :: Nil) => // It is describing a table with the format like "describe db.table". // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. - val (db, tableName) = extractDbNameTableName(nameParts.head) + val tableIdent = extractTableIdent(nameParts.head) DescribeCommand( - UnresolvedRelation(db, tableName, None), extended.isDefined) + UnresolvedRelation(tableIdent, None), extended.isDefined) case Token(".", dbName :: tableName :: colName :: Nil) => // It is describing a column with the format like "describe db.table column". NativePlaceholder case tableName => // It is describing a table with the format like "describe table". 
DescribeCommand( - UnresolvedRelation(None, tableName.getText, None), + UnresolvedRelation(Seq(tableName.getText), None), extended.isDefined) } } @@ -757,13 +789,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C nonAliasClauses) } - val (db, tableName) = + val tableIdent = tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { - case Seq(tableOnly) => (None, tableOnly) - case Seq(databaseName, table) => (Some(databaseName), table) + case Seq(tableOnly) => Seq(tableOnly) + case Seq(databaseName, table) => Seq(databaseName, table) + case other => sys.error("Hive only supports tables names like 'tableName' " + + s"or 'databaseName.tableName', found '$other'") } val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) } - val relation = UnresolvedRelation(db, tableName, alias) + val relation = UnresolvedRelation(tableIdent, alias) // Apply sampling if requested. (bucketSampleClause orElse splitSampleClause).map { @@ -882,7 +916,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val Some(tableNameParts) :: partitionClause :: Nil = getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - val (db, tableName) = extractDbNameTableName(tableNameParts) + val tableIdent = extractTableIdent(tableNameParts) val partitionKeys = partitionClause.map(_.getChildren.map { // Parse partitions. We also make keys case insensitive. @@ -892,7 +926,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C cleanIdentifier(key.toLowerCase) -> None }.toMap).getOrElse(Map.empty) - InsertIntoTable(UnresolvedRelation(db, tableName, None), partitionKeys, query, overwrite) + InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite) case a: ASTNode => throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") @@ -1077,6 +1111,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right)) case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right)) case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) + case Token("!", child :: Nil) => Not(nodeToExpr(child)) /* Case statements */ case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index d3f6381b69a4d..c439b9ebfe104 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.StringType +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive import org.apache.spark.sql.hive.execution._ @@ -209,14 +210,14 @@ private[hive] trait HiveStrategies { case class HiveCommandStrategy(context: HiveContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case describe: logical.DescribeCommand => + case describe: DescribeCommand => val resolvedTable = context.executePlan(describe.table).analyzed resolvedTable match { case t: MetastoreRelation => 
ExecutedCommand( DescribeHiveTableCommand(t, describe.output, describe.isExtended)) :: Nil case o: LogicalPlan => - ExecutedCommand(DescribeCommand(planLater(o), describe.output)) :: Nil + ExecutedCommand(RunnableDescribeCommand(planLater(o), describe.output)) :: Nil } case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index b2149bd95a336..1358a0eccb353 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -34,8 +34,9 @@ import org.apache.hadoop.hive.serde2.avro.AvroSerDe import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.plans.logical.{CacheTableCommand, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.execution.HiveNativeCommand @@ -167,7 +168,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { // Make sure any test tables referenced are loaded. val referencedTables = describedTables ++ - logical.collect { case UnresolvedRelation(databaseName, name, _) => name } + logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.last } val referencedTestTables = referencedTables.filter(testTables.contains) logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(loadTestTable) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index fe21454e7fb38..a547babcebfff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -53,14 +53,14 @@ case class CreateTableAsSelect( hiveContext.catalog.createTable(database, tableName, query.output, allowExisting, desc) // Get the Metastore Relation - hiveContext.catalog.lookupRelation(Some(database), tableName, None) match { + hiveContext.catalog.lookupRelation(Seq(database, tableName), None) match { case r: MetastoreRelation => r } } // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. 
-    if (hiveContext.catalog.tableExists(Some(database), tableName)) {
+    if (hiveContext.catalog.tableExists(Seq(database, tableName))) {
       if (allowExisting) {
         // table already exists, will do nothing, to keep consistent with Hive
       } else {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 6fc4153f6a5df..6b733a280e6d5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -53,7 +53,7 @@ case class DropTable(
     val hiveContext = sqlContext.asInstanceOf[HiveContext]
     val ifExistsClause = if (ifExists) "IF EXISTS " else ""
     hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName")
-    hiveContext.catalog.unregisterTable(None, tableName)
+    hiveContext.catalog.unregisterTable(Seq(tableName))
     Seq.empty[Row]
   }
 }
diff --git a/sql/hive/src/test/resources/golden/! operator-0-81d1a187c7f4a6337baf081510a5dc5e b/sql/hive/src/test/resources/golden/! operator-0-81d1a187c7f4a6337baf081510a5dc5e
new file mode 100644
index 0000000000000..d00491fd7e5bb
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/! operator-0-81d1a187c7f4a6337baf081510a5dc5e
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index 2060e1f1a7a4b..f95a6b43af357 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -158,4 +158,10 @@ class CachedTableSuite extends QueryTest {
     uncacheTable("src")
     assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
   }
+
+  test("CACHE TABLE with Hive UDF") {
+    sql("CACHE TABLE udfTest AS SELECT * FROM src WHERE floor(key) = 1")
+    assertCached(table("udfTest"))
+    uncacheTable("udfTest")
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index 86535f8dd4f58..041a36f1295ef 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive
 import org.scalatest.FunSuite
 
 import org.apache.spark.sql.catalyst.types.StructType
+import org.apache.spark.sql.sources.DDLParser
 import org.apache.spark.sql.test.ExamplePointUDT
 
 class HiveMetastoreCatalogSuite extends FunSuite {
@@ -27,7 +28,9 @@
 
   test("struct field should accept underscore in sub-column name") {
     val metastr = "struct<a: int, b_1: string, c: string>"
-    val datatype = HiveMetastoreTypes.toDataType(metastr)
+    val ddlParser = new DDLParser
+
+    val datatype = ddlParser.parseType(metastr)
     assert(datatype.isInstanceOf[StructType])
   }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 4b6a9308b9811..a758f921e0417 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -72,7 +72,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
 
   test("analyze MetastoreRelations") {
     def queryTotalSize(tableName: String): BigInt =
- catalog.lookupRelation(None, tableName).statistics.sizeInBytes + catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() @@ -123,7 +123,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { intercept[NotImplementedError] { analyze("tempTable") } - catalog.unregisterTable(None, "tempTable") + catalog.unregisterTable(Seq("tempTable")) } test("estimates the size of a test MetastoreRelation") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 4104df8f8e022..f8a957d55d57e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -22,6 +22,8 @@ import java.io._ import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} import org.apache.spark.Logging +import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} +import org.apache.spark.sql.hive.DescribeCommand import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index fb6da33e88ef6..700a45edb11d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -62,6 +62,14 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("SHOW TABLES") } } + + createQueryTest("! 
operator", + """ + |SELECT a FROM ( + | SELECT 1 AS a FROM src LIMIT 1 UNION ALL + | SELECT 2 AS a FROM src LIMIT 1) table + |WHERE !(a>1) + """.stripMargin) createQueryTest("constant object inspector for generic udf", """SELECT named_struct( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 5d0fb7237011f..c1c3683f84ab2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.Row import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.util.Utils case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) @@ -202,4 +203,15 @@ class SQLQuerySuite extends QueryTest { checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"), sql("SELECT distinct key FROM src order by key").collect().toSeq) } + + test("SPARK-4963 SchemaRDD sample on mutable row return wrong result") { + sql("SELECT * FROM src WHERE key % 2 = 0") + .sample(withReplacement = false, fraction = 0.3) + .registerTempTable("sampled") + (1 to 10).foreach { i => + checkAnswer( + sql("SELECT * FROM sampled WHERE key % 2 = 1"), + Seq.empty[Row]) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala index fc0e42c201d56..8bbb7f2fdbf48 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala @@ -174,6 +174,18 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll } Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table => + test(s"ordering of the partitioning columns $table") { + checkAnswer( + sql(s"SELECT p, stringField FROM $table WHERE p = 1"), + Seq.fill(10)((1, "part-1")) + ) + + checkAnswer( + sql(s"SELECT stringField, p FROM $table WHERE p = 1"), + Seq.fill(10)(("part-1", 1)) + ) + } + test(s"project the partitioning column $table") { checkAnswer( sql(s"SELECT p, count(*) FROM $table group by p"), diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 2d01a85067518..25fdf5c5f3da6 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -44,7 +44,7 @@ import scala.language.implicitConversions import org.apache.spark.sql.catalyst.types.DecimalType -class HiveFunctionWrapper(var functionClassName: String) extends java.io.Serializable { +case class HiveFunctionWrapper(functionClassName: String) extends java.io.Serializable { // for Serialization def this() = this(null) diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index b78c75798e988..e47002cb0b8c8 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -53,7 +53,7 @@ import scala.language.implicitConversions * * @param functionClassName UDF class name */ -class HiveFunctionWrapper(var functionClassName: String) extends 
java.io.Externalizable { +case class HiveFunctionWrapper(var functionClassName: String) extends java.io.Externalizable { // for Serialization def this() = this(null) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index c834744631e02..afd3c4bc4c4fe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -86,7 +86,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont }.toArray // Since storeInBlockManager = false, the storage level does not matter. new WriteAheadLogBackedBlockRDD[T](ssc.sparkContext, - blockIds, logSegments, storeInBlockManager = true, StorageLevel.MEMORY_ONLY_SER) + blockIds, logSegments, storeInBlockManager = false, StorageLevel.MEMORY_ONLY_SER) } else { new BlockRDD[T](ssc.sc, blockIds) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 2ce458cddec1a..c3d9d7b6813d3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -203,9 +203,11 @@ private[streaming] class ReceivedBlockTracker( /** Write an update to the tracker to the write ahead log */ private def writeToLog(record: ReceivedBlockTrackerLogEvent) { - logDebug(s"Writing to log $record") - logManagerOption.foreach { logManager => + if (isLogManagerEnabled) { + logDebug(s"Writing to log $record") + logManagerOption.foreach { logManager => logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record))) + } } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index c363d755c1752..032106371cd60 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -65,7 +65,7 @@ private[spark] class Client( private val amMemoryOverhead = args.amMemoryOverhead // MB private val executorMemoryOverhead = args.executorMemoryOverhead // MB private val distCacheMgr = new ClientDistributedCacheManager() - private val isClusterMode = args.userClass != null + private val isClusterMode = args.isClusterMode def stop(): Unit = yarnClient.stop() diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 39f1021c9d942..fdbf9f8eed029 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -38,23 +38,27 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var amMemory: Int = 512 // MB var appName: String = "Spark" var priority = 0 + def isClusterMode: Boolean = userClass != null + + private var driverMemory: Int = 512 // MB + private val driverMemOverheadKey = "spark.yarn.driver.memoryOverhead" + private val amMemKey = "spark.yarn.am.memory" + private val amMemOverheadKey = "spark.yarn.am.memoryOverhead" + private val isDynamicAllocationEnabled = + sparkConf.getBoolean("spark.dynamicAllocation.enabled", false) parseArgs(args.toList) + 
loadEnvironmentArgs() + validateArgs() // Additional memory to allocate to containers - // For now, use driver's memory overhead as our AM container's memory overhead - val amMemoryOverhead = sparkConf.getInt("spark.yarn.driver.memoryOverhead", + val amMemoryOverheadConf = if (isClusterMode) driverMemOverheadKey else amMemOverheadKey + val amMemoryOverhead = sparkConf.getInt(amMemoryOverheadConf, math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toInt, MEMORY_OVERHEAD_MIN)) val executorMemoryOverhead = sparkConf.getInt("spark.yarn.executor.memoryOverhead", math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)) - private val isDynamicAllocationEnabled = - sparkConf.getBoolean("spark.dynamicAllocation.enabled", false) - - loadEnvironmentArgs() - validateArgs() - /** Load any default arguments provided through environment variables and Spark properties. */ private def loadEnvironmentArgs(): Unit = { // For backward compatibility, SPARK_YARN_DIST_{ARCHIVES/FILES} should be resolved to hdfs://, @@ -87,6 +91,21 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new IllegalArgumentException( "You must specify at least 1 executor!\n" + getUsageMessage()) } + if (isClusterMode) { + for (key <- Seq(amMemKey, amMemOverheadKey)) { + if (sparkConf.contains(key)) { + println(s"$key is set but does not apply in cluster mode.") + } + } + amMemory = driverMemory + } else { + if (sparkConf.contains(driverMemOverheadKey)) { + println(s"$driverMemOverheadKey is set but does not apply in client mode.") + } + sparkConf.getOption(amMemKey) + .map(Utils.memoryStringToMb) + .foreach { mem => amMemory = mem } + } } private def parseArgs(inputArgs: List[String]): Unit = { @@ -118,7 +137,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) if (args(0) == "--master-memory") { println("--master-memory is deprecated. Use --driver-memory instead.") } - amMemory = value + driverMemory = value args = tail case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail => diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index ebf5616e8d303..c537da9f67552 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -148,12 +148,9 @@ class ExecutorRunnable( // registers with the Scheduler and transfers the spark configs. Since the Executor backend // uses Akka to connect to the scheduler, the akka settings are needed as well as the // authentication settings. - sparkConf.getAll. - filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }. - foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } - - sparkConf.getAkkaConf. - foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } + sparkConf.getAll + .filter { case (k, v) => SparkConf.isExecutorStartupConf(k) } + .foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } // Commenting it out for now - so that people can refer to the properties if required. Remove // it once cpuset version is pushed out. 
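The ClientArguments change above boils down to one decision: in cluster mode the driver runs inside the Application Master, so `--driver-memory` and `spark.yarn.driver.memoryOverhead` size the AM container, while in client mode the standalone AM is sized by the new `spark.yarn.am.memory` setting. The following is a condensed sketch of that rule; the object, method, and `toMb` helper names are simplifications for illustration, not the patch's code (which uses `Utils.memoryStringToMb`):

```scala
import org.apache.spark.SparkConf

object AmMemorySketch {
  // Simplified stand-in for Spark's memory-string parsing ("512m" -> 512, "2g" -> 2048).
  private def toMb(s: String): Int = {
    val t = s.trim.toLowerCase
    if (t.endsWith("g")) t.dropRight(1).toInt * 1024
    else if (t.endsWith("m")) t.dropRight(1).toInt
    else t.toInt
  }

  def resolveAmMemoryMb(conf: SparkConf, isClusterMode: Boolean, driverMemoryMb: Int): Int = {
    if (isClusterMode) {
      // The driver runs inside the AM, so the driver memory setting sizes the container;
      // spark.yarn.am.* is ignored (the patch prints a warning if it is set).
      driverMemoryMb
    } else {
      // Client mode: the AM is a separate, lightweight process sized by spark.yarn.am.memory.
      conf.getOption("spark.yarn.am.memory").map(toMb).getOrElse(512)
    }
  }
}
```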
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 09597bd0e6ab9..f99291553b7b8 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -68,8 +68,6 @@ private[spark] class YarnClientSchedulerBackend( // List of (target Client argument, environment variable, Spark property) val optionTuples = List( - ("--driver-memory", "SPARK_MASTER_MEMORY", "spark.master.memory"), - ("--driver-memory", "SPARK_DRIVER_MEMORY", "spark.driver.memory"), ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"), ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"),
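The YARN documentation table that follows spells out the user-facing side of these changes. As a quick illustration (the values are arbitrary examples, not defaults), the settings one would typically touch in each deploy mode look like this, e.g. in spark-shell or application code:

```scala
import org.apache.spark.SparkConf

// Client mode: the standalone Application Master is sized by the new spark.yarn.am.* settings.
val clientModeConf = new SparkConf()
  .set("spark.yarn.am.memory", "1g")           // AM heap, as a JVM memory string
  .set("spark.yarn.am.memoryOverhead", "384")  // AM off-heap overhead, in MB

// Cluster mode: the driver runs inside the AM, so the driver settings size the container.
val clusterModeConf = new SparkConf()
  .set("spark.driver.memory", "2g")
  .set("spark.yarn.driver.memoryOverhead", "512")  // driver/AM container overhead, in MB
```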
 <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.yarn.am.memory</code></td>
+  <td>512m</td>
+  <td>
+    Amount of memory to use for the YARN Application Master in client mode, in the same format as JVM memory strings (e.g. <code>512m</code>, <code>2g</code>).
+    In cluster mode, use <code>spark.driver.memory</code> instead.
+  </td>
+</tr>
 <tr>
   <td><code>spark.yarn.am.waitTime</code></td>
   <td>100000</td>
 <tr>
   <td><code>spark.yarn.driver.memoryOverhead</code></td>
   <td>driverMemory * 0.07, with minimum of 384</td>
   <td>
-    The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%).
+    The amount of off heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%).
   </td>
 </tr>
+<tr>
+  <td><code>spark.yarn.am.memoryOverhead</code></td>
+  <td>AM memory * 0.07, with minimum of 384</td>
+  <td>
+    Same as <code>spark.yarn.driver.memoryOverhead</code>, but for the Application Master in client mode.
+  </td>
+</tr>
 <tr>
   <td><code>spark.yarn.am.extraJavaOptions</code></td>
   <td>(none)</td>
   <td>
-    A string of extra JVM options to pass to the Yarn ApplicationMaster in client mode.
+    A string of extra JVM options to pass to the YARN Application Master in client mode. In cluster mode, use <code>spark.driver.extraJavaOptions</code> instead.
   </td>
 </tr>
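Stepping back to the sql/core changes earlier in this patch: together, the `DDLParser` column list and the new `SchemaRelationProvider` trait let a data source accept the schema a user declares in `CREATE TEMPORARY TABLE`. Below is a minimal sketch of such a source, modeled on the `AllDataTypesScanSource` test source above; the class names and the `rows` option are hypothetical and not part of the patch:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext, StructType}
import org.apache.spark.sql.sources.{BaseRelation, SchemaRelationProvider, TableScan}

// Hypothetical data source: returns `rows` all-null rows, but reports exactly the
// schema the user declared in the DDL.
class NullRowSource extends SchemaRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation =
    NullRowRelation(parameters("rows").toInt, schema)(sqlContext)
}

case class NullRowRelation(numRows: Int, userSchema: StructType)
    (@transient val sqlContext: SQLContext) extends TableScan {

  // The user-specified schema flows through CreateTableUsing into this relation.
  override def schema: StructType = userSchema

  // Data is irrelevant for the sketch; every column is null.
  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(1 to numRows)
      .map(_ => Row(Seq.fill(userSchema.fields.length)(null): _*))
}
```

With the parser changes, such a source could then be registered as, say, `CREATE TEMPORARY TABLE t (name string, amount decimal(9,2)) USING some.pkg.NullRowSource OPTIONS (rows '3')` (the package path is illustrative), and the declared column types would be parsed by `DDLParser` and handed to `createRelation` as a `StructType`.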