diff --git a/bin/spark-submit b/bin/spark-submit
index aefd38a0a2b90..3e5cbdbb24394 100755
--- a/bin/spark-submit
+++ b/bin/spark-submit
@@ -44,7 +44,10 @@ while (($#)); do
shift
done
-DEFAULT_PROPERTIES_FILE="$SPARK_HOME/conf/spark-defaults.conf"
+if [ -z "$SPARK_CONF_DIR" ]; then
+ export SPARK_CONF_DIR="$SPARK_HOME/conf"
+fi
+DEFAULT_PROPERTIES_FILE="$SPARK_CONF_DIR/spark-defaults.conf"
if [ "$MASTER" == "yarn-cluster" ]; then
SPARK_SUBMIT_DEPLOY_MODE=cluster
fi
diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd
index daf0284db9230..12244a9cb04fb 100644
--- a/bin/spark-submit2.cmd
+++ b/bin/spark-submit2.cmd
@@ -24,7 +24,11 @@ set ORIG_ARGS=%*
rem Reset the values of all variables used
set SPARK_SUBMIT_DEPLOY_MODE=client
-set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf
+
+if not defined SPARK_CONF_DIR (
+ set SPARK_CONF_DIR=%SPARK_HOME%\conf
+)
+set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_CONF_DIR%\spark-defaults.conf
set SPARK_SUBMIT_DRIVER_MEMORY=
set SPARK_SUBMIT_LIBRARY_PATH=
set SPARK_SUBMIT_CLASSPATH=
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index e9e90e3f2f65a..a0ee2a7cbb2a2 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -65,6 +65,9 @@ private[spark] class ExecutorAllocationManager(
listenerBus: LiveListenerBus,
conf: SparkConf)
extends Logging {
+
+ allocationManager =>
+
import ExecutorAllocationManager._
// Lower and upper bounds on the number of executors. These are required.
@@ -121,7 +124,7 @@ private[spark] class ExecutorAllocationManager(
private var clock: Clock = new RealClock
// Listener for Spark events that impact the allocation policy
- private val listener = new ExecutorAllocationListener(this)
+ private val listener = new ExecutorAllocationListener
/**
* Verify that the settings specified through the config are valid.
@@ -209,11 +212,12 @@ private[spark] class ExecutorAllocationManager(
addTime += sustainedSchedulerBacklogTimeout * 1000
}
- removeTimes.foreach { case (executorId, expireTime) =>
- if (now >= expireTime) {
+ removeTimes.retain { case (executorId, expireTime) =>
+ val expired = now >= expireTime
+ if (expired) {
removeExecutor(executorId)
- removeTimes.remove(executorId)
}
+ !expired
}
}
@@ -291,7 +295,7 @@ private[spark] class ExecutorAllocationManager(
// Do not kill the executor if we have already reached the lower bound
val numExistingExecutors = executorIds.size - executorsPendingToRemove.size
if (numExistingExecutors - 1 < minNumExecutors) {
- logInfo(s"Not removing idle executor $executorId because there are only " +
+ logDebug(s"Not removing idle executor $executorId because there are only " +
s"$numExistingExecutors executor(s) left (limit $minNumExecutors)")
return false
}
@@ -315,7 +319,11 @@ private[spark] class ExecutorAllocationManager(
private def onExecutorAdded(executorId: String): Unit = synchronized {
if (!executorIds.contains(executorId)) {
executorIds.add(executorId)
- executorIds.foreach(onExecutorIdle)
+ // If an executor (call this executor X) is not removed because the lower bound
+ // has been reached, it will no longer be marked as idle. When new executors join,
+ // however, we are no longer at the lower bound, and so we must mark executor X
+ // as idle again so as not to forget that it is a candidate for removal. (see SPARK-4951)
+ executorIds.filter(listener.isExecutorIdle).foreach(onExecutorIdle)
logInfo(s"New executor $executorId has registered (new total is ${executorIds.size})")
if (numExecutorsPending > 0) {
numExecutorsPending -= 1
@@ -373,10 +381,14 @@ private[spark] class ExecutorAllocationManager(
* the executor is not already marked as idle.
*/
private def onExecutorIdle(executorId: String): Unit = synchronized {
- if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) {
- logDebug(s"Starting idle timer for $executorId because there are no more tasks " +
- s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)")
- removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000
+ if (executorIds.contains(executorId)) {
+ if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) {
+ logDebug(s"Starting idle timer for $executorId because there are no more tasks " +
+ s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)")
+ removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000
+ }
+ } else {
+ logWarning(s"Attempted to mark unknown executor $executorId idle")
}
}
@@ -396,25 +408,24 @@ private[spark] class ExecutorAllocationManager(
* and consistency of events returned by the listener. For simplicity, it does not account
* for speculated tasks.
*/
- private class ExecutorAllocationListener(allocationManager: ExecutorAllocationManager)
- extends SparkListener {
+ private class ExecutorAllocationListener extends SparkListener {
private val stageIdToNumTasks = new mutable.HashMap[Int, Int]
private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]]
private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]]
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
- synchronized {
- val stageId = stageSubmitted.stageInfo.stageId
- val numTasks = stageSubmitted.stageInfo.numTasks
+ val stageId = stageSubmitted.stageInfo.stageId
+ val numTasks = stageSubmitted.stageInfo.numTasks
+ allocationManager.synchronized {
stageIdToNumTasks(stageId) = numTasks
allocationManager.onSchedulerBacklogged()
}
}
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
- synchronized {
- val stageId = stageCompleted.stageInfo.stageId
+ val stageId = stageCompleted.stageInfo.stageId
+ allocationManager.synchronized {
stageIdToNumTasks -= stageId
stageIdToTaskIndices -= stageId
@@ -426,39 +437,49 @@ private[spark] class ExecutorAllocationManager(
}
}
- override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
val stageId = taskStart.stageId
val taskId = taskStart.taskInfo.taskId
val taskIndex = taskStart.taskInfo.index
val executorId = taskStart.taskInfo.executorId
- // If this is the last pending task, mark the scheduler queue as empty
- stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex
- val numTasksScheduled = stageIdToTaskIndices(stageId).size
- val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1)
- if (numTasksScheduled == numTasksTotal) {
- // No more pending tasks for this stage
- stageIdToNumTasks -= stageId
- if (stageIdToNumTasks.isEmpty) {
- allocationManager.onSchedulerQueueEmpty()
+ allocationManager.synchronized {
+ // This guards against the race condition in which the `SparkListenerTaskStart`
+ // event is posted before the `SparkListenerBlockManagerAdded` event, which is
+ // possible because these events are posted in different threads. (see SPARK-4951)
+ if (!allocationManager.executorIds.contains(executorId)) {
+ allocationManager.onExecutorAdded(executorId)
+ }
+
+ // If this is the last pending task, mark the scheduler queue as empty
+ stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex
+ val numTasksScheduled = stageIdToTaskIndices(stageId).size
+ val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1)
+ if (numTasksScheduled == numTasksTotal) {
+ // No more pending tasks for this stage
+ stageIdToNumTasks -= stageId
+ if (stageIdToNumTasks.isEmpty) {
+ allocationManager.onSchedulerQueueEmpty()
+ }
}
- }
- // Mark the executor on which this task is scheduled as busy
- executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId
- allocationManager.onExecutorBusy(executorId)
+ // Mark the executor on which this task is scheduled as busy
+ executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId
+ allocationManager.onExecutorBusy(executorId)
+ }
}
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val executorId = taskEnd.taskInfo.executorId
val taskId = taskEnd.taskInfo.taskId
-
- // If the executor is no longer running scheduled any tasks, mark it as idle
- if (executorIdToTaskIds.contains(executorId)) {
- executorIdToTaskIds(executorId) -= taskId
- if (executorIdToTaskIds(executorId).isEmpty) {
- executorIdToTaskIds -= executorId
- allocationManager.onExecutorIdle(executorId)
+ allocationManager.synchronized {
+ // If the executor is no longer running any scheduled tasks, mark it as idle
+ if (executorIdToTaskIds.contains(executorId)) {
+ executorIdToTaskIds(executorId) -= taskId
+ if (executorIdToTaskIds(executorId).isEmpty) {
+ executorIdToTaskIds -= executorId
+ allocationManager.onExecutorIdle(executorId)
+ }
}
}
}
@@ -466,7 +487,12 @@ private[spark] class ExecutorAllocationManager(
override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = {
val executorId = blockManagerAdded.blockManagerId.executorId
if (executorId != SparkContext.DRIVER_IDENTIFIER) {
- allocationManager.onExecutorAdded(executorId)
+ // This guards against the race condition in which the `SparkListenerTaskStart`
+ // event is posted before the `SparkListenerBlockManagerAdded` event, which is
+ // possible because these events are posted in different threads. (see SPARK-4951)
+ if (!allocationManager.executorIds.contains(executorId)) {
+ allocationManager.onExecutorAdded(executorId)
+ }
}
}
@@ -478,12 +504,23 @@ private[spark] class ExecutorAllocationManager(
/**
* An estimate of the total number of pending tasks remaining for currently running stages. Does
* not account for tasks which may have failed and been resubmitted.
+ *
+ * Note: This is not thread-safe; the caller must hold the `allocationManager` lock.
*/
def totalPendingTasks(): Int = {
stageIdToNumTasks.map { case (stageId, numTasks) =>
numTasks - stageIdToTaskIndices.get(stageId).map(_.size).getOrElse(0)
}.sum
}
+
+ /**
+ * Return true if an executor is not currently running a task, and false otherwise.
+ *
+ * Note: This is not thread-safe; the caller must hold the `allocationManager` lock.
+ */
+ def isExecutorIdle(executorId: String): Boolean = {
+ !executorIdToTaskIds.contains(executorId)
+ }
}
}
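
A minimal, standalone sketch of the `retain`-based expiry idiom introduced above (hypothetical names, not Spark code): `retain` keeps only the entries for which the predicate returns true, so expired executors are acted on and dropped without calling `remove()` on the map while iterating over it.

```scala
import scala.collection.mutable

// Hypothetical helper illustrating the idiom: expire entries and act on them
// in a single pass over the mutable map.
def expireAndRemove(
    removeTimes: mutable.Map[String, Long],
    now: Long)(removeExecutor: String => Unit): Unit = {
  removeTimes.retain { case (executorId, expireTime) =>
    val expired = now >= expireTime
    if (expired) {
      removeExecutor(executorId)  // side effect on expiry
    }
    !expired                      // keep only entries that have not expired
  }
}
```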
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index edc3889c9ae51..677c5e0f89d72 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -24,6 +24,7 @@ import com.google.common.io.Files
import org.apache.spark.util.Utils
private[spark] class HttpFileServer(
+ conf: SparkConf,
securityManager: SecurityManager,
requestedPort: Int = 0)
extends Logging {
@@ -41,7 +42,7 @@ private[spark] class HttpFileServer(
fileDir.mkdir()
jarDir.mkdir()
logInfo("HTTP File server directory is " + baseDir)
- httpServer = new HttpServer(baseDir, securityManager, requestedPort, "HTTP file server")
+ httpServer = new HttpServer(conf, baseDir, securityManager, requestedPort, "HTTP file server")
httpServer.start()
serverUri = httpServer.uri
logDebug("HTTP file server started at: " + serverUri)
diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala
index 912558d0cab7d..fa22787ce7ea3 100644
--- a/core/src/main/scala/org/apache/spark/HttpServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpServer.scala
@@ -42,6 +42,7 @@ private[spark] class ServerStateException(message: String) extends Exception(mes
* around a Jetty server.
*/
private[spark] class HttpServer(
+ conf: SparkConf,
resourceBase: File,
securityManager: SecurityManager,
requestedPort: Int = 0,
@@ -57,7 +58,7 @@ private[spark] class HttpServer(
} else {
logInfo("Starting HTTP Server")
val (actualServer, actualPort) =
- Utils.startServiceOnPort[Server](requestedPort, doStart, serverName)
+ Utils.startServiceOnPort[Server](requestedPort, doStart, conf, serverName)
server = actualServer
port = actualPort
}
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index c14764f773982..a0ce107f43b16 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -370,7 +370,9 @@ private[spark] object SparkConf {
}
/**
- * Return whether the given config is a Spark port config.
+ * Return true if the given config matches either `spark.*.port` or `spark.port.*`.
*/
- def isSparkPortConf(name: String): Boolean = name.startsWith("spark.") && name.endsWith(".port")
+ def isSparkPortConf(name: String): Boolean = {
+ (name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.")
+ }
}
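
A small spot-check of the widened predicate, assuming the calling code lives inside the `org.apache.spark` package (the `SparkConf` object is `private[spark]`):

```scala
import org.apache.spark.SparkConf

// Both naming schemes are now treated as port configs; unrelated keys are not.
val portConfs  = Seq("spark.ui.port", "spark.driver.port", "spark.port.maxRetries")
val otherConfs = Seq("spark.executor.memory", "sparkling.port")

assert(portConfs.forall(SparkConf.isSparkPortConf))
assert(!otherConfs.exists(SparkConf.isSparkPortConf))
```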
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 3bf3acd245d8f..ff5d796ee2766 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -458,7 +458,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
/** Set a human readable description of the current job. */
- @deprecated("use setJobGroup", "0.8.1")
def setJobDescription(value: String) {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
}
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 43436a1697000..4d418037bd33f 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -312,7 +312,7 @@ object SparkEnv extends Logging {
val httpFileServer =
if (isDriver) {
val fileServerPort = conf.getInt("spark.fileserver.port", 0)
- val server = new HttpFileServer(securityManager, fileServerPort)
+ val server = new HttpFileServer(conf, securityManager, fileServerPort)
server.initialize()
conf.set("spark.fileserver.uri", server.serverUri)
server
diff --git a/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala b/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
new file mode 100644
index 0000000000000..9df61062e1f85
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Exception thrown when a task cannot be serialized.
+ */
+private[spark] class TaskNotSerializableException(error: Throwable) extends Exception(error)
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 31f0a462f84d8..31d6958c403b3 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -153,7 +153,8 @@ private[broadcast] object HttpBroadcast extends Logging {
private def createServer(conf: SparkConf) {
broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
val broadcastPort = conf.getInt("spark.broadcast.port", 0)
- server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
+ server =
+ new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
server.start()
serverUri = server.uri
logInfo("Broadcast server started at " + serverUri)
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 1faabe91f49a8..f14ef4d299383 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -405,7 +405,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| --queue QUEUE_NAME The YARN queue to submit to (Default: "default").
| --num-executors NUM Number of executors to launch (Default: 2).
| --archives ARCHIVES Comma separated list of archives to be extracted into the
- | working directory of each executor.""".stripMargin
+ | working directory of each executor.
+ """.stripMargin
)
SparkSubmit.exitFn()
}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index 3340fca08014e..03c4137ca0a81 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -174,7 +174,7 @@ private[nio] class ConnectionManager(
serverChannel.socket.bind(new InetSocketAddress(port))
(serverChannel, serverChannel.socket.getLocalPort)
}
- Utils.startServiceOnPort[ServerSocketChannel](port, startService, name)
+ Utils.startServiceOnPort[ServerSocketChannel](port, startService, conf, name)
serverChannel.register(selector, SelectionKey.OP_ACCEPT)
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 259621d263d7c..61d09d73e17cb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -866,26 +866,6 @@ class DAGScheduler(
}
if (tasks.size > 0) {
- // Preemptively serialize a task to make sure it can be serialized. We are catching this
- // exception here because it would be fairly hard to catch the non-serializable exception
- // down the road, where we have several different implementations for local scheduler and
- // cluster schedulers.
- //
- // We've already serialized RDDs and closures in taskBinary, but here we check for all other
- // objects such as Partition.
- try {
- closureSerializer.serialize(tasks.head)
- } catch {
- case e: NotSerializableException =>
- abortStage(stage, "Task not serializable: " + e.toString)
- runningStages -= stage
- return
- case NonFatal(e) => // Other exceptions, such as IllegalArgumentException from Kryo.
- abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
- runningStages -= stage
- return
- }
-
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
stage.pendingTasks ++= tasks
logDebug("New pending tasks: " + stage.pendingTasks)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index a41f3eef195d2..a1dfb01062591 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -31,6 +31,7 @@ import scala.util.Random
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.scheduler.TaskLocality.TaskLocality
import org.apache.spark.util.Utils
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
@@ -209,6 +210,40 @@ private[spark] class TaskSchedulerImpl(
.format(manager.taskSet.id, manager.parent.name))
}
+ private def resourceOfferSingleTaskSet(
+ taskSet: TaskSetManager,
+ maxLocality: TaskLocality,
+ shuffledOffers: Seq[WorkerOffer],
+ availableCpus: Array[Int],
+ tasks: Seq[ArrayBuffer[TaskDescription]]) : Boolean = {
+ var launchedTask = false
+ for (i <- 0 until shuffledOffers.size) {
+ val execId = shuffledOffers(i).executorId
+ val host = shuffledOffers(i).host
+ if (availableCpus(i) >= CPUS_PER_TASK) {
+ try {
+ for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
+ tasks(i) += task
+ val tid = task.taskId
+ taskIdToTaskSetId(tid) = taskSet.taskSet.id
+ taskIdToExecutorId(tid) = execId
+ executorsByHost(host) += execId
+ availableCpus(i) -= CPUS_PER_TASK
+ assert(availableCpus(i) >= 0)
+ launchedTask = true
+ }
+ } catch {
+ case e: TaskNotSerializableException =>
+ logError(s"Resource offer failed, task set ${taskSet.name} was not serializable")
+ // Do not offer resources for this task set, but don't throw an error to allow other
+ // task sets to be submitted.
+ return launchedTask
+ }
+ }
+ }
+ return launchedTask
+ }
+
/**
* Called by cluster manager to offer resources on slaves. We respond by asking our active task
* sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
@@ -251,23 +286,8 @@ private[spark] class TaskSchedulerImpl(
var launchedTask = false
for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) {
do {
- launchedTask = false
- for (i <- 0 until shuffledOffers.size) {
- val execId = shuffledOffers(i).executorId
- val host = shuffledOffers(i).host
- if (availableCpus(i) >= CPUS_PER_TASK) {
- for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
- tasks(i) += task
- val tid = task.taskId
- taskIdToTaskSetId(tid) = taskSet.taskSet.id
- taskIdToExecutorId(tid) = execId
- executorsByHost(host) += execId
- availableCpus(i) -= CPUS_PER_TASK
- assert(availableCpus(i) >= 0)
- launchedTask = true
- }
- }
- }
+ launchedTask = resourceOfferSingleTaskSet(
+ taskSet, maxLocality, shuffledOffers, availableCpus, tasks)
} while (launchedTask)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 28e6147509f78..4667850917151 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -18,12 +18,14 @@
package org.apache.spark.scheduler
import java.io.NotSerializableException
+import java.nio.ByteBuffer
import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.{min, max}
+import scala.util.control.NonFatal
import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
@@ -417,6 +419,7 @@ private[spark] class TaskSetManager(
* @param host the host Id of the offered resource
* @param maxLocality the maximum locality we want to schedule the tasks at
*/
+ @throws[TaskNotSerializableException]
def resourceOffer(
execId: String,
host: String,
@@ -456,10 +459,17 @@ private[spark] class TaskSetManager(
}
// Serialize and return the task
val startTime = clock.getTime()
- // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
- // we assume the task can be serialized without exceptions.
- val serializedTask = Task.serializeWithDependencies(
- task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ val serializedTask: ByteBuffer = try {
+ Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ } catch {
+ // If the task cannot be serialized, then there's no point to re-attempt the task,
+ // as it will always fail. So just abort the whole task-set.
+ case NonFatal(e) =>
+ val msg = s"Failed to serialize task $taskId, not attempting to retry it."
+ logError(msg, e)
+ abort(s"$msg Exception during serialization: $e")
+ throw new TaskNotSerializableException(e)
+ }
if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 &&
!emittedTaskSizeWarning) {
emittedTaskSizeWarning = true
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 50721b9d6cd6c..f14aaeea0a25c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -17,6 +17,8 @@
package org.apache.spark.scheduler.cluster
+import scala.concurrent.{Future, ExecutionContext}
+
import akka.actor.{Actor, ActorRef, Props}
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
@@ -24,7 +26,9 @@ import org.apache.spark.SparkContext
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.ui.JettyUtils
-import org.apache.spark.util.AkkaUtils
+import org.apache.spark.util.{AkkaUtils, Utils}
+
+import scala.util.control.NonFatal
/**
* Abstract Yarn scheduler backend that contains common logic
@@ -97,6 +101,9 @@ private[spark] abstract class YarnSchedulerBackend(
private class YarnSchedulerActor extends Actor {
private var amActor: Option[ActorRef] = None
+ implicit val askAmActorExecutor = ExecutionContext.fromExecutor(
+ Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor"))
+
override def preStart(): Unit = {
// Listen for disassociation events
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
@@ -110,7 +117,12 @@ private[spark] abstract class YarnSchedulerBackend(
case r: RequestExecutors =>
amActor match {
case Some(actor) =>
- sender ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout)
+ val driverActor = sender
+ Future {
+ driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout)
+ } onFailure {
+ case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e)
+ }
case None =>
logWarning("Attempted to request executors before the AM has registered!")
sender ! false
@@ -119,7 +131,12 @@ private[spark] abstract class YarnSchedulerBackend(
case k: KillExecutors =>
amActor match {
case Some(actor) =>
- sender ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout)
+ val driverActor = sender
+ Future {
+ driverActor ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout)
+ } onFailure {
+ case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e)
+ }
case None =>
logWarning("Attempted to kill executors before the AM has registered!")
sender ! false
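
The change above captures `sender` into a local `driverActor` before handing the blocking ask to a `Future`, because an Akka actor's `sender` is only stable while the current message is being processed. A generic, hypothetical sketch of the same pattern (not Spark code; `blockingAsk` stands in for the `AkkaUtils.askWithReply` call):

```scala
import scala.concurrent.{ExecutionContext, Future}
import scala.util.control.NonFatal

import akka.actor.Actor

// Hypothetical actor: performs a blocking ask off the actor thread and replies
// to the original requester.
class AskAndForwardActor(blockingAsk: Any => Boolean)(implicit ec: ExecutionContext)
  extends Actor {

  def receive = {
    case request =>
      // Capture the requester now; inside the Future, `sender` could already
      // refer to whoever sent a later message to this actor.
      val requester = sender()
      Future {
        requester ! blockingAsk(request)
      } onFailure {
        case NonFatal(e) => context.system.log.error(e, s"Forwarding $request failed")
      }
  }
}
```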
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index d2947dcea4f7c..d56e23ce4478a 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -29,7 +29,7 @@ import org.apache.spark._
import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.broadcast.HttpBroadcast
import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock}
-import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus}
import org.apache.spark.storage._
import org.apache.spark.util.BoundedPriorityQueue
import org.apache.spark.util.collection.CompactBuffer
@@ -207,7 +207,8 @@ private[serializer] object KryoSerializer {
classOf[PutBlock],
classOf[GotBlock],
classOf[GetBlock],
- classOf[MapStatus],
+ classOf[CompressedMapStatus],
+ classOf[HighlyCompressedMapStatus],
classOf[CompactBuffer[_]],
classOf[BlockManagerId],
classOf[Array[Byte]],
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 8dadf6794039e..61ef5ff168791 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -31,7 +31,8 @@ import org.apache.spark.util.Utils
private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
extends BlockStore(blockManager) with Logging {
- val minMemoryMapBytes = blockManager.conf.getLong("spark.storage.memoryMapThreshold", 2 * 4096L)
+ val minMemoryMapBytes = blockManager.conf.getLong(
+ "spark.storage.memoryMapThreshold", 2 * 1024L * 1024L)
override def getSize(blockId: BlockId): Long = {
diskManager.getFile(blockId.name).length
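
A hedged sketch of overriding the raised default: since the threshold is read with `conf.getLong`, the value is a plain byte count, so restoring the previous 8 KB behavior would look like this.

```scala
import org.apache.spark.SparkConf

// spark.storage.memoryMapThreshold is a raw byte count at this point; this
// restores the old 8 KB threshold for workloads that prefer memory-mapping
// small blocks.
val conf = new SparkConf().set("spark.storage.memoryMapThreshold", (8 * 1024).toString)
```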
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 2a27d49d2de05..88fed833f922d 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -201,7 +201,7 @@ private[spark] object JettyUtils extends Logging {
}
}
- val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, serverName)
+ val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName)
ServerInfo(server, boundPort, collection)
}
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index db2531dc171f8..4c9b1e3c46f0f 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -53,7 +53,7 @@ private[spark] object AkkaUtils extends Logging {
val startService: Int => (ActorSystem, Int) = { actualPort =>
doCreateActorSystem(name, host, actualPort, conf, securityManager)
}
- Utils.startServiceOnPort(port, startService, name)
+ Utils.startServiceOnPort(port, startService, conf, name)
}
private def doCreateActorSystem(
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index c4f1898a2db15..2c04e4ddfbcb7 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -701,7 +701,7 @@ private[spark] object Utils extends Logging {
}
}
- private var customHostname: Option[String] = None
+ private var customHostname: Option[String] = sys.env.get("SPARK_LOCAL_HOSTNAME")
/**
* Allow setting a custom host name because when we run on Mesos we need to use the same
@@ -1690,17 +1690,15 @@ private[spark] object Utils extends Logging {
}
/**
- * Default maximum number of retries when binding to a port before giving up.
+ * Maximum number of retries when binding to a port before giving up.
*/
- val portMaxRetries: Int = {
- if (sys.props.contains("spark.testing")) {
+ def portMaxRetries(conf: SparkConf): Int = {
+ val maxRetries = conf.getOption("spark.port.maxRetries").map(_.toInt)
+ if (conf.contains("spark.testing")) {
// Set a higher number of retries for tests...
- sys.props.get("spark.port.maxRetries").map(_.toInt).getOrElse(100)
+ maxRetries.getOrElse(100)
} else {
- Option(SparkEnv.get)
- .flatMap(_.conf.getOption("spark.port.maxRetries"))
- .map(_.toInt)
- .getOrElse(16)
+ maxRetries.getOrElse(16)
}
}
@@ -1709,17 +1707,18 @@ private[spark] object Utils extends Logging {
* Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0).
*
* @param startPort The initial port to start the service on.
- * @param maxRetries Maximum number of retries to attempt.
- * A value of 3 means attempting ports n, n+1, n+2, and n+3, for example.
* @param startService Function to start service on a given port.
* This is expected to throw java.net.BindException on port collision.
+ * @param conf A SparkConf used to get the maximum number of retries when binding to a port.
+ * @param serviceName Name of the service.
*/
def startServiceOnPort[T](
startPort: Int,
startService: Int => (T, Int),
- serviceName: String = "",
- maxRetries: Int = portMaxRetries): (T, Int) = {
+ conf: SparkConf,
+ serviceName: String = ""): (T, Int) = {
val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'"
+ val maxRetries = portMaxRetries(conf)
for (offset <- 0 to maxRetries) {
// Do not increment port if startPort is 0, which is treated as a special port
val tryPort = if (startPort == 0) {
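
A minimal sketch of a caller after the signature change, assuming code compiled in the `org.apache.spark` package (`Utils` is `private[spark]`); only the `Utils.startServiceOnPort` call mirrors the patch, the echo-server wrapper is hypothetical.

```scala
import java.net.ServerSocket

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

// Hypothetical wrapper: the number of retried ports now comes from
// spark.port.maxRetries in `conf` rather than from SparkEnv or system properties.
def startEchoServer(conf: SparkConf, requestedPort: Int): (ServerSocket, Int) = {
  val startService: Int => (ServerSocket, Int) = { port =>
    val socket = new ServerSocket(port)  // throws java.net.BindException on collision
    (socket, socket.getLocalPort)
  }
  Utils.startServiceOnPort[ServerSocket](requestedPort, startService, conf, "echo server")
}
```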
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index c817f6dcede75..0e4df17c1bf87 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark
+import scala.collection.mutable
+
import org.scalatest.{FunSuite, PrivateMethodTester}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
@@ -143,11 +145,17 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext {
// Verify that running a task reduces the cap
sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 3)))
+ sc.listenerBus.postToAll(SparkListenerBlockManagerAdded(
+ 0L, BlockManagerId("executor-1", "host1", 1), 100L))
sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1")))
+ assert(numExecutorsPending(manager) === 4)
assert(addExecutors(manager) === 1)
- assert(numExecutorsPending(manager) === 6)
+ assert(numExecutorsPending(manager) === 5)
assert(numExecutorsToAdd(manager) === 2)
- assert(addExecutors(manager) === 1)
+ assert(addExecutors(manager) === 2)
+ assert(numExecutorsPending(manager) === 7)
+ assert(numExecutorsToAdd(manager) === 4)
+ assert(addExecutors(manager) === 0)
assert(numExecutorsPending(manager) === 7)
assert(numExecutorsToAdd(manager) === 1)
@@ -325,6 +333,8 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext {
val manager = sc.executorAllocationManager.get
manager.setClock(clock)
+ executorIds(manager).asInstanceOf[mutable.Set[String]] ++= List("1", "2", "3")
+
// Starting remove timer is idempotent for each executor
assert(removeTimes(manager).isEmpty)
onExecutorIdle(manager, "1")
@@ -597,6 +607,41 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext {
assert(removeTimes(manager).size === 1)
}
+ test("SPARK-4951: call onTaskStart before onBlockManagerAdded") {
+ sc = createSparkContext(2, 10)
+ val manager = sc.executorAllocationManager.get
+ assert(executorIds(manager).isEmpty)
+ assert(removeTimes(manager).isEmpty)
+
+ sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1")))
+ sc.listenerBus.postToAll(SparkListenerBlockManagerAdded(
+ 0L, BlockManagerId("executor-1", "host1", 1), 100L))
+ assert(executorIds(manager).size === 1)
+ assert(executorIds(manager).contains("executor-1"))
+ assert(removeTimes(manager).size === 0)
+ }
+
+ test("SPARK-4951: onExecutorAdded should not add a busy executor to removeTimes") {
+ sc = createSparkContext(2, 10)
+ val manager = sc.executorAllocationManager.get
+ assert(executorIds(manager).isEmpty)
+ assert(removeTimes(manager).isEmpty)
+ sc.listenerBus.postToAll(SparkListenerBlockManagerAdded(
+ 0L, BlockManagerId("executor-1", "host1", 1), 100L))
+ sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1")))
+
+ assert(executorIds(manager).size === 1)
+ assert(executorIds(manager).contains("executor-1"))
+ assert(removeTimes(manager).size === 0)
+
+ sc.listenerBus.postToAll(SparkListenerBlockManagerAdded(
+ 0L, BlockManagerId("executor-2", "host1", 1), 100L))
+ assert(executorIds(manager).size === 2)
+ assert(executorIds(manager).contains("executor-2"))
+ assert(removeTimes(manager).size === 1)
+ assert(removeTimes(manager).contains("executor-2"))
+ assert(!removeTimes(manager).contains("executor-1"))
+ }
}
/**
diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
index 0b6511a80df1d..3d2700b7e6be4 100644
--- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
@@ -30,7 +30,7 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>
var conf = new SparkConf(false)
override def beforeAll() {
- _sc = new SparkContext("local", "test", conf)
+ _sc = new SparkContext("local[4]", "test", conf)
super.beforeAll()
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 6836e9ab0fd6b..0deb9b18b8688 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -17,6 +17,10 @@
package org.apache.spark.rdd
+import java.io.{ObjectInputStream, ObjectOutputStream, IOException}
+
+import com.esotericsoftware.kryo.KryoException
+
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
@@ -887,6 +891,23 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(ancestors6.count(_.isInstanceOf[CyclicalDependencyRDD[_]]) === 3)
}
+ test("task serialization exception should not hang scheduler") {
+ class BadSerializable extends Serializable {
+ @throws(classOf[IOException])
+ private def writeObject(out: ObjectOutputStream): Unit = throw new KryoException("Bad serialization")
+
+ @throws(classOf[IOException])
+ private def readObject(in: ObjectInputStream): Unit = {}
+ }
+ // Note that in the original bug, SPARK-4349, which this test verifies, the job would only
+ // hang if there were more threads in the SparkContext than objects in this sequence.
+ intercept[Throwable] {
+ sc.parallelize(Seq(new BadSerializable, new BadSerializable)).collect
+ }
+ // Check that the context has not crashed
+ sc.parallelize(1 to 100).map(x => x*2).collect
+ }
+
/** A contrived RDD that allows the manual addition of dependencies after creation. */
private class CyclicalDependencyRDD[T: ClassTag] extends RDD[T](sc, Nil) {
private val mutableDependencies: ArrayBuffer[Dependency[_]] = ArrayBuffer.empty
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
new file mode 100644
index 0000000000000..6b75c98839e03
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.{ObjectInputStream, ObjectOutputStream, IOException}
+
+import org.apache.spark.TaskContext
+
+/**
+ * A Task implementation that fails to serialize.
+ */
+private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) extends Task[Array[Byte]](stageId, 0) {
+ override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
+ override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
+
+ @throws(classOf[IOException])
+ private def writeObject(out: ObjectOutputStream): Unit = {
+ if (stageId == 0) {
+ throw new IllegalStateException("Cannot serialize")
+ }
+ }
+
+ @throws(classOf[IOException])
+ private def readObject(in: ObjectInputStream): Unit = {}
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 8874cf00e9993..add13f5b21765 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -100,4 +100,34 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin
assert(1 === taskDescriptions.length)
assert("executor0" === taskDescriptions(0).executorId)
}
+
+ test("Scheduler does not crash when tasks are not serializable") {
+ sc = new SparkContext("local", "TaskSchedulerImplSuite")
+ val taskCpus = 2
+
+ sc.conf.set("spark.task.cpus", taskCpus.toString)
+ val taskScheduler = new TaskSchedulerImpl(sc)
+ taskScheduler.initialize(new FakeSchedulerBackend)
+ // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+ val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+ override def executorAdded(execId: String, host: String) {}
+ }
+ val numFreeCores = 1
+ taskScheduler.setDAGScheduler(dagScheduler)
+ val taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null)
+ val multiCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", taskCpus),
+ new WorkerOffer("executor1", "host1", numFreeCores))
+ taskScheduler.submitTasks(taskSet)
+ var taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
+ assert(0 === taskDescriptions.length)
+
+ // Now check that we can still submit tasks.
+ // Even if one task set has unserializable tasks, the other task sets should still be
+ // processed without error.
+ taskScheduler.submitTasks(taskSet)
+ taskScheduler.submitTasks(FakeTask.createTaskSet(1))
+ taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
+ assert(taskDescriptions.map(_.executorId) === Seq("executor0"))
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 472191551a01f..84b9b788237bf 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.scheduler
+import java.io.{ObjectInputStream, ObjectOutputStream, IOException}
import java.util.Random
import scala.collection.mutable.ArrayBuffer
@@ -563,6 +564,19 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
assert(manager.emittedTaskSizeWarning)
}
+ test("Not serializable exception thrown if the task cannot be serialized") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+
+ val taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
+
+ intercept[TaskNotSerializableException] {
+ manager.resourceOffer("exec1", "host1", ANY)
+ }
+ assert(manager.isZombie)
+ }
+
test("abort the job if total size of results is too large") {
val conf = new SparkConf().set("spark.driver.maxResultSize", "2m")
sc = new SparkContext("local", "test", conf)
diff --git a/dev/run-tests b/dev/run-tests
index 20603fc089239..2257a566bb1bb 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -21,8 +21,10 @@
FWDIR="$(cd "`dirname $0`"/..; pwd)"
cd "$FWDIR"
-# Remove work directory
+# Clean up work directory and caches
rm -rf ./work
+rm -rf ~/.ivy2/local/org.apache.spark
+rm -rf ~/.ivy2/cache/org.apache.spark
source "$FWDIR/dev/run-tests-codes.sh"
diff --git a/docs/_config.yml b/docs/_config.yml
index a96a76dd9ab5e..e2db274e1f619 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -17,6 +17,6 @@ SPARK_VERSION: 1.3.0-SNAPSHOT
SPARK_VERSION_SHORT: 1.3.0
SCALA_BINARY_VERSION: "2.10"
SCALA_VERSION: "2.10.4"
-MESOS_VERSION: 0.18.1
+MESOS_VERSION: 0.21.0
SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK
SPARK_GITHUB_URL: https://github.com/apache/spark
diff --git a/docs/configuration.md b/docs/configuration.md
index 2add48569bece..f292bfbb7dcd6 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -678,7 +678,7 @@ Apart from these, the following properties are also available, and may be useful
  <td><code>spark.storage.memoryMapThreshold</code></td>
-  <td>8192</td>
+  <td>2097152</td>
Size of a block, in bytes, above which Spark memory maps when reading a block from disk.
This prevents Spark from memory mapping very small blocks. In general, memory
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 183698ffe9304..4f273098c5db3 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -21,6 +21,14 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.yarn.am.memory</code></td>
+  <td><code>512m</code></td>
+  <td>
+    Amount of memory to use for the YARN Application Master in client mode, in the same format as JVM memory strings (e.g. <code>512m</code>, <code>2g</code>).
+    In cluster mode, use <code>spark.driver.memory</code> instead.
+  </td>
+</tr>
<tr>
  <td><code>spark.yarn.am.waitTime</code></td>
  <td>100000</td>
@@ -90,7 +98,14 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
  <td><code>spark.yarn.driver.memoryOverhead</code></td>
  <td>driverMemory * 0.07, with minimum of 384</td>
  <td>
-    The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%).
+    The amount of off heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%).
+  </td>
+</tr>
+<tr>
+  <td><code>spark.yarn.am.memoryOverhead</code></td>
+  <td>AM memory * 0.07, with minimum of 384</td>
+  <td>
+    Same as <code>spark.yarn.driver.memoryOverhead</code>, but for the Application Master in client mode.
  </td>
@@ -145,7 +160,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
  <td><code>spark.yarn.am.extraJavaOptions</code></td>
  <td>(none)</td>
  <td>
-    A string of extra JVM options to pass to the Yarn ApplicationMaster in client mode.
+    A string of extra JVM options to pass to the YARN Application Master in client mode.
    In cluster mode, use <code>spark.driver.extraJavaOptions</code> instead.
  </td>
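
A hedged sketch of setting the new client-mode Application Master properties programmatically; they can typically also be supplied via `--conf` or `spark-defaults.conf` before the application is launched.

```scala
import org.apache.spark.SparkConf

// Client-mode AM sizing; in cluster mode, size the driver/AM with
// spark.driver.memory and spark.yarn.driver.memoryOverhead instead.
val conf = new SparkConf()
  .set("spark.yarn.am.memory", "1g")           // JVM memory string, e.g. 512m, 2g
  .set("spark.yarn.am.memoryOverhead", "512")  // off-heap headroom, in megabytes
```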
diff --git a/examples/pom.xml b/examples/pom.xml
index 002d4458c4b3e..4b92147725f6b 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -392,29 +392,6 @@
-    <profile>
-      <id>hbase-hadoop2</id>
-      <activation>
-        <property>
-          <name>hbase.profile</name>
-          <value>hadoop2</value>
-        </property>
-      </activation>
-      <properties>
-        <hbase.version>0.98.7-hadoop2</hbase.version>
-      </properties>
-    </profile>
-    <profile>
-      <id>hbase-hadoop1</id>
-      <activation>
-        <property>
-          <name>!hbase.profile</name>
-        </property>
-      </activation>
-      <properties>
-        <hbase.version>0.98.7-hadoop1</hbase.version>
-      </properties>
-    </profile>
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
index 9fbb0a800d735..35b8dd6c29b66 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
@@ -27,8 +27,8 @@ object SparkPi {
val conf = new SparkConf().setAppName("Spark Pi")
val spark = new SparkContext(conf)
val slices = if (args.length > 0) args(0).toInt else 2
- val n = 100000 * slices
- val count = spark.parallelize(1 to n, slices).map { i =>
+ val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow
+ val count = spark.parallelize(1 until n, slices).map { i =>
val x = random * 2 - 1
val y = random * 2 - 1
if (x*x + y*y < 1) 1 else 0
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
index 13943ed5442b9..f333e3891b5f0 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
@@ -80,7 +80,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L
val socket = new ServerSocket(trialPort)
socket.close()
(null, trialPort)
- })._2
+ }, conf)._2
}
/** Setup and start the streaming context */
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index 98fe6cb301f52..fe53a29cba0c9 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -19,16 +19,21 @@ package org.apache.spark.streaming.mqtt
import java.net.{URI, ServerSocket}
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
import org.apache.activemq.broker.{TransportConnector, BrokerService}
-import org.apache.spark.util.Utils
+import org.eclipse.paho.client.mqttv3._
+import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
+
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.concurrent.Eventually
-import scala.concurrent.duration._
+
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
-import org.eclipse.paho.client.mqttv3._
-import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
+import org.apache.spark.SparkConf
+import org.apache.spark.util.Utils
class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
@@ -38,8 +43,9 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
private val freePort = findFreePort()
private val brokerUri = "//localhost:" + freePort
private val topic = "def"
- private var ssc: StreamingContext = _
private val persistenceDir = Utils.createTempDir()
+
+ private var ssc: StreamingContext = _
private var broker: BrokerService = _
private var connector: TransportConnector = _
@@ -101,7 +107,7 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
val socket = new ServerSocket(trialPort)
socket.close()
(null, trialPort)
- })._2
+ }, new SparkConf())._2
}
def publishData(data: String): Unit = {
@@ -115,8 +121,9 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
val message: MqttMessage = new MqttMessage(data.getBytes("utf-8"))
message.setQos(1)
message.setRetained(true)
- for (i <- 0 to 100)
+ for (i <- 0 to 100) {
msgTopic.publish(message)
+ }
}
} finally {
client.disconnect()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
index 3a6c0e681e3fa..d8e134619411b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
@@ -20,10 +20,12 @@ package org.apache.spark.mllib.clustering
import scala.collection.mutable.IndexedSeq
import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose}
-import org.apache.spark.rdd.RDD
+
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS}
-import org.apache.spark.mllib.stat.impl.MultivariateGaussian
+import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
/**
* This class performs expectation maximization for multivariate Gaussian
@@ -45,10 +47,11 @@ import org.apache.spark.mllib.util.MLUtils
class GaussianMixtureEM private (
private var k: Int,
private var convergenceTol: Double,
- private var maxIterations: Int) extends Serializable {
+ private var maxIterations: Int,
+ private var seed: Long) extends Serializable {
/** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
- def this() = this(2, 0.01, 100)
+ def this() = this(2, 0.01, 100, Utils.random.nextLong())
// number of samples per cluster to use when initializing Gaussians
private val nSamples = 5
@@ -100,11 +103,21 @@ class GaussianMixtureEM private (
this
}
- /** Return the largest change in log-likelihood at which convergence is
- * considered to have occurred.
+ /**
+ * Return the largest change in log-likelihood at which convergence is
+ * considered to have occurred.
*/
def getConvergenceTol: Double = convergenceTol
-
+
+ /** Set the random seed */
+ def setSeed(seed: Long): this.type = {
+ this.seed = seed
+ this
+ }
+
+ /** Return the random seed */
+ def getSeed: Long = seed
+
/** Perform expectation maximization */
def run(data: RDD[Vector]): GaussianMixtureModel = {
val sc = data.sparkContext
@@ -113,7 +126,7 @@ class GaussianMixtureEM private (
val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
// Get length of the input vectors
- val d = breezeData.first.length
+ val d = breezeData.first().length
// Determine initial weights and corresponding Gaussians.
// If the user supplied an initial GMM, we use those values, otherwise
@@ -122,11 +135,11 @@ class GaussianMixtureEM private (
// derived from the samples
val (weights, gaussians) = initialModel match {
case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) =>
- new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix)
+ new MultivariateGaussian(mu, sigma)
})
case None => {
- val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
+ val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
(Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
@@ -164,8 +177,8 @@ class GaussianMixtureEM private (
}
// Need to convert the breeze matrices to MLlib matrices
- val means = Array.tabulate(k) { i => Vectors.fromBreeze(gaussians(i).mu) }
- val sigmas = Array.tabulate(k) { i => Matrices.fromBreeze(gaussians(i).sigma) }
+ val means = Array.tabulate(k) { i => gaussians(i).mu }
+ val sigmas = Array.tabulate(k) { i => gaussians(i).sigma }
new GaussianMixtureModel(weights, means, sigmas)
}
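
A hedged usage sketch of the new seed setter: `data` is assumed to be an existing `RDD[Vector]` and `setK` is assumed from the existing builder API; only `setSeed` is introduced by this patch. Fixing the seed makes the randomly sampled initial Gaussians, and hence the fitted model, reproducible across runs.

```scala
import org.apache.spark.mllib.clustering.GaussianMixtureEM
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Fit a two-component mixture with a fixed random seed.
def fitReproducibly(data: RDD[Vector]) = {
  new GaussianMixtureEM()
    .setK(2)          // assumed existing setter for the number of Gaussians
    .setSeed(12345L)  // new in this patch
    .run(data)
}
```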
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index b461ea4f0f06e..416cad080c408 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -21,7 +21,7 @@ import breeze.linalg.{DenseVector => BreezeVector}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Matrix, Vector}
-import org.apache.spark.mllib.stat.impl.MultivariateGaussian
+import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
index 36d8cadd2bdd7..181f507516485 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
@@ -102,6 +102,9 @@ class IndexedRowMatrix(
k: Int,
computeU: Boolean = false,
rCond: Double = 1e-9): SingularValueDecomposition[IndexedRowMatrix, Matrix] = {
+
+ val n = numCols().toInt
+ require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.")
val indices = rows.map(_.index)
val svd = toRowMatrix().computeSVD(k, computeU, rCond)
val U = if (computeU) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index fbd35e372f9b1..d5abba6a4b645 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -212,7 +212,7 @@ class RowMatrix(
tol: Double,
mode: String): SingularValueDecomposition[RowMatrix, Matrix] = {
val n = numCols().toInt
- require(k > 0 && k <= n, s"Request up to n singular values but got k=$k and n=$n.")
+ require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.")
object SVDMode extends Enumeration {
val LocalARPACK, LocalLAPACK, DistARPACK = Value
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
similarity index 61%
rename from mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala
rename to mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
index bc7f6c5197ac7..fd186b5ee6f72 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
@@ -15,13 +15,16 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.stat.impl
+package org.apache.spark.mllib.stat.distribution
import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
import org.apache.spark.mllib.util.MLUtils
/**
+ * :: DeveloperApi ::
* This class provides basic functionality for a Multivariate Gaussian (Normal) Distribution. In
* the event that the covariance matrix is singular, the density will be computed in a
* reduced dimensional subspace under which the distribution is supported.
@@ -30,33 +33,64 @@ import org.apache.spark.mllib.util.MLUtils
* @param mu The mean vector of the distribution
* @param sigma The covariance matrix of the distribution
*/
-private[mllib] class MultivariateGaussian(
- val mu: DBV[Double],
- val sigma: DBM[Double]) extends Serializable {
+@DeveloperApi
+class MultivariateGaussian (
+ val mu: Vector,
+ val sigma: Matrix) extends Serializable {
+ require(sigma.numCols == sigma.numRows, "Covariance matrix must be square")
+ require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size")
+
+ private val breezeMu = mu.toBreeze.toDenseVector
+
+ /**
+ * private[mllib] constructor
+ *
+ * @param mu The mean vector of the distribution
+ * @param sigma The covariance matrix of the distribution
+ */
+ private[mllib] def this(mu: DBV[Double], sigma: DBM[Double]) = {
+ this(Vectors.fromBreeze(mu), Matrices.fromBreeze(sigma))
+ }
+
/**
* Compute distribution dependent constants:
- * rootSigmaInv = D^(-1/2) * U, where sigma = U * D * U.t
- * u = (2*pi)^(-k/2) * det(sigma)^(-1/2)
+ * rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t
+ * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
*/
private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
/** Returns density of this multivariate Gaussian at given point, x */
- def pdf(x: DBV[Double]): Double = {
- val delta = x - mu
+ def pdf(x: Vector): Double = {
+ pdf(x.toBreeze.toDenseVector)
+ }
+
+ /** Returns the log-density of this multivariate Gaussian at given point, x */
+ def logpdf(x: Vector): Double = {
+ logpdf(x.toBreeze.toDenseVector)
+ }
+
+ /** Returns density of this multivariate Gaussian at given point, x */
+ private[mllib] def pdf(x: DBV[Double]): Double = {
+ math.exp(logpdf(x))
+ }
+
+ /** Returns the log-density of this multivariate Gaussian at given point, x */
+ private[mllib] def logpdf(x: DBV[Double]): Double = {
+ val delta = x - breezeMu
val v = rootSigmaInv * delta
- u * math.exp(v.t * v * -0.5)
+ u + v.t * v * -0.5
}
/**
* Calculate distribution dependent components used for the density function:
- * pdf(x) = (2*pi)^(-k/2) * det(sigma)^(-1/2) * exp( (-1/2) * (x-mu).t * inv(sigma) * (x-mu) )
+ * pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu))
* where k is length of the mean vector.
*
* We here compute distribution-fixed parts
- * (2*pi)^(-k/2) * det(sigma)^(-1/2)
+ * log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
* and
- * D^(-1/2) * U, where sigma = U * D * U.t
+ * D^(-1/2)^ * U, where sigma = U * D * U.t
*
* Both the determinant and the inverse can be computed from the singular value decomposition
* of sigma. Noting that covariance matrices are always symmetric and positive semi-definite,
@@ -65,11 +99,11 @@ private[mllib] class MultivariateGaussian(
*
* sigma = U * D * U.t
* inv(Sigma) = U * inv(D) * U.t
- * = (D^{-1/2} * U).t * (D^{-1/2} * U)
+ * = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U)
*
* and thus
*
- * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2} * U * (x-mu))^2
+ * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^
*
* To guard against singular covariance matrices, this method computes both the
* pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered
@@ -77,21 +111,21 @@ private[mllib] class MultivariateGaussian(
* relation to the maximum singular value (same tolerance used by, e.g., Octave).
*/
private def calculateCovarianceConstants: (DBM[Double], Double) = {
- val eigSym.EigSym(d, u) = eigSym(sigma) // sigma = u * diag(d) * u.t
+ val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t
// For numerical stability, values are considered to be non-zero only if they exceed tol.
// This prevents any inverted value from exceeding (eps * n * max(d))^-1
val tol = MLUtils.EPSILON * max(d) * d.length
try {
- // pseudo-determinant is product of all non-zero singular values
- val pdetSigma = d.activeValuesIterator.filter(_ > tol).reduce(_ * _)
+ // log(pseudo-determinant) is sum of the logs of all non-zero singular values
+ val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum
// calculate the root-pseudo-inverse of the diagonal matrix of singular values
// by inverting the square root of all non-zero values
val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray))
- (pinvS * u, math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(pdetSigma, -0.5))
+ (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma))
} catch {
case uex: UnsupportedOperationException =>
throw new IllegalArgumentException("Covariance matrix has no non-zero singular values")
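
Moving the density to log space is what keeps it usable far out in the tails: `u` now stores log((2*pi)^(-k/2) * det(sigma)^(-1/2)) and `logpdf` adds the quadratic form to it, so nothing is exponentiated unless `pdf` is explicitly requested. A small sketch against the now-public class, with illustrative values:

import org.apache.spark.mllib.linalg.{Matrices, Vectors}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian

// Standard bivariate normal.
val dist = new MultivariateGaussian(
  Vectors.dense(0.0, 0.0),
  Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)))

val x = Vectors.dense(10.0, 10.0)
dist.pdf(x)     // ~5.9e-45; products of such values underflow quickly in EM
dist.logpdf(x)  // ~-101.84 = -log(2*pi) - 0.5 * (100 + 100), stays finite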
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
index 23feb82874b70..9da5495741a80 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
@@ -35,12 +35,14 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
val Ew = 1.0
val Emu = Vectors.dense(5.0, 10.0)
val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0))
-
- val gmm = new GaussianMixtureEM().setK(1).run(data)
-
- assert(gmm.weight(0) ~== Ew absTol 1E-5)
- assert(gmm.mu(0) ~== Emu absTol 1E-5)
- assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+
+ val seeds = Array(314589, 29032897, 50181, 494821, 4660)
+ seeds.foreach { seed =>
+ val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data)
+ assert(gmm.weight(0) ~== Ew absTol 1E-5)
+ assert(gmm.mu(0) ~== Emu absTol 1E-5)
+ assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+ }
}
test("two clusters") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
index e25bc02b06c9a..741cd4997b853 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
@@ -113,6 +113,13 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
assert(closeToZero(U * brzDiag(s) * V.t - localA))
}
+ test("validate k in svd") {
+ val A = new IndexedRowMatrix(indexedRows)
+ intercept[IllegalArgumentException] {
+ A.computeSVD(-1)
+ }
+ }
+
def closeToZero(G: BDM[Double]): Boolean = {
G.valuesIterator.map(math.abs).sum < 1e-6
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index dbf55ff81ca99..3309713e91f87 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -171,6 +171,14 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
}
}
+ test("validate k in svd") {
+ for (mat <- Seq(denseMat, sparseMat)) {
+ intercept[IllegalArgumentException] {
+ mat.computeSVD(-1)
+ }
+ }
+ }
+
def closeToZero(G: BDM[Double]): Boolean = {
G.valuesIterator.map(math.abs).sum < 1e-6
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
similarity index 72%
rename from mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala
rename to mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
index d58f2587e55aa..fac2498e4dcb3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
@@ -15,54 +15,53 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.stat.impl
+package org.apache.spark.mllib.stat.distribution
import org.scalatest.FunSuite
-import breeze.linalg.{ DenseVector => BDV, DenseMatrix => BDM }
-
+import org.apache.spark.mllib.linalg.{ Vectors, Matrices }
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext {
test("univariate") {
- val x1 = new BDV(Array(0.0))
- val x2 = new BDV(Array(1.5))
+ val x1 = Vectors.dense(0.0)
+ val x2 = Vectors.dense(1.5)
- val mu = new BDV(Array(0.0))
- val sigma1 = new BDM(1, 1, Array(1.0))
+ val mu = Vectors.dense(0.0)
+ val sigma1 = Matrices.dense(1, 1, Array(1.0))
val dist1 = new MultivariateGaussian(mu, sigma1)
assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5)
assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5)
- val sigma2 = new BDM(1, 1, Array(4.0))
+ val sigma2 = Matrices.dense(1, 1, Array(4.0))
val dist2 = new MultivariateGaussian(mu, sigma2)
assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5)
assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5)
}
test("multivariate") {
- val x1 = new BDV(Array(0.0, 0.0))
- val x2 = new BDV(Array(1.0, 1.0))
+ val x1 = Vectors.dense(0.0, 0.0)
+ val x2 = Vectors.dense(1.0, 1.0)
- val mu = new BDV(Array(0.0, 0.0))
- val sigma1 = new BDM(2, 2, Array(1.0, 0.0, 0.0, 1.0))
+ val mu = Vectors.dense(0.0, 0.0)
+ val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
val dist1 = new MultivariateGaussian(mu, sigma1)
assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5)
assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5)
- val sigma2 = new BDM(2, 2, Array(4.0, -1.0, -1.0, 2.0))
+ val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
val dist2 = new MultivariateGaussian(mu, sigma2)
assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5)
assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5)
}
test("multivariate degenerate") {
- val x1 = new BDV(Array(0.0, 0.0))
- val x2 = new BDV(Array(1.0, 1.0))
+ val x1 = Vectors.dense(0.0, 0.0)
+ val x2 = Vectors.dense(1.0, 1.0)
- val mu = new BDV(Array(0.0, 0.0))
- val sigma = new BDM(2, 2, Array(1.0, 1.0, 1.0, 1.0))
+ val mu = Vectors.dense(0.0, 0.0)
+ val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0))
val dist = new MultivariateGaussian(mu, sigma)
assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5)
assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5)
diff --git a/pom.xml b/pom.xml
index 703e5c47bf59b..f4466e56c2a53 100644
--- a/pom.xml
+++ b/pom.xml
@@ -115,14 +115,14 @@
1.6
spark
2.0.1
- 0.18.1
+ 0.21.0
shaded-protobuf
1.7.5
1.2.17
1.0.4
2.4.1
${hadoop.version}
- 0.94.6
+ 0.98.7-hadoop1
hbase
1.4.0
3.4.5
@@ -1130,6 +1130,7 @@
${test_classpath}
true
+ false
@@ -1465,6 +1466,7 @@
2.2.0
2.5.0
+ 0.98.7-hadoop2
hadoop2
@@ -1475,6 +1477,7 @@
2.3.0
2.5.0
0.9.0
+ 0.98.7-hadoop2
3.1.1
hadoop2
@@ -1486,6 +1489,7 @@
2.4.0
2.5.0
0.9.0
+ 0.98.7-hadoop2
3.1.1
hadoop2
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 31d4c317ae569..51e8bd4cf6419 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,7 +36,6 @@ object MimaExcludes {
case v if v.startsWith("1.3") =>
Seq(
MimaBuild.excludeSparkPackage("deploy"),
- MimaBuild.excludeSparkPackage("graphx"),
// These are needed if checking against the sbt build, since they are part of
// the maven-generated artifacts in the 1.2 build.
MimaBuild.excludeSparkPackage("unused"),
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 0e8b398fc6b97..014ac1791c849 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -807,14 +807,14 @@ def convert_struct(obj):
return
if isinstance(obj, tuple):
- if hasattr(obj, "fields"):
- d = dict(zip(obj.fields, obj))
- if hasattr(obj, "__FIELDS__"):
+ if hasattr(obj, "_fields"):
+ d = dict(zip(obj._fields, obj))
+ elif hasattr(obj, "__FIELDS__"):
d = dict(zip(obj.__FIELDS__, obj))
elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
d = dict(obj)
else:
- raise ValueError("unexpected tuple: %s" % obj)
+ raise ValueError("unexpected tuple: %s" % str(obj))
elif isinstance(obj, dict):
d = obj
@@ -1327,6 +1327,16 @@ def inferSchema(self, rdd, samplingRatio=None):
>>> srdd = sqlCtx.inferSchema(nestedRdd2)
>>> srdd.collect()
[Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
+
+ >>> from collections import namedtuple
+ >>> CustomRow = namedtuple('CustomRow', 'field1 field2')
+ >>> rdd = sc.parallelize(
+ ... [CustomRow(field1=1, field2="row1"),
+ ... CustomRow(field1=2, field2="row2"),
+ ... CustomRow(field1=3, field2="row3")])
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.collect()[0]
+ Row(field1=1, field2=u'row1')
"""
if isinstance(rdd, SchemaRDD):
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index 646c68e60c2e9..b646f0b6f0868 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -106,7 +106,7 @@ import org.apache.spark.util.Utils
val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles
/** Jetty server that will serve our classes to worker nodes */
val classServerPort = conf.getInt("spark.replClassServer.port", 0)
- val classServer = new HttpServer(outputDir, new SecurityManager(conf), classServerPort, "HTTP class server")
+ val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server")
private var currentSettings: Settings = initialSettings
var printResults = true // whether to print result lines
var totalSilence = false // whether to print anything
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
index 5e93a71995072..69e44d4f916e1 100644
--- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
@@ -32,7 +32,7 @@ object Main extends Logging {
val s = new Settings()
s.processArguments(List("-Yrepl-class-based",
"-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true)
- val classServer = new HttpServer(outputDir, new SecurityManager(conf))
+ val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf))
var sparkContext: SparkContext = _
var interp = new SparkILoop // this is a public var because tests reset it.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
similarity index 62%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
index f1a1ca6616a21..93d74adbcc957 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
@@ -105,72 +105,3 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical {
}
}
}
-
-/**
- * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL
- * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser.
- *
- * @param fallback A function that parses an input string to a logical plan
- */
-private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser {
-
- // A parser for the key-value part of the "SET [key = [value ]]" syntax
- private object SetCommandParser extends RegexParsers {
- private val key: Parser[String] = "(?m)[^=]+".r
-
- private val value: Parser[String] = "(?m).*$".r
-
- private val pair: Parser[LogicalPlan] =
- (key ~ ("=".r ~> value).?).? ^^ {
- case None => SetCommand(None)
- case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim)))
- }
-
- def apply(input: String): LogicalPlan = parseAll(pair, input) match {
- case Success(plan, _) => plan
- case x => sys.error(x.toString)
- }
- }
-
- protected val AS = Keyword("AS")
- protected val CACHE = Keyword("CACHE")
- protected val LAZY = Keyword("LAZY")
- protected val SET = Keyword("SET")
- protected val TABLE = Keyword("TABLE")
- protected val UNCACHE = Keyword("UNCACHE")
-
- protected implicit def asParser(k: Keyword): Parser[String] =
- lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
-
- private val reservedWords: Seq[String] =
- this
- .getClass
- .getMethods
- .filter(_.getReturnType == classOf[Keyword])
- .map(_.invoke(this).asInstanceOf[Keyword].str)
-
- override val lexical = new SqlLexical(reservedWords)
-
- override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others
-
- private lazy val cache: Parser[LogicalPlan] =
- CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
- case isLazy ~ tableName ~ plan =>
- CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined)
- }
-
- private lazy val uncache: Parser[LogicalPlan] =
- UNCACHE ~ TABLE ~> ident ^^ {
- case tableName => UncacheTableCommand(tableName)
- }
-
- private lazy val set: Parser[LogicalPlan] =
- SET ~> restInput ^^ {
- case input => SetCommandParser(input)
- }
-
- private lazy val others: Parser[LogicalPlan] =
- wholeInput ^^ {
- case input => fallback(input)
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index f79d4ff444dc0..5d974df98b699 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -125,7 +125,7 @@ class SqlParser extends AbstractSparkSQLParser {
}
protected lazy val start: Parser[LogicalPlan] =
- ( select *
+ ( (select | ("(" ~> select <~ ")")) *
( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) }
| INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) }
| EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)}
@@ -178,10 +178,10 @@ class SqlParser extends AbstractSparkSQLParser {
joinedRelation | relationFactor
protected lazy val relationFactor: Parser[LogicalPlan] =
- ( ident ~ (opt(AS) ~> opt(ident)) ^^ {
- case tableName ~ alias => UnresolvedRelation(None, tableName, alias)
+ ( rep1sep(ident, ".") ~ (opt(AS) ~> opt(ident)) ^^ {
+ case tableIdent ~ alias => UnresolvedRelation(tableIdent, alias)
}
- | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) }
+ | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) }
)
protected lazy val joinedRelation: Parser[LogicalPlan] =
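
Two grammar changes above are easy to miss: set-operation operands may now be parenthesized selects, and relation names may be dot-separated multi-part identifiers (rep1sep(ident, ".")). A hedged sketch of queries the updated parser is meant to accept, assuming a SQLContext `sqlContext` and placeholder table names (the database-qualified form is only useful with a catalog that understands databases, such as the Hive metastore):

sqlContext.sql("(SELECT id FROM t1 WHERE id > 10) UNION ALL (SELECT id FROM t2)")
sqlContext.sql("SELECT e.id FROM mydb.events e WHERE e.id = 1")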
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 72680f37a0b4d..c009cc1e1e85c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -228,11 +228,11 @@ class Analyzer(catalog: Catalog,
*/
object ResolveRelations extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case i @ InsertIntoTable(UnresolvedRelation(databaseName, name, alias), _, _, _) =>
+ case i @ InsertIntoTable(UnresolvedRelation(tableIdentifier, alias), _, _, _) =>
i.copy(
- table = EliminateAnalysisOperators(catalog.lookupRelation(databaseName, name, alias)))
- case UnresolvedRelation(databaseName, name, alias) =>
- catalog.lookupRelation(databaseName, name, alias)
+ table = EliminateAnalysisOperators(catalog.lookupRelation(tableIdentifier, alias)))
+ case UnresolvedRelation(tableIdentifier, alias) =>
+ catalog.lookupRelation(tableIdentifier, alias)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index 0415d74bd8141..df8d03b86c533 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -28,77 +28,74 @@ trait Catalog {
def caseSensitive: Boolean
- def tableExists(db: Option[String], tableName: String): Boolean
+ def tableExists(tableIdentifier: Seq[String]): Boolean
def lookupRelation(
- databaseName: Option[String],
- tableName: String,
- alias: Option[String] = None): LogicalPlan
+ tableIdentifier: Seq[String],
+ alias: Option[String] = None): LogicalPlan
- def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit
+ def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit
- def unregisterTable(databaseName: Option[String], tableName: String): Unit
+ def unregisterTable(tableIdentifier: Seq[String]): Unit
def unregisterAllTables(): Unit
- protected def processDatabaseAndTableName(
- databaseName: Option[String],
- tableName: String): (Option[String], String) = {
+ protected def processTableIdentifier(tableIdentifier: Seq[String]): Seq[String] = {
if (!caseSensitive) {
- (databaseName.map(_.toLowerCase), tableName.toLowerCase)
+ tableIdentifier.map(_.toLowerCase)
} else {
- (databaseName, tableName)
+ tableIdentifier
}
}
- protected def processDatabaseAndTableName(
- databaseName: String,
- tableName: String): (String, String) = {
- if (!caseSensitive) {
- (databaseName.toLowerCase, tableName.toLowerCase)
+ protected def getDbTableName(tableIdent: Seq[String]): String = {
+ val size = tableIdent.size
+ if (size <= 2) {
+ tableIdent.mkString(".")
} else {
- (databaseName, tableName)
+ tableIdent.slice(size - 2, size).mkString(".")
}
}
+
+ protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = {
+ (tableIdent.lift(tableIdent.size - 2), tableIdent.last)
+ }
}
class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
val tables = new mutable.HashMap[String, LogicalPlan]()
override def registerTable(
- databaseName: Option[String],
- tableName: String,
+ tableIdentifier: Seq[String],
plan: LogicalPlan): Unit = {
- val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
- tables += ((tblName, plan))
+ val tableIdent = processTableIdentifier(tableIdentifier)
+ tables += ((getDbTableName(tableIdent), plan))
}
- override def unregisterTable(
- databaseName: Option[String],
- tableName: String) = {
- val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
- tables -= tblName
+ override def unregisterTable(tableIdentifier: Seq[String]) = {
+ val tableIdent = processTableIdentifier(tableIdentifier)
+ tables -= getDbTableName(tableIdent)
}
override def unregisterAllTables() = {
tables.clear()
}
- override def tableExists(db: Option[String], tableName: String): Boolean = {
- val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
- tables.get(tblName) match {
+ override def tableExists(tableIdentifier: Seq[String]): Boolean = {
+ val tableIdent = processTableIdentifier(tableIdentifier)
+ tables.get(getDbTableName(tableIdent)) match {
case Some(_) => true
case None => false
}
}
override def lookupRelation(
- databaseName: Option[String],
- tableName: String,
+ tableIdentifier: Seq[String],
alias: Option[String] = None): LogicalPlan = {
- val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
- val table = tables.getOrElse(tblName, sys.error(s"Table Not Found: $tableName"))
- val tableWithQualifiers = Subquery(tblName, table)
+ val tableIdent = processTableIdentifier(tableIdentifier)
+ val tableFullName = getDbTableName(tableIdent)
+ val table = tables.getOrElse(tableFullName, sys.error(s"Table Not Found: $tableFullName"))
+ val tableWithQualifiers = Subquery(tableIdent.last, table)
// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
// properly qualified with this alias.
@@ -117,41 +114,39 @@ trait OverrideCatalog extends Catalog {
// TODO: This doesn't work when the database changes...
val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]()
- abstract override def tableExists(db: Option[String], tableName: String): Boolean = {
- val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
- overrides.get((dbName, tblName)) match {
+ abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = {
+ val tableIdent = processTableIdentifier(tableIdentifier)
+ overrides.get(getDBTable(tableIdent)) match {
case Some(_) => true
- case None => super.tableExists(db, tableName)
+ case None => super.tableExists(tableIdentifier)
}
}
abstract override def lookupRelation(
- databaseName: Option[String],
- tableName: String,
+ tableIdentifier: Seq[String],
alias: Option[String] = None): LogicalPlan = {
- val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
- val overriddenTable = overrides.get((dbName, tblName))
- val tableWithQualifers = overriddenTable.map(r => Subquery(tblName, r))
+ val tableIdent = processTableIdentifier(tableIdentifier)
+ val overriddenTable = overrides.get(getDBTable(tableIdent))
+ val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r))
// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
// properly qualified with this alias.
val withAlias =
tableWithQualifers.map(r => alias.map(a => Subquery(a, r)).getOrElse(r))
- withAlias.getOrElse(super.lookupRelation(dbName, tblName, alias))
+ withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias))
}
override def registerTable(
- databaseName: Option[String],
- tableName: String,
+ tableIdentifier: Seq[String],
plan: LogicalPlan): Unit = {
- val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
- overrides.put((dbName, tblName), plan)
+ val tableIdent = processTableIdentifier(tableIdentifier)
+ overrides.put(getDBTable(tableIdent), plan)
}
- override def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
- val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
- overrides.remove((dbName, tblName))
+ override def unregisterTable(tableIdentifier: Seq[String]): Unit = {
+ val tableIdent = processTableIdentifier(tableIdentifier)
+ overrides.remove(getDBTable(tableIdent))
}
override def unregisterAllTables(): Unit = {
@@ -167,22 +162,21 @@ object EmptyCatalog extends Catalog {
val caseSensitive: Boolean = true
- def tableExists(db: Option[String], tableName: String): Boolean = {
+ def tableExists(tableIdentifier: Seq[String]): Boolean = {
throw new UnsupportedOperationException
}
def lookupRelation(
- databaseName: Option[String],
- tableName: String,
+ tableIdentifier: Seq[String],
alias: Option[String] = None) = {
throw new UnsupportedOperationException
}
- def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = {
+ def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = {
throw new UnsupportedOperationException
}
- def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
+ def unregisterTable(tableIdentifier: Seq[String]): Unit = {
throw new UnsupportedOperationException
}
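
The Seq[String]-based identifiers keep at most the trailing database.table pair when building catalog keys. A standalone illustration of the helper semantics (this mirrors the protected methods above; it is not part of the patch):

def getDbTableName(tableIdent: Seq[String]): String =
  if (tableIdent.size <= 2) tableIdent.mkString(".")
  else tableIdent.takeRight(2).mkString(".")

def getDBTable(tableIdent: Seq[String]): (Option[String], String) =
  (tableIdent.lift(tableIdent.size - 2), tableIdent.last)

assert(getDbTableName(Seq("t")) == "t")
assert(getDbTableName(Seq("db", "t")) == "db.t")
assert(getDbTableName(Seq("catalog", "db", "t")) == "db.t")  // only the last two parts
assert(getDBTable(Seq("t")) == (None, "t"))
assert(getDBTable(Seq("db", "t")) == (Some("db"), "t"))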
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 77d84e1687e1b..71a738a0b2ca0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -34,8 +34,7 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str
* Holds the name of a relation that has yet to be looked up in a [[Catalog]].
*/
case class UnresolvedRelation(
- databaseName: Option[String],
- tableName: String,
+ tableIdentifier: Seq[String],
alias: Option[String] = None) extends LeafNode {
override def output = Nil
override lazy val resolved = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 9608e15c0f302..b2262e5e6efb6 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -290,7 +290,7 @@ package object dsl {
def insertInto(tableName: String, overwrite: Boolean = false) =
InsertIntoTable(
- analysis.UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite)
+ analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite)
def analyze = analysis.SimpleAnalyzer(logicalPlan)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
index 5a1863953eae9..45905f8ef98c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.catalyst.types.StringType
+import org.apache.spark.sql.catalyst.expressions.Attribute
/**
* A logical node that represents a non-query command to be executed by the system. For example,
@@ -28,48 +27,3 @@ abstract class Command extends LeafNode {
self: Product =>
def output: Seq[Attribute] = Seq.empty
}
-
-/**
- *
- * Commands of the form "SET [key [= value] ]".
- */
-case class SetCommand(kv: Option[(String, Option[String])]) extends Command {
- override def output = Seq(
- AttributeReference("", StringType, nullable = false)())
-}
-
-/**
- * Returned by a parser when the users only wants to see what query plan would be executed, without
- * actually performing the execution.
- */
-case class ExplainCommand(plan: LogicalPlan, extended: Boolean = false) extends Command {
- override def output =
- Seq(AttributeReference("plan", StringType, nullable = false)())
-}
-
-/**
- * Returned for the "CACHE TABLE tableName [AS SELECT ...]" command.
- */
-case class CacheTableCommand(tableName: String, plan: Option[LogicalPlan], isLazy: Boolean)
- extends Command
-
-/**
- * Returned for the "UNCACHE TABLE tableName" command.
- */
-case class UncacheTableCommand(tableName: String) extends Command
-
-/**
- * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command.
- * @param table The table to be described.
- * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false.
- * It is effective only when the table is a Hive table.
- */
-case class DescribeCommand(
- table: LogicalPlan,
- isExtended: Boolean) extends Command {
- override def output = Seq(
- // Column names are based on Hive.
- AttributeReference("col_name", StringType, nullable = false)(),
- AttributeReference("data_type", StringType, nullable = false)(),
- AttributeReference("comment", StringType, nullable = false)())
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 82f2101d8ce17..f430057ef7191 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -44,8 +44,8 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
AttributeReference("e", ShortType)())
before {
- caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation)
- caseInsensitiveCatalog.registerTable(None, "TaBlE", testRelation)
+ caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
+ caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
}
test("union project *") {
@@ -64,45 +64,45 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(
caseSensitiveAnalyze(
Project(Seq(UnresolvedAttribute("TbL.a")),
- UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
+ UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
val e = intercept[TreeNodeException[_]] {
caseSensitiveAnalyze(
Project(Seq(UnresolvedAttribute("tBl.a")),
- UnresolvedRelation(None, "TaBlE", Some("TbL"))))
+ UnresolvedRelation(Seq("TaBlE"), Some("TbL"))))
}
assert(e.getMessage().toLowerCase.contains("unresolved"))
assert(
caseInsensitiveAnalyze(
Project(Seq(UnresolvedAttribute("TbL.a")),
- UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
+ UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
assert(
caseInsensitiveAnalyze(
Project(Seq(UnresolvedAttribute("tBl.a")),
- UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
+ UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
}
test("resolve relations") {
val e = intercept[RuntimeException] {
- caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None))
+ caseSensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None))
}
assert(e.getMessage == "Table Not Found: tAbLe")
assert(
- caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
+ caseSensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
testRelation)
assert(
- caseInsensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) ===
+ caseInsensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) ===
testRelation)
assert(
- caseInsensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
+ caseInsensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
testRelation)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 3677a6e72e23a..bbbeb4f2e4fe3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -41,7 +41,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
val f: Expression = UnresolvedAttribute("f")
before {
- catalog.registerTable(None, "table", relation)
+ catalog.registerTable(Seq("table"), relation)
}
private def checkType(expression: Expression, expectedType: DataType): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 6a1a4d995bf61..6c575dd727b46 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -76,7 +76,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] val sqlParser = {
val fallback = new catalyst.SqlParser
- new catalyst.SparkSQLParser(fallback(_))
+ new SparkSQLParser(fallback(_))
}
protected[sql] def parseSql(sql: String): LogicalPlan = {
@@ -276,7 +276,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group userf
*/
def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
- catalog.registerTable(None, tableName, rdd.queryExecution.logical)
+ catalog.registerTable(Seq(tableName), rdd.queryExecution.logical)
}
/**
@@ -289,7 +289,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
def dropTempTable(tableName: String): Unit = {
tryUncacheQuery(table(tableName))
- catalog.unregisterTable(None, tableName)
+ catalog.unregisterTable(Seq(tableName))
}
/**
@@ -308,7 +308,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
/** Returns the specified table as a SchemaRDD */
def table(tableName: String): SchemaRDD =
- new SchemaRDD(this, catalog.lookupRelation(None, tableName))
+ new SchemaRDD(this, catalog.lookupRelation(Seq(tableName)))
/**
* :: DeveloperApi ::
@@ -329,7 +329,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
def strategies: Seq[Strategy] =
extraStrategies ++ (
- CommandStrategy ::
DataSourceStrategy ::
TakeOrdered ::
HashAggregation ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
index fd5f4abcbcd65..3cf9209465b76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
@@ -97,8 +97,8 @@ private[sql] trait SchemaRDDLike {
*/
@Experimental
def insertInto(tableName: String, overwrite: Boolean): Unit =
- sqlContext.executePlan(
- InsertIntoTable(UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite)).toRdd
+ sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
+ Map.empty, logicalPlan, overwrite)).toRdd
/**
* :: Experimental ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
new file mode 100644
index 0000000000000..65358b7d4ea8e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.{SqlLexical, AbstractSparkSQLParser}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.{UncacheTableCommand, CacheTableCommand, SetCommand}
+
+import scala.util.parsing.combinator.RegexParsers
+
+/**
+ * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL
+ * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser.
+ *
+ * @param fallback A function that parses an input string to a logical plan
+ */
+private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser {
+
+ // A parser for the key-value part of the "SET [key = [value ]]" syntax
+ private object SetCommandParser extends RegexParsers {
+ private val key: Parser[String] = "(?m)[^=]+".r
+
+ private val value: Parser[String] = "(?m).*$".r
+
+ private val output: Seq[Attribute] = Seq(AttributeReference("", StringType, nullable = false)())
+
+ private val pair: Parser[LogicalPlan] =
+ (key ~ ("=".r ~> value).?).? ^^ {
+ case None => SetCommand(None, output)
+ case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim)), output)
+ }
+
+ def apply(input: String): LogicalPlan = parseAll(pair, input) match {
+ case Success(plan, _) => plan
+ case x => sys.error(x.toString)
+ }
+ }
+
+ protected val AS = Keyword("AS")
+ protected val CACHE = Keyword("CACHE")
+ protected val LAZY = Keyword("LAZY")
+ protected val SET = Keyword("SET")
+ protected val TABLE = Keyword("TABLE")
+ protected val UNCACHE = Keyword("UNCACHE")
+
+ protected implicit def asParser(k: Keyword): Parser[String] =
+ lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
+
+ private val reservedWords: Seq[String] =
+ this
+ .getClass
+ .getMethods
+ .filter(_.getReturnType == classOf[Keyword])
+ .map(_.invoke(this).asInstanceOf[Keyword].str)
+
+ override val lexical = new SqlLexical(reservedWords)
+
+ override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others
+
+ private lazy val cache: Parser[LogicalPlan] =
+ CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
+ case isLazy ~ tableName ~ plan =>
+ CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined)
+ }
+
+ private lazy val uncache: Parser[LogicalPlan] =
+ UNCACHE ~ TABLE ~> ident ^^ {
+ case tableName => UncacheTableCommand(tableName)
+ }
+
+ private lazy val set: Parser[LogicalPlan] =
+ SET ~> restInput ^^ {
+ case input => SetCommandParser(input)
+ }
+
+ private lazy val others: Parser[LogicalPlan] =
+ wholeInput ^^ {
+ case input => fallback(input)
+ }
+
+}
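
The relocated parser keeps only the statements every dialect shares and hands everything else to the fallback. Assuming a SQLContext `sqlContext` and placeholder table names, the statements it handles directly look like:

sqlContext.sql("SET spark.sql.shuffle.partitions=10")
sqlContext.sql("CACHE LAZY TABLE logs AS SELECT * FROM rawLogs")
sqlContext.sql("UNCACHE TABLE logs")
sqlContext.sql("SELECT count(*) FROM rawLogs")  // not matched above, so it goes to the fallback parser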
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ce878c137e627..99b6611d3bbcf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -259,6 +259,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def numPartitions = self.numPartitions
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case r: RunnableCommand => ExecutedCommand(r) :: Nil
+
case logical.Distinct(child) =>
execution.Distinct(partial = false,
execution.Distinct(partial = true, planLater(child))) :: Nil
@@ -308,22 +310,4 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => Nil
}
}
-
- case object CommandStrategy extends Strategy {
- def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case r: RunnableCommand => ExecutedCommand(r) :: Nil
- case logical.SetCommand(kv) =>
- Seq(ExecutedCommand(execution.SetCommand(kv, plan.output)))
- case logical.ExplainCommand(logicalPlan, extended) =>
- Seq(ExecutedCommand(
- execution.ExplainCommand(logicalPlan, plan.output, extended)))
- case logical.CacheTableCommand(tableName, optPlan, isLazy) =>
- Seq(ExecutedCommand(
- execution.CacheTableCommand(tableName, optPlan, isLazy)))
- case logical.UncacheTableCommand(tableName) =>
- Seq(ExecutedCommand(
- execution.UncacheTableCommand(tableName)))
- case _ => Nil
- }
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index e53723c176569..16ca4be5587c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -70,7 +70,7 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child:
override def output = child.output
// TODO: How to pick seed?
- override def execute() = child.execute().sample(withReplacement, fraction, seed)
+ override def execute() = child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
}
/**
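
The extra `.copy()` matters because upstream physical operators may reuse a single mutable Row per partition, and sampling with replacement buffers rows. A plain-Scala illustration of the aliasing hazard the copy avoids (not SparkSQL internals):

val reused = Array(0)
val aliased = (1 to 3).iterator.map { i => reused(0) = i; reused }
aliased.toList.map(_.toList)  // List(List(3), List(3), List(3)): every element aliases the final state

val copied = (1 to 3).iterator.map { i => reused(0) = i; reused.clone() }
copied.toList.map(_.toList)   // List(List(1), List(2), List(3)): copying preserves per-element values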
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index b8fa4b019953e..0d765c4c92f85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -113,7 +113,7 @@ case class SetCommand(
@DeveloperApi
case class ExplainCommand(
logicalPlan: LogicalPlan,
- override val output: Seq[Attribute], extended: Boolean) extends RunnableCommand {
+ override val output: Seq[Attribute], extended: Boolean = false) extends RunnableCommand {
// Run through the optimizer to generate the physical plan.
override def run(sqlContext: SQLContext) = try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index fc70c183437f6..a9a6696cb15e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -18,31 +18,48 @@
package org.apache.spark.sql.json
import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.sources._
-private[sql] class DefaultSource extends RelationProvider {
- /** Returns a new base relation with the given parameters. */
+private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider {
+
+ /** Returns a new base relation with the given parameters. */
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
- JSONRelation(fileName, samplingRatio)(sqlContext)
+ JSONRelation(fileName, samplingRatio, None)(sqlContext)
+ }
+
+ /** Returns a new base relation with the given schema and parameters. */
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ schema: StructType): BaseRelation = {
+ val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+ val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+
+ JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext)
}
}
-private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)(
+private[sql] case class JSONRelation(
+ fileName: String,
+ samplingRatio: Double,
+ userSpecifiedSchema: Option[StructType])(
@transient val sqlContext: SQLContext)
extends TableScan {
private def baseRDD = sqlContext.sparkContext.textFile(fileName)
- override val schema =
- JsonRDD.inferSchema(
- baseRDD,
- samplingRatio,
- sqlContext.columnNameOfCorruptRecord)
+ override val schema = userSpecifiedSchema.getOrElse(
+ JsonRDD.nullTypeToStringType(
+ JsonRDD.inferSchema(
+ baseRDD,
+ samplingRatio,
+ sqlContext.columnNameOfCorruptRecord)))
override def buildScan() =
JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord)
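
With SchemaRelationProvider implemented, the JSON source skips sampling-based inference entirely when the user supplies a schema. A hedged example of the DDL path, assuming an existing SQLContext; the table name and path are placeholders:

sqlContext.sql(
  """CREATE TEMPORARY TABLE jsonTable (name string, age int)
    |USING org.apache.spark.sql.json
    |OPTIONS (path "examples/src/main/resources/people.json")""".stripMargin)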
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index b237a07c72d07..2835dc3408b96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -28,7 +28,7 @@ import parquet.schema.MessageType
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
/**
@@ -67,6 +67,8 @@ private[sql] case class ParquetRelation(
conf,
sqlContext.isParquetBinaryAsString)
+ lazy val attributeMap = AttributeMap(output.map(o => o -> o))
+
override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type]
// Equals must also take into account the output attributes so that we can distinguish between
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 96bace1769f71..f5487740d3af9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -64,18 +64,17 @@ case class ParquetTableScan(
// The resolution of Parquet attributes is case sensitive, so we resolve the original attributes
// by exprId. note: output cannot be transient, see
// https://issues.apache.org/jira/browse/SPARK-1367
- val normalOutput =
- attributes
- .filterNot(a => relation.partitioningAttributes.map(_.exprId).contains(a.exprId))
- .flatMap(a => relation.output.find(o => o.exprId == a.exprId))
+ val output = attributes.map(relation.attributeMap)
- val partOutput =
- attributes.flatMap(a => relation.partitioningAttributes.find(o => o.exprId == a.exprId))
+ // Maps the ordinal of each requested partitioning attribute to its position in the final output row.
+ val requestedPartitionOrdinals = {
+ val partitionAttributeOrdinals = AttributeMap(relation.partitioningAttributes.zipWithIndex)
- def output = partOutput ++ normalOutput
-
- assert(normalOutput.size + partOutput.size == attributes.size,
- s"$normalOutput + $partOutput != $attributes, ${relation.output}")
+ attributes.zipWithIndex.flatMap {
+ case (attribute, finalOrdinal) =>
+ partitionAttributeOrdinals.get(attribute).map(_ -> finalOrdinal)
+ }
+ }.toArray
override def execute(): RDD[Row] = {
import parquet.filter2.compat.FilterCompat.FilterPredicateCompat
@@ -97,7 +96,7 @@ case class ParquetTableScan(
// Store both requested and original schema in `Configuration`
conf.set(
RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
- ParquetTypesConverter.convertToString(normalOutput))
+ ParquetTypesConverter.convertToString(output))
conf.set(
RowWriteSupport.SPARK_ROW_SCHEMA,
ParquetTypesConverter.convertToString(relation.output))
@@ -125,7 +124,7 @@ case class ParquetTableScan(
classOf[Row],
conf)
- if (partOutput.nonEmpty) {
+ if (requestedPartitionOrdinals.nonEmpty) {
baseRDD.mapPartitionsWithInputSplit { case (split, iter) =>
val partValue = "([^=]+)=([^=]+)".r
val partValues =
@@ -138,15 +137,25 @@ case class ParquetTableScan(
case _ => None
}.toMap
+ // Convert the partitioning attributes into the correct types
val partitionRowValues =
- partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
+ relation.partitioningAttributes
+ .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
new Iterator[Row] {
- private[this] val joinedRow = new JoinedRow5(Row(partitionRowValues:_*), null)
-
def hasNext = iter.hasNext
-
- def next() = joinedRow.withRight(iter.next()._2)
+ def next() = {
+ val row = iter.next()._2.asInstanceOf[SpecificMutableRow]
+
+ // Parquet will leave partitioning columns empty, so we fill them in here.
+ var i = 0
+ while (i < requestedPartitionOrdinals.size) {
+ row(requestedPartitionOrdinals(i)._2) =
+ partitionRowValues(requestedPartitionOrdinals(i)._1)
+ i += 1
+ }
+ row
+ }
}
}
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 8a66ac31f2dfb..fe2c4d8436b2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -17,16 +17,15 @@
package org.apache.spark.sql.sources
-import org.apache.spark.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.execution.RunnableCommand
-import org.apache.spark.util.Utils
-
import scala.language.implicitConversions
-import scala.util.parsing.combinator.lexical.StdLexical
import scala.util.parsing.combinator.syntactical.StandardTokenParsers
import scala.util.parsing.combinator.PackratParsers
+import org.apache.spark.Logging
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.execution.RunnableCommand
+import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.SqlLexical
@@ -44,6 +43,14 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
}
}
+ def parseType(input: String): DataType = {
+ phrase(dataType)(new lexical.Scanner(input)) match {
+ case Success(r, x) => r
+ case x =>
+ sys.error(s"Unsupported dataType: $x")
+ }
+ }
+
protected case class Keyword(str: String)
protected implicit def asParser(k: Keyword): Parser[String] =
@@ -55,6 +62,24 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
protected val USING = Keyword("USING")
protected val OPTIONS = Keyword("OPTIONS")
+ // Data types.
+ protected val STRING = Keyword("STRING")
+ protected val BINARY = Keyword("BINARY")
+ protected val BOOLEAN = Keyword("BOOLEAN")
+ protected val TINYINT = Keyword("TINYINT")
+ protected val SMALLINT = Keyword("SMALLINT")
+ protected val INT = Keyword("INT")
+ protected val BIGINT = Keyword("BIGINT")
+ protected val FLOAT = Keyword("FLOAT")
+ protected val DOUBLE = Keyword("DOUBLE")
+ protected val DECIMAL = Keyword("DECIMAL")
+ protected val DATE = Keyword("DATE")
+ protected val TIMESTAMP = Keyword("TIMESTAMP")
+ protected val VARCHAR = Keyword("VARCHAR")
+ protected val ARRAY = Keyword("ARRAY")
+ protected val MAP = Keyword("MAP")
+ protected val STRUCT = Keyword("STRUCT")
+
// Use reflection to find the reserved words defined in this class.
protected val reservedWords =
this.getClass
@@ -67,15 +92,25 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
protected lazy val ddl: Parser[LogicalPlan] = createTable
/**
- * CREATE TEMPORARY TABLE avroTable
+ * `CREATE TEMPORARY TABLE avroTable
* USING org.apache.spark.sql.avro
- * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
+ * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
+ * or
+ * `CREATE TEMPORARY TABLE avroTable(intField int, stringField string...)
+ * USING org.apache.spark.sql.avro
+ * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
*/
protected lazy val createTable: Parser[LogicalPlan] =
- CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
- case tableName ~ provider ~ opts =>
- CreateTableUsing(tableName, provider, opts)
+ (
+ CREATE ~ TEMPORARY ~ TABLE ~> ident
+ ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
+ case tableName ~ columns ~ provider ~ opts =>
+ val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields)))
+ CreateTableUsing(tableName, userSpecifiedSchema, provider, opts)
}
+ )
+
+ protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")"
protected lazy val options: Parser[Map[String, String]] =
"(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }
@@ -83,10 +118,66 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")}
protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) }
+
+ protected lazy val column: Parser[StructField] =
+ ident ~ dataType ^^ { case columnName ~ typ =>
+ StructField(columnName, typ)
+ }
+
+ protected lazy val primitiveType: Parser[DataType] =
+ STRING ^^^ StringType |
+ BINARY ^^^ BinaryType |
+ BOOLEAN ^^^ BooleanType |
+ TINYINT ^^^ ByteType |
+ SMALLINT ^^^ ShortType |
+ INT ^^^ IntegerType |
+ BIGINT ^^^ LongType |
+ FLOAT ^^^ FloatType |
+ DOUBLE ^^^ DoubleType |
+ fixedDecimalType | // decimal with precision/scale
+ DECIMAL ^^^ DecimalType.Unlimited | // decimal with no precision/scale
+ DATE ^^^ DateType |
+ TIMESTAMP ^^^ TimestampType |
+ VARCHAR ~ "(" ~ numericLit ~ ")" ^^^ StringType
+
+ protected lazy val fixedDecimalType: Parser[DataType] =
+ (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
+ case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
+ }
+
+ protected lazy val arrayType: Parser[DataType] =
+ ARRAY ~> "<" ~> dataType <~ ">" ^^ {
+ case tpe => ArrayType(tpe)
+ }
+
+ protected lazy val mapType: Parser[DataType] =
+ MAP ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
+ case t1 ~ _ ~ t2 => MapType(t1, t2)
+ }
+
+ protected lazy val structField: Parser[StructField] =
+ ident ~ ":" ~ dataType ^^ {
+ case fieldName ~ _ ~ tpe => StructField(fieldName, tpe, nullable = true)
+ }
+
+ protected lazy val structType: Parser[DataType] =
+ (STRUCT ~> "<" ~> repsep(structField, ",") <~ ">" ^^ {
+ case fields => new StructType(fields)
+ }) |
+ (STRUCT ~> "<>" ^^ {
+ case _ => new StructType(Nil)
+ })
+
+ private[sql] lazy val dataType: Parser[DataType] =
+ arrayType |
+ mapType |
+ structType |
+ primitiveType
}
private[sql] case class CreateTableUsing(
tableName: String,
+ userSpecifiedSchema: Option[StructType],
provider: String,
options: Map[String, String]) extends RunnableCommand {
@@ -99,8 +190,29 @@ private[sql] case class CreateTableUsing(
sys.error(s"Failed to load class for data source: $provider")
}
}
- val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
- val relation = dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
+
+ val relation = userSpecifiedSchema match {
+ case Some(schema: StructType) =>
+ clazz.newInstance match {
+ case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+ dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+ case _ =>
+ sys.error(s"${clazz.getCanonicalName} should extend SchemaRelationProvider.")
+ }
+ case None =>
+ clazz.newInstance match {
+ case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+ dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
+ case _ =>
+ sys.error(s"${clazz.getCanonicalName} should extend RelationProvider.")
+ }
+ }
sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
Seq.empty
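One consequence of the dispatch above: a source that wants to support both DDL variants can mix in both provider traits, and the matching createRelation overload is chosen depending on whether the user declared columns. A hypothetical sketch, with the bodies left as ??? placeholders (a real source would build its BaseRelation there):

import org.apache.spark.sql.{SQLContext, StructType}
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}

// Sketch only: a single DefaultSource serving both CREATE TEMPORARY TABLE variants.
class DefaultSource extends RelationProvider with SchemaRelationProvider {

  // Chosen by CreateTableUsing when the DDL has no column list.
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation =
    createRelation(sqlContext, parameters, inferSchema(parameters))

  // Chosen when the DDL declares columns; the user-specified schema wins.
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation = ???

  // Hypothetical helper: derive a schema from the source itself (e.g. file metadata).
  private def inferSchema(parameters: Map[String, String]): StructType = ???
}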
@@ -110,7 +222,8 @@ private[sql] case class CreateTableUsing(
/**
* Builds a map in which keys are case insensitive
*/
-protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] {
+protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
+ with Serializable {
val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 02eff80456dbe..990f7e0e74bcf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.sources
import org.apache.spark.annotation.{Experimental, DeveloperApi}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SQLConf, Row, SQLContext, StructType}
+import org.apache.spark.sql.{Row, SQLContext, StructType}
import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}
/**
@@ -44,6 +44,33 @@ trait RelationProvider {
def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
}
+/**
+ * ::DeveloperApi::
+ * Implemented by objects that produce relations for a specific kind of data source. Spark SQL
+ * uses this trait when it is given a DDL operation with
+ *   1. a USING clause, which specifies the implementing SchemaRelationProvider, and
+ *   2. a user-defined schema, declared as a column list in the CREATE TABLE statement.
+ *
+ * Users may specify the fully qualified class name of a given data source. When that class is
+ * not found, Spark SQL will append the class name `DefaultSource` to the path, allowing for
+ * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the
+ * data source 'org.apache.spark.sql.json.DefaultSource'.
+ *
+ * A new instance of this class will be instantiated each time a DDL call is made.
+ */
+@DeveloperApi
+trait SchemaRelationProvider {
+ /**
+ * Returns a new base relation with the given parameters and user defined schema.
+ * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
+ * by the Map that is passed to the function.
+ */
+ def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ schema: StructType): BaseRelation
+}
+
/**
* ::DeveloperApi::
* Represents a collection of tuples with a known schema. Classes that extend BaseRelation must
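The TableScanSuite changes below add a full AllDataTypesScanSource exercising this trait; a more minimal sketch of the contract (all names here are hypothetical) might look like:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext, StructType}
import org.apache.spark.sql.sources.{BaseRelation, SchemaRelationProvider, TableScan}

// Minimal sketch of a SchemaRelationProvider: the relation simply echoes the
// user-supplied schema and produces no rows.
class DefaultSource extends SchemaRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation =
    EmptyScan(schema)(sqlContext)
}

case class EmptyScan(userSchema: StructType)(@transient val sqlContext: SQLContext)
  extends TableScan {

  override def schema: StructType = userSchema

  // No data; only the user-specified schema is exposed.
  override def buildScan(): RDD[Row] = sqlContext.sparkContext.parallelize(Seq.empty[Row])
}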
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 1a4232dab86e7..c7e136388fce8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -302,8 +302,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
upperCaseData.where('N <= 4).registerTempTable("left")
upperCaseData.where('N >= 3).registerTempTable("right")
- val left = UnresolvedRelation(None, "left", None)
- val right = UnresolvedRelation(None, "right", None)
+ val left = UnresolvedRelation(Seq("left"), None)
+ val right = UnresolvedRelation(Seq("right"), None)
checkAnswer(
left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index add4e218a22ee..d9de5686dce48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -272,6 +272,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
mapData.collect().take(1).toSeq)
}
+ test("from follow multiple brackets") {
+ checkAnswer(sql(
+ "select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"),
+ 1
+ )
+
+ checkAnswer(sql(
+ "select key from (select * from testData) x limit 1"),
+ 1
+ )
+
+ checkAnswer(sql(
+ "select key from (select * from testData limit 1 union all select * from testData limit 1) x limit 1"),
+ 1
+ )
+ }
+
test("average") {
checkAnswer(
sql("SELECT AVG(a) FROM testData2"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 3cd7b0115d567..605190f5ae6a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.sources
+import java.sql.{Timestamp, Date}
+
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.types.DecimalType
class DefaultSource extends SimpleScanSource
@@ -38,9 +41,77 @@ case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_))
}
+class AllDataTypesScanSource extends SchemaRelationProvider {
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ schema: StructType): BaseRelation = {
+ AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
+ }
+}
+
+case class AllDataTypesScan(
+ from: Int,
+ to: Int,
+ userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
+ extends TableScan {
+
+ override def schema = userSpecifiedSchema
+
+ override def buildScan() = {
+ sqlContext.sparkContext.parallelize(from to to).map { i =>
+ Row(
+ s"str_$i",
+ s"str_$i".getBytes(),
+ i % 2 == 0,
+ i.toByte,
+ i.toShort,
+ i,
+ i.toLong,
+ i.toFloat,
+ i.toDouble,
+ BigDecimal(i),
+ BigDecimal(i),
+ new Date((i + 1) * 8640000),
+ new Timestamp(20000 + i),
+ s"varchar_$i",
+ Seq(i, i + 1),
+ Seq(Map(s"str_$i" -> Row(i.toLong))),
+ Map(i -> i.toString),
+ Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
+ Row(i, i.toString),
+ Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
+ }
+ }
+}
+
class TableScanSuite extends DataSourceTest {
import caseInsensisitiveContext._
+ var tableWithSchemaExpected = (1 to 10).map { i =>
+ Row(
+ s"str_$i",
+ s"str_$i",
+ i % 2 == 0,
+ i.toByte,
+ i.toShort,
+ i,
+ i.toLong,
+ i.toFloat,
+ i.toDouble,
+ BigDecimal(i),
+ BigDecimal(i),
+ new Date((i + 1) * 8640000),
+ new Timestamp(20000 + i),
+ s"varchar_$i",
+ Seq(i, i + 1),
+ Seq(Map(s"str_$i" -> Row(i.toLong))),
+ Map(i -> i.toString),
+ Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
+ Row(i, i.toString),
+ Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
+ }.toSeq
+
before {
sql(
"""
@@ -51,6 +122,37 @@ class TableScanSuite extends DataSourceTest {
| To '10'
|)
""".stripMargin)
+
+ sql(
+ """
+ |CREATE TEMPORARY TABLE tableWithSchema (
+ |`string$%Field` stRIng,
+ |binaryField binary,
+ |`booleanField` boolean,
+ |ByteField tinyint,
+ |shortField smaLlint,
+ |int_Field iNt,
+ |`longField_:,<>=+/~^` Bigint,
+ |floatField flOat,
+ |doubleField doubLE,
+ |decimalField1 decimal,
+ |decimalField2 decimal(9,2),
+ |dateField dAte,
+ |timestampField tiMestamp,
+ |varcharField varchaR(12),
+ |arrayFieldSimple Array,
+ |arrayFieldComplex Array |