
Commit

Merge branch 'master' of git://git.apache.org/spark into timeline-viewer-feature
sarutak committed Apr 6, 2015
2 parents dab7cc1 + 9fe4125 commit fcdab7d
Showing 44 changed files with 946 additions and 763 deletions.
59 changes: 59 additions & 0 deletions bin/load-spark-env.cmd
@@ -0,0 +1,59 @@
@echo off

rem
rem Licensed to the Apache Software Foundation (ASF) under one or more
rem contributor license agreements. See the NOTICE file distributed with
rem this work for additional information regarding copyright ownership.
rem The ASF licenses this file to You under the Apache License, Version 2.0
rem (the "License"); you may not use this file except in compliance with
rem the License. You may obtain a copy of the License at
rem
rem http://www.apache.org/licenses/LICENSE-2.0
rem
rem Unless required by applicable law or agreed to in writing, software
rem distributed under the License is distributed on an "AS IS" BASIS,
rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
rem See the License for the specific language governing permissions and
rem limitations under the License.
rem

rem This script loads spark-env.cmd if it exists, and ensures it is only loaded once.
rem spark-env.cmd is loaded from SPARK_CONF_DIR if set, or within the current directory's
rem conf/ subdirectory.

if [%SPARK_ENV_LOADED%] == [] (
set SPARK_ENV_LOADED=1

if not [%SPARK_CONF_DIR%] == [] (
set user_conf_dir=%SPARK_CONF_DIR%
) else (
set user_conf_dir=%~dp0..\..\conf
)

call :LoadSparkEnv
)

rem Setting SPARK_SCALA_VERSION if not already set.

set ASSEMBLY_DIR2=%SPARK_HOME%/assembly/target/scala-2.11
set ASSEMBLY_DIR1=%SPARK_HOME%/assembly/target/scala-2.10

if [%SPARK_SCALA_VERSION%] == [] (

if exist %ASSEMBLY_DIR2% if exist %ASSEMBLY_DIR1% (
echo "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected."
echo "Either clean one of them or, set SPARK_SCALA_VERSION=2.11 in spark-env.cmd."
exit 1
)
if exist %ASSEMBLY_DIR2% (
set SPARK_SCALA_VERSION=2.11
) else (
set SPARK_SCALA_VERSION=2.10
)
)
exit /b 0

:LoadSparkEnv
if exist "%user_conf_dir%\spark-env.cmd" (
call "%user_conf_dir%\spark-env.cmd"
)
3 changes: 1 addition & 2 deletions bin/pyspark2.cmd
@@ -20,8 +20,7 @@ rem
rem Figure out where the Spark framework is installed
set SPARK_HOME=%~dp0..

rem Load environment variables from conf\spark-env.cmd, if it exists
if exist "%SPARK_HOME%\conf\spark-env.cmd" call "%SPARK_HOME%\conf\spark-env.cmd"
call %SPARK_HOME%\bin\load-spark-env.cmd

rem Figure out which Python to use.
if "x%PYSPARK_DRIVER_PYTHON%"=="x" (
3 changes: 1 addition & 2 deletions bin/run-example2.cmd
@@ -25,8 +25,7 @@ set FWDIR=%~dp0..\
rem Export this as SPARK_HOME
set SPARK_HOME=%FWDIR%

rem Load environment variables from conf\spark-env.cmd, if it exists
if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
call %SPARK_HOME%\bin\load-spark-env.cmd

rem Test that an argument was given
if not "x%1"=="x" goto arg_given
3 changes: 1 addition & 2 deletions bin/spark-class2.cmd
@@ -20,8 +20,7 @@ rem
rem Figure out where the Spark framework is installed
set SPARK_HOME=%~dp0..

rem Load environment variables from conf\spark-env.cmd, if it exists
if exist "%SPARK_HOME%\conf\spark-env.cmd" call "%SPARK_HOME%\conf\spark-env.cmd"
call %SPARK_HOME%\bin\load-spark-env.cmd

rem Test that an argument was given
if "x%1"=="x" (
66 changes: 43 additions & 23 deletions core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -17,15 +17,15 @@

package org.apache.spark

import scala.concurrent.duration._
import scala.collection.mutable
import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors}

import akka.actor.{Actor, Cancellable}
import scala.collection.mutable

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext}
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.scheduler.{SlaveLost, TaskScheduler}
import org.apache.spark.util.ActorLogReceive
import org.apache.spark.util.Utils

/**
* A heartbeat from executors to the driver. This is a shared message used by several internal
@@ -51,9 +51,11 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
* Lives in the driver to receive heartbeats from executors.
*/
private[spark] class HeartbeatReceiver(sc: SparkContext)
extends Actor with ActorLogReceive with Logging {
extends ThreadSafeRpcEndpoint with Logging {

override val rpcEnv: RpcEnv = sc.env.rpcEnv

private var scheduler: TaskScheduler = null
private[spark] var scheduler: TaskScheduler = null

// executor ID -> timestamp of when the last heartbeat from this executor was received
private val executorLastSeen = new mutable.HashMap[String, Long]
@@ -69,34 +71,44 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
sc.conf.getOption("spark.network.timeoutInterval").map(_.toLong * 1000).
getOrElse(sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000))

private var timeoutCheckingTask: Cancellable = null

override def preStart(): Unit = {
import context.dispatcher
timeoutCheckingTask = context.system.scheduler.schedule(0.seconds,
checkTimeoutIntervalMs.milliseconds, self, ExpireDeadHosts)
super.preStart()
private var timeoutCheckingTask: ScheduledFuture[_] = null

private val timeoutCheckingThread = Executors.newSingleThreadScheduledExecutor(
Utils.namedThreadFactory("heartbeat-timeout-checking-thread"))

private val killExecutorThread = Executors.newSingleThreadExecutor(
Utils.namedThreadFactory("kill-executor-thread"))

override def onStart(): Unit = {
timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
Option(self).foreach(_.send(ExpireDeadHosts))
}
}, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS)
}

override def receiveWithLogging: PartialFunction[Any, Unit] = {

override def receive: PartialFunction[Any, Unit] = {
case ExpireDeadHosts =>
expireDeadHosts()
case TaskSchedulerIsSet =>
scheduler = sc.taskScheduler
}

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) =>
if (scheduler != null) {
val unknownExecutor = !scheduler.executorHeartbeatReceived(
executorId, taskMetrics, blockManagerId)
val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
executorLastSeen(executorId) = System.currentTimeMillis()
sender ! response
context.reply(response)
} else {
// Because Executor will sleep several seconds before sending the first "Heartbeat", this
// case rarely happens. However, if it really happens, log it and ask the executor to
// register itself again.
logWarning(s"Dropping $heartbeat because TaskScheduler is not ready yet")
sender ! HeartbeatResponse(reregisterBlockManager = true)
context.reply(HeartbeatResponse(reregisterBlockManager = true))
}
case ExpireDeadHosts =>
expireDeadHosts()
}

private def expireDeadHosts(): Unit = {
@@ -109,17 +121,25 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " +
s"timed out after ${now - lastSeenMs} ms"))
if (sc.supportDynamicAllocation) {
sc.killExecutor(executorId)
// Asynchronously kill the executor to avoid blocking the current thread
killExecutorThread.submit(new Runnable {
override def run(): Unit = sc.killExecutor(executorId)
})
}
executorLastSeen.remove(executorId)
}
}
}

override def postStop(): Unit = {
override def onStop(): Unit = {
if (timeoutCheckingTask != null) {
timeoutCheckingTask.cancel()
timeoutCheckingTask.cancel(true)
}
super.postStop()
timeoutCheckingThread.shutdownNow()
killExecutorThread.shutdownNow()
}
}

object HeartbeatReceiver {
val ENDPOINT_NAME = "HeartbeatReceiver"
}
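
For context, the rewritten HeartbeatReceiver above schedules its periodic timeout checks with a plain java.util.concurrent ScheduledExecutorService started in onStart and cancelled in onStop, instead of Akka's context.system.scheduler. The following is a minimal, self-contained Scala sketch of that start/stop lifecycle; the object and method names are hypothetical, and the hand-rolled thread factory and error handling stand in for Spark's Utils.namedThreadFactory and Utils.tryLogNonFatalError.

import java.util.concurrent.{Executors, ScheduledFuture, ThreadFactory, TimeUnit}
import scala.util.control.NonFatal

object HeartbeatTimeoutSketch {

  // Named daemon-thread factory, standing in for Utils.namedThreadFactory.
  private def namedThreadFactory(name: String): ThreadFactory = new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, name)
      t.setDaemon(true)
      t
    }
  }

  private val timeoutCheckingThread =
    Executors.newSingleThreadScheduledExecutor(namedThreadFactory("heartbeat-timeout-checking-thread"))

  @volatile private var timeoutCheckingTask: ScheduledFuture[_] = null

  // Roughly what onStart() does: schedule a periodic check, swallowing non-fatal
  // errors so a single failed run does not cancel the scheduled task.
  def start(checkIntervalMs: Long)(check: () => Unit): Unit = {
    timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable {
      override def run(): Unit = try check() catch { case NonFatal(e) => e.printStackTrace() }
    }, 0, checkIntervalMs, TimeUnit.MILLISECONDS)
  }

  // Roughly what onStop() does: cancel the task (interrupting a run in progress,
  // as cancel(true) does above) and shut the executor down.
  def stop(): Unit = {
    if (timeoutCheckingTask != null) {
      timeoutCheckingTask.cancel(true)
    }
    timeoutCheckingThread.shutdownNow()
  }
}

A caller would invoke something like HeartbeatTimeoutSketch.start(60000L)(() => println("expire dead hosts")) and later HeartbeatTimeoutSketch.stop(). Note that the real HeartbeatReceiver does not run the check inline on the scheduler thread; its scheduled task only sends ExpireDeadHosts to self, so the actual expiry work happens on the endpoint's own thread.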
61 changes: 29 additions & 32 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -21,13 +21,11 @@ import java.io._
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import scala.collection.mutable.{HashSet, HashMap, Map}
import scala.concurrent.Await
import scala.collection.mutable.{HashSet, Map}
import scala.collection.JavaConversions._
import scala.reflect.ClassTag

import akka.actor._
import akka.pattern.ask

import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.storage.BlockManagerId
@@ -38,34 +36,35 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage

/** Actor class for MapOutputTrackerMaster */
private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf)
extends Actor with ActorLogReceive with Logging {
/** RpcEndpoint class for MapOutputTrackerMaster */
private[spark] class MapOutputTrackerMasterEndpoint(
override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf)
extends RpcEndpoint with Logging {
val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)

override def receiveWithLogging: PartialFunction[Any, Unit] = {
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GetMapOutputStatuses(shuffleId: Int) =>
val hostPort = sender.path.address.hostPort
val hostPort = context.sender.address.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
val serializedSize = mapOutputStatuses.size
if (serializedSize > maxAkkaFrameSize) {
val msg = s"Map output statuses were $serializedSize bytes which " +
s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."

/* For SPARK-1244 we'll opt for just logging an error and then throwing an exception.
* Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239)
* will ultimately remove this entire code path. */
/* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender.
* A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */
val exception = new SparkException(msg)
logError(msg, exception)
throw exception
context.sendFailure(exception)
} else {
context.reply(mapOutputStatuses)
}
sender ! mapOutputStatuses

case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
sender ! true
context.stop(self)
logInfo("MapOutputTrackerMasterEndpoint stopped!")
context.reply(true)
stop()
}
}

@@ -75,12 +74,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
* (driver and executor) use different HashMap to store its metadata.
*/
private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
private val timeout = AkkaUtils.askTimeout(conf)
private val retryAttempts = AkkaUtils.numRetries(conf)
private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)

/** Set to the MapOutputTrackerActor living on the driver. */
var trackerActor: ActorRef = _
/** Set to the MapOutputTrackerMasterEndpoint living on the driver. */
var trackerEndpoint: RpcEndpointRef = _

/**
* This HashMap has different behavior for the driver and the executors.
@@ -105,22 +101,22 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
private val fetching = new HashSet[Int]

/**
* Send a message to the trackerActor and get its result within a default timeout, or
* Send a message to the trackerEndpoint and get its result within a default timeout, or
* throw a SparkException if this fails.
*/
protected def askTracker(message: Any): Any = {
protected def askTracker[T: ClassTag](message: Any): T = {
try {
AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout)
trackerEndpoint.askWithReply[T](message)
} catch {
case e: Exception =>
logError("Error communicating with MapOutputTracker", e)
throw new SparkException("Error communicating with MapOutputTracker", e)
}
}

/** Send a one-way message to the trackerActor, to which we expect it to reply with true. */
/** Send a one-way message to the trackerEndpoint, to which we expect it to reply with true. */
protected def sendTracker(message: Any) {
val response = askTracker(message)
val response = askTracker[Boolean](message)
if (response != true) {
throw new SparkException(
"Error reply received from MapOutputTracker. Expecting true, got " + response.toString)
@@ -157,11 +153,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging

if (fetchedStatuses == null) {
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
// This try-finally prevents hangs due to timeouts:
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
@@ -328,7 +323,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
override def stop() {
sendTracker(StopMapOutputTracker)
mapStatuses.clear()
trackerActor = null
trackerEndpoint = null
metadataCleaner.cancel()
cachedSerializedStatuses.clear()
}
@@ -350,6 +345,8 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr

private[spark] object MapOutputTracker extends Logging {

val ENDPOINT_NAME = "MapOutputTracker"

// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
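
The comment above explains that serialized map output locations are GZIP-compressed before being sent to reduce tasks, because many map outputs share a hostname and therefore compress well. Here is a minimal, generic sketch of that serialize-then-compress round trip using only the JDK; GzipSerializeSketch and its methods are illustrative names, not Spark's actual MapOutputTracker.serializeMapStatuses.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

object GzipSerializeSketch {

  // Java-serialize a value and GZIP-compress the resulting bytes in one pass.
  def serialize(value: AnyRef): Array[Byte] = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(new GZIPOutputStream(buffer))
    try out.writeObject(value) finally out.close()
    buffer.toByteArray
  }

  // Decompress and deserialize; the caller states the expected type.
  def deserialize[T](bytes: Array[Byte]): T = {
    val in = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
    try in.readObject().asInstanceOf[T] finally in.close()
  }
}

Repetitive payloads, such as status records that repeat the same hostnames, shrink substantially under GZIP, which is the property the tracker relies on when it caches and ships these byte arrays.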