diff --git a/.rat-excludes b/.rat-excludes index 15344dfb292db..796c32a80896c 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -4,6 +4,8 @@ target .classpath .mima-excludes .generated-mima-excludes +.generated-mima-class-excludes +.generated-mima-member-excludes .rat-excludes .*md derby.log diff --git a/bin/spark-class b/bin/spark-class index cfe363a71da31..60d9657c0ffcd 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -130,6 +130,11 @@ else fi if [[ "$1" =~ org.apache.spark.tools.* ]]; then + if test -z "$SPARK_TOOLS_JAR"; then + echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SCALA_VERSION/" 1>&2 + echo "You need to build spark before running $1." 1>&2 + exit 1 + fi CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR" fi diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index cdfd338081fa2..9c55bfbb47626 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -127,7 +127,7 @@ class Accumulable[R, T] ( Accumulators.register(this, false) } - override def toString = value_.toString + override def toString = if (value_ == null) "null" else value_.toString } /** diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index c8c194a111aac..09a60571238ea 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -61,7 +61,8 @@ class ShuffleDependency[K, V, C]( val partitioner: Partitioner, val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, - val aggregator: Option[Aggregator[K, V, C]] = None) + val aggregator: Option[Aggregator[K, V, C]] = None, + val mapSideCombine: Boolean = false) extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index a3074916d13e7..5e8bd8c8e533a 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -30,27 +30,69 @@ import org.apache.spark.storage.BlockManagerId @DeveloperApi sealed trait TaskEndReason +/** + * :: DeveloperApi :: + * Task succeeded. + */ @DeveloperApi case object Success extends TaskEndReason +/** + * :: DeveloperApi :: + * Various possible reasons why a task failed. + */ +@DeveloperApi +sealed trait TaskFailedReason extends TaskEndReason { + /** Error message displayed in the web UI. */ + def toErrorString: String +} + +/** + * :: DeveloperApi :: + * A [[org.apache.spark.scheduler.ShuffleMapTask]] that completed successfully earlier, but we + * lost the executor before the stage completed. This means Spark needs to reschedule the task + * to be re-executed on a different executor. + */ @DeveloperApi -case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it +case object Resubmitted extends TaskFailedReason { + override def toErrorString: String = "Resubmitted (resubmitted due to lost executor)" +} +/** + * :: DeveloperApi :: + * Task failed to fetch shuffle data from a remote node. Probably means we have lost the remote + * executors the task is trying to fetch from, and thus need to rerun the previous stage. 
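The new sealed TaskFailedReason trait above gives every failure case a human-readable toErrorString for the web UI. A minimal sketch of how calling code might consume it, using only the types defined in this file (the TaskEndReasonSketch object and describeFailure helper are hypothetical, not part of the patch):

```scala
import org.apache.spark.{Success, TaskEndReason, TaskFailedReason}

object TaskEndReasonSketch {
  // Success plus TaskFailedReason cover the whole sealed TaskEndReason hierarchy.
  def describeFailure(reason: TaskEndReason): String = reason match {
    case Success                  => "Task succeeded"
    case failed: TaskFailedReason => failed.toErrorString
  }
}
```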
+ */ @DeveloperApi case class FetchFailed( bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) - extends TaskEndReason + extends TaskFailedReason { + override def toErrorString: String = { + val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString + s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId)" + } +} +/** + * :: DeveloperApi :: + * Task failed due to a runtime exception. This is the most common failure case and also captures + * user program exceptions. + */ @DeveloperApi case class ExceptionFailure( className: String, description: String, stackTrace: Array[StackTraceElement], metrics: Option[TaskMetrics]) - extends TaskEndReason + extends TaskFailedReason { + override def toErrorString: String = { + val stackTraceString = if (stackTrace == null) "null" else stackTrace.mkString("\n") + s"$className ($description}\n$stackTraceString" + } +} /** * :: DeveloperApi :: @@ -58,10 +100,18 @@ case class ExceptionFailure( * it was fetched. */ @DeveloperApi -case object TaskResultLost extends TaskEndReason +case object TaskResultLost extends TaskFailedReason { + override def toErrorString: String = "TaskResultLost (result lost from block manager)" +} +/** + * :: DeveloperApi :: + * Task was killed intentionally and needs to be rescheduled. + */ @DeveloperApi -case object TaskKilled extends TaskEndReason +case object TaskKilled extends TaskFailedReason { + override def toErrorString: String = "TaskKilled (killed intentionally)" +} /** * :: DeveloperApi :: @@ -69,7 +119,9 @@ case object TaskKilled extends TaskEndReason * the task crashed the JVM. */ @DeveloperApi -case object ExecutorLostFailure extends TaskEndReason +case object ExecutorLostFailure extends TaskFailedReason { + override def toErrorString: String = "ExecutorLostFailure (executor lost)" +} /** * :: DeveloperApi :: @@ -77,4 +129,6 @@ case object ExecutorLostFailure extends TaskEndReason * deserializing the task result. */ @DeveloperApi -case object UnknownReason extends TaskEndReason +case object UnknownReason extends TaskFailedReason { + override def toErrorString: String = "UnknownReason" +} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala new file mode 100644 index 0000000000000..a0e8bd403a41d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.history + +import org.apache.spark.ui.SparkUI + +private[spark] case class ApplicationHistoryInfo( + id: String, + name: String, + startTime: Long, + endTime: Long, + lastUpdated: Long, + sparkUser: String) + +private[spark] abstract class ApplicationHistoryProvider { + + /** + * Returns a list of applications available for the history server to show. + * + * @return List of all know applications. + */ + def getListing(): Seq[ApplicationHistoryInfo] + + /** + * Returns the Spark UI for a specific application. + * + * @param appId The application ID. + * @return The application's UI, or null if application is not found. + */ + def getAppUI(appId: String): SparkUI + + /** + * Called when the server is shutting down. + */ + def stop(): Unit = { } + + /** + * Returns configuration data to be shown in the History Server home page. + * + * @return A map with the configuration data. Data is show in the order returned by the map. + */ + def getConfig(): Map[String, String] = Map() + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala new file mode 100644 index 0000000000000..a8c9ac072449f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import java.io.FileNotFoundException + +import scala.collection.mutable + +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.scheduler._ +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.Utils + +private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHistoryProvider + with Logging { + + // Interval between each check for event log updates + private val UPDATE_INTERVAL_MS = conf.getInt("spark.history.fs.updateInterval", + conf.getInt("spark.history.updateInterval", 10)) * 1000 + + private val logDir = conf.get("spark.history.fs.logDirectory", null) + if (logDir == null) { + throw new IllegalArgumentException("Logging directory must be specified.") + } + + private val fs = Utils.getHadoopFileSystem(logDir) + + // A timestamp of when the disk was last accessed to check for log updates + private var lastLogCheckTimeMs = -1L + + // List of applications, in order from newest to oldest. + @volatile private var appList: Seq[ApplicationHistoryInfo] = Nil + + /** + * A background thread that periodically checks for event log updates on disk. 
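Because HistoryServer.main (later in this diff) instantiates the provider reflectively from spark.history.provider via a single SparkConf constructor, alternative backends can be plugged in against the abstract class above. A minimal sketch of such a provider — the InMemoryHistoryProvider name and its fixed listing are purely illustrative, not part of this patch:

```scala
package org.apache.spark.deploy.history

import org.apache.spark.SparkConf
import org.apache.spark.ui.SparkUI

private[spark] class InMemoryHistoryProvider(conf: SparkConf) extends ApplicationHistoryProvider {

  // A fixed, in-memory listing standing in for a real backend.
  private val apps = Seq(
    ApplicationHistoryInfo("app-0001", "Example app", 0L, 1000L, 1000L, "spark"))

  override def getListing(): Seq[ApplicationHistoryInfo] = apps

  // Per the contract above, return null when the application is not found;
  // this toy provider never renders a UI.
  override def getAppUI(appId: String): SparkUI = null

  override def getConfig(): Map[String, String] = Map("Provider" -> "in-memory (example)")
}
```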
+ * + * If a log check is invoked manually in the middle of a period, this thread re-adjusts the + * time at which it performs the next log check to maintain the same period as before. + * + * TODO: Add a mechanism to update manually. + */ + private val logCheckingThread = new Thread("LogCheckingThread") { + override def run() = Utils.logUncaughtExceptions { + while (true) { + val now = getMonotonicTimeMs() + if (now - lastLogCheckTimeMs > UPDATE_INTERVAL_MS) { + Thread.sleep(UPDATE_INTERVAL_MS) + } else { + // If the user has manually checked for logs recently, wait until + // UPDATE_INTERVAL_MS after the last check time + Thread.sleep(lastLogCheckTimeMs + UPDATE_INTERVAL_MS - now) + } + checkForLogs() + } + } + } + + initialize() + + private def initialize() { + // Validate the log directory. + val path = new Path(logDir) + if (!fs.exists(path)) { + throw new IllegalArgumentException( + "Logging directory specified does not exist: %s".format(logDir)) + } + if (!fs.getFileStatus(path).isDir) { + throw new IllegalArgumentException( + "Logging directory specified is not a directory: %s".format(logDir)) + } + + checkForLogs() + logCheckingThread.setDaemon(true) + logCheckingThread.start() + } + + override def getListing() = appList + + override def getAppUI(appId: String): SparkUI = { + try { + val appLogDir = fs.getFileStatus(new Path(logDir, appId)) + loadAppInfo(appLogDir, true)._2 + } catch { + case e: FileNotFoundException => null + } + } + + override def getConfig(): Map[String, String] = + Map(("Event Log Location" -> logDir)) + + /** + * Builds the application list based on the current contents of the log directory. + * Tries to reuse as much of the data already in memory as possible, by not reading + * applications that haven't been updated since last time the logs were checked. + */ + private def checkForLogs() = { + lastLogCheckTimeMs = getMonotonicTimeMs() + logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs)) + try { + val logStatus = fs.listStatus(new Path(logDir)) + val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]() + val logInfos = logDirs.filter { + dir => fs.isFile(new Path(dir.getPath(), EventLoggingListener.APPLICATION_COMPLETE)) + } + + val currentApps = Map[String, ApplicationHistoryInfo]( + appList.map(app => (app.id -> app)):_*) + + // For any application that either (i) is not listed or (ii) has changed since the last time + // the listing was created (defined by the log dir's modification time), load the app's info. + // Otherwise just reuse what's already in memory. + val newApps = new mutable.ArrayBuffer[ApplicationHistoryInfo](logInfos.size) + for (dir <- logInfos) { + val curr = currentApps.getOrElse(dir.getPath().getName(), null) + if (curr == null || curr.lastUpdated < getModificationTime(dir)) { + try { + newApps += loadAppInfo(dir, false)._1 + } catch { + case e: Exception => logError(s"Failed to load app info from directory $dir.") + } + } else { + newApps += curr + } + } + + appList = newApps.sortBy { info => -info.endTime } + } catch { + case t: Throwable => logError("Exception in checking for event log updates", t) + } + } + + /** + * Parse the application's logs to find out the information we need to build the + * listing page. + * + * When creating the listing of available apps, there is no need to load the whole UI for the + * application. The UI is requested by the HistoryServer (by calling getAppInfo()) when the user + * clicks on a specific application. 
+ * + * @param logDir Directory with application's log files. + * @param renderUI Whether to create the SparkUI for the application. + * @return A 2-tuple `(app info, ui)`. `ui` will be null if `renderUI` is false. + */ + private def loadAppInfo(logDir: FileStatus, renderUI: Boolean) = { + val elogInfo = EventLoggingListener.parseLoggingInfo(logDir.getPath(), fs) + val path = logDir.getPath + val appId = path.getName + val replayBus = new ReplayListenerBus(elogInfo.logPaths, fs, elogInfo.compressionCodec) + val appListener = new ApplicationEventListener + replayBus.addListener(appListener) + + val ui: SparkUI = if (renderUI) { + val conf = this.conf.clone() + val appSecManager = new SecurityManager(conf) + new SparkUI(conf, appSecManager, replayBus, appId, "/history/" + appId) + // Do not call ui.bind() to avoid creating a new server for each application + } else { + null + } + + replayBus.replay() + val appInfo = ApplicationHistoryInfo( + appId, + appListener.appName, + appListener.startTime, + appListener.endTime, + getModificationTime(logDir), + appListener.sparkUser) + + if (ui != null) { + val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) + ui.getSecurityManager.setUIAcls(uiAclsEnabled) + ui.getSecurityManager.setViewAcls(appListener.sparkUser, appListener.viewAcls) + } + (appInfo, ui) + } + + /** Return when this directory was last modified. */ + private def getModificationTime(dir: FileStatus): Long = { + try { + val logFiles = fs.listStatus(dir.getPath) + if (logFiles != null && !logFiles.isEmpty) { + logFiles.map(_.getModificationTime).max + } else { + dir.getModificationTime + } + } catch { + case t: Throwable => + logError("Exception in accessing modification time of %s".format(dir.getPath), t) + -1L + } + } + + /** Returns the system's mononotically increasing time. */ + private def getMonotonicTimeMs() = System.nanoTime() / (1000 * 1000) + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 180c853ce3096..a958c837c2ff6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -25,20 +25,36 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { + private val pageSize = 20 + def render(request: HttpServletRequest): Seq[Node] = { - val appRows = parent.appIdToInfo.values.toSeq.sortBy { app => -app.lastUpdated } - val appTable = UIUtils.listingTable(appHeader, appRow, appRows) + val requestedPage = Option(request.getParameter("page")).getOrElse("1").toInt + val requestedFirst = (requestedPage - 1) * pageSize + + val allApps = parent.getApplicationList() + val actualFirst = if (requestedFirst < allApps.size) requestedFirst else 0 + val apps = allApps.slice(actualFirst, Math.min(actualFirst + pageSize, allApps.size)) + + val actualPage = (actualFirst / pageSize) + 1 + val last = Math.min(actualFirst + pageSize, allApps.size) - 1 + val pageCount = allApps.size / pageSize + (if (allApps.size % pageSize > 0) 1 else 0) + + val appTable = UIUtils.listingTable(appHeader, appRow, apps) + val providerConfig = parent.getProviderConfig() val content =
{ - if (parent.appIdToInfo.size > 0) { + if (allApps.size > 0) {

- Showing {parent.appIdToInfo.size}/{parent.getNumApplications} - Completed Application{if (parent.getNumApplications > 1) "s" else ""} + Showing {actualFirst + 1}-{last + 1} of {allApps.size} + + {if (actualPage > 1) <} + {if (actualPage < pageCount) >} +

++ appTable } else { @@ -56,26 +72,20 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { "Completed", "Duration", "Spark User", - "Log Directory", "Last Updated") private def appRow(info: ApplicationHistoryInfo): Seq[Node] = { - val appName = if (info.started) info.name else info.logDirPath.getName - val uiAddress = parent.getAddress + info.ui.basePath - val startTime = if (info.started) UIUtils.formatDate(info.startTime) else "Not started" - val endTime = if (info.completed) UIUtils.formatDate(info.endTime) else "Not completed" - val difference = if (info.started && info.completed) info.endTime - info.startTime else -1L - val duration = if (difference > 0) UIUtils.formatDuration(difference) else "---" - val sparkUser = if (info.started) info.sparkUser else "Unknown user" - val logDirectory = info.logDirPath.getName + val uiAddress = "/history/" + info.id + val startTime = UIUtils.formatDate(info.startTime) + val endTime = UIUtils.formatDate(info.endTime) + val duration = UIUtils.formatDuration(info.endTime - info.startTime) val lastUpdated = UIUtils.formatDate(info.lastUpdated) - {appName} + {info.name} {startTime} {endTime} {duration} - {sparkUser} - {logDirectory} + {info.sparkUser} {lastUpdated} } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index a9c11dca5678e..29a78a56c8ed5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -17,14 +17,15 @@ package org.apache.spark.deploy.history -import scala.collection.mutable +import java.util.NoSuchElementException +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} -import org.apache.hadoop.fs.{FileStatus, Path} +import com.google.common.cache._ +import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.scheduler._ -import org.apache.spark.ui.{WebUI, SparkUI} +import org.apache.spark.ui.{WebUI, SparkUI, UIUtils} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.Utils @@ -38,56 +39,68 @@ import org.apache.spark.util.Utils * application's event logs are maintained in the application's own sub-directory. This * is the same structure as maintained in the event log write code path in * EventLoggingListener. 
- * - * @param baseLogDir The base directory in which event logs are found */ class HistoryServer( - val baseLogDir: String, + conf: SparkConf, + provider: ApplicationHistoryProvider, securityManager: SecurityManager, - conf: SparkConf) - extends WebUI(securityManager, HistoryServer.WEB_UI_PORT, conf) with Logging { - - import HistoryServer._ + port: Int) + extends WebUI(securityManager, port, conf) with Logging { - private val fileSystem = Utils.getHadoopFileSystem(baseLogDir) - private val localHost = Utils.localHostName() - private val publicHost = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHost) + // How many applications to retain + private val retainedApplications = conf.getInt("spark.history.retainedApplications", 50) - // A timestamp of when the disk was last accessed to check for log updates - private var lastLogCheckTime = -1L + private val appLoader = new CacheLoader[String, SparkUI] { + override def load(key: String): SparkUI = { + val ui = provider.getAppUI(key) + if (ui == null) { + throw new NoSuchElementException() + } + attachSparkUI(ui) + ui + } + } - // Number of completed applications found in this directory - private var numCompletedApplications = 0 + private val appCache = CacheBuilder.newBuilder() + .maximumSize(retainedApplications) + .removalListener(new RemovalListener[String, SparkUI] { + override def onRemoval(rm: RemovalNotification[String, SparkUI]) = { + detachSparkUI(rm.getValue()) + } + }) + .build(appLoader) + + private val loaderServlet = new HttpServlet { + protected override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = { + val parts = Option(req.getPathInfo()).getOrElse("").split("/") + if (parts.length < 2) { + res.sendError(HttpServletResponse.SC_BAD_REQUEST, + s"Unexpected path info in request (URI = ${req.getRequestURI()}") + return + } - @volatile private var stopped = false + val appId = parts(1) - /** - * A background thread that periodically checks for event log updates on disk. - * - * If a log check is invoked manually in the middle of a period, this thread re-adjusts the - * time at which it performs the next log check to maintain the same period as before. - * - * TODO: Add a mechanism to update manually. - */ - private val logCheckingThread = new Thread { - override def run(): Unit = Utils.logUncaughtExceptions { - while (!stopped) { - val now = System.currentTimeMillis - if (now - lastLogCheckTime > UPDATE_INTERVAL_MS) { - checkForLogs() - Thread.sleep(UPDATE_INTERVAL_MS) - } else { - // If the user has manually checked for logs recently, wait until - // UPDATE_INTERVAL_MS after the last check time - Thread.sleep(lastLogCheckTime + UPDATE_INTERVAL_MS - now) + // Note we don't use the UI retrieved from the cache; the cache loader above will register + // the app's UI, and all we need to do is redirect the user to the same URI that was + // requested, and the proper data should be served at that point. + try { + appCache.get(appId) + res.sendRedirect(res.encodeRedirectURL(req.getRequestURI())) + } catch { + case e: Exception => e.getCause() match { + case nsee: NoSuchElementException => + val msg =
Application {appId} not found.
+ res.setStatus(HttpServletResponse.SC_NOT_FOUND) + UIUtils.basicSparkPage(msg, "Not Found").foreach( + n => res.getWriter().write(n.toString)) + + case cause: Exception => throw cause } } } } - // A mapping of application ID to its history information, which includes the rendered UI - val appIdToInfo = mutable.HashMap[String, ApplicationHistoryInfo]() - initialize() /** @@ -98,108 +111,23 @@ class HistoryServer( */ def initialize() { attachPage(new HistoryPage(this)) - attachHandler(createStaticHandler(STATIC_RESOURCE_DIR, "/static")) + attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + + val contextHandler = new ServletContextHandler + contextHandler.setContextPath("/history") + contextHandler.addServlet(new ServletHolder(loaderServlet), "/*") + attachHandler(contextHandler) } /** Bind to the HTTP server behind this web interface. */ override def bind() { super.bind() - logCheckingThread.start() - } - - /** - * Check for any updates to event logs in the base directory. This is only effective once - * the server has been bound. - * - * If a new completed application is found, the server renders the associated SparkUI - * from the application's event logs, attaches this UI to itself, and stores metadata - * information for this application. - * - * If the logs for an existing completed application are no longer found, the server - * removes all associated information and detaches the SparkUI. - */ - def checkForLogs() = synchronized { - if (serverInfo.isDefined) { - lastLogCheckTime = System.currentTimeMillis - logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTime)) - try { - val logStatus = fileSystem.listStatus(new Path(baseLogDir)) - val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]() - val logInfos = logDirs - .sortBy { dir => getModificationTime(dir) } - .map { dir => (dir, EventLoggingListener.parseLoggingInfo(dir.getPath, fileSystem)) } - .filter { case (dir, info) => info.applicationComplete } - - // Logging information for applications that should be retained - val retainedLogInfos = logInfos.takeRight(RETAINED_APPLICATIONS) - val retainedAppIds = retainedLogInfos.map { case (dir, _) => dir.getPath.getName } - - // Remove any applications that should no longer be retained - appIdToInfo.foreach { case (appId, info) => - if (!retainedAppIds.contains(appId)) { - detachSparkUI(info.ui) - appIdToInfo.remove(appId) - } - } - - // Render the application's UI if it is not already there - retainedLogInfos.foreach { case (dir, info) => - val appId = dir.getPath.getName - if (!appIdToInfo.contains(appId)) { - renderSparkUI(dir, info) - } - } - - // Track the total number of completed applications observed this round - numCompletedApplications = logInfos.size - - } catch { - case e: Exception => logError("Exception in checking for event log updates", e) - } - } else { - logWarning("Attempted to check for event log updates before binding the server.") - } - } - - /** - * Render a new SparkUI from the event logs if the associated application is completed. - * - * HistoryServer looks for a special file that indicates application completion in the given - * directory. If this file exists, the associated application is regarded to be completed, in - * which case the server proceeds to render the SparkUI. Otherwise, the server does nothing. 
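The retention behavior above hangs off Guava's LoadingCache: the CacheLoader attaches an application's UI on first access, and the RemovalListener detaches it when an entry is evicted past spark.history.retainedApplications. A standalone sketch of that pattern with plain String values (illustrative only, not the actual SparkUI wiring):

```scala
import com.google.common.cache.{CacheBuilder, CacheLoader, RemovalListener, RemovalNotification}

object UiCacheSketch {
  // Load-on-miss: in HistoryServer this is where provider.getAppUI is called and the UI attached.
  private val loader = new CacheLoader[String, String] {
    override def load(appId: String): String = s"UI for $appId"
  }

  private val cache = CacheBuilder.newBuilder()
    .maximumSize(50) // mirrors the spark.history.retainedApplications default
    .removalListener(new RemovalListener[String, String] {
      // In HistoryServer, this is where the evicted SparkUI is detached from the web server.
      override def onRemoval(rm: RemovalNotification[String, String]): Unit =
        println(s"Evicted ${rm.getKey}")
    })
    .build(loader)

  def main(args: Array[String]): Unit = {
    // First access loads via the CacheLoader; least-recently-used entries beyond maximumSize are evicted.
    println(cache.get("app-0001"))
  }
}
```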
- */ - private def renderSparkUI(logDir: FileStatus, elogInfo: EventLoggingInfo) { - val path = logDir.getPath - val appId = path.getName - val replayBus = new ReplayListenerBus(elogInfo.logPaths, fileSystem, elogInfo.compressionCodec) - val appListener = new ApplicationEventListener - replayBus.addListener(appListener) - val appConf = conf.clone() - val appSecManager = new SecurityManager(appConf) - val ui = new SparkUI(conf, appSecManager, replayBus, appId, "/history/" + appId) - - // Do not call ui.bind() to avoid creating a new server for each application - replayBus.replay() - if (appListener.applicationStarted) { - appSecManager.setUIAcls(HISTORY_UI_ACLS_ENABLED) - appSecManager.setViewAcls(appListener.sparkUser, appListener.viewAcls) - attachSparkUI(ui) - val appName = appListener.appName - val sparkUser = appListener.sparkUser - val startTime = appListener.startTime - val endTime = appListener.endTime - val lastUpdated = getModificationTime(logDir) - ui.setAppName(appName + " (completed)") - appIdToInfo(appId) = ApplicationHistoryInfo(appId, appName, startTime, endTime, - lastUpdated, sparkUser, path, ui) - } } /** Stop the server and close the file system. */ override def stop() { super.stop() - stopped = true - fileSystem.close() + provider.stop() } /** Attach a reconstructed UI to this server. Only valid after bind(). */ @@ -215,27 +143,20 @@ class HistoryServer( ui.getHandlers.foreach(detachHandler) } - /** Return the address of this server. */ - def getAddress: String = "http://" + publicHost + ":" + boundPort + /** + * Returns a list of available applications, in descending order according to their end time. + * + * @return List of all known applications. + */ + def getApplicationList() = provider.getListing() - /** Return the number of completed applications found, whether or not the UI is rendered. */ - def getNumApplications: Int = numCompletedApplications + /** + * Returns the provider configuration to show in the listing page. + * + * @return A map with the provider's configuration. + */ + def getProviderConfig() = provider.getConfig() - /** Return when this directory was last modified. 
*/ - private def getModificationTime(dir: FileStatus): Long = { - try { - val logFiles = fileSystem.listStatus(dir.getPath) - if (logFiles != null && !logFiles.isEmpty) { - logFiles.map(_.getModificationTime).max - } else { - dir.getModificationTime - } - } catch { - case e: Exception => - logError("Exception in accessing modification time of %s".format(dir.getPath), e) - -1L - } - } } /** @@ -251,30 +172,31 @@ class HistoryServer( object HistoryServer { private val conf = new SparkConf - // Interval between each check for event log updates - val UPDATE_INTERVAL_MS = conf.getInt("spark.history.updateInterval", 10) * 1000 - - // How many applications to retain - val RETAINED_APPLICATIONS = conf.getInt("spark.history.retainedApplications", 250) - - // The port to which the web UI is bound - val WEB_UI_PORT = conf.getInt("spark.history.ui.port", 18080) - - // set whether to enable or disable view acls for all applications - val HISTORY_UI_ACLS_ENABLED = conf.getBoolean("spark.history.ui.acls.enable", false) - - val STATIC_RESOURCE_DIR = SparkUI.STATIC_RESOURCE_DIR - def main(argStrings: Array[String]) { initSecurity() - val args = new HistoryServerArguments(argStrings) + val args = new HistoryServerArguments(conf, argStrings) val securityManager = new SecurityManager(conf) - val server = new HistoryServer(args.logDir, securityManager, conf) + + val providerName = conf.getOption("spark.history.provider") + .getOrElse(classOf[FsHistoryProvider].getName()) + val provider = Class.forName(providerName) + .getConstructor(classOf[SparkConf]) + .newInstance(conf) + .asInstanceOf[ApplicationHistoryProvider] + + val port = conf.getInt("spark.history.ui.port", 18080) + + val server = new HistoryServer(conf, provider, securityManager, port) server.bind() + Runtime.getRuntime().addShutdownHook(new Thread("HistoryServerStopper") { + override def run() = { + server.stop() + } + }) + // Wait until the end of the world... or if the HistoryServer process is manually stopped while(true) { Thread.sleep(Int.MaxValue) } - server.stop() } def initSecurity() { @@ -291,17 +213,3 @@ object HistoryServer { } } - - -private[spark] case class ApplicationHistoryInfo( - id: String, - name: String, - startTime: Long, - endTime: Long, - lastUpdated: Long, - sparkUser: String, - logDirPath: Path, - ui: SparkUI) { - def started = startTime != -1 - def completed = endTime != -1 -} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 943c061743dbd..be9361b754fc3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -17,17 +17,14 @@ package org.apache.spark.deploy.history -import java.net.URI - -import org.apache.hadoop.fs.Path - +import org.apache.spark.SparkConf import org.apache.spark.util.Utils /** * Command-line parser for the master. 
*/ -private[spark] class HistoryServerArguments(args: Array[String]) { - var logDir = "" +private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) { + private var logDir: String = null parse(args.toList) @@ -45,32 +42,36 @@ private[spark] class HistoryServerArguments(args: Array[String]) { case _ => printUsageAndExit(1) } - validateLogDir() - } - - private def validateLogDir() { - if (logDir == "") { - System.err.println("Logging directory must be specified.") - printUsageAndExit(1) - } - val fileSystem = Utils.getHadoopFileSystem(new URI(logDir)) - val path = new Path(logDir) - if (!fileSystem.exists(path)) { - System.err.println("Logging directory specified does not exist: %s".format(logDir)) - printUsageAndExit(1) - } - if (!fileSystem.getFileStatus(path).isDir) { - System.err.println("Logging directory specified is not a directory: %s".format(logDir)) - printUsageAndExit(1) + if (logDir != null) { + conf.set("spark.history.fs.logDirectory", logDir) } } private def printUsageAndExit(exitCode: Int) { System.err.println( - "Usage: HistoryServer [options]\n" + - "\n" + - "Options:\n" + - " -d DIR, --dir DIR Location of event log files") + """ + |Usage: HistoryServer + | + |Configuration options can be set by setting the corresponding JVM system property. + |History Server options are always available; additional options depend on the provider. + | + |History Server options: + | + | spark.history.ui.port Port where server will listen for connections + | (default 18080) + | spark.history.acls.enable Whether to enable view acls for all applications + | (default false) + | spark.history.provider Name of history provider class (defaults to + | file system-based provider) + | spark.history.retainedApplications Max number of application UIs to keep loaded in memory + | (default 50) + |FsHistoryProvider options: + | + | spark.history.fs.logDirectory Directory where app logs are stored (required) + | spark.history.fs.updateInterval How often to reload log data from storage (in seconds, + | default 10) + |""".stripMargin) System.exit(exitCode) } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 6433aac1c23e0..467317dd9b44c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -77,6 +77,7 @@ private[spark] class ExecutorRunner( * @param message the exception message which caused the executor's death */ private def killProcess(message: Option[String]) { + var exitCode: Option[Int] = None if (process != null) { logInfo("Killing process!") process.destroy() @@ -87,9 +88,9 @@ private[spark] class ExecutorRunner( if (stderrAppender != null) { stderrAppender.stop() } - val exitCode = process.waitFor() - worker ! ExecutorStateChanged(appId, execId, state, message, Some(exitCode)) + exitCode = Some(process.waitFor()) } + worker ! 
ExecutorStateChanged(appId, execId, state, message, exitCode) } /** Stop this executor runner, including killing the process it launched */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 6a5ffb1b71bfb..b389cb546de6c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -120,7 +120,7 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") w
- UIUtils.basicSparkPage(content, logType + " log page for " + appId) + UIUtils.basicSparkPage(content, logType + " log page for " + appId.getOrElse("unknown app")) } /** Get the part of the log files given the offset and desired length of bytes */ diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 2279d77c91c89..b5fd334f40203 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,25 +19,26 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import akka.actor._ -import akka.remote._ +import scala.concurrent.Await -import org.apache.spark.{SparkEnv, Logging, SecurityManager, SparkConf} +import akka.actor.{Actor, ActorSelection, Props} +import akka.pattern.Patterns +import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} + +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskDescription +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{AkkaUtils, Utils} private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, executorId: String, hostPort: String, - cores: Int) - extends Actor - with ExecutorBackend - with Logging { + cores: Int, + sparkProperties: Seq[(String, String)]) extends Actor with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") @@ -52,7 +53,7 @@ private[spark] class CoarseGrainedExecutorBackend( } override def receive = { - case RegisteredExecutor(sparkProperties) => + case RegisteredExecutor => logInfo("Successfully registered with driver") // Make this host instead of hostPort ? executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties, @@ -101,26 +102,33 @@ private[spark] object CoarseGrainedExecutorBackend { workerUrl: Option[String]) { SparkHadoopUtil.get.runAsSparkUser { () => - // Debug code - Utils.checkHost(hostname) - - val conf = new SparkConf - // Create a new ActorSystem to run the backend, because we can't create a - // SparkEnv / Executor before getting started with all our system properties, etc - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0, - conf, new SecurityManager(conf)) - // set it - val sparkHostPort = hostname + ":" + boundPort - actorSystem.actorOf( - Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, - sparkHostPort, cores), - name = "Executor") - workerUrl.foreach { - url => - actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") - } - actorSystem.awaitTermination() - + // Debug code + Utils.checkHost(hostname) + + // Bootstrap to fetch the driver's Spark properties. 
+ val executorConf = new SparkConf + val (fetcher, _) = AkkaUtils.createActorSystem( + "driverPropsFetcher", hostname, 0, executorConf, new SecurityManager(executorConf)) + val driver = fetcher.actorSelection(driverUrl) + val timeout = AkkaUtils.askTimeout(executorConf) + val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) + val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] + fetcher.shutdown() + + // Create a new ActorSystem using driver's Spark properties to run the backend. + val driverConf = new SparkConf().setAll(props) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + "sparkExecutor", hostname, 0, driverConf, new SecurityManager(driverConf)) + // set it + val sparkHostPort = hostname + ":" + boundPort + actorSystem.actorOf( + Props(classOf[CoarseGrainedExecutorBackend], + driverUrl, executorId, sparkHostPort, cores, props), + name = "Executor") + workerUrl.foreach { url => + actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") + } + actorSystem.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 1c95f4d9ba136..1f0785d4056a7 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -212,7 +212,7 @@ private[spark] class Executor( val serializedDirectResult = ser.serialize(directResult) logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit) val serializedResult = { - if (serializedDirectResult.limit >= akkaFrameSize - 1024) { + if (serializedDirectResult.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { logInfo("Storing result for " + taskId + " in local BlockManager") val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index 3b6298a26d7c5..5285ec82c1b64 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -17,11 +17,6 @@ package org.apache.spark.network -import org.apache.spark._ -import org.apache.spark.SparkSaslServer - -import scala.collection.mutable.{HashMap, Queue, ArrayBuffer} - import java.net._ import java.nio._ import java.nio.channels._ @@ -41,7 +36,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = { this(channel_, selector_, ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]), id_) + channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_) } channel.configureBlocking(false) @@ -89,7 +84,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, private def disposeSasl() { if (sparkSaslServer != null) { - sparkSaslServer.dispose(); + sparkSaslServer.dispose() } if (sparkSaslClient != null) { @@ -328,15 +323,13 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // Is highly unlikely unless there was an unclean close of socket, etc registerInterest() logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") - true } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) callOnExceptionCallback(e) - // ignore - 
return true } } + true } override def write(): Boolean = { @@ -546,7 +539,7 @@ private[spark] class ReceivingConnection( /* println("Filled buffer at " + System.currentTimeMillis) */ val bufferMessage = inbox.getMessageForChunk(currentChunk).get if (bufferMessage.isCompletelyReceived) { - bufferMessage.flip + bufferMessage.flip() bufferMessage.finishTime = System.currentTimeMillis logDebug("Finished receiving [" + bufferMessage + "] from " + "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index cf1c985c2fff9..8a1cdb812962e 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -249,7 +249,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, def run() { try { while(!selectorThread.isInterrupted) { - while (! registerRequests.isEmpty) { + while (!registerRequests.isEmpty) { val conn: SendingConnection = registerRequests.dequeue() addListeners(conn) conn.connect() @@ -308,7 +308,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, // Some keys within the selectors list are invalid/closed. clear them. val allKeys = selector.keys().iterator() - while (allKeys.hasNext()) { + while (allKeys.hasNext) { val key = allKeys.next() try { if (! key.isValid) { @@ -341,7 +341,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, if (0 != selectedKeysCount) { val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext()) { + while (selectedKeys.hasNext) { val key = selectedKeys.next selectedKeys.remove() try { @@ -419,62 +419,63 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, connectionsByKey -= connection.key try { - if (connection.isInstanceOf[SendingConnection]) { - val sendingConnection = connection.asInstanceOf[SendingConnection] - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - logInfo("Removing SendingConnection to " + sendingConnectionManagerId) - - connectionsById -= sendingConnectionManagerId - connectionsAwaitingSasl -= connection.connectionId + connection match { + case sendingConnection: SendingConnection => + val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() + logInfo("Removing SendingConnection to " + sendingConnectionManagerId) + + connectionsById -= sendingConnectionManagerId + connectionsAwaitingSasl -= connection.connectionId + + messageStatuses.synchronized { + messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) + .foreach(status => { + logInfo("Notifying " + status) + status.synchronized { + status.attempted = true + status.acked = false + status.markDone() + } + }) - messageStatuses.synchronized { - messageStatuses - .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { - logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.markDone() - } + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId }) + } + case receivingConnection: ReceivingConnection => + val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() + logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - messageStatuses.retain((i, status) => { - 
status.connectionManagerId != sendingConnectionManagerId - }) - } - } else if (connection.isInstanceOf[ReceivingConnection]) { - val receivingConnection = connection.asInstanceOf[ReceivingConnection] - val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() - logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - - val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) - if (! sendingConnectionOpt.isDefined) { - logError("Corresponding SendingConnectionManagerId not found") - return - } + val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) + if (!sendingConnectionOpt.isDefined) { + logError("Corresponding SendingConnectionManagerId not found") + return + } - val sendingConnection = sendingConnectionOpt.get - connectionsById -= remoteConnectionManagerId - sendingConnection.close() + val sendingConnection = sendingConnectionOpt.get + connectionsById -= remoteConnectionManagerId + sendingConnection.close() - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() + val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - assert (sendingConnectionManagerId == remoteConnectionManagerId) + assert(sendingConnectionManagerId == remoteConnectionManagerId) - messageStatuses.synchronized { - for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { - logInfo("Notifying " + s) - s.synchronized { - s.attempted = true - s.acked = false - s.markDone() + messageStatuses.synchronized { + for (s <- messageStatuses.values + if s.connectionManagerId == sendingConnectionManagerId) { + logInfo("Notifying " + s) + s.synchronized { + s.attempted = true + s.acked = false + s.markDone() + } } - } - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId + }) + } + case _ => logError("Unsupported type of connection.") } } finally { // So that the selection keys can be removed. 
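The rewritten hunk above drops the chained isInstanceOf/asInstanceOf checks in favor of a single match on the connection subtype, which also forces an explicit default branch ("Unsupported type of connection."). A self-contained sketch of the idiom, using simplified stand-in types rather than the real Connection classes:

```scala
object ConnectionMatchSketch {
  trait Conn { def remoteId: String }
  class Sending(val remoteId: String) extends Conn
  class Receiving(val remoteId: String) extends Conn

  // One match replaces the isInstanceOf test plus the asInstanceOf cast for each branch.
  def describe(conn: Conn): String = conn match {
    case s: Sending   => s"Removing SendingConnection to ${s.remoteId}"
    case r: Receiving => s"Removing ReceivingConnection to ${r.remoteId}"
    case _            => "Unsupported type of connection."
  }

  def main(args: Array[String]): Unit =
    println(describe(new Sending("host-1:7077")))
}
```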
@@ -517,13 +518,13 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, logDebug("Client sasl completed for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId waitingConn.getAuthenticated().synchronized { - waitingConn.getAuthenticated().notifyAll(); + waitingConn.getAuthenticated().notifyAll() } return } else { var replyToken : Array[Byte] = null try { - replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken); + replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken) if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId @@ -533,7 +534,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, return } val securityMsgResp = SecurityMessage.fromResponse(replyToken, - securityMsg.getConnectionId.toString()) + securityMsg.getConnectionId.toString) val message = securityMsgResp.toBufferMessage if (message == null) throw new Exception("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) @@ -630,13 +631,13 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, case bufferMessage: BufferMessage => { if (authEnabled) { val res = handleAuthentication(connection, bufferMessage) - if (res == true) { + if (res) { // message was security negotiation so skip the rest logDebug("After handleAuth result was true, returning") return } } - if (bufferMessage.hasAckId) { + if (bufferMessage.hasAckId()) { val sentMessageStatus = messageStatuses.synchronized { messageStatuses.get(bufferMessage.ackId) match { case Some(status) => { @@ -646,7 +647,6 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, case None => { throw new Exception("Could not find reference for received ack message " + message.id) - null } } } @@ -668,7 +668,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, if (ackMessage.isDefined) { if (!ackMessage.get.isInstanceOf[BufferMessage]) { logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " - + ackMessage.get.getClass()) + + ackMessage.get.getClass) } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { logDebug("Response to " + bufferMessage + " does not have ack id set") ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala index b82edb6850d23..57f7586883af1 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala @@ -32,6 +32,6 @@ private[spark] case class ConnectionManagerId(host: String, port: Int) { private[spark] object ConnectionManagerId { def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { - new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) + new ConnectionManagerId(socketAddress.getHostName, socketAddress.getPort) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 6a3f698444283..f1f4b4324edfd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -57,7 +57,7 @@ class 
OrderedRDDFunctions[K : Ordering : ClassTag, */ def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = { val part = new RangePartitioner(numPartitions, self, ascending) - val shuffled = new ShuffledRDD[K, V, P](self, part) + val shuffled = new ShuffledRDD[K, V, V, P](self, part).setKeyOrdering(ordering) shuffled.mapPartitions(iter => { val buf = iter.toArray if (ascending) { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 443d1c587c3ee..fc9beb166befe 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -90,21 +90,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) self.mapPartitionsWithContext((context, iter) => { new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context)) }, preservesPartitioning = true) - } else if (mapSideCombine) { - val combined = self.mapPartitionsWithContext((context, iter) => { - aggregator.combineValuesByKey(iter, context) - }, preservesPartitioning = true) - val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner) - .setSerializer(serializer) - partitioned.mapPartitionsWithContext((context, iter) => { - new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter, context)) - }, preservesPartitioning = true) } else { - // Don't apply map-side combiner. - val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializer) - values.mapPartitionsWithContext((context, iter) => { - new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context)) - }, preservesPartitioning = true) + new ShuffledRDD[K, V, C, (K, C)](self, partitioner) + .setSerializer(serializer) + .setAggregator(aggregator) + .setMapSideCombine(mapSideCombine) } } @@ -401,7 +391,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) if (self.partitioner == Some(partitioner)) { self } else { - new ShuffledRDD[K, V, (K, V)](self, partitioner) + new ShuffledRDD[K, V, V, (K, V)](self, partitioner) } } @@ -772,7 +762,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) outputFormatClass: Class[_ <: NewOutputFormat[_, _]], conf: Configuration = self.context.hadoopConfiguration) { - val job = new NewAPIHadoopJob(conf) + // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). + val hadoopConf = conf + val job = new NewAPIHadoopJob(hadoopConf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) @@ -805,22 +797,25 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) outputFormatClass: Class[_ <: OutputFormat[_, _]], conf: JobConf = new JobConf(self.context.hadoopConfiguration), codec: Option[Class[_ <: CompressionCodec]] = None) { - conf.setOutputKeyClass(keyClass) - conf.setOutputValueClass(valueClass) + // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). + val hadoopConf = conf + hadoopConf.setOutputKeyClass(keyClass) + hadoopConf.setOutputValueClass(valueClass) // Doesn't work in Scala 2.9 due to what may be a generics bug // TODO: Should we uncomment this for Scala 2.10? 
// conf.setOutputFormat(outputFormatClass) - conf.set("mapred.output.format.class", outputFormatClass.getName) + hadoopConf.set("mapred.output.format.class", outputFormatClass.getName) for (c <- codec) { - conf.setCompressMapOutput(true) - conf.set("mapred.output.compress", "true") - conf.setMapOutputCompressorClass(c) - conf.set("mapred.output.compression.codec", c.getCanonicalName) - conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) + hadoopConf.setCompressMapOutput(true) + hadoopConf.set("mapred.output.compress", "true") + hadoopConf.setMapOutputCompressorClass(c) + hadoopConf.set("mapred.output.compression.codec", c.getCanonicalName) + hadoopConf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) } - conf.setOutputCommitter(classOf[FileOutputCommitter]) - FileOutputFormat.setOutputPath(conf, SparkHadoopWriter.createPathFromString(path, conf)) - saveAsHadoopDataset(conf) + hadoopConf.setOutputCommitter(classOf[FileOutputCommitter]) + FileOutputFormat.setOutputPath(hadoopConf, + SparkHadoopWriter.createPathFromString(path, hadoopConf)) + saveAsHadoopDataset(hadoopConf) } /** @@ -830,7 +825,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * configured for a Hadoop MapReduce job. */ def saveAsNewAPIHadoopDataset(conf: Configuration) { - val job = new NewAPIHadoopJob(conf) + // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). + val hadoopConf = conf + val job = new NewAPIHadoopJob(hadoopConf) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id @@ -887,9 +884,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * MapReduce job. */ def saveAsHadoopDataset(conf: JobConf) { - val outputFormatInstance = conf.getOutputFormat - val keyClass = conf.getOutputKeyClass - val valueClass = conf.getOutputValueClass + // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). 
+ val hadoopConf = conf + val outputFormatInstance = hadoopConf.getOutputFormat + val keyClass = hadoopConf.getOutputKeyClass + val valueClass = hadoopConf.getOutputValueClass if (outputFormatInstance == null) { throw new SparkException("Output format class not set") } @@ -899,18 +898,18 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) if (valueClass == null) { throw new SparkException("Output value class not set") } - SparkHadoopUtil.get.addCredentials(conf) + SparkHadoopUtil.get.addCredentials(hadoopConf) logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + valueClass.getSimpleName + ")") if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) { // FileOutputFormat ignores the filesystem parameter - val ignoredFs = FileSystem.get(conf) - conf.getOutputFormat.checkOutputSpecs(ignoredFs, conf) + val ignoredFs = FileSystem.get(hadoopConf) + hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf) } - val writer = new SparkHadoopWriter(conf) + val writer = new SparkHadoopWriter(hadoopConf) writer.preSetup() def writeToFile(context: TaskContext, iter: Iterator[(K, V)]) { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index cebfd109d825f..4e841bc992bff 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -340,7 +340,7 @@ abstract class RDD[T: ClassTag]( // include a shuffle step so that our upstream tasks are still distributed new CoalescedRDD( - new ShuffledRDD[Int, T, (Int, T)](mapPartitionsWithIndex(distributePartition), + new ShuffledRDD[Int, T, T, (Int, T)](mapPartitionsWithIndex(distributePartition), new HashPartitioner(numPartitions)), numPartitions).values } else { diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index bb108ef163c56..bf02f68d0d3d3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag -import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext} +import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer @@ -35,23 +35,48 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { * @param part the partitioner used to partition the RDD * @tparam K the key class. * @tparam V the value class. + * @tparam C the combiner class. */ @DeveloperApi -class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( - @transient var prev: RDD[P], +class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( + @transient var prev: RDD[_ <: Product2[K, V]], part: Partitioner) extends RDD[P](prev.context, Nil) { private var serializer: Option[Serializer] = None + private var keyOrdering: Option[Ordering[K]] = None + + private var aggregator: Option[Aggregator[K, V, C]] = None + + private var mapSideCombine: Boolean = false + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ - def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = { + def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = { this.serializer = Option(serializer) this } + /** Set key ordering for RDD's shuffle. 
*/ + def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C, P] = { + this.keyOrdering = Option(keyOrdering) + this + } + + /** Set aggregator for RDD's shuffle. */ + def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C, P] = { + this.aggregator = Option(aggregator) + this + } + + /** Set mapSideCombine flag for RDD's shuffle. */ + def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C, P] = { + this.mapSideCombine = mapSideCombine + this + } + override def getDependencies: Seq[Dependency[_]] = { - List(new ShuffleDependency(prev, part, serializer)) + List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine)) } override val partitioner = Some(part) @@ -61,7 +86,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( } override def compute(split: Partition, context: TaskContext): Iterator[P] = { - val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]] + val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) .read() .asInstanceOf[Iterator[P]] diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b3ebaa547de0d..c8559a7a82868 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1038,7 +1038,7 @@ class DAGScheduler( private def failJobAndIndependentStages(job: ActiveJob, failureReason: String, resultStage: Option[Stage]) { val error = new SparkException(failureReason) - job.listener.jobFailed(error) + var ableToCancelStages = true val shouldInterruptThread = if (job.properties == null) false @@ -1062,18 +1062,26 @@ class DAGScheduler( // This is the only job that uses this stage, so fail the stage if it is running. 
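As a non-normative aside on the ShuffledRDD changes above: a minimal sketch of how the new four-parameter ShuffledRDD and its builder-style setters could express a reduceByKey-style shuffle with map-side combining. The SparkContext, the sample data, and the sumByKey helper are assumptions for illustration, not part of the patch.

import org.apache.spark.{Aggregator, HashPartitioner, SparkContext}
import org.apache.spark.rdd.ShuffledRDD

// Hypothetical driver-side helper: sum the values for each key via an explicit ShuffledRDD.
def sumByKey(sc: SparkContext): Array[(String, Int)] = {
  val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
  // The Aggregator describes how to create and merge per-key combiners.
  val agg = Aggregator[String, Int, Int](
    createCombiner = (v: Int) => v,
    mergeValue = (c: Int, v: Int) => c + v,
    mergeCombiners = (c1: Int, c2: Int) => c1 + c2)
  new ShuffledRDD[String, Int, Int, (String, Int)](pairs, new HashPartitioner(2))
    .setAggregator(agg)
    .setMapSideCombine(true)
    .collect()  // e.g. Array(("a", 4), ("b", 2))
}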
val stage = stageIdToStage(stageId) if (runningStages.contains(stage)) { - taskScheduler.cancelTasks(stageId, shouldInterruptThread) - val stageInfo = stageToInfos(stage) - stageInfo.stageFailed(failureReason) - listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage))) + try { // cancelTasks will fail if a SchedulerBackend does not implement killTask + taskScheduler.cancelTasks(stageId, shouldInterruptThread) + val stageInfo = stageToInfos(stage) + stageInfo.stageFailed(failureReason) + listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage))) + } catch { + case e: UnsupportedOperationException => + logInfo(s"Could not cancel tasks for stage $stageId", e) + ableToCancelStages = false + } } } } } - cleanupStateForJobAndIndependentStages(job, resultStage) - - listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) + if (ableToCancelStages) { + job.listener.jobFailed(error) + cleanupStateForJobAndIndependentStages(job, resultStage) + listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) + } } /** @@ -1155,7 +1163,11 @@ private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler) case x: Exception => logError("eventProcesserActor failed due to the error %s; shutting down SparkContext" .format(x.getMessage)) - dagScheduler.doCancelAllJobs() + try { + dagScheduler.doCancelAllJobs() + } catch { + case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) + } dagScheduler.sc.stop() Stop } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 859cdc524a581..fdaf1de83f051 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -144,10 +144,8 @@ private[spark] class ShuffleMapTask( try { val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) - for (elem <- rdd.iterator(split, context)) { - writer.write(elem.asInstanceOf[Product2[Any, Any]]) - } - writer.stop(success = true).get + writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) + return writer.stop(success = true).get } catch { case e: Exception => if (writer != null) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 17292b4c15b8b..5ed2803d76afc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -210,11 +210,14 @@ private[spark] class TaskSchedulerImpl( SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname + // Also track if new executor is added + var newExecAvail = false for (o <- offers) { executorIdToHost(o.executorId) = o.host if (!executorsByHost.contains(o.host)) { executorsByHost(o.host) = new HashSet[String]() executorAdded(o.executorId, o.host) + newExecAvail = true } } @@ -227,12 +230,15 @@ private[spark] class TaskSchedulerImpl( for (taskSet <- sortedTaskSets) { logDebug("parentName: %s, name: %s, runningTasks: %s".format( taskSet.parent.name, taskSet.name, taskSet.runningTasks)) + if (newExecAvail) { + taskSet.executorAdded() + } } // Take each TaskSet in our scheduling order, and then offer it each node in increasing order // of locality levels so that it gets a chance to launch local tasks on all of them. 
var launchedTask = false - for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) { + for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) { do { launchedTask = false for (i <- 0 until shuffledOffers.size) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index f3bd0797aa035..c0898f64fb0c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -118,7 +118,7 @@ private[spark] class TaskSetManager( private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] // Set containing pending tasks with no locality preferences. - val pendingTasksWithNoPrefs = new ArrayBuffer[Int] + var pendingTasksWithNoPrefs = new ArrayBuffer[Int] // Set containing all pending tasks (also used as a stack, as above). val allPendingTasks = new ArrayBuffer[Int] @@ -153,8 +153,8 @@ private[spark] class TaskSetManager( } // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling - val myLocalityLevels = computeValidLocalityLevels() - val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level + var myLocalityLevels = computeValidLocalityLevels() + var localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level // Delay scheduling variables: we keep track of our current locality level and the time we // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. @@ -181,16 +181,14 @@ private[spark] class TaskSetManager( var hadAliveLocations = false for (loc <- tasks(index).preferredLocations) { for (execId <- loc.executorId) { - if (sched.isExecutorAlive(execId)) { - addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) - hadAliveLocations = true - } + addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) } if (sched.hasExecutorsAliveOnHost(loc.host)) { - addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) - for (rack <- sched.getRackForHost(loc.host)) { - addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) - } + hadAliveLocations = true + } + addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) + for (rack <- sched.getRackForHost(loc.host)) { + addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) hadAliveLocations = true } } @@ -643,7 +641,9 @@ private[spark] class TaskSetManager( addPendingTask(index, readding=true) } - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage. + // The reason is the next stage wouldn't be able to fetch the data from this dead executor + // so we would need to rerun these tasks on other executors. 
if (tasks(0).isInstanceOf[ShuffleMapTask]) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index @@ -725,10 +725,12 @@ private[spark] class TaskSetManager( private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} val levels = new ArrayBuffer[TaskLocality.TaskLocality] - if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) { + if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0 && + pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))) { levels += PROCESS_LOCAL } - if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) { + if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0 && + pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) { levels += NODE_LOCAL } if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) { @@ -738,4 +740,21 @@ private[spark] class TaskSetManager( logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", ")) levels.toArray } + + // Re-compute pendingTasksWithNoPrefs since new preferred locations may become available + def executorAdded() { + def newLocAvail(index: Int): Boolean = { + for (loc <- tasks(index).preferredLocations) { + if (sched.hasExecutorsAliveOnHost(loc.host) || + sched.getRackForHost(loc.host).isDefined) { + return true + } + } + false + } + logInfo("Re-computing pending task lists.") + pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.filter(!newLocAvail(_)) + myLocalityLevels = computeValidLocalityLevels() + localityWaits = myLocalityLevels.map(getLocalityWait) + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index ca74069ef885c..318e16552201c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,21 +20,21 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler.TaskDescription import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable private[spark] object CoarseGrainedClusterMessages { + case object RetrieveSparkProps extends CoarseGrainedClusterMessage + // Driver to executors case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage case class KillTask(taskId: Long, executor: String, interruptThread: Boolean) extends CoarseGrainedClusterMessage - case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) - extends CoarseGrainedClusterMessage + case object RegisteredExecutor extends CoarseGrainedClusterMessage case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index e47a060683a2d..05d01b0c821f9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -75,7 +75,7 @@ class CoarseGrainedSchedulerBackend(scheduler: 
TaskSchedulerImpl, actorSystem: A sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) } else { logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! RegisteredExecutor(sparkProperties) + sender ! RegisteredExecutor executorActor(executorId) = sender executorHost(executorId) = Utils.parseHostPort(hostPort)._1 totalCores(executorId) = cores @@ -124,6 +124,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated")) + case RetrieveSparkProps => + sender ! sparkProperties } // Make fake resource offers on all executors @@ -143,14 +145,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A for (task <- tasks.flatten) { val ser = SparkEnv.get.closureSerializer.newInstance() val serializedTask = ser.serialize(task) - if (serializedTask.limit >= akkaFrameSize - 1024) { + if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { val taskSetId = scheduler.taskIdToTaskSetId(task.taskId) scheduler.activeTaskSets.get(taskSetId).foreach { taskSet => try { - var msg = "Serialized task %s:%d was %d bytes which " + - "exceeds spark.akka.frameSize (%d bytes). " + - "Consider using broadcast variables for large values." - msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize) + var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + + "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " + + "spark.akka.frameSize or using broadcast variables for large values." + msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize, + AkkaUtils.reservedSizeBytes) taskSet.abort(msg) } catch { case e: Exception => logError("Exception in error callback", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index a089a02d42170..c717e7c621a8f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -185,8 +185,8 @@ private[spark] class MesosSchedulerBackend( synchronized { // Build a big list of the offerable workers, and remember their indices so that we can // figure out which Offer to reply to for each worker - val offerableIndices = new ArrayBuffer[Int] val offerableWorkers = new ArrayBuffer[WorkerOffer] + val offerableIndices = new HashMap[String, Int] def enoughMemory(o: Offer) = { val mem = getResource(o.getResourcesList, "mem") @@ -195,7 +195,7 @@ private[spark] class MesosSchedulerBackend( } for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) { - offerableIndices += index + offerableIndices.put(offer.getSlaveId.getValue, index) offerableWorkers += new WorkerOffer( offer.getSlaveId.getValue, offer.getHostname, @@ -206,14 +206,13 @@ private[spark] class MesosSchedulerBackend( val taskLists = scheduler.resourceOffers(offerableWorkers) // Build a list of Mesos tasks for each slave - val mesosTasks = offers.map(o => Collections.emptyList[MesosTaskInfo]()) + val mesosTasks = offers.map(o => new JArrayList[MesosTaskInfo]()) for ((taskList, index) <- taskLists.zipWithIndex) { if (!taskList.isEmpty) { - val offerNum = offerableIndices(index) - val slaveId = offers(offerNum).getSlaveId.getValue - slaveIdsWithExecutors += slaveId - 
mesosTasks(offerNum) = new JArrayList[MesosTaskInfo](taskList.size) for (taskDesc <- taskList) { + val slaveId = taskDesc.executorId + val offerNum = offerableIndices(slaveId) + slaveIdsWithExecutors += slaveId taskIdToSlaveId(taskDesc.taskId) = slaveId mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 43f0e18a0cbe0..9b95ccca0443e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -97,7 +97,8 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: localActor ! ReviveOffers } - override def defaultParallelism() = totalCores + override def defaultParallelism() = + scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { localActor ! KillTask(taskId, interruptThread) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 5286f7b4c211a..82b62aaf61521 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -64,6 +64,9 @@ class KryoSerializer(conf: SparkConf) kryo.register(cls) } + // For results returned by asJavaIterable. See JavaIterableWrapperSerializer. + kryo.register(JavaIterableWrapperSerializer.wrapperClass, new JavaIterableWrapperSerializer) + // Allow sending SerializableWritable kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) @@ -183,3 +186,50 @@ private[serializer] object KryoSerializer { classOf[Array[Byte]] ) } + +/** + * A Kryo serializer for serializing results returned by asJavaIterable. + * + * The underlying object is scala.collection.convert.Wrappers$IterableWrapper. + * Kryo deserializes this into an AbstractCollection, which unfortunately doesn't work. + */ +private class JavaIterableWrapperSerializer + extends com.esotericsoftware.kryo.Serializer[java.lang.Iterable[_]] { + + import JavaIterableWrapperSerializer._ + + override def write(kryo: Kryo, out: KryoOutput, obj: java.lang.Iterable[_]): Unit = { + // If the object is the wrapper, simply serialize the underlying Scala Iterable object. + // Otherwise, serialize the object itself. + if (obj.getClass == wrapperClass && underlyingMethodOpt.isDefined) { + kryo.writeClassAndObject(out, underlyingMethodOpt.get.invoke(obj)) + } else { + kryo.writeClassAndObject(out, obj) + } + } + + override def read(kryo: Kryo, in: KryoInput, clz: Class[java.lang.Iterable[_]]) + : java.lang.Iterable[_] = { + kryo.readClassAndObject(in) match { + case scalaIterable: Iterable[_] => + scala.collection.JavaConversions.asJavaIterable(scalaIterable) + case javaIterable: java.lang.Iterable[_] => + javaIterable + } + } +} + +private object JavaIterableWrapperSerializer extends Logging { + // The class returned by asJavaIterable (scala.collection.convert.Wrappers$IterableWrapper). + val wrapperClass = + scala.collection.convert.WrapAsJava.asJavaIterable(Seq(1)).getClass + + // Get the underlying method so we can use it to get the Scala collection for serialization. 
+ private val underlyingMethodOpt = { + try Some(wrapperClass.getDeclaredMethod("underlying")) catch { + case e: Exception => + logError("Failed to find the underlying field in " + wrapperClass, e) + None + } + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index ead3ebd652ca5..b934480cfb9be 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -23,8 +23,8 @@ import org.apache.spark.scheduler.MapStatus * Obtained inside a map task to write out records to the shuffle system. */ private[spark] trait ShuffleWriter[K, V] { - /** Write a record to this task's output */ - def write(record: Product2[K, V]): Unit + /** Write a bunch of records to this task's output */ + def write(records: Iterator[_ <: Product2[K, V]]): Unit /** Close this writer, passing along whether the map completed */ def stop(success: Boolean): Option[MapStatus] diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index f6a790309a587..d45258c0a492b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,9 +17,9 @@ package org.apache.spark.shuffle.hash +import org.apache.spark.{InterruptibleIterator, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} -import org.apache.spark.TaskContext class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], @@ -31,10 +31,24 @@ class HashShuffleReader[K, C]( require(endPartition == startPartition + 1, "Hash shuffle currently only supports fetching one partition") + private val dep = handle.dependency + /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, - Serializer.getSerializer(handle.dependency.serializer)) + val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, + Serializer.getSerializer(dep.serializer)) + + if (dep.aggregator.isDefined) { + if (dep.mapSideCombine) { + new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context)) + } else { + new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context)) + } + } else if (dep.aggregator.isEmpty && dep.mapSideCombine) { + throw new IllegalStateException("Aggregator is empty for map-side combine") + } else { + iter + } } /** Close this reader */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 4c6749098c110..9b78228519da4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -40,11 +40,24 @@ class HashShuffleWriter[K, V]( private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null)) private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser) - /** Write a record to this task's output */ - override def write(record: Product2[K, V]): Unit = { - val pair = record.asInstanceOf[Product2[Any, Any]] - val bucketId = 
dep.partitioner.getPartition(pair._1) - shuffle.writers(bucketId).write(pair) + /** Write a bunch of records to this task's output */ + override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + val iter = if (dep.aggregator.isDefined) { + if (dep.mapSideCombine) { + dep.aggregator.get.combineValuesByKey(records, context) + } else { + records + } + } else if (dep.aggregator.isEmpty && dep.mapSideCombine) { + throw new IllegalStateException("Aggregator is empty for map-side combine") + } else { + records + } + + for (elem <- iter) { + val bucketId = dep.partitioner.getPartition(elem._1) + shuffle.writers(bucketId).write(elem) + } } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index b08f308fda1dd..856273e1d4e21 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -51,6 +51,7 @@ private[spark] abstract class WebUI( def getTabs: Seq[WebUITab] = tabs.toSeq def getHandlers: Seq[ServletContextHandler] = handlers.toSeq + def getSecurityManager: SecurityManager = securityManager /** Attach a tab to this UI, along with all of its attached pages. */ def attachTab(tab: WebUITab) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 396cbcbc8d268..381a5443df8b5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui.jobs import scala.collection.mutable.{HashMap, ListBuffer} -import org.apache.spark.{ExceptionFailure, SparkConf, SparkContext, Success} +import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -51,6 +51,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { var totalShuffleRead = 0L var totalShuffleWrite = 0L + // TODO: Should probably consolidate all following into a single hash map. 
val stageIdToTime = HashMap[Int, Long]() val stageIdToShuffleRead = HashMap[Int, Long]() val stageIdToShuffleWrite = HashMap[Int, Long]() @@ -183,14 +184,17 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { // Remove by taskId, rather than by TaskInfo, in case the TaskInfo is from storage tasksActive.remove(info.taskId) - val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = + val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) = taskEnd.reason match { - case e: ExceptionFailure => - stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1 - (Some(e), e.metrics) - case _ => + case org.apache.spark.Success => stageIdToTasksComplete(sid) = stageIdToTasksComplete.getOrElse(sid, 0) + 1 (None, Option(taskEnd.taskMetrics)) + case e: ExceptionFailure => // Handle ExceptionFailure because we might have metrics + stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1 + (Some(e.toErrorString), e.metrics) + case e: TaskFailedReason => // All other failure cases + stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1 + (Some(e.toErrorString), None) } stageIdToTime.getOrElseUpdate(sid, 0L) @@ -218,7 +222,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { stageIdToDiskBytesSpilled(sid) += diskBytesSpilled val taskMap = stageIdToTaskData.getOrElse(sid, HashMap[Long, TaskUIData]()) - taskMap(info.taskId) = new TaskUIData(info, metrics, failureInfo) + taskMap(info.taskId) = new TaskUIData(info, metrics, errorMessage) stageIdToTaskData(sid) = taskMap } } @@ -253,7 +257,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { case class TaskUIData( taskInfo: TaskInfo, taskMetrics: Option[TaskMetrics] = None, - exception: Option[ExceptionFailure] = None) + errorMessage: Option[String] = None) private object JobProgressListener { val DEFAULT_POOL_NAME = "default" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 4bce472036f7d..8b65f0671bdb9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -210,10 +210,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { def taskRow(shuffleRead: Boolean, shuffleWrite: Boolean, bytesSpilled: Boolean) (taskData: TaskUIData): Seq[Node] = { - def fmtStackTrace(trace: Seq[StackTraceElement]): Seq[Node] = - trace.map(e => {e.toString}) - - taskData match { case TaskUIData(info, metrics, exception) => + taskData match { case TaskUIData(info, metrics, errorMessage) => val duration = if (info.status == "RUNNING") info.timeRunning(System.currentTimeMillis()) else metrics.map(_.executorRunTime).getOrElse(1L) val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) @@ -283,12 +280,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { }} - {exception.map { e => - - {e.className} ({e.description})
- {fmtStackTrace(e.stackTrace)} -
- }.getOrElse("")} + {errorMessage.map { e =>
{e}
}.getOrElse("")} } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index a3f824a4e1f57..30971f769682f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -91,13 +91,13 @@ private[ui] class StageTableBase( {s.name} - val details = if (s.details.nonEmpty) ( + val details = if (s.details.nonEmpty) { +show details - ) + } listener.stageIdToDescription.get(s.stageId) .map(d =>
{d}
{nameLink} {killLink}
) diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index a8d12bb2a0165..9930c717492f2 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -121,4 +121,7 @@ private[spark] object AkkaUtils extends Logging { def maxFrameSizeBytes(conf: SparkConf): Int = { conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024 } + + /** Space reserved for extra data in an Akka message besides serialized task or task result. */ + val reservedSizeBytes = 200 * 1024 } diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala index 8e9c3036d09c2..1d5467060623c 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -125,16 +125,16 @@ private[spark] object FileAppender extends Logging { val validatedParams: Option[(Long, String)] = rollingInterval match { case "daily" => logInfo(s"Rolling executor logs enabled for $file with daily rolling") - Some(24 * 60 * 60 * 1000L, "--YYYY-MM-dd") + Some(24 * 60 * 60 * 1000L, "--yyyy-MM-dd") case "hourly" => logInfo(s"Rolling executor logs enabled for $file with hourly rolling") - Some(60 * 60 * 1000L, "--YYYY-MM-dd--HH") + Some(60 * 60 * 1000L, "--yyyy-MM-dd--HH") case "minutely" => logInfo(s"Rolling executor logs enabled for $file with rolling every minute") - Some(60 * 1000L, "--YYYY-MM-dd--HH-mm") + Some(60 * 1000L, "--yyyy-MM-dd--HH-mm") case IntParam(seconds) => logInfo(s"Rolling executor logs enabled for $file with rolling $seconds seconds") - Some(seconds * 1000L, "--YYYY-MM-dd--HH-mm-ss") + Some(seconds * 1000L, "--yyyy-MM-dd--HH-mm-ss") case _ => logWarning(s"Illegal interval for rolling executor logs [$rollingInterval], " + s"rolling logs not enabled") diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index 1bbbd20cf076f..e579421676343 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -19,7 +19,7 @@ package org.apache.spark.util.logging import java.io.{File, FileFilter, InputStream} -import org.apache.commons.io.FileUtils +import com.google.common.io.Files import org.apache.spark.SparkConf import RollingFileAppender._ @@ -83,7 +83,7 @@ private[spark] class RollingFileAppender( logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile") if (activeFile.exists) { if (!rolloverFile.exists) { - FileUtils.moveFile(activeFile, rolloverFile) + Files.move(activeFile, rolloverFile) logInfo(s"Rolled over $activeFile to $rolloverFile") } else { // In case the rollover file name clashes, make a unique file name. 
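For background on the date-pattern fixes in this patch (this note and the snippet below are not part of the diff): in java.text.SimpleDateFormat, "YYYY" denotes the week-based year, which can differ from the calendar year "yyyy" around New Year and would stamp rolled-over log files with a misleading year. A small, self-contained illustration, assuming default week-numbering rules:

import java.text.SimpleDateFormat
import java.util.{Calendar, GregorianCalendar}

object WeekYearVsCalendarYear {
  def main(args: Array[String]): Unit = {
    // 2014-12-29 falls into the first week of 2015 under common week-numbering rules.
    val date = new GregorianCalendar(2014, Calendar.DECEMBER, 29).getTime
    println(new SimpleDateFormat("YYYY-MM-dd").format(date))  // typically 2015-12-29 (week year)
    println(new SimpleDateFormat("yyyy-MM-dd").format(date))  // 2014-12-29 (calendar year)
  }
}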
@@ -100,7 +100,7 @@ private[spark] class RollingFileAppender( logWarning(s"Rollover file $rolloverFile already exists, " + s"rolled over $activeFile to file $altRolloverFile") - FileUtils.moveFile(activeFile, altRolloverFile) + Files.move(activeFile, altRolloverFile) } } else { logWarning(s"File $activeFile does not exist") diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala index 84e5c3c917dcb..d7b7219e179d0 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -109,7 +109,7 @@ private[spark] class SizeBasedRollingPolicy( } @volatile private var bytesWrittenSinceRollover = 0L - val formatter = new SimpleDateFormat("--YYYY-MM-dd--HH-mm-ss--SSSS") + val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS") /** Should rollover if the next set of bytes is going to exceed the size limit */ def shouldRollover(bytesToBeWritten: Long): Boolean = { diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index f64f3c9036034..fc00458083a33 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { test("ShuffledRDD") { testRDD(rdd => { // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD - new ShuffledRDD[Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner) + new ShuffledRDD[Int, Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner) }) } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 47112ce66d695..b40fee7e9ab23 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -56,8 +56,11 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { } // If the Kryo serializer is not used correctly, the shuffle would fail because the // default Java serializer cannot handle the non serializable class. - val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( - b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf)) + val c = new ShuffledRDD[Int, + NonJavaSerializableClass, + NonJavaSerializableClass, + (Int, NonJavaSerializableClass)](b, new HashPartitioner(NUM_BLOCKS)) + c.setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 10) @@ -78,8 +81,11 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { } // If the Kryo serializer is not used correctly, the shuffle would fail because the // default Java serializer cannot handle the non serializable class. 
- val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( - b, new HashPartitioner(3)).setSerializer(new KryoSerializer(conf)) + val c = new ShuffledRDD[Int, + NonJavaSerializableClass, + NonJavaSerializableClass, + (Int, NonJavaSerializableClass)](b, new HashPartitioner(3)) + c.setSerializer(new KryoSerializer(conf)) assert(c.count === 10) } @@ -94,7 +100,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // NOTE: The default Java serializer doesn't create zero-sized blocks. // So, use Kryo - val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) + val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10)) .setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId @@ -120,7 +126,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val b = a.map(x => (x, x*2)) // NOTE: The default Java serializer should create zero-sized blocks - val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) + val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) @@ -141,8 +147,8 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) - val results = new ShuffledRDD[Int, Int, MutablePair[Int, Int]](pairs, new HashPartitioner(2)) - .collect() + val results = new ShuffledRDD[Int, Int, Int, MutablePair[Int, Int]](pairs, + new HashPartitioner(2)).collect() data.foreach { pair => results should contain (pair) } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 94fba102865b3..67e3be21c3c93 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -77,6 +77,22 @@ class SparkContextSchedulerCreationSuite } } + test("local-default-parallelism") { + val defaultParallelism = System.getProperty("spark.default.parallelism") + System.setProperty("spark.default.parallelism", "16") + val sched = createTaskScheduler("local") + + sched.backend match { + case s: LocalBackend => assert(s.defaultParallelism() === 16) + case _ => fail() + } + + Option(defaultParallelism) match { + case Some(v) => System.setProperty("spark.default.parallelism", v) + case _ => System.clearProperty("spark.default.parallelism") + } + } + test("simr") { createTaskScheduler("simr://uri").backend match { case s: SimrSchedulerBackend => // OK diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 0e5625b7645d5..0f9cbe213ea17 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -276,7 +276,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { // we can optionally shuffle to keep the upstream parallel val coalesced5 = data.coalesce(1, shuffle = true) val isEquals = coalesced5.dependencies.head.rdd.dependencies.head.rdd. 
- asInstanceOf[ShuffledRDD[_, _, _]] != null + asInstanceOf[ShuffledRDD[_, _, _, _]] != null assert(isEquals) // when shuffling, we can increase the number of partitions @@ -509,7 +509,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("takeSample") { val n = 1000000 val data = sc.parallelize(1 to n, 2) - + for (num <- List(5, 20, 100)) { val sample = data.takeSample(withReplacement=false, num=num) assert(sample.size === num) // Got exactly num elements @@ -704,11 +704,11 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(ancestors3.count(_.isInstanceOf[MappedRDD[_, _]]) === 2) // Any ancestors before the shuffle are not considered - assert(ancestors4.size === 1) - assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1) - assert(ancestors5.size === 4) - assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1) - assert(ancestors5.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) === 1) + assert(ancestors4.size === 0) + assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 0) + assert(ancestors5.size === 3) + assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 1) + assert(ancestors5.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) === 0) assert(ancestors5.count(_.isInstanceOf[MappedValuesRDD[_, _, _]]) === 2) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index efef9d26dadca..f77661ccbd1c5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -35,7 +35,7 @@ class CoarseGrainedSchedulerBackendSuite extends FunSuite with LocalSparkContext val thrown = intercept[SparkException] { larger.collect() } - assert(thrown.getMessage.contains("Consider using broadcast variables for large values")) + assert(thrown.getMessage.contains("using broadcast variables for large values")) val smaller = sc.parallelize(1 to 4).collect() assert(smaller.size === 4) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 45368328297d3..8dd2a9b9f7373 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -115,6 +115,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F sc = new SparkContext("local", "DAGSchedulerSuite") sparkListener.successfulStages.clear() sparkListener.failedStages.clear() + failure = null sc.addSparkListener(sparkListener) taskSets.clear() cancelledStages.clear() @@ -314,6 +315,53 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assertDataStructuresEmpty } + test("job cancellation no-kill backend") { + // make sure that the DAGScheduler doesn't crash when the TaskScheduler + // doesn't implement killTask() + val noKillTaskScheduler = new TaskScheduler() { + override def rootPool: Pool = null + override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def start() = {} + override def stop() = {} + override def submitTasks(taskSet: TaskSet) = { + taskSets += taskSet + } + override def cancelTasks(stageId: Int, interruptThread: Boolean) { + throw new UnsupportedOperationException + } + override def setDAGScheduler(dagScheduler: 
DAGScheduler) = {} + override def defaultParallelism() = 2 + } + val noKillScheduler = new DAGScheduler( + sc, + noKillTaskScheduler, + sc.listenerBus, + mapOutputTracker, + blockManagerMaster, + sc.env) { + override def runLocally(job: ActiveJob) { + // don't bother with the thread while unit testing + runLocallyWithinThread(job) + } + } + dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor]( + Props(classOf[DAGSchedulerEventProcessActor], noKillScheduler))(system) + val rdd = makeRdd(1, Nil) + val jobId = submit(rdd, Array(0)) + cancel(jobId) + // Because the job wasn't actually cancelled, we shouldn't have received a failure message. + assert(failure === null) + + // When the task set completes normally, state should be correctly updated. + complete(taskSets(0), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty + + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.failedStages.isEmpty) + assert(sparkListener.successfulStages.contains(0)) + } + test("run trivial shuffle") { val shuffleMapRdd = makeRdd(2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index abd7b22310f1a..6df0a080961b6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -181,7 +181,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) listener.stageInfos.size should be {2} // Shuffle map stage + result stage val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 2).get - stageInfo3.rddInfos.size should be {2} // ShuffledRDD, MapPartitionsRDD + stageInfo3.rddInfos.size should be {1} // ShuffledRDD stageInfo3.rddInfos.forall(_.numPartitions == 4) should be {true} stageInfo3.rddInfos.exists(_.name == "Trois") should be {true} } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 6f1fd25764544..59a618956a356 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -77,6 +77,10 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex override def isExecutorAlive(execId: String): Boolean = executors.contains(execId) override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host) + + def addExecutor(execId: String, host: String) { + executors.put(execId, host) + } } class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { @@ -400,6 +404,36 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(sched.taskSetsFailed.contains(taskSet.id)) } + test("new executors get added") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(4, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host1", "execB")), + Seq(TaskLocation("host2", "execC")), + Seq()) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + // All tasks added to no-pref list since no preferred location is available + 
assert(manager.pendingTasksWithNoPrefs.size === 4) + // Only ANY is valid + assert(manager.myLocalityLevels.sameElements(Array(ANY))) + // Add a new executor + sched.addExecutor("execD", "host1") + manager.executorAdded() + // Task 0 and 1 should be removed from no-pref list + assert(manager.pendingTasksWithNoPrefs.size === 2) + // Valid locality should contain NODE_LOCAL and ANY + assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, ANY))) + // Add another executor + sched.addExecutor("execC", "host2") + manager.executorAdded() + // No-pref list now only contains task 3 + assert(manager.pendingTasksWithNoPrefs.size === 1) + // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) + } + def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index cdd6b3d8feed7..79280d1a06653 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -128,6 +128,21 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { check(1.0 until 1000000.0 by 2.0) } + test("asJavaIterable") { + // Serialize a collection wrapped by asJavaIterable + val ser = new KryoSerializer(conf).newInstance() + val a = ser.serialize(scala.collection.convert.WrapAsJava.asJavaIterable(Seq(12345))) + val b = ser.deserialize[java.lang.Iterable[Int]](a) + assert(b.iterator().next() === 12345) + + // Serialize a normal Java collection + val col = new java.util.ArrayList[Int] + col.add(54321) + val c = ser.serialize(col) + val d = ser.deserialize[java.lang.Iterable[Int]](c) + assert(d.iterator().next() === 54321) + } + test("custom registrator") { val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index c3a14f48de38e..e0fec6a068bd1 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs import org.scalatest.FunSuite import org.scalatest.Matchers -import org.apache.spark.{LocalSparkContext, SparkConf, Success} +import org.apache.spark._ import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -101,4 +101,32 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-2", fail()) .shuffleRead == 1000) } + + test("test task success vs failure counting for different task end reasons") { + val conf = new SparkConf() + val listener = new JobProgressListener(conf) + val metrics = new TaskMetrics() + val taskInfo = new TaskInfo(1234L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL) + taskInfo.finishTime = 1 + val task = new ShuffleMapTask(0, null, null, 0, null) + val taskType = Utils.getFormattedClassName(task) + + // Go through all the failure cases to make sure we are counting
them as failures. + val taskFailedReasons = Seq( + Resubmitted, + new FetchFailed(null, 0, 0, 0), + new ExceptionFailure("Exception", "description", null, None), + TaskResultLost, + TaskKilled, + ExecutorLostFailure, + UnknownReason) + for (reason <- taskFailedReasons) { + listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, reason, taskInfo, metrics)) + assert(listener.stageIdToTasksComplete.get(task.stageId) === None) + } + + // Make sure we count success as success. + listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, metrics)) + assert(listener.stageIdToTasksComplete.get(task.stageId) === Some(1)) + } } diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 02e228945bbd9..ca37d707b06ca 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -18,13 +18,16 @@ package org.apache.spark.util import java.io._ +import java.nio.charset.Charset import scala.collection.mutable.HashSet import scala.reflect._ -import org.apache.commons.io.{FileUtils, IOUtils} -import org.apache.spark.{Logging, SparkConf} import org.scalatest.{BeforeAndAfter, FunSuite} + +import com.google.common.io.Files + +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy, FileAppender} class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { @@ -41,11 +44,11 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { test("basic file appender") { val testString = (1 to 1000).mkString(", ") - val inputStream = IOUtils.toInputStream(testString) + val inputStream = new ByteArrayInputStream(testString.getBytes(Charset.forName("UTF-8"))) val appender = new FileAppender(inputStream, testFile) inputStream.close() appender.awaitTermination() - assert(FileUtils.readFileToString(testFile) === testString) + assert(Files.toString(testFile, Charset.forName("UTF-8")) === testString) } test("rolling file appender - time-based rolling") { @@ -93,7 +96,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { val allGeneratedFiles = new HashSet[String]() val items = (1 to 10).map { _.toString * 10000 } for (i <- 0 until items.size) { - testOutputStream.write(items(i).getBytes("UTF8")) + testOutputStream.write(items(i).getBytes(Charset.forName("UTF-8"))) testOutputStream.flush() allGeneratedFiles ++= RollingFileAppender.getSortedRolledOverFiles( testFile.getParentFile.toString, testFile.getName).map(_.toString) @@ -197,7 +200,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { // send data to appender through the input stream, and wait for the data to be written val expectedText = textToAppend.mkString("") for (i <- 0 until textToAppend.size) { - outputStream.write(textToAppend(i).getBytes("UTF8")) + outputStream.write(textToAppend(i).getBytes(Charset.forName("UTF-8"))) outputStream.flush() Thread.sleep(sleepTimeBetweenTexts) } @@ -212,7 +215,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { logInfo("Filtered files: \n" + generatedFiles.mkString("\n")) assert(generatedFiles.size > 1) val allText = generatedFiles.map { file => - FileUtils.readFileToString(file) + Files.toString(file, Charset.forName("UTF-8")) }.mkString("") assert(allText === expectedText) generatedFiles diff --git 
a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index f72389b6b323f..495e1b7a0a214 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -366,7 +366,7 @@ class JsonProtocolSuite extends FunSuite { private def assertJsonStringEquals(json1: String, json2: String) { val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "") - formatJsonString(json1) === formatJsonString(json2) + assert(formatJsonString(json1) === formatJsonString(json2)) } private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) { @@ -449,7 +449,7 @@ class JsonProtocolSuite extends FunSuite { } private def makeStageInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { - val rddInfos = (1 to a % 5).map { i => makeRddInfo(a % i, b % i, c % i, d % i, e % i) } + val rddInfos = (0 until a % 5).map { i => makeRddInfo(a + i, b + i, c + i, d + i, e + i) } new StageInfo(a, "greetings", b, rddInfos, "details") } @@ -493,20 +493,19 @@ class JsonProtocolSuite extends FunSuite { private val stageSubmittedJsonString = """ {"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":100,"Stage Name": - "greetings","Number of Tasks":200,"RDD Info":{"RDD ID":100,"Name":"mayor","Storage - Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":true, - "Replication":1},"Number of Partitions":200,"Number of Cached Partitions":300, - "Memory Size":400,"Disk Size":500,"Tachyon Size":0},"Emitted Task Size Warning":false}, - "Properties":{"France":"Paris","Germany":"Berlin","Russia":"Moscow","Ukraine":"Kiev"}} + "greetings","Number of Tasks":200,"RDD Info":[],"Details":"details", + "Emitted Task Size Warning":false},"Properties":{"France":"Paris","Germany":"Berlin", + "Russia":"Moscow","Ukraine":"Kiev"}} """ private val stageCompletedJsonString = """ {"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":101,"Stage Name": - "greetings","Number of Tasks":201,"RDD Info":{"RDD ID":101,"Name":"mayor","Storage + "greetings","Number of Tasks":201,"RDD Info":[{"RDD ID":101,"Name":"mayor","Storage Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":true, "Replication":1},"Number of Partitions":201,"Number of Cached Partitions":301, - "Memory Size":401,"Disk Size":501,"Tachyon Size":0},"Emitted Task Size Warning":false}} + "Memory Size":401,"Tachyon Size":0,"Disk Size":501}],"Details":"details", + "Emitted Task Size Warning":false}} """ private val taskStartJsonString = @@ -538,9 +537,9 @@ class JsonProtocolSuite extends FunSuite { 900,"Total Blocks Fetched":1500,"Remote Blocks Fetched":800,"Local Blocks Fetched": 700,"Fetch Wait Time":900,"Remote Bytes Read":1000},"Shuffle Write Metrics": {"Shuffle Bytes Written":1200,"Shuffle Write Time":1500},"Updated Blocks": - [{"Block ID":{"Type":"RDDBlockId","RDD ID":0,"Split Index":0},"Status": - {"Storage Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false, - "Deserialized":false,"Replication":2},"Memory Size":0,"Disk Size":0,"Tachyon Size":0}}]}} + [{"Block ID":"rdd_0_0","Status":{"Storage Level":{"Use Disk":true,"Use Memory":true, + "Use Tachyon":false,"Deserialized":false,"Replication":2},"Memory Size":0,"Tachyon Size":0, + "Disk Size":0}}]}} """ private val jobStartJsonString = diff --git a/dev/audit-release/blank_maven_build/pom.xml b/dev/audit-release/blank_maven_build/pom.xml index 047659e4a8b7c..02dd9046c9a49 100644 
--- a/dev/audit-release/blank_maven_build/pom.xml +++ b/dev/audit-release/blank_maven_build/pom.xml @@ -28,10 +28,6 @@ Spray.cc repository http://repo.spray.cc - - Akka repository - http://repo.akka.io/releases - Spark Staging Repo ${spark.release.repository} diff --git a/dev/audit-release/blank_sbt_build/build.sbt b/dev/audit-release/blank_sbt_build/build.sbt index 1cf52743f27f4..696c7f651837c 100644 --- a/dev/audit-release/blank_sbt_build/build.sbt +++ b/dev/audit-release/blank_sbt_build/build.sbt @@ -25,5 +25,4 @@ libraryDependencies += "org.apache.spark" % System.getenv.get("SPARK_MODULE") % resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), - "Akka Repository" at "http://repo.akka.io/releases/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/maven_app_core/pom.xml b/dev/audit-release/maven_app_core/pom.xml index 76a381f8e17e0..b516396825573 100644 --- a/dev/audit-release/maven_app_core/pom.xml +++ b/dev/audit-release/maven_app_core/pom.xml @@ -28,10 +28,6 @@ Spray.cc repository http://repo.spray.cc - - Akka repository - http://repo.akka.io/releases - Spark Staging Repo ${spark.release.repository} diff --git a/dev/audit-release/sbt_app_core/build.sbt b/dev/audit-release/sbt_app_core/build.sbt index 97a8cc3a4e095..291b1d6440bac 100644 --- a/dev/audit-release/sbt_app_core/build.sbt +++ b/dev/audit-release/sbt_app_core/build.sbt @@ -25,5 +25,4 @@ libraryDependencies += "org.apache.spark" %% "spark-core" % System.getenv.get("S resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), - "Akka Repository" at "http://repo.akka.io/releases/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_ganglia/build.sbt b/dev/audit-release/sbt_app_ganglia/build.sbt index 55db675c722d1..6d9474acf5bbc 100644 --- a/dev/audit-release/sbt_app_ganglia/build.sbt +++ b/dev/audit-release/sbt_app_ganglia/build.sbt @@ -27,5 +27,4 @@ libraryDependencies += "org.apache.spark" %% "spark-ganglia-lgpl" % System.geten resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), - "Akka Repository" at "http://repo.akka.io/releases/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_graphx/build.sbt b/dev/audit-release/sbt_app_graphx/build.sbt index 66f2db357d49b..dd11245e67d44 100644 --- a/dev/audit-release/sbt_app_graphx/build.sbt +++ b/dev/audit-release/sbt_app_graphx/build.sbt @@ -25,5 +25,4 @@ libraryDependencies += "org.apache.spark" %% "spark-graphx" % System.getenv.get( resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), - "Akka Repository" at "http://repo.akka.io/releases/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_hive/build.sbt b/dev/audit-release/sbt_app_hive/build.sbt index 7ac1be729c561..a0d4f25da5842 100644 --- a/dev/audit-release/sbt_app_hive/build.sbt +++ b/dev/audit-release/sbt_app_hive/build.sbt @@ -25,5 +25,4 @@ libraryDependencies += "org.apache.spark" %% "spark-hive" % System.getenv.get("S resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), - "Akka Repository" at "http://repo.akka.io/releases/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_sql/build.sbt b/dev/audit-release/sbt_app_sql/build.sbt index 6e0ad3b4b2960..9116180f71a44 100644 --- a/dev/audit-release/sbt_app_sql/build.sbt +++ 
b/dev/audit-release/sbt_app_sql/build.sbt @@ -25,5 +25,4 @@ libraryDependencies += "org.apache.spark" %% "spark-sql" % System.getenv.get("SP resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), - "Akka Repository" at "http://repo.akka.io/releases/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_streaming/build.sbt b/dev/audit-release/sbt_app_streaming/build.sbt index 492e5e7c8d763..cb369d516dd16 100644 --- a/dev/audit-release/sbt_app_streaming/build.sbt +++ b/dev/audit-release/sbt_app_streaming/build.sbt @@ -25,5 +25,4 @@ libraryDependencies += "org.apache.spark" %% "spark-streaming" % System.getenv.g resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), - "Akka Repository" at "http://repo.akka.io/releases/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index ffb70096d6014..c44320239bbbf 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -130,7 +130,9 @@ def merge_pr(pr_num, target_ref): merge_message_flags += ["-m", title] if body != None: - merge_message_flags += ["-m", body] + # We remove @ symbols from the body to avoid triggering e-mails + # to people every time someone creates a public fork of Spark. + merge_message_flags += ["-m", body.replace("@", "")] authors = "\n".join(["Author: %s" % a for a in distinct_authors]) diff --git a/dev/run-tests b/dev/run-tests index c82a47ebb618b..d9df020f7563c 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -21,6 +21,9 @@ FWDIR="$(cd `dirname $0`/..; pwd)" cd $FWDIR +export SPARK_HADOOP_VERSION=2.3.0 +export SPARK_YARN=true + # Remove work directory rm -rf ./work diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 4ba20e590f2c2..b30ab1e5218c0 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -136,21 +136,31 @@


- + }); + diff --git a/docs/index.md b/docs/index.md index 1a4ff3dbf57be..4ac0982ae54f1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -6,7 +6,7 @@ title: Spark Overview Apache Spark is a fast and general-purpose cluster computing system. It provides high-level APIs in Java, Scala and Python, and an optimized engine that supports general execution graphs. -It also supports a rich set of higher-level tools including [Shark](http://shark.cs.berkeley.edu) (Hive on Spark), [Spark SQL](sql-programming-guide.html) for structured data, [MLlib](mllib-guide.html) for machine learning, [GraphX](graphx-programming-guide.html) for graph processing, and [Spark Streaming](streaming-programming-guide.html). +It also supports a rich set of higher-level tools including [Spark SQL](sql-programming-guide.html) for SQL and structured data processing, [MLlib](mllib-guide.html) for machine learning, [GraphX](graphx-programming-guide.html) for graph processing, and [Spark Streaming](streaming-programming-guide.html). # Downloading @@ -109,10 +109,9 @@ options for deployment: **External Resources:** * [Spark Homepage](http://spark.apache.org) -* [Shark](http://shark.cs.berkeley.edu): Apache Hive over Spark * [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here * [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and - exercises about Spark, Shark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/3/), + exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/3/), [slides](http://ampcamp.berkeley.edu/3/) and [exercises](http://ampcamp.berkeley.edu/3/exercises/) are available online for free. * [Code Examples](http://spark.apache.org/examples.html): more are also available in the `examples` subfolder of Spark ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), diff --git a/docs/monitoring.md b/docs/monitoring.md index 2b9e9e5bd7ea0..84073fe4d949a 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -35,11 +35,13 @@ If Spark is run on Mesos or YARN, it is still possible to reconstruct the UI of application through Spark's history server, provided that the application's event logs exist. You can start a the history server by executing: - ./sbin/start-history-server.sh + ./sbin/start-history-server.sh -The base logging directory must be supplied, and should contain sub-directories that each -represents an application's event logs. This creates a web interface at -`http://:18080` by default. The history server can be configured as follows: +When using the file-system provider class (see spark.history.provider below), the base logging +directory must be supplied in the spark.history.fs.logDirectory configuration option, +and should contain sub-directories that each represents an application's event logs. This creates a +web interface at `http://:18080` by default. The history server can be configured as +follows: @@ -69,7 +71,14 @@ represents an application's event logs. This creates a web interface at
   <tr><th>Environment Variable</th><th>Meaning</th></tr>
   <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
   <tr>
-    <td>spark.history.updateInterval</td>
+    <td>spark.history.provider</td>
+    <td>org.apache.spark.deploy.history.FsHistoryProvider</td>
+    <td>Name of the class implementing the application history backend. Currently there is only
+      one implementation, provided by Spark, which looks for application logs stored in the
+      file system.</td>
+  </tr>
+  <tr>
+    <td>spark.history.fs.updateInterval</td>
     <td>10</td>
     <td>The period, in seconds, at which information displayed by this history server is updated.
@@ -78,7 +87,7 @@ represents an application's event logs. This creates a web interface at
spark.history.retainedApplications25050 The number of application UIs to retain. If this cap is exceeded, then the oldest applications will be removed. diff --git a/docs/quick-start.md b/docs/quick-start.md index 64023994771b7..23313d8aa6152 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -266,8 +266,6 @@ version := "1.0" scalaVersion := "{{site.SCALA_VERSION}}" libraryDependencies += "org.apache.spark" %% "spark-core" % "{{site.SPARK_VERSION}}" - -resolvers += "Akka Repository" at "http://repo.akka.io/releases/" {% endhighlight %} For sbt to work correctly, we'll need to layout `SimpleApp.scala` and `simple.sbt` @@ -349,12 +347,6 @@ Note that Spark artifacts are tagged with a Scala version. Simple Project jar 1.0 - - - Akka repository - http://repo.akka.io/releases - - org.apache.spark diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index fecd8f2cc2d48..5d8d603aa3e37 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -95,10 +95,19 @@ Most of the configs are the same for Spark on YARN as for other deployment modes The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc.
+  <tr>
+    <td>spark.yarn.jar</td>
+    <td>(none)</td>
+    <td>
+      The location of the Spark jar file, in case overriding the default location is desired.
+      By default, Spark on YARN will use a Spark jar installed locally, but the Spark jar can also be
+      in a world-readable location on HDFS. This allows YARN to cache it on nodes so that it doesn't
+      need to be distributed each time an application runs. To point to a jar on HDFS, for example,
+      set this configuration to "hdfs:///some/path".
+    </td>
+  </tr>
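Since `spark.yarn.jar` is a regular Spark configuration entry, one plausible way to supply it from application code is through `SparkConf`. This is only a minimal sketch, assuming yarn-client mode; the HDFS location is the documentation's placeholder, not a real path:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Sketch only: reuse a Spark jar already staged on HDFS instead of shipping the
// local one with every application. "hdfs:///some/path" is a placeholder path.
val conf = new SparkConf()
  .setAppName("SparkYarnJarExample")
  .setMaster("yarn-client")
  .set("spark.yarn.jar", "hdfs:///some/path")
val sc = new SparkContext(conf)
```

The same property can also be set in `conf/spark-defaults.conf` or passed to `spark-submit` with `--conf spark.yarn.jar=...`.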
-By default, Spark on YARN will use a Spark jar installed locally, but the Spark JAR can also be in a world-readable location on HDFS. This allows YARN to cache it on nodes so that it doesn't need to be distributed each time an application runs. To point to a JAR on HDFS, `export SPARK_JAR=hdfs:///some/path`. - # Launching Spark on YARN Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. @@ -119,7 +128,7 @@ For example: --num-executors 3 \ --driver-memory 4g \ --executor-memory 2g \ - --executor-cores 1 + --executor-cores 1 \ lib/spark-examples*.jar \ 10 @@ -156,7 +165,20 @@ all environment variables used for launching each container. This process is use classpath problems in particular. (Note that enabling this requires admin privileges on cluster settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). -# Important Notes +To use a custom log4j configuration for the application master or executors, there are two options: + +- upload a custom log4j.properties using spark-submit, by adding it to the "--files" list of files + to be uploaded with the application. +- add "-Dlog4j.configuration=" to "spark.driver.extraJavaOptions" + (for the driver) or "spark.executor.extraJavaOptions" (for executors). Note that if using a file, + the "file:" protocol should be explicitly provided, and the file needs to exist locally on all + the nodes. + +Note that for the first option, both executors and the application master will share the same +log4j configuration, which may cause issues when they run on the same node (e.g. trying to write +to the same log file). + +# Important notes - Before Hadoop 2.2, YARN does not support cores in container resource requests. Thus, when running against an earlier version, the numbers of cores given via command line arguments cannot be passed to YARN. Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. - The local directories used by Spark executors will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 index 454057aa0d279..31f9771223e51 100755 --- a/ec2/spark-ec2 +++ b/ec2/spark-ec2 @@ -19,4 +19,4 @@ # cd "`dirname $0`" -PYTHONPATH="./third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" python ./spark_ec2.py $@ +PYTHONPATH="./third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" python ./spark_ec2.py "$@" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index a40311d9fcf02..e22d93bd31bc2 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -203,6 +203,8 @@ def get_spark_shark_version(opts): # Attempt to resolve an appropriate AMI given the architecture and # region of the request. 
+# Information regarding Amazon Linux AMI instance type was update on 2014-6-20: +# http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ def get_spark_ami(opts): instance_types = { "m1.small": "pvm", @@ -218,10 +220,12 @@ def get_spark_ami(opts): "cc1.4xlarge": "hvm", "cc2.8xlarge": "hvm", "cg1.4xlarge": "hvm", - "hs1.8xlarge": "hvm", - "hi1.4xlarge": "hvm", - "m3.xlarge": "hvm", - "m3.2xlarge": "hvm", + "hs1.8xlarge": "pvm", + "hi1.4xlarge": "pvm", + "m3.medium": "pvm", + "m3.large": "pvm", + "m3.xlarge": "pvm", + "m3.2xlarge": "pvm", "cr1.8xlarge": "hvm", "i2.xlarge": "hvm", "i2.2xlarge": "hvm", @@ -526,7 +530,8 @@ def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes): # Get number of local disks available for a given EC2 instance type. def get_num_disks(instance_type): - # From http://docs.amazonwebservices.com/AWSEC2/latest/UserGuide/index.html?InstanceStorage.html + # From http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html + # Updated 2014-6-20 disks_by_instance = { "m1.small": 1, "m1.medium": 1, @@ -544,8 +549,10 @@ def get_num_disks(instance_type): "hs1.8xlarge": 24, "cr1.8xlarge": 2, "hi1.4xlarge": 2, - "m3.xlarge": 0, - "m3.2xlarge": 0, + "m3.medium": 1, + "m3.large": 1, + "m3.xlarge": 2, + "m3.2xlarge": 2, "i2.xlarge": 1, "i2.2xlarge": 2, "i2.4xlarge": 4, @@ -559,7 +566,9 @@ def get_num_disks(instance_type): "r3.xlarge": 1, "r3.2xlarge": 1, "r3.4xlarge": 1, - "r3.8xlarge": 2 + "r3.8xlarge": 2, + "g2.2xlarge": 1, + "t1.micro": 0 } if instance_type in disks_by_instance: return disks_by_instance[instance_type] @@ -770,12 +779,16 @@ def real_main(): setup_cluster(conn, master_nodes, slave_nodes, opts, True) elif action == "destroy": - response = raw_input("Are you sure you want to destroy the cluster " + - cluster_name + "?\nALL DATA ON ALL NODES WILL BE LOST!!\n" + - "Destroy cluster " + cluster_name + " (y/N): ") + print "Are you sure you want to destroy the cluster %s?" % cluster_name + print "The following instances will be terminated:" + (master_nodes, slave_nodes) = get_existing_cluster( + conn, opts, cluster_name, die_on_error=False) + for inst in master_nodes + slave_nodes: + print "> %s" % inst.public_dns_name + + msg = "ALL DATA ON ALL NODES WILL BE LOST!!\nDestroy cluster %s (y/N): " % cluster_name + response = raw_input(msg) if response == "y": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) print "Terminating master..." for inst in master_nodes: inst.terminate() diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 21443ebbbfb0e..38095e88dcea9 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -38,7 +38,7 @@ import org.apache.spark.streaming.receiver.Receiver /** * Input stream that pulls messages from a Kafka Broker. * - * @param kafkaParams Map of kafka configuration paramaters. + * @param kafkaParams Map of kafka configuration parameters. * See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. 
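The receiver changes in the next hunk read `zookeeper.connect` and `group.id` from `kafkaParams`, and trigger the Zookeeper consumer-group cleanup only when `auto.offset.reset` is present. A minimal sketch of an application supplying those parameters (the Zookeeper address, group, and topic names below are placeholders, not taken from this patch):

```scala
import kafka.serializer.StringDecoder

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

object KafkaParamsSketch {
  def main(args: Array[String]): Unit = {
    // local[4] leaves cores free for the receiver thread(s); placeholder app name.
    val ssc = new StreamingContext(
      new SparkConf().setMaster("local[4]").setAppName("KafkaParamsSketch"), Seconds(2))

    // Keys the receiver looks up; all values below are placeholders.
    val kafkaParams = Map(
      "zookeeper.connect" -> "localhost:2181",
      "group.id" -> "example-consumer-group",
      "auto.offset.reset" -> "smallest")

    // topic -> number of consumer threads; the receiver sizes its thread pool from these counts.
    val topics = Map("example-topic" -> 2)

    val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2)

    stream.map(_._2).count().print()
    ssc.start()
    ssc.awaitTermination()
  }
}
```

When `auto.offset.reset` is supplied, the receiver also deletes its `/consumers/<group>` node in Zookeeper, as the cleanup helper in this file shows.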
@@ -76,29 +76,31 @@ class KafkaReceiver[ // Connection to Kafka var consumerConnector : ConsumerConnector = null - def onStop() { } + def onStop() { + if (consumerConnector != null) { + consumerConnector.shutdown() + } + } def onStart() { - // In case we are using multiple Threads to handle Kafka Messages - val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) - logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("group.id")) // Kafka connection properties val props = new Properties() kafkaParams.foreach(param => props.put(param._1, param._2)) + val zkConnect = kafkaParams("zookeeper.connect") // Create the connection to the cluster - logInfo("Connecting to Zookeper: " + kafkaParams("zookeeper.connect")) + logInfo("Connecting to Zookeeper: " + zkConnect) val consumerConfig = new ConsumerConfig(props) consumerConnector = Consumer.create(consumerConfig) - logInfo("Connected to " + kafkaParams("zookeeper.connect")) + logInfo("Connected to " + zkConnect) - // When autooffset.reset is defined, it is our responsibility to try and whack the + // When auto.offset.reset is defined, it is our responsibility to try and whack the // consumer group zk node. if (kafkaParams.contains("auto.offset.reset")) { - tryZookeeperConsumerGroupCleanup(kafkaParams("zookeeper.connect"), kafkaParams("group.id")) + tryZookeeperConsumerGroupCleanup(zkConnect, kafkaParams("group.id")) } val keyDecoder = manifest[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) @@ -112,10 +114,14 @@ class KafkaReceiver[ val topicMessageStreams = consumerConnector.createMessageStreams( topics, keyDecoder, valueDecoder) - - // Start the messages handler for each partition - topicMessageStreams.values.foreach { streams => - streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } + val executorPool = Executors.newFixedThreadPool(topics.values.sum) + try { + // Start the messages handler for each partition + topicMessageStreams.values.foreach { streams => + streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } + } + } finally { + executorPool.shutdown() // Just causes threads to terminate after work is done } } @@ -124,30 +130,35 @@ class KafkaReceiver[ extends Runnable { def run() { logInfo("Starting MessageHandler.") - for (msgAndMetadata <- stream) { - store((msgAndMetadata.key, msgAndMetadata.message)) + try { + for (msgAndMetadata <- stream) { + store((msgAndMetadata.key, msgAndMetadata.message)) + } + } catch { + case e: Throwable => logError("Error handling message; exiting", e) } } } - // It is our responsibility to delete the consumer group when specifying autooffset.reset. This + // It is our responsibility to delete the consumer group when specifying auto.offset.reset. This // is because Kafka 0.7.2 only honors this param when the group is not in zookeeper. // // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied - // from Kafkas' ConsoleConsumer. See code related to 'autooffset.reset' when it is set to + // from Kafka's ConsoleConsumer. 
See code related to 'auto.offset.reset' when it is set to // 'smallest'/'largest': // scalastyle:off // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala // scalastyle:on private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { + val dir = "/consumers/" + groupId + logInfo("Cleaning up temporary Zookeeper data under " + dir + ".") + val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer) try { - val dir = "/consumers/" + groupId - logInfo("Cleaning up temporary zookeeper data under " + dir + ".") - val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer) zk.deleteRecursive(dir) - zk.close() } catch { - case _ : Throwable => // swallow + case e: Throwable => logWarning("Error cleaning up temporary Zookeeper data", e) + } finally { + zk.close() } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala index 1c6d7e59e9a27..d85afa45b1264 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala @@ -62,7 +62,8 @@ class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/* , AnyRe private[graphx] class VertexBroadcastMsgRDDFunctions[T: ClassTag](self: RDD[VertexBroadcastMsg[T]]) { def partitionBy(partitioner: Partitioner): RDD[VertexBroadcastMsg[T]] = { - val rdd = new ShuffledRDD[PartitionID, (VertexId, T), VertexBroadcastMsg[T]](self, partitioner) + val rdd = new ShuffledRDD[PartitionID, (VertexId, T), (VertexId, T), VertexBroadcastMsg[T]]( + self, partitioner) // Set a custom serializer if the data is of int or double type. if (classTag[T] == ClassTag.Int) { @@ -84,7 +85,7 @@ class MsgRDDFunctions[T: ClassTag](self: RDD[MessageToPartition[T]]) { * Return a copy of the RDD partitioned using the specified partitioner. */ def partitionBy(partitioner: Partitioner): RDD[MessageToPartition[T]] = { - new ShuffledRDD[PartitionID, T, MessageToPartition[T]](self, partitioner) + new ShuffledRDD[PartitionID, T, T, MessageToPartition[T]](self, partitioner) } } @@ -103,7 +104,7 @@ object MsgRDDFunctions { private[graphx] class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) { def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = { - val rdd = new ShuffledRDD[VertexId, VD, (VertexId, VD)](self, partitioner) + val rdd = new ShuffledRDD[VertexId, VD, VD, (VertexId, VD)](self, partitioner) // Set a custom serializer if the data is of int or double type. if (classTag[VD] == ClassTag.Int) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index d02e9238adba5..3827ac8d0fd6a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -46,8 +46,8 @@ private[graphx] class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) { /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. 
*/ def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = { - new ShuffledRDD[VertexId, (PartitionID, Byte), RoutingTableMessage](self, partitioner) - .setSerializer(new RoutingTableMessageSerializer) + new ShuffledRDD[VertexId, (PartitionID, Byte), (PartitionID, Byte), RoutingTableMessage]( + self, partitioner).setSerializer(new RoutingTableMessageSerializer) } } diff --git a/make-distribution.sh b/make-distribution.sh index ae52b4976dc25..86868438e75c3 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -84,17 +84,28 @@ while (( "$#" )); do shift done +if [ -z "$JAVA_HOME" ]; then + # Fall back on JAVA_HOME from rpm, if found + if which rpm &>/dev/null; then + RPM_JAVA_HOME=$(rpm -E %java_home 2>/dev/null) + if [ "$RPM_JAVA_HOME" != "%java_home" ]; then + JAVA_HOME=$RPM_JAVA_HOME + echo "No JAVA_HOME set, proceeding with '$JAVA_HOME' learned from rpm" + fi + fi +fi + if [ -z "$JAVA_HOME" ]; then echo "Error: JAVA_HOME is not set, cannot proceed." exit -1 fi -VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) -if [ $? != 0 ]; then +if ! which mvn &>/dev/null; then echo -e "You need Maven installed to build Spark." echo -e "Download Maven from https://maven.apache.org/" exit -1; fi +VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) JAVA_CMD="$JAVA_HOME"/bin/java JAVA_VERSION=$("$JAVA_CMD" -version 2>&1) diff --git a/mllib/pom.xml b/mllib/pom.xml index 878cb83dbf783..b622f96dd7901 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -84,5 +84,13 @@ scalatest-maven-plugin + + + ../python + + pyspark/mllib/*.py + + + diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7b42b35a06380..9a5bb1f1c1194 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,7 +34,10 @@ object MimaExcludes { val excludes = SparkBuild.SPARK_VERSION match { case v if v.startsWith("1.1") => - Seq(MimaBuild.excludeSparkPackage("graphx")) ++ + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx") + ) ++ Seq( // Adding new method to JavaRDLike trait // We should probably mark this as a developer API. diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 19235d5f79f85..0dbead4415b02 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -43,18 +43,23 @@ def launch_gateway(): # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE, preexec_fn=preexec_func) + proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func) else: # preexec_fn not supported on Windows - proc = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE) - + proc = Popen(command, stdout=PIPE, stdin=PIPE) + try: # Determine which ephemeral port the server started on: - gateway_port = int(proc.stdout.readline()) - except: - error_code = proc.poll() - raise Exception("Launching GatewayServer failed with exit code %d: %s" % - (error_code, "".join(proc.stderr.readlines()))) + gateway_port = proc.stdout.readline() + gateway_port = int(gateway_port) + except ValueError: + (stdout, _) = proc.communicate() + exit_code = proc.poll() + error_msg = "Launching GatewayServer failed" + error_msg += " with exit code %d!" % exit_code if exit_code else "! 
" + error_msg += "(Warning: unexpected output detected.)\n\n" + error_msg += gateway_port + stdout + raise Exception(error_msg) # Create a thread to echo output from the GatewayServer, which is required # for Java log output to show up: diff --git a/sbin/start-history-server.sh b/sbin/start-history-server.sh index 4a90c68763b68..e30493da32a7a 100755 --- a/sbin/start-history-server.sh +++ b/sbin/start-history-server.sh @@ -19,19 +19,18 @@ # Starts the history server on the machine this script is executed on. # -# Usage: start-history-server.sh [] -# Example: ./start-history-server.sh --dir /tmp/spark-events --port 18080 +# Usage: start-history-server.sh +# +# Use the SPARK_HISTORY_OPTS environment variable to set history server configuration. # sbin=`dirname "$0"` sbin=`cd "$sbin"; pwd` -if [ $# -lt 1 ]; then - echo "Usage: ./start-history-server.sh " - echo "Example: ./start-history-server.sh /tmp/spark-events" - exit +if [ $# != 0 ]; then + echo "Using command line arguments for setting the log directory is deprecated. Please " + echo "set the spark.history.fs.logDirectory configuration option instead." + export SPARK_HISTORY_OPTS="$SPARK_HISTORY_OPTS -Dspark.history.fs.logDirectory=$1" fi -LOG_DIR=$1 - -"$sbin"/spark-daemon.sh start org.apache.spark.deploy.history.HistoryServer 1 --dir "$LOG_DIR" +exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.history.HistoryServer 1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 196695a0a188f..ada48eaf5dc0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -30,53 +30,56 @@ import org.apache.spark.sql.catalyst.types._ object ScalaReflection { import scala.reflect.runtime.universe._ + case class Schema(dataType: DataType, nullable: Boolean) + /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { - case s: StructType => - s.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)()) + case Schema(s: StructType, _) => + s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) } - /** Returns a catalyst DataType for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: DataType = schemaFor(typeOf[T]) + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T]) - /** Returns a catalyst DataType for the given Scala Type using reflection. */ - def schemaFor(tpe: `Type`): DataType = tpe match { + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + def schemaFor(tpe: `Type`): Schema = tpe match { case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - schemaFor(optType) + Schema(schemaFor(optType).dataType, nullable = true) case t if t <:< typeOf[Product] => val params = t.member("": TermName).asMethod.paramss - StructType( - params.head.map(p => - StructField(p.name.toString, schemaFor(p.typeSignature), nullable = true))) + Schema(StructType( + params.head.map { p => + val Schema(dataType, nullable) = schemaFor(p.typeSignature) + StructField(p.name.toString, dataType, nullable) + }), nullable = true) // Need to decide if we actually need a special type here. 
- case t if t <:< typeOf[Array[Byte]] => BinaryType + case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) case t if t <:< typeOf[Array[_]] => sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") case t if t <:< typeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - ArrayType(schemaFor(elementType)) + Schema(ArrayType(schemaFor(elementType).dataType), nullable = true) case t if t <:< typeOf[Map[_,_]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - MapType(schemaFor(keyType), schemaFor(valueType)) - case t if t <:< typeOf[String] => StringType - case t if t <:< typeOf[Timestamp] => TimestampType - case t if t <:< typeOf[BigDecimal] => DecimalType - case t if t <:< typeOf[java.lang.Integer] => IntegerType - case t if t <:< typeOf[java.lang.Long] => LongType - case t if t <:< typeOf[java.lang.Double] => DoubleType - case t if t <:< typeOf[java.lang.Float] => FloatType - case t if t <:< typeOf[java.lang.Short] => ShortType - case t if t <:< typeOf[java.lang.Byte] => ByteType - case t if t <:< typeOf[java.lang.Boolean] => BooleanType - // TODO: The following datatypes could be marked as non-nullable. - case t if t <:< definitions.IntTpe => IntegerType - case t if t <:< definitions.LongTpe => LongType - case t if t <:< definitions.DoubleTpe => DoubleType - case t if t <:< definitions.FloatTpe => FloatType - case t if t <:< definitions.ShortTpe => ShortType - case t if t <:< definitions.ByteTpe => ByteType - case t if t <:< definitions.BooleanTpe => BooleanType + Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true) + case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) + case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) + case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) + case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) + case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) + case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) + case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) + case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) + case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) + case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) + case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) + case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) + case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) + case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) } implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index c9b7cea6a3e5f..2c71d2c7b3563 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -45,8 +45,10 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) { * that 
schema. * * In contrast to a normal projection, a MutableProjection reuses the same underlying row object - * each time an input row is added. This significatly reduces the cost of calcuating the - * projection, but means that it is not safe + * each time an input row is added. This significantly reduces the cost of calculating the + * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` + * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` + * and hold on to the returned [[Row]] before calling `next()`. */ case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = @@ -67,7 +69,7 @@ case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) } /** - * A mutable wrapper that makes two rows appear appear as a single concatenated row. Designed to + * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to * be instantiated once per thread and reused. */ class JoinedRow extends Row { @@ -81,6 +83,18 @@ class JoinedRow extends Row { this } + /** Updates this JoinedRow by updating its left base row. Returns itself. */ + def withLeft(newLeft: Row): Row = { + row1 = newLeft + this + } + + /** Updates this JoinedRow by updating its right base row. Returns itself. */ + def withRight(newRight: Row): Row = { + row2 = newRight + this + } + def iterator = row1.iterator ++ row2.iterator def length = row1.length + row2.length @@ -124,4 +138,9 @@ class JoinedRow extends Row { } new GenericRow(copiedValues) } + + override def toString() = { + val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) + s"[${row.mkString(",")}]" + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala index 7c616788a3830..582334aa42590 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala @@ -21,5 +21,4 @@ abstract class BaseRelation extends LeafNode { self: Product => def tableName: String - def isPartitioned: Boolean = false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala new file mode 100644 index 0000000000000..489d7e9c2437f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import java.sql.Timestamp + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + +case class PrimitiveData( + intField: Int, + longField: Long, + doubleField: Double, + floatField: Float, + shortField: Short, + byteField: Byte, + booleanField: Boolean) + +case class NullableData( + intField: java.lang.Integer, + longField: java.lang.Long, + doubleField: java.lang.Double, + floatField: java.lang.Float, + shortField: java.lang.Short, + byteField: java.lang.Byte, + booleanField: java.lang.Boolean, + stringField: String, + decimalField: BigDecimal, + timestampField: Timestamp, + binaryField: Array[Byte]) + +case class OptionalData( + intField: Option[Int], + longField: Option[Long], + doubleField: Option[Double], + floatField: Option[Float], + shortField: Option[Short], + byteField: Option[Byte], + booleanField: Option[Boolean]) + +case class ComplexData( + arrayField: Seq[Int], + mapField: Map[Int, String], + structField: PrimitiveData) + +class ScalaReflectionSuite extends FunSuite { + import ScalaReflection._ + + test("primitive data") { + val schema = schemaFor[PrimitiveData] + assert(schema === Schema( + StructType(Seq( + StructField("intField", IntegerType, nullable = false), + StructField("longField", LongType, nullable = false), + StructField("doubleField", DoubleType, nullable = false), + StructField("floatField", FloatType, nullable = false), + StructField("shortField", ShortType, nullable = false), + StructField("byteField", ByteType, nullable = false), + StructField("booleanField", BooleanType, nullable = false))), + nullable = true)) + } + + test("nullable data") { + val schema = schemaFor[NullableData] + assert(schema === Schema( + StructType(Seq( + StructField("intField", IntegerType, nullable = true), + StructField("longField", LongType, nullable = true), + StructField("doubleField", DoubleType, nullable = true), + StructField("floatField", FloatType, nullable = true), + StructField("shortField", ShortType, nullable = true), + StructField("byteField", ByteType, nullable = true), + StructField("booleanField", BooleanType, nullable = true), + StructField("stringField", StringType, nullable = true), + StructField("decimalField", DecimalType, nullable = true), + StructField("timestampField", TimestampType, nullable = true), + StructField("binaryField", BinaryType, nullable = true))), + nullable = true)) + } + + test("optinal data") { + val schema = schemaFor[OptionalData] + assert(schema === Schema( + StructType(Seq( + StructField("intField", IntegerType, nullable = true), + StructField("longField", LongType, nullable = true), + StructField("doubleField", DoubleType, nullable = true), + StructField("floatField", FloatType, nullable = true), + StructField("shortField", ShortType, nullable = true), + StructField("byteField", ByteType, nullable = true), + StructField("booleanField", BooleanType, nullable = true))), + nullable = true)) + } + + test("complex data") { + val schema = schemaFor[ComplexData] + assert(schema === Schema( + StructType(Seq( + StructField("arrayField", ArrayType(IntegerType), nullable = true), + StructField("mapField", MapType(IntegerType, StringType), nullable = true), + StructField( + "structField", + StructType(Seq( + StructField("intField", IntegerType, nullable = false), + StructField("longField", LongType, nullable = false), 
+ StructField("doubleField", DoubleType, nullable = false), + StructField("floatField", FloatType, nullable = false), + StructField("shortField", ShortType, nullable = false), + StructField("byteField", ByteType, nullable = false), + StructField("booleanField", BooleanType, nullable = false))), + nullable = true))), + nullable = true)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index b378252ba2f55..2fe7f94663996 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -29,9 +29,26 @@ import scala.collection.JavaConverters._ */ trait SQLConf { + /** ************************ Spark SQL Params/Hints ******************* */ + // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext? + /** Number of partitions to use for shuffle operators. */ private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt + /** + * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to + * a broadcast value during the physical executions of join operations. Setting this to 0 + * effectively disables auto conversion. + * Hive setting: hive.auto.convert.join.noconditionaltask.size. + */ + private[spark] def autoConvertJoinSize: Int = + get("spark.sql.auto.convert.join.size", "10000").toInt + + /** A comma-separated list of table names marked to be broadcasted during joins. */ + private[spark] def joinBroadcastTables: String = get("spark.sql.join.broadcastTables", "") + + /** ********************** SQLConf functionality methods ************ */ + @transient private val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index c60af28b2a1f3..7edb548678c33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -170,7 +170,11 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { - catalog.registerTable(None, tableName, rdd.logicalPlan) + val name = tableName + val newPlan = rdd.logicalPlan transform { + case s @ SparkLogicalPlan(ExistingRdd(_, _), _) => s.copy(tableName = name) + } + catalog.registerTable(None, tableName, newPlan) } /** @@ -186,18 +190,23 @@ class SQLContext(@transient val sparkContext: SparkContext) /** Caches the specified table in-memory. */ def cacheTable(tableName: String): Unit = { - val currentTable = catalog.lookupRelation(None, tableName) - val useCompression = - sparkContext.conf.getBoolean("spark.sql.inMemoryColumnarStorage.compressed", false) - val asInMemoryRelation = - InMemoryRelation(useCompression, executePlan(currentTable).executedPlan) + val currentTable = table(tableName).queryExecution.analyzed + val asInMemoryRelation = currentTable match { + case _: InMemoryRelation => + currentTable.logicalPlan + + case _ => + val useCompression = + sparkContext.conf.getBoolean("spark.sql.inMemoryColumnarStorage.compressed", false) + InMemoryRelation(useCompression, executePlan(currentTable).executedPlan) + } catalog.registerTable(None, tableName, asInMemoryRelation) } /** Removes the specified table from the in-memory cache. 
*/ def uncacheTable(tableName: String): Unit = { - EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) match { + table(tableName).queryExecution.analyzed match { // This is kind of a hack to make sure that if this was just an RDD registered as a table, // we reregister the RDD as a table. case inMem @ InMemoryRelation(_, _, e: ExistingRdd) => @@ -213,8 +222,8 @@ class SQLContext(@transient val sparkContext: SparkContext) /** Returns true if the table is currently cached in-memory. */ def isCached(tableName: String): Boolean = { - val relation = catalog.lookupRelation(None, tableName) - EliminateAnalysisOperators(relation) match { + val relation = table(tableName).queryExecution.analyzed + relation match { case _: InMemoryRelation => true case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index ff6deeda2394d..790d9ef22cf16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -137,26 +137,25 @@ class JavaSQLContext(val sqlContext: SQLContext) { val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") fields.map { property => - val dataType = property.getPropertyType match { - case c: Class[_] if c == classOf[java.lang.String] => StringType - case c: Class[_] if c == java.lang.Short.TYPE => ShortType - case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType - case c: Class[_] if c == java.lang.Long.TYPE => LongType - case c: Class[_] if c == java.lang.Double.TYPE => DoubleType - case c: Class[_] if c == java.lang.Byte.TYPE => ByteType - case c: Class[_] if c == java.lang.Float.TYPE => FloatType - case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType - - case c: Class[_] if c == classOf[java.lang.Short] => ShortType - case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType - case c: Class[_] if c == classOf[java.lang.Long] => LongType - case c: Class[_] if c == classOf[java.lang.Double] => DoubleType - case c: Class[_] if c == classOf[java.lang.Byte] => ByteType - case c: Class[_] if c == classOf[java.lang.Float] => FloatType - case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType + val (dataType, nullable) = property.getPropertyType match { + case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) + case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) + case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) + case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) + case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) + case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) + case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) + + case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) + case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) + case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) + case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) + case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) + case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) + case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) } - // TODO: Nullability could be stricter. 
- AttributeReference(property.getName, dataType, nullable = true)() + AttributeReference(property.getName, dataType, nullable)() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index f46fa0516566f..00010ef6e798a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -47,7 +47,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una iter.map(r => mutablePair.update(hashExpressions(r), r)) } val part = new HashPartitioner(numPartitions) - val shuffled = new ShuffledRDD[Row, Row, MutablePair[Row, Row]](rdd, part) + val shuffled = new ShuffledRDD[Row, Row, Row, MutablePair[Row, Row]](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) @@ -60,7 +60,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una iter.map(row => mutablePair.update(row, null)) } val part = new RangePartitioner(numPartitions, rdd, ascending = true) - val shuffled = new ShuffledRDD[Row, Null, MutablePair[Row, Null]](rdd, part) + val shuffled = new ShuffledRDD[Row, Null, Null, MutablePair[Row, Null]](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._1) @@ -71,7 +71,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una iter.map(r => mutablePair.update(null, r)) } val partitioner = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Null, Row, MutablePair[Null, Row]](rdd, partitioner) + val shuffled = new ShuffledRDD[Null, Row, Row, MutablePair[Null, Row]](rdd, partitioner) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 07967fe75e882..27dc091b85812 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.{Logging, Row} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.catalyst.plans.{QueryPlan, logical} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.BaseRelation import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan /** * :: DeveloperApi :: @@ -66,19 +66,20 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { * linking. 
*/ @DeveloperApi -case class SparkLogicalPlan(alreadyPlanned: SparkPlan) - extends logical.LogicalPlan with MultiInstanceRelation { +case class SparkLogicalPlan(alreadyPlanned: SparkPlan, tableName: String = "SparkLogicalPlan") + extends BaseRelation with MultiInstanceRelation { def output = alreadyPlanned.output - def references = Set.empty - def children = Nil + override def references = Set.empty + override def children = Nil override final def newInstance: this.type = { SparkLogicalPlan( alreadyPlanned match { case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) case _ => sys.error("Multiple instance of the same relation detected.") - }).asInstanceOf[this.type] + }, tableName) + .asInstanceOf[this.type] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index bd8ae4cddef89..3cd29967d1cd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -21,10 +21,10 @@ import org.apache.spark.sql.{SQLContext, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{BaseRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.parquet._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} +import org.apache.spark.sql.parquet._ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => @@ -45,14 +45,52 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Uses the HashFilteredJoin pattern to find joins where at least some of the predicates can be + * evaluated by matching hash keys. + */ object HashJoin extends Strategy with PredicateHelper { + private[this] def broadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: LogicalPlan, + right: LogicalPlan, + condition: Option[Expression], + side: BuildSide) = { + val broadcastHashJoin = execution.BroadcastHashJoin( + leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext) + condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil + } + + def broadcastTables: Seq[String] = sqlContext.joinBroadcastTables.split(",").toBuffer + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // Find inner joins where at least some predicates can be evaluated by matching hash keys - // using the HashFilteredJoin pattern. 
+ case HashFilteredJoin( + Inner, + leftKeys, + rightKeys, + condition, + left, + right @ PhysicalOperation(_, _, b: BaseRelation)) + if broadcastTables.contains(b.tableName) => + broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) + + case HashFilteredJoin( + Inner, + leftKeys, + rightKeys, + condition, + left @ PhysicalOperation(_, _, b: BaseRelation), + right) + if broadcastTables.contains(b.tableName) => + broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) + case HashFilteredJoin(Inner, leftKeys, rightKeys, condition, left, right) => val hashJoin = - execution.HashJoin(leftKeys, rightKeys, BuildRight, planLater(left), planLater(right)) + execution.ShuffledHashJoin( + leftKeys, rightKeys, BuildRight, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + case _ => Nil } } @@ -62,10 +100,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => // Collect all aggregate expressions. val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) + aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a }) // Collect all aggregate expressions that can be computed partially. val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) + aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p }) // Only do partial aggregation if supported by all aggregate expressions. if (allAggregates.size == partialAggregates.size) { @@ -242,7 +280,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.ExistingRdd(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil - case SparkLogicalPlan(existingPlan) => existingPlan :: Nil + case SparkLogicalPlan(existingPlan, _) => existingPlan :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 18f4a5877bb21..a278f1ca98476 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -105,7 +105,7 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext iter.take(limit).map(row => mutablePair.update(false, row)) } val part = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part) + val shuffled = new ShuffledRDD[Boolean, Row, Row, MutablePair[Boolean, Row]](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.mapPartitions(_.take(limit).map(_._2)) } @@ -205,4 +205,3 @@ object ExistingRdd { case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { override def execute() = rdd } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 84bdde38b7e9e..32c5f26fe8aa0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -18,12 +18,15 @@ package org.apache.spark.sql.execution import scala.collection.mutable.{ArrayBuffer, BitSet} +import 
scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent._ +import scala.concurrent.duration._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical._ @DeveloperApi sealed abstract class BuildSide @@ -34,28 +37,19 @@ case object BuildLeft extends BuildSide @DeveloperApi case object BuildRight extends BuildSide -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class HashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override def outputPartitioning: Partitioning = left.outputPartitioning +trait HashJoin { + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val buildSide: BuildSide + val left: SparkPlan + val right: SparkPlan - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - val (buildPlan, streamedPlan) = buildSide match { + lazy val (buildPlan, streamedPlan) = buildSide match { case BuildLeft => (left, right) case BuildRight => (right, left) } - val (buildKeys, streamedKeys) = buildSide match { + lazy val (buildKeys, streamedKeys) = buildSide match { case BuildLeft => (leftKeys, rightKeys) case BuildRight => (rightKeys, leftKeys) } @@ -66,73 +60,74 @@ case class HashJoin( @transient lazy val streamSideKeyGenerator = () => new MutableProjection(streamedKeys, streamedPlan.output) - def execute() = { - - buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - // TODO: Use Spark's HashMap implementation. - val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() - var currentRow: Row = null - - // Create a mapping of buildKeys -> rows - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if(!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new ArrayBuffer[Row]() - hashTable.put(rowKey, newMatchList) - newMatchList - } else { - existingMatchList - } - matchList += currentRow.copy() + def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { + // TODO: Use Spark's HashMap implementation. + + val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() + var currentRow: Row = null + + // Create a mapping of buildKeys -> rows + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if(!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new ArrayBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + existingMatchList } + matchList += currentRow.copy() } + } - new Iterator[Row] { - private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatches: ArrayBuffer[Row] = _ - private[this] var currentMatchPosition: Int = -1 + new Iterator[Row] { + private[this] var currentStreamedRow: Row = _ + private[this] var currentHashMatches: ArrayBuffer[Row] = _ + private[this] var currentMatchPosition: Int = -1 - // Mutable per row objects. - private[this] val joinRow = new JoinedRow + // Mutable per row objects. 
+ private[this] val joinRow = new JoinedRow - private[this] val joinKeys = streamSideKeyGenerator() + private[this] val joinKeys = streamSideKeyGenerator() - override final def hasNext: Boolean = - (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || + override final def hasNext: Boolean = + (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || (streamIter.hasNext && fetchNext()) - override final def next() = { - val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) - currentMatchPosition += 1 - ret + override final def next() = { + val ret = buildSide match { + case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) } + currentMatchPosition += 1 + ret + } - /** - * Searches the streamed iterator for the next row that has at least one match in hashtable. - * - * @return true if the search is successful, and false the streamed iterator runs out of - * tuples. - */ - private final def fetchNext(): Boolean = { - currentHashMatches = null - currentMatchPosition = -1 - - while (currentHashMatches == null && streamIter.hasNext) { - currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashTable.get(joinKeys.currentValue) - } + /** + * Searches the streamed iterator for the next row that has at least one match in hashtable. + * + * @return true if the search is successful, and false if the streamed iterator runs out of + * tuples. + */ + private final def fetchNext(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + if (!joinKeys(currentStreamedRow).anyNull) { + currentHashMatches = hashTable.get(joinKeys.currentValue) } + } - if (currentHashMatches == null) { - false - } else { - currentMatchPosition = 0 - true - } + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true } } } @@ -141,32 +136,49 @@ case class HashJoin( /** * :: DeveloperApi :: - * Build the right table's join keys into a HashSet, and iteratively go through the left - * table, to find the if join keys are in the Hash set. + * Performs an inner hash join of two child relations by first shuffling the data using the join + * keys. */ @DeveloperApi -case class LeftSemiJoinHash( +case class ShuffledHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + buildSide: BuildSide, left: SparkPlan, - right: SparkPlan) extends BinaryNode { + right: SparkPlan) extends BinaryNode with HashJoin { override def outputPartitioning: Partitioning = left.outputPartitioning override def requiredChildDistribution = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - val (buildPlan, streamedPlan) = (right, left) - val (buildKeys, streamedKeys) = (rightKeys, leftKeys) + def execute() = { + buildPlan.execute().zipPartitions(streamedPlan.execute()) { + (buildIter, streamIter) => joinIterators(buildIter, streamIter) + } + } +} - def output = left.output +/** + * :: DeveloperApi :: + * Build the right table's join keys into a HashSet, and iteratively go through the left + * table, to find the if join keys are in the Hash set. 
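 * Only columns from the left table are produced (output = left.output), each matching left row
 * is emitted at most once, and rows whose join keys contain nulls never count as matches.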
+ */ +@DeveloperApi +case class LeftSemiJoinHash( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashJoin { - @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) - @transient lazy val streamSideKeyGenerator = - () => new MutableProjection(streamedKeys, streamedPlan.output) + val buildSide = BuildRight - def execute() = { + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def output = left.output + def execute() = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => val hashSet = new java.util.HashSet[Row]() var currentRow: Row = null @@ -191,6 +203,43 @@ case class LeftSemiJoinHash( } } + +/** + * :: DeveloperApi :: + * Performs an inner hash join of two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. + */ +@DeveloperApi +case class BroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin { + + override def otherCopyArgs = sqlContext :: Nil + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + + @transient + lazy val broadcastFuture = future { + sqlContext.sparkContext.broadcast(buildPlan.executeCollect()) + } + + def execute() = { + val broadcastRelation = Await.result(broadcastFuture, 5.minute) + + streamedPlan.execute().mapPartitions { streamedIter => + joinIterators(broadcastRelation.value.iterator, streamedIter) + } + } +} + /** * :: DeveloperApi :: * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys @@ -220,7 +269,6 @@ case class LeftSemiJoinBNL( .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) - def execute() = { val broadcastedRelation = sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) @@ -284,7 +332,6 @@ case class BroadcastNestedLoopJoin( .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) - def execute() = { val broadcastedRelation = sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 96c131a7f8af1..9c4771d1a9846 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -44,8 +44,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} * @param path The path to the Parquet file. 
*/ private[sql] case class ParquetRelation( - val path: String, - @transient val conf: Option[Configuration] = None) extends LeafNode with MultiInstanceRelation { + path: String, + @transient conf: Option[Configuration] = None) extends LeafNode with MultiInstanceRelation { + self: Product => /** Schema derived from ParquetFile */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c794da4da4069..c3c0dcb1aa00b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -20,10 +20,30 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ class CachedTableSuite extends QueryTest { TestData // Load test tables. + test("SPARK-1669: cacheTable should be idempotent") { + assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + + cacheTable("testData") + table("testData").queryExecution.analyzed match { + case _: InMemoryRelation => + case _ => + fail("testData should be cached") + } + + cacheTable("testData") + table("testData").queryExecution.analyzed match { + case InMemoryRelation(_, _, _: InMemoryColumnarTableScan) => + fail("cacheTable is not idempotent") + + case _ => + } + } + test("read from cached table and uncache") { TestSQLContext.cacheTable("testData") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index fb599e1e01e73..e4a64a7a482b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.test._ /* Implicits */ @@ -149,102 +148,4 @@ class DslQuerySuite extends QueryTest { test("zero count") { assert(emptyTableData.count() === 0) } - - test("inner join where, one match per row") { - checkAnswer( - upperCaseData.join(lowerCaseData, Inner).where('n === 'N), - Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d") - )) - } - - test("inner join ON, one match per row") { - checkAnswer( - upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)), - Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d") - )) - } - - test("inner join, where, multiple matches") { - val x = testData2.where('a === 1).as('x) - val y = testData2.where('a === 1).as('y) - checkAnswer( - x.join(y).where("x.a".attr === "y.a".attr), - (1,1,1,1) :: - (1,1,1,2) :: - (1,2,1,1) :: - (1,2,1,2) :: Nil - ) - } - - test("inner join, no matches") { - val x = testData2.where('a === 1).as('x) - val y = testData2.where('a === 2).as('y) - checkAnswer( - x.join(y).where("x.a".attr === "y.a".attr), - Nil) - } - - test("big inner join, 4 matches per row") { - val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) - val bigDataX = bigData.as('x) - val bigDataY = bigData.as('y) - - checkAnswer( - bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), - testData.flatMap( - row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq) - } - - 
test("cartisian product join") { - checkAnswer( - testData3.join(testData3), - (1, null, 1, null) :: - (1, null, 2, 2) :: - (2, 2, 1, null) :: - (2, 2, 2, 2) :: Nil) - } - - test("left outer join") { - checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), - (1, "A", 1, "a") :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) - } - - test("right outer join") { - checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), - (1, "a", 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) - } - - test("full outer join") { - val left = upperCaseData.where('N <= 4).as('left) - val right = upperCaseData.where('N >= 3).as('right) - - checkAnswer( - left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala new file mode 100644 index 0000000000000..3d7d5eedbe8ed --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ + +class JoinSuite extends QueryTest { + + // Ensures tables are loaded. + TestData + + test("equi-join is hash-join") { + val x = testData2.as('x) + val y = testData2.as('y) + val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed + val planned = planner.HashJoin(join) + assert(planned.size === 1) + } + + test("plans broadcast hash join, given hints") { + + def mkTest(buildSide: BuildSide, leftTable: String, rightTable: String) = { + TestSQLContext.set("spark.sql.join.broadcastTables", + s"${if (buildSide == BuildRight) rightTable else leftTable}") + val rdd = sql(s"""SELECT * FROM $leftTable JOIN $rightTable ON key = a""") + // Using `sparkPlan` because for relevant patterns in HashJoin to be + // matched, other strategies need to be applied. 
+ val physical = rdd.queryExecution.sparkPlan + val bhj = physical.collect { case j: BroadcastHashJoin if j.buildSide == buildSide => j } + + assert(bhj.size === 1, "planner does not pick up hint to generate broadcast hash join") + checkAnswer( + rdd, + Seq( + (1, "1", 1, 1), + (1, "1", 1, 2), + (2, "2", 2, 1), + (2, "2", 2, 2), + (3, "3", 3, 1), + (3, "3", 3, 2) + )) + } + + mkTest(BuildRight, "testData", "testData2") + mkTest(BuildLeft, "testData", "testData2") + } + + test("multiple-key equi-join is hash-join") { + val x = testData2.as('x) + val y = testData2.as('y) + val join = x.join(y, Inner, + Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed + val planned = planner.HashJoin(join) + assert(planned.size === 1) + } + + test("inner join where, one match per row") { + checkAnswer( + upperCaseData.join(lowerCaseData, Inner).where('n === 'N), + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + )) + } + + test("inner join ON, one match per row") { + checkAnswer( + upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)), + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + )) + } + + test("inner join, where, multiple matches") { + val x = testData2.where('a === 1).as('x) + val y = testData2.where('a === 1).as('y) + checkAnswer( + x.join(y).where("x.a".attr === "y.a".attr), + (1,1,1,1) :: + (1,1,1,2) :: + (1,2,1,1) :: + (1,2,1,2) :: Nil + ) + } + + test("inner join, no matches") { + val x = testData2.where('a === 1).as('x) + val y = testData2.where('a === 2).as('y) + checkAnswer( + x.join(y).where("x.a".attr === "y.a".attr), + Nil) + } + + test("big inner join, 4 matches per row") { + val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) + val bigDataX = bigData.as('x) + val bigDataY = bigData.as('y) + + checkAnswer( + bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), + testData.flatMap( + row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq) + } + + test("cartisian product join") { + checkAnswer( + testData3.join(testData3), + (1, null, 1, null) :: + (1, null, 2, 2) :: + (2, 2, 1, null) :: + (2, 2, 2, 2) :: Nil) + } + + test("left outer join") { + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), + (1, "A", 1, "a") :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + } + + test("right outer join") { + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), + (1, "a", 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + } + + test("full outer join") { + val left = upperCaseData.where('N <= 4).as('left) + val right = upperCaseData.where('N >= 3).as('right) + + checkAnswer( + left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index ef84ead2e6e8b..8e1e1971d968b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -35,7 +35,7 @@ class QueryTest extends PlanTest { case singleItem => Seq(Seq(singleItem)) } - val isSorted = rdd.logicalPlan.collect 
{ case s: logical.Sort => s}.nonEmpty + val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Any]) = if (!isSorted) answer.sortBy(_.toString) else answer val sparkAnswer = try rdd.collect().toSeq catch { case e: Exception => @@ -48,7 +48,7 @@ class QueryTest extends PlanTest { """.stripMargin) } - if(prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) { + if (prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) { fail(s""" |Results do not match for query: |${rdd.logicalPlan} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index e9360b0fc7910..bf7fafe952303 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.test._ /* Implicits */ @@ -404,5 +406,4 @@ class SQLQuerySuite extends QueryTest { ) clear() } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index df6b118360d01..215618e852eb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -57,21 +57,4 @@ class PlannerSuite extends FunSuite { val planned = PartialAggregation(query) assert(planned.isEmpty) } - - test("equi-join is hash-join") { - val x = testData2.as('x) - val y = testData2.as('y) - val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed - val planned = planner.HashJoin(join) - assert(planned.size === 1) - } - - test("multiple-key equi-join is hash-join") { - val x = testData2.as('x) - val y = testData2.as('y) - val join = x.join(y, Inner, - Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed - val planned = planner.HashJoin(join) - assert(planned.size === 1) - } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7695242a81601..7aedfcd74189b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -258,7 +258,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ))=> + case (seq: Seq[_], ArrayType(typ)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_,_], MapType(kType, vType)) => map.map { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index faa30c9ae5cca..90eacf4268780 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -34,9 +34,8 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ 
import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.execution.SparkLogicalPlan -import org.apache.spark.sql.hive.execution.{HiveTableScan, InsertIntoHiveTable} -import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} +import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.hive.execution.HiveTableScan /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -259,8 +258,6 @@ private[hive] case class MetastoreRelation new Partition(hiveQlTable, p) } - override def isPartitioned = hiveQlTable.isPartitioned - val tableDesc = new TableDesc( Class.forName(hiveQlTable.getSerializationLib).asInstanceOf[Class[Deserializer]], hiveQlTable.getInputFormatClass, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index b073dc3895f05..b70104dd5be5a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -38,8 +38,6 @@ import scala.collection.JavaConversions._ */ private[hive] case object NativePlaceholder extends Command -private[hive] case class DfsCommand(cmd: String) extends Command - private[hive] case class ShellCommand(cmd: String) extends Command private[hive] case class SourceCommand(filePath: String) extends Command @@ -227,15 +225,15 @@ private[hive] object HiveQl { SetCommand(Some(key), Some(value)) } } else if (sql.trim.toLowerCase.startsWith("cache table")) { - CacheCommand(sql.drop(12).trim, true) + CacheCommand(sql.trim.drop(12).trim, true) } else if (sql.trim.toLowerCase.startsWith("uncache table")) { - CacheCommand(sql.drop(14).trim, false) + CacheCommand(sql.trim.drop(14).trim, false) } else if (sql.trim.toLowerCase.startsWith("add jar")) { - AddJar(sql.drop(8)) + AddJar(sql.trim.drop(8)) } else if (sql.trim.toLowerCase.startsWith("add file")) { - AddFile(sql.drop(9)) - } else if (sql.trim.startsWith("dfs")) { - DfsCommand(sql) + AddFile(sql.trim.drop(9)) + } else if (sql.trim.toLowerCase.startsWith("dfs")) { + NativeCommand(sql) } else if (sql.trim.startsWith("source")) { SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath }) } else if (sql.trim.startsWith("!")) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 594a803806ede..c2b0b00aa5852 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConversions._ +import java.util.{HashMap => JHashMap} + import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.metastore.MetaStoreUtils import org.apache.hadoop.hive.ql.Context @@ -88,6 +90,12 @@ case class InsertIntoHiveTable( val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector)) seqAsJavaList(wrappedSeq) + case (m: Map[_, _], oi: MapObjectInspector) => + val keyOi = oi.getMapKeyObjectInspector + val valueOi = oi.getMapValueObjectInspector + val wrappedMap = m.map { case (key, value) => wrap(key, keyOi) -> wrap(value, valueOi) } + mapAsJavaMap(wrappedMap) + case (obj, _) => obj } diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index ad5e24c62c621..9b105308ab7cf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -84,7 +84,7 @@ private[hive] object HiveFunctionRegistry case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType - + // java class case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType @@ -98,7 +98,7 @@ private[hive] object HiveFunctionRegistry case c: Class[_] if c == classOf[java.lang.Byte] => ByteType case c: Class[_] if c == classOf[java.lang.Float] => FloatType case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType - + // primitive type case c: Class[_] if c == java.lang.Short.TYPE => ShortType case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType @@ -107,7 +107,7 @@ private[hive] object HiveFunctionRegistry case c: Class[_] if c == java.lang.Byte.TYPE => ByteType case c: Class[_] if c == java.lang.Float.TYPE => FloatType case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType - + case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) } } @@ -148,7 +148,7 @@ private[hive] trait HiveFunctionFactory { case p: java.lang.Byte => p case p: java.lang.Boolean => p case str: String => str - case p: BigDecimal => p + case p: java.math.BigDecimal => p case p: Array[Byte] => p case p: java.sql.Timestamp => p } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index d855310253bf3..9f1cd703103ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -228,7 +228,7 @@ class HiveQuerySuite extends HiveComparisonTest { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} TestHive.sparkContext.parallelize(fixture).registerAsTable("having_test") - val results = + val results = hql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() .map(x => Pair(x.getString(0), x.getInt(1))) @@ -236,7 +236,7 @@ class HiveQuerySuite extends HiveComparisonTest { assert(results === Array(Pair("foo", 4))) TestHive.reset() } - + test("SPARK-2180: HAVING with non-boolean clause raises no exceptions") { hql("select key, count(*) c from src group by key having c").collect() } @@ -370,6 +370,16 @@ class HiveQuerySuite extends HiveComparisonTest { } } + test("SPARK-2263: Insert Map values") { + hql("CREATE TABLE m(value MAP)") + hql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") + hql("SELECT * FROM m").collect().zip(hql("SELECT * FROM src LIMIT 10").collect()).map { + case (Row(map: Map[Int, String]), Row(key: Int, value: String)) => + assert(map.size === 1) + assert(map.head === (key, value)) + } + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. 
val testKey = "spark.sql.key.usedfortestonly" @@ -460,7 +470,6 @@ class HiveQuerySuite extends HiveComparisonTest { // Put tests that depend on specific Hive settings before these last two test, // since they modify /clear stuff. - } // for SPARK-2180 test diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index a9e3f42a3adfc..f944d010660eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -122,6 +122,3 @@ class PairUdf extends GenericUDF { override def getDisplayString(p1: Array[String]): String = "" } - - - diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 34434449a0d77..4d7c84f443879 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -26,6 +26,11 @@ import scala.collection.JavaConversions._ * A set of test cases that validate partition and column pruning. */ class PruningSuite extends HiveComparisonTest { + // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset + // the environment to ensure all referenced tables in this suites are not cached in-memory. + // Refer to https://issues.apache.org/jira/browse/SPARK-2283 for details. + TestHive.reset() + // Column pruning tests createPruningTest("Column pruning - with partitioned table", diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 8f2267599914c..556f49342977a 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -154,7 +154,7 @@ trait ClientBase extends Logging { } /** Copy the file into HDFS if needed. */ - private def copyRemoteFile( + private[yarn] def copyRemoteFile( dstDir: Path, originalPath: Path, replication: Short, @@ -213,10 +213,19 @@ trait ClientBase extends Logging { val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - Map( - ClientBase.SPARK_JAR -> ClientBase.getSparkJar, ClientBase.APP_JAR -> args.userJar, - ClientBase.LOG4J_PROP -> System.getenv(ClientBase.LOG4J_CONF_ENV_KEY) - ).foreach { case(destName, _localPath) => + val oldLog4jConf = Option(System.getenv("SPARK_LOG4J_CONF")) + if (oldLog4jConf.isDefined) { + logWarning( + "SPARK_LOG4J_CONF detected in the system environment. This variable has been " + + "deprecated. Please refer to the \"Launching Spark on YARN\" documentation " + + "for alternatives.") + } + + List( + (ClientBase.SPARK_JAR, ClientBase.sparkJar(sparkConf), ClientBase.CONF_SPARK_JAR), + (ClientBase.APP_JAR, args.userJar, ClientBase.CONF_SPARK_USER_JAR), + ("log4j.properties", oldLog4jConf.getOrElse(null), null) + ).foreach { case(destName, _localPath, confKey) => val localPath: String = if (_localPath != null) _localPath.trim() else "" if (! 
localPath.isEmpty()) { val localURI = new URI(localPath) @@ -225,6 +234,8 @@ trait ClientBase extends Logging { val destPath = copyRemoteFile(dst, qualifyForLocal(localURI), replication, setPermissions) distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, destName, statCache) + } else if (confKey != null) { + sparkConf.set(confKey, localPath) } } } @@ -246,6 +257,8 @@ trait ClientBase extends Logging { if (addToClasspath) { cachedSecondaryJarLinks += linkname } + } else if (addToClasspath) { + cachedSecondaryJarLinks += file.trim() } } } @@ -265,14 +278,10 @@ trait ClientBase extends Logging { val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.driver.extraClassPath") - val log4jConf = System.getenv(ClientBase.LOG4J_CONF_ENV_KEY) - ClientBase.populateClasspath(yarnConf, sparkConf, log4jConf, env, extraCp) + ClientBase.populateClasspath(args, yarnConf, sparkConf, env, extraCp) env("SPARK_YARN_MODE") = "true" env("SPARK_YARN_STAGING_DIR") = stagingDir env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() - if (log4jConf != null) { - env(ClientBase.LOG4J_CONF_ENV_KEY) = log4jConf - } // Set the environment variables to be passed on to the executors. distCacheMgr.setDistFilesEnv(env) @@ -285,7 +294,6 @@ trait ClientBase extends Logging { // Pass SPARK_YARN_USER_ENV itself to the AM so it can use it to set up executor environments. env("SPARK_YARN_USER_ENV") = userEnvs } - env } @@ -310,6 +318,37 @@ trait ClientBase extends Logging { logInfo("Setting up container launch context") val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) amContainer.setLocalResources(localResources) + + // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to + // executors. But we can't just set spark.executor.extraJavaOptions, because the driver's + // SparkContext will not let that set spark* system properties, which is expected behavior for + // Yarn clients. So propagate it through the environment. + // + // Note that to warn the user about the deprecation in cluster mode, some code from + // SparkConf#validateSettings() is duplicated here (to avoid triggering the condition + // described above). + if (args.amClass == classOf[ApplicationMaster].getName) { + sys.env.get("SPARK_JAVA_OPTS").foreach { value => + val warning = + s""" + |SPARK_JAVA_OPTS was detected (set to '$value'). + |This is deprecated in Spark 1.0+. + | + |Please instead use: + | - ./spark-submit with conf/spark-defaults.conf to set defaults for an application + | - ./spark-submit with --driver-java-options to set -X options for a driver + | - spark.executor.extraJavaOptions to set -X options for executors + """.stripMargin + logWarning(warning) + for (proc <- Seq("driver", "executor")) { + val key = s"spark.$proc.extraJavaOptions" + if (sparkConf.contains(key)) { + throw new SparkException(s"Found both $key and SPARK_JAVA_OPTS. Use only the former.") + } + } + env("SPARK_JAVA_OPTS") = value + } + } amContainer.setEnvironment(env) val amMemory = calculateAMMemory(newApp) @@ -341,30 +380,20 @@ trait ClientBase extends Logging { javaOpts += "-XX:CMSIncrementalDutyCycle=10" } - // SPARK_JAVA_OPTS is deprecated, but for backwards compatibility: - sys.env.get("SPARK_JAVA_OPTS").foreach { opts => - sparkConf.set("spark.executor.extraJavaOptions", opts) - sparkConf.set("spark.driver.extraJavaOptions", opts) - } - + // Forward the Spark configuration to the application master / executors. 
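    // Unlike before, the loop below runs for both launcher classes: every entry in the SparkConf
    // is forwarded as an escaped -Dkey=value option whether the AM is ExecutorLauncher (client
    // mode) or ApplicationMaster (cluster mode); the driver's extraJavaOptions and library path
    // are then appended only in the ApplicationMaster case.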
// TODO: it might be nicer to pass these as an internal environment variable rather than // as Java options, due to complications with string parsing of nested quotes. - if (args.amClass == classOf[ExecutorLauncher].getName) { - // If we are being launched in client mode, forward the spark-conf options - // onto the executor launcher - for ((k, v) <- sparkConf.getAll) { - javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" - } - } else { - // If we are being launched in standalone mode, capture and forward any spark - // system properties (e.g. set by spark-class). - for ((k, v) <- sys.props.filterKeys(_.startsWith("spark"))) { - javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" - } - sys.props.get("spark.driver.extraJavaOptions").foreach(opts => javaOpts += opts) - sys.props.get("spark.driver.libraryPath").foreach(p => javaOpts += s"-Djava.library.path=$p") + for ((k, v) <- sparkConf.getAll) { + javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" + } + + if (args.amClass == classOf[ApplicationMaster].getName) { + sparkConf.getOption("spark.driver.extraJavaOptions") + .orElse(sys.env.get("SPARK_JAVA_OPTS")) + .foreach(opts => javaOpts += opts) + sparkConf.getOption("spark.driver.libraryPath") + .foreach(p => javaOpts += s"-Djava.library.path=$p") } - javaOpts += ClientBase.getLog4jConfiguration(localResources) // Command for the ApplicationMaster val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ @@ -377,7 +406,10 @@ trait ClientBase extends Logging { "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") - logInfo("Command for starting the Spark ApplicationMaster: " + commands) + logInfo("Yarn AM launch context:") + logInfo(s" class: ${args.amClass}") + logInfo(s" env: $env") + logInfo(s" command: ${commands.mkString(" ")}") // TODO: it would be nicer to just make sure there are no null commands here val printableCommands = commands.map(s => if (s == null) "null" else s).toList @@ -391,12 +423,39 @@ trait ClientBase extends Logging { object ClientBase extends Logging { val SPARK_JAR: String = "__spark__.jar" val APP_JAR: String = "__app__.jar" - val LOG4J_PROP: String = "log4j.properties" - val LOG4J_CONF_ENV_KEY: String = "SPARK_LOG4J_CONF" val LOCAL_SCHEME = "local" + val CONF_SPARK_JAR = "spark.yarn.jar" + /** + * This is an internal config used to propagate the location of the user's jar file to the + * driver/executors. + */ + val CONF_SPARK_USER_JAR = "spark.yarn.user.jar" + /** + * This is an internal config used to propagate the list of extra jars to add to the classpath + * of executors. + */ val CONF_SPARK_YARN_SECONDARY_JARS = "spark.yarn.secondary.jars" + val ENV_SPARK_JAR = "SPARK_JAR" - def getSparkJar = sys.env.get("SPARK_JAR").getOrElse(SparkContext.jarOfClass(this.getClass).head) + /** + * Find the user-defined Spark jar if configured, or return the jar containing this + * class if not. + * + * This method first looks in the SparkConf object for the CONF_SPARK_JAR key, and in the + * user environment if that is not found (for backwards compatibility). + */ + def sparkJar(conf: SparkConf) = { + if (conf.contains(CONF_SPARK_JAR)) { + conf.get(CONF_SPARK_JAR) + } else if (System.getenv(ENV_SPARK_JAR) != null) { + logWarning( + s"$ENV_SPARK_JAR detected in the system environment. 
This variable has been deprecated " + + s"in favor of the $CONF_SPARK_JAR configuration variable.") + System.getenv(ENV_SPARK_JAR) + } else { + SparkContext.jarOfClass(this.getClass).head + } + } def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) = { val classPathElementsToAdd = getYarnAppClasspath(conf) ++ getMRAppClasspath(conf) @@ -469,71 +528,74 @@ object ClientBase extends Logging { triedDefault.toOption } + def populateClasspath(args: ClientArguments, conf: Configuration, sparkConf: SparkConf, + env: HashMap[String, String], extraClassPath: Option[String] = None) { + extraClassPath.foreach(addClasspathEntry(_, env)) + addClasspathEntry(Environment.PWD.$(), env) + + // Normally the users app.jar is last in case conflicts with spark jars + if (sparkConf.get("spark.yarn.user.classpath.first", "false").toBoolean) { + addUserClasspath(args, sparkConf, env) + addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) + ClientBase.populateHadoopClasspath(conf, env) + } else { + addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) + ClientBase.populateHadoopClasspath(conf, env) + addUserClasspath(args, sparkConf, env) + } + + // Append all jar files under the working directory to the classpath. + addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + "*", env); + } /** - * Returns the java command line argument for setting up log4j. If there is a log4j.properties - * in the given local resources, it is used, otherwise the SPARK_LOG4J_CONF environment variable - * is checked. + * Adds the user jars which have local: URIs (or alternate names, such as APP_JAR) explicitly + * to the classpath. */ - def getLog4jConfiguration(localResources: HashMap[String, LocalResource]): String = { - var log4jConf = LOG4J_PROP - if (!localResources.contains(log4jConf)) { - log4jConf = System.getenv(LOG4J_CONF_ENV_KEY) match { - case conf: String => - val confUri = new URI(conf) - if (ClientBase.LOCAL_SCHEME.equals(confUri.getScheme())) { - "file://" + confUri.getPath() - } else { - ClientBase.LOG4J_PROP - } - case null => "log4j-spark-container.properties" + private def addUserClasspath(args: ClientArguments, conf: SparkConf, + env: HashMap[String, String]) = { + if (args != null) { + addFileToClasspath(args.userJar, APP_JAR, env) + if (args.addJars != null) { + args.addJars.split(",").foreach { case file: String => + addFileToClasspath(file, null, env) + } } + } else { + val userJar = conf.get(CONF_SPARK_USER_JAR, null) + addFileToClasspath(userJar, APP_JAR, env) + + val cachedSecondaryJarLinks = conf.get(CONF_SPARK_YARN_SECONDARY_JARS, "").split(",") + cachedSecondaryJarLinks.foreach(jar => addFileToClasspath(jar, null, env)) } - " -Dlog4j.configuration=" + log4jConf } - def populateClasspath(conf: Configuration, sparkConf: SparkConf, log4jConf: String, - env: HashMap[String, String], extraClassPath: Option[String] = None) { - - if (log4jConf != null) { - // If a custom log4j config file is provided as a local: URI, add its parent directory to the - // classpath. Note that this only works if the custom config's file name is - // "log4j.properties". - val localPath = getLocalPath(log4jConf) - if (localPath != null) { - val parentPath = new File(localPath).getParent() - YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, parentPath, - File.pathSeparator) + /** + * Adds the given path to the classpath, handling "local:" URIs correctly. 
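   * (for example, an entry of "local:/opt/libs/dep.jar", assumed to already exist on every node,
   * ends up on the classpath as the plain path "/opt/libs/dep.jar").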
+ * + * If an alternate name for the file is given, and it's not a "local:" file, the alternate + * name will be added to the classpath (relative to the job's work directory). + * + * If not a "local:" file and no alternate name, the environment is not modified. + * + * @param path Path to add to classpath (optional). + * @param fileName Alternate name for the file (optional). + * @param env Map holding the environment variables. + */ + private def addFileToClasspath(path: String, fileName: String, + env: HashMap[String, String]) : Unit = { + if (path != null) { + scala.util.control.Exception.ignoring(classOf[URISyntaxException]) { + val localPath = getLocalPath(path) + if (localPath != null) { + addClasspathEntry(localPath, env) + return + } } } - - /** Add entry to the classpath. */ - def addClasspathEntry(path: String) = YarnSparkHadoopUtil.addToEnvironment(env, - Environment.CLASSPATH.name, path, File.pathSeparator) - /** Add entry to the classpath. Interpreted as a path relative to the working directory. */ - def addPwdClasspathEntry(entry: String) = - addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + entry) - - extraClassPath.foreach(addClasspathEntry) - - val cachedSecondaryJarLinks = - sparkConf.getOption(CONF_SPARK_YARN_SECONDARY_JARS).getOrElse("").split(",") - .filter(_.nonEmpty) - // Normally the users app.jar is last in case conflicts with spark jars - if (sparkConf.get("spark.yarn.user.classpath.first", "false").toBoolean) { - addPwdClasspathEntry(APP_JAR) - cachedSecondaryJarLinks.foreach(addPwdClasspathEntry) - addPwdClasspathEntry(SPARK_JAR) - ClientBase.populateHadoopClasspath(conf, env) - } else { - addPwdClasspathEntry(SPARK_JAR) - ClientBase.populateHadoopClasspath(conf, env) - addPwdClasspathEntry(APP_JAR) - cachedSecondaryJarLinks.foreach(addPwdClasspathEntry) + if (fileName != null) { + addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + fileName, env); } - // Append all class files and jar files under the working directory to the classpath. - addClasspathEntry(Environment.PWD.$()) - addPwdClasspathEntry("*") } /** @@ -547,4 +609,8 @@ object ClientBase extends Logging { null } + private def addClasspathEntry(path: String, env: HashMap[String, String]) = + YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, path, + File.pathSeparator) + } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index 43dbb2464f929..4ba7133a959ed 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -55,10 +55,12 @@ trait ExecutorRunnableUtil extends Logging { sys.props.get("spark.executor.extraJavaOptions").foreach { opts => javaOpts += opts } + sys.env.get("SPARK_JAVA_OPTS").foreach { opts => + javaOpts += opts + } javaOpts += "-Djava.io.tmpdir=" + new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) - javaOpts += ClientBase.getLog4jConfiguration(localResources) // Certain configs need to be passed here because they are needed before the Executor // registers with the Scheduler and transfers the spark configs. 
Since the Executor backend @@ -166,13 +168,8 @@ trait ExecutorRunnableUtil extends Logging { def prepareEnvironment: HashMap[String, String] = { val env = new HashMap[String, String]() - val extraCp = sparkConf.getOption("spark.executor.extraClassPath") - val log4jConf = System.getenv(ClientBase.LOG4J_CONF_ENV_KEY) - ClientBase.populateClasspath(yarnConf, sparkConf, log4jConf, env, extraCp) - if (log4jConf != null) { - env(ClientBase.LOG4J_CONF_ENV_KEY) = log4jConf - } + ClientBase.populateClasspath(null, yarnConf, sparkConf, env, extraCp) // Allow users to specify some environment variables YarnSparkHadoopUtil.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"), diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 412dfe38d55eb..fd2694fe7278d 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -63,7 +63,7 @@ private[spark] class YarnClientSchedulerBackend( // variables. List(("--driver-memory", "SPARK_MASTER_MEMORY", "spark.master.memory"), ("--driver-memory", "SPARK_DRIVER_MEMORY", "spark.driver.memory"), - ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.worker.instances"), + ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"), ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index 608c6e92624c6..686714dc36488 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -17,22 +17,31 @@ package org.apache.spark.deploy.yarn +import java.io.File import java.net.URI +import com.google.common.io.Files import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants.Environment - +import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse +import org.apache.hadoop.yarn.api.records.ContainerLaunchContext +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.mockito.Matchers._ +import org.mockito.Mockito._ import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers._ +import org.scalatest.Matchers import scala.collection.JavaConversions._ import scala.collection.mutable.{ HashMap => MutableHashMap } import scala.util.Try +import org.apache.spark.SparkConf +import org.apache.spark.util.Utils -class ClientBaseSuite extends FunSuite { +class ClientBaseSuite extends FunSuite with Matchers { test("default Yarn application classpath") { ClientBase.getDefaultYarnApplicationClasspath should be(Some(Fixtures.knownDefYarnAppCP)) @@ -68,6 +77,67 @@ class ClientBaseSuite extends FunSuite { } } + private val SPARK = "local:/sparkJar" + private val USER = "local:/userJar" + private val ADDED = "local:/addJar1,local:/addJar2,/addJar3" + + test("Local jar URIs") { + val conf = new Configuration() + 
val sparkConf = new SparkConf().set(ClientBase.CONF_SPARK_JAR, SPARK) + val env = new MutableHashMap[String, String]() + val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) + + ClientBase.populateClasspath(args, conf, sparkConf, env, None) + + val cp = env("CLASSPATH").split(File.pathSeparator) + s"$SPARK,$USER,$ADDED".split(",").foreach({ entry => + val uri = new URI(entry) + if (ClientBase.LOCAL_SCHEME.equals(uri.getScheme())) { + cp should contain (uri.getPath()) + } else { + cp should not contain (uri.getPath()) + } + }) + cp should contain (Environment.PWD.$()) + cp should contain (s"${Environment.PWD.$()}${File.separator}*") + cp should not contain (ClientBase.SPARK_JAR) + cp should not contain (ClientBase.APP_JAR) + } + + test("Jar path propagation through SparkConf") { + val conf = new Configuration() + val sparkConf = new SparkConf().set(ClientBase.CONF_SPARK_JAR, SPARK) + val yarnConf = new YarnConfiguration() + val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) + + val client = spy(new DummyClient(args, conf, sparkConf, yarnConf)) + doReturn(new Path("/")).when(client).copyRemoteFile(any(classOf[Path]), + any(classOf[Path]), anyShort(), anyBoolean()) + + var tempDir = Files.createTempDir(); + try { + client.prepareLocalResources(tempDir.getAbsolutePath()) + sparkConf.getOption(ClientBase.CONF_SPARK_USER_JAR) should be (Some(USER)) + + // The non-local path should be propagated by name only, since it will end up in the app's + // staging dir. + val expected = ADDED.split(",") + .map(p => { + val uri = new URI(p) + if (ClientBase.LOCAL_SCHEME == uri.getScheme()) { + p + } else { + Option(uri.getFragment()).getOrElse(new File(p).getName()) + } + }) + .mkString(",") + + sparkConf.getOption(ClientBase.CONF_SPARK_YARN_SECONDARY_JARS) should be (Some(expected)) + } finally { + Utils.deleteRecursively(tempDir) + } + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = @@ -109,4 +179,18 @@ class ClientBaseSuite extends FunSuite { def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = Try(clazz.getField(field)).map(_.get(null).asInstanceOf[A]).toOption.map(mapTo).getOrElse(defaults) + private class DummyClient( + val args: ClientArguments, + val conf: Configuration, + val sparkConf: SparkConf, + val yarnConf: YarnConfiguration) extends ClientBase { + + override def calculateAMMemory(newApp: GetNewApplicationResponse): Int = + throw new UnsupportedOperationException() + + override def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = + throw new UnsupportedOperationException() + + } + } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 117b33f466f85..07ba0a4b30bd7 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -81,6 +81,7 @@ class ExecutorRunnable( val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores, localResources) + logInfo(s"Setting up executor with environment: $env") logInfo("Setting up executor with commands: " + commands) ctx.setCommands(commands)
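// Usage note for the spark.yarn.jar handling exercised above: instead of exporting the now
// deprecated SPARK_JAR environment variable, the client configuration can point YARN at an
// assembly that is already installed on every node. A minimal sketch; the path below is an
// assumed example, not something shipped by this patch:

import org.apache.spark.SparkConf

val clientSparkConf = new SparkConf()
  // The "local:" scheme marks the jar as present on each node's local filesystem, so ClientBase
  // puts its path directly on the container classpath (see the "Local jar URIs" test above).
  .set("spark.yarn.jar", "local:/opt/spark/lib/spark-assembly.jar")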