diff --git a/.gitignore b/.gitignore
index a31bf7e0091f4..34939e3a97aaa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,9 +1,12 @@
*~
+*.#*
+*#*#
*.swp
*.ipr
*.iml
*.iws
.idea/
+.idea_modules/
sbt/*.jar
.settings
.cache
@@ -16,9 +19,11 @@ third_party/libmesos.so
third_party/libmesos.dylib
conf/java-opts
conf/*.sh
+conf/*.cmd
conf/*.properties
conf/*.conf
conf/*.xml
+conf/slaves
docs/_site
docs/api
target/
diff --git a/.rat-excludes b/.rat-excludes
index fb6323daf9211..b14ad53720f32 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -19,7 +19,9 @@ log4j.properties
log4j.properties.template
metrics.properties.template
slaves
+slaves.template
spark-env.sh
+spark-env.cmd
spark-env.sh.template
log4j-defaults.properties
bootstrap-tooltip.js
@@ -58,3 +60,4 @@ dist/*
.*iws
logs
.*scalastyle-output.xml
+.*dependency-reduced-pom.xml
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 604b1ab3de6a8..31a01e4d8e1de 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -141,7 +141,9 @@
               <include>com.google.common.**</include>
-              <exclude>com.google.common.base.Optional**</exclude>
+              <exclude>com/google/common/base/Absent*</exclude>
+              <exclude>com/google/common/base/Optional*</exclude>
+              <exclude>com/google/common/base/Present*</exclude>
@@ -347,5 +349,15 @@
+    <profile>
+      <id>kinesis-asl</id>
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.httpcomponents</groupId>
+          <artifactId>httpclient</artifactId>
+          <version>${commons.httpclient.version}</version>
+        </dependency>
+      </dependencies>
+    </profile>
diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties
index 30b4baa4d714a..789869f72e3b0 100644
--- a/bagel/src/test/resources/log4j.properties
+++ b/bagel/src/test/resources/log4j.properties
@@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd
index 5ad52452a5c98..3cd0579aea8d3 100644
--- a/bin/compute-classpath.cmd
+++ b/bin/compute-classpath.cmd
@@ -36,7 +36,13 @@ rem Load environment variables from conf\spark-env.cmd, if it exists
if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
rem Build up classpath
-set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%;%FWDIR%conf
+set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%
+
+if not "x%SPARK_CONF_DIR%"=="x" (
+ set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR%
+) else (
+ set CLASSPATH=%CLASSPATH%;%FWDIR%conf
+)
if exist "%FWDIR%RELEASE" (
for %%d in ("%FWDIR%lib\spark-assembly*.jar") do (
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index 0f63e36d8aeca..905bbaf99b374 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -27,8 +27,14 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
. "$FWDIR"/bin/load-spark-env.sh
+CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH"
+
# Build up classpath
-CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH:$FWDIR/conf"
+if [ -n "$SPARK_CONF_DIR" ]; then
+ CLASSPATH="$CLASSPATH:$SPARK_CONF_DIR"
+else
+ CLASSPATH="$CLASSPATH:$FWDIR/conf"
+fi
ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION"
diff --git a/bin/pyspark b/bin/pyspark
index 5142411e36974..6655725ef8e8e 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -52,10 +52,20 @@ fi
# Figure out which Python executable to use
if [[ -z "$PYSPARK_PYTHON" ]]; then
- PYSPARK_PYTHON="python"
+ if [[ "$IPYTHON" = "1" || -n "$IPYTHON_OPTS" ]]; then
+ # for backward compatibility
+ PYSPARK_PYTHON="ipython"
+ else
+ PYSPARK_PYTHON="python"
+ fi
fi
export PYSPARK_PYTHON
+if [[ -z "$PYSPARK_PYTHON_OPTS" && -n "$IPYTHON_OPTS" ]]; then
+ # for backward compatibility
+ PYSPARK_PYTHON_OPTS="$IPYTHON_OPTS"
+fi
+
# Add the PySpark classes to the Python path:
export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH"
export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH"
@@ -64,11 +74,6 @@ export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH"
export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"
export PYTHONSTARTUP="$FWDIR/python/pyspark/shell.py"
-# If IPython options are specified, assume user wants to run IPython
-if [[ -n "$IPYTHON_OPTS" ]]; then
- IPYTHON=1
-fi
-
# Build up arguments list manually to preserve quotes and backslashes.
# We export Spark submit arguments as an environment variable because shell.py must run as a
# PYTHONSTARTUP script, which does not take in arguments. This is required for IPython notebooks.
@@ -106,10 +111,5 @@ if [[ "$1" =~ \.py$ ]]; then
else
# PySpark shell requires special handling downstream
export PYSPARK_SHELL=1
- # Only use ipython if no command line arguments were provided [SPARK-1134]
- if [[ "$IPYTHON" = "1" ]]; then
- exec ${PYSPARK_PYTHON:-ipython} $IPYTHON_OPTS
- else
- exec "$PYSPARK_PYTHON"
- fi
+ exec "$PYSPARK_PYTHON" $PYSPARK_PYTHON_OPTS
fi
diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd
index 2c4b08af8d4c3..a0e66abcc26c9 100644
--- a/bin/pyspark2.cmd
+++ b/bin/pyspark2.cmd
@@ -33,7 +33,7 @@ for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*
)
if [%FOUND_JAR%] == [0] (
echo Failed to find Spark assembly JAR.
- echo You need to build Spark with sbt\sbt assembly before running this program.
+ echo You need to build Spark before running this program.
goto exit
)
:skip_build_test
diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd
index b29bf90c64e90..b49d0dcb4ff2d 100644
--- a/bin/run-example2.cmd
+++ b/bin/run-example2.cmd
@@ -52,7 +52,7 @@ if exist "%FWDIR%RELEASE" (
)
if "x%SPARK_EXAMPLES_JAR%"=="x" (
echo Failed to find Spark examples assembly JAR.
- echo You need to build Spark with sbt\sbt assembly before running this program.
+ echo You need to build Spark before running this program.
goto exit
)
diff --git a/bin/spark-class b/bin/spark-class
index 613dc9c4566f2..e8201c18d52de 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -146,7 +146,7 @@ fi
if [[ "$1" =~ org.apache.spark.tools.* ]]; then
if test -z "$SPARK_TOOLS_JAR"; then
echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SCALA_VERSION/" 1>&2
- echo "You need to build spark before running $1." 1>&2
+ echo "You need to build Spark before running $1." 1>&2
exit 1
fi
CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR"
diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd
index 6c5672819172b..da46543647efd 100644
--- a/bin/spark-class2.cmd
+++ b/bin/spark-class2.cmd
@@ -104,7 +104,7 @@ for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*
)
if "%FOUND_JAR%"=="0" (
echo Failed to find Spark assembly JAR.
- echo You need to build Spark with sbt\sbt assembly before running this program.
+ echo You need to build Spark before running this program.
goto exit
)
:skip_build_test
diff --git a/bin/spark-sql b/bin/spark-sql
index ae096530cad04..63d00437d508d 100755
--- a/bin/spark-sql
+++ b/bin/spark-sql
@@ -24,7 +24,6 @@
set -o posix
CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver"
-CLASS_NOT_FOUND_EXIT_STATUS=1
# Figure out where Spark is installed
FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
@@ -53,13 +52,4 @@ source "$FWDIR"/bin/utils.sh
SUBMIT_USAGE_FUNCTION=usage
gatherSparkSubmitOpts "$@"
-"$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}"
-exit_status=$?
-
-if [[ exit_status -eq CLASS_NOT_FOUND_EXIT_STATUS ]]; then
- echo
- echo "Failed to load Spark SQL CLI main class $CLASS."
- echo "You need to build Spark with -Phive."
-fi
-
-exit $exit_status
+exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}"
diff --git a/bin/utils.sh b/bin/utils.sh
index 0804b1ed9f231..22ea2b9a6d586 100755
--- a/bin/utils.sh
+++ b/bin/utils.sh
@@ -17,7 +17,7 @@
# limitations under the License.
#
-# Gather all all spark-submit options into SUBMISSION_OPTS
+# Gather all spark-submit options into SUBMISSION_OPTS
function gatherSparkSubmitOpts() {
if [ -z "$SUBMIT_USAGE_FUNCTION" ]; then
diff --git a/conf/slaves b/conf/slaves.template
similarity index 100%
rename from conf/slaves
rename to conf/slaves.template
diff --git a/core/pom.xml b/core/pom.xml
index 2a81f6df289c0..a5a178079bc57 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -322,6 +322,17 @@
+      <plugin>
+        <artifactId>maven-clean-plugin</artifactId>
+        <configuration>
+          <filesets>
+            <fileset>
+              <directory>${basedir}/../python/build</directory>
+            </fileset>
+          </filesets>
+          <verbose>true</verbose>
+        </configuration>
+      </plugin>
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-shade-plugin</artifactId>
@@ -343,7 +354,9 @@
               <artifact>com.google.guava:guava</artifact>
               <includes>
+                <include>com/google/common/base/Absent*</include>
                 <include>com/google/common/base/Optional*</include>
+                <include>com/google/common/base/Present*</include>
diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java
new file mode 100644
index 0000000000000..4e6d708af0ea7
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/TaskContext.java
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import scala.Function0;
+import scala.Function1;
+import scala.Unit;
+import scala.collection.JavaConversions;
+
+import org.apache.spark.annotation.DeveloperApi;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.util.TaskCompletionListener;
+import org.apache.spark.util.TaskCompletionListenerException;
+
+/**
+* :: DeveloperApi ::
+* Contextual information about a task which can be read or mutated during execution.
+*/
+@DeveloperApi
+public class TaskContext implements Serializable {
+
+ private int stageId;
+ private int partitionId;
+ private long attemptId;
+ private boolean runningLocally;
+ private TaskMetrics taskMetrics;
+
+ /**
+ * :: DeveloperApi ::
+ * Contextual information about a task which can be read or mutated during execution.
+ *
+ * @param stageId stage id
+ * @param partitionId index of the partition
+ * @param attemptId the number of attempts to execute this task
+ * @param runningLocally whether the task is running locally in the driver JVM
+ * @param taskMetrics performance metrics of the task
+ */
+ @DeveloperApi
+ public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally,
+ TaskMetrics taskMetrics) {
+ this.attemptId = attemptId;
+ this.partitionId = partitionId;
+ this.runningLocally = runningLocally;
+ this.stageId = stageId;
+ this.taskMetrics = taskMetrics;
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Contextual information about a task which can be read or mutated during execution.
+ *
+ * @param stageId stage id
+ * @param partitionId index of the partition
+ * @param attemptId the number of attempts to execute this task
+ * @param runningLocally whether the task is running locally in the driver JVM
+ */
+ @DeveloperApi
+ public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally) {
+ this.attemptId = attemptId;
+ this.partitionId = partitionId;
+ this.runningLocally = runningLocally;
+ this.stageId = stageId;
+ this.taskMetrics = TaskMetrics.empty();
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Contextual information about a task which can be read or mutated during execution.
+ *
+ * @param stageId stage id
+ * @param partitionId index of the partition
+ * @param attemptId the number of attempts to execute this task
+ */
+ @DeveloperApi
+ public TaskContext(int stageId, int partitionId, long attemptId) {
+ this.attemptId = attemptId;
+ this.partitionId = partitionId;
+ this.runningLocally = false;
+ this.stageId = stageId;
+ this.taskMetrics = TaskMetrics.empty();
+ }
+
+ private static ThreadLocal<TaskContext> taskContext =
+ new ThreadLocal<TaskContext>();
+
+ /**
+ * :: Internal API ::
+ * This is a Spark-internal API, not intended to be called from user programs.
+ */
+ public static void setTaskContext(TaskContext tc) {
+ taskContext.set(tc);
+ }
+
+ public static TaskContext get() {
+ return taskContext.get();
+ }
+
+ /** :: Internal API :: */
+ public static void unset() {
+ taskContext.remove();
+ }
+
+ // List of callback functions to execute when the task completes.
+ private transient List<TaskCompletionListener> onCompleteCallbacks =
+ new ArrayList<TaskCompletionListener>();
+
+ // Whether the corresponding task has been killed.
+ private volatile boolean interrupted = false;
+
+ // Whether the task has completed.
+ private volatile boolean completed = false;
+
+ /**
+ * Checks whether the task has completed.
+ */
+ public boolean isCompleted() {
+ return completed;
+ }
+
+ /**
+ * Checks whether the task has been killed.
+ */
+ public boolean isInterrupted() {
+ return interrupted;
+ }
+
+ /**
+ * Add a (Java friendly) listener to be executed on task completion.
+ * This will be called in all situations - success, failure, or cancellation.
+ *
+ * An example use is for HadoopRDD to register a callback to close the input stream.
+ */
+ public TaskContext addTaskCompletionListener(TaskCompletionListener listener) {
+ onCompleteCallbacks.add(listener);
+ return this;
+ }
+
+ /**
+ * Add a listener in the form of a Scala closure to be executed on task completion.
+ * This will be called in all situations - success, failure, or cancellation.
+ *
+ * An example use is for HadoopRDD to register a callback to close the input stream.
+ */
+ public TaskContext addTaskCompletionListener(final Function1<TaskContext, Unit> f) {
+ onCompleteCallbacks.add(new TaskCompletionListener() {
+ @Override
+ public void onTaskCompletion(TaskContext context) {
+ f.apply(context);
+ }
+ });
+ return this;
+ }
+
+ /**
+ * Add a callback function to be executed on task completion. An example use
+ * is for HadoopRDD to register a callback to close the input stream.
+ * Will be called in any situation - success, failure, or cancellation.
+ *
+ * Deprecated: use addTaskCompletionListener
+ *
+ * @param f Callback function.
+ */
+ @Deprecated
+ public void addOnCompleteCallback(final Function0<Unit> f) {
+ onCompleteCallbacks.add(new TaskCompletionListener() {
+ @Override
+ public void onTaskCompletion(TaskContext context) {
+ f.apply();
+ }
+ });
+ }
+
+ /**
+ * ::Internal API::
+ * Marks the task as completed and triggers the listeners.
+ */
+ public void markTaskCompleted() throws TaskCompletionListenerException {
+ completed = true;
+ List<String> errorMsgs = new ArrayList<String>(2);
+ // Process complete callbacks in the reverse order of registration
+ List<TaskCompletionListener> revlist =
+ new ArrayList<TaskCompletionListener>(onCompleteCallbacks);
+ Collections.reverse(revlist);
+ for (TaskCompletionListener tcl: revlist) {
+ try {
+ tcl.onTaskCompletion(this);
+ } catch (Throwable e) {
+ errorMsgs.add(e.getMessage());
+ }
+ }
+
+ if (!errorMsgs.isEmpty()) {
+ throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs));
+ }
+ }
+
+ /**
+ * ::Internal API::
+ * Marks the task for interruption, i.e. cancellation.
+ */
+ public void markInterrupted() {
+ interrupted = true;
+ }
+
+ @Deprecated
+ /** Deprecated: use getStageId() */
+ public int stageId() {
+ return stageId;
+ }
+
+ @Deprecated
+ /** Deprecated: use getPartitionId() */
+ public int partitionId() {
+ return partitionId;
+ }
+
+ @Deprecated
+ /** Deprecated: use getAttemptId() */
+ public long attemptId() {
+ return attemptId;
+ }
+
+ @Deprecated
+ /** Deprecated: use isRunningLocally() */
+ public boolean runningLocally() {
+ return runningLocally;
+ }
+
+ public boolean isRunningLocally() {
+ return runningLocally;
+ }
+
+ public int getStageId() {
+ return stageId;
+ }
+
+ public int getPartitionId() {
+ return partitionId;
+ }
+
+ public long getAttemptId() {
+ return attemptId;
+ }
+
+ /** ::Internal API:: */
+ public TaskMetrics taskMetrics() {
+ return taskMetrics;
+ }
+}
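
The new Java TaskContext keeps the running task's context reachable through TaskContext.get() (installed by the executor via setTaskContext) and runs completion listeners in reverse registration order on success, failure, or cancellation. A minimal Scala sketch of how user code might lean on this, assuming an existing SparkContext `sc` and a stand-in per-partition resource (both hypothetical here):

    import org.apache.spark.TaskContext

    val result = sc.parallelize(1 to 100, 4).mapPartitions { iter =>
      val resource = new java.io.ByteArrayInputStream(new Array[Byte](0)) // stand-in resource
      // Close the resource whether the task succeeds, fails, or is cancelled.
      TaskContext.get().addTaskCompletionListener { (ctx: TaskContext) =>
        resource.close()
      }
      iter.map(_ + 1)
    }
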
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index 445110d63e184..152bde5f6994f 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -51,6 +51,11 @@ table.sortable thead {
cursor: pointer;
}
+table.sortable td {
+ word-wrap: break-word;
+ max-width: 600px;
+}
+
.progress {
margin-bottom: 0px; position: relative
}
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index f8584b90cabe6..d89bb50076c9a 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -168,8 +168,6 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
arr.iterator.asInstanceOf[Iterator[T]]
case Right(it) =>
// There is not enough space to cache this partition in memory
- logWarning(s"Not enough space to cache partition $key in memory! " +
- s"Free memory is ${blockManager.memoryStore.freeMemory} bytes.")
val returnValues = it.asInstanceOf[Iterator[T]]
if (putLevel.useDisk) {
logWarning(s"Persisting partition $key to disk instead.")
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 75ea535f2f57b..e8f761eaa5799 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -83,6 +83,15 @@ trait FutureAction[T] extends Future[T] {
*/
@throws(classOf[Exception])
def get(): T = Await.result(this, Duration.Inf)
+
+ /**
+ * Returns the job IDs run by the underlying async operation.
+ *
+ * This returns the current snapshot of the job list. Certain operations may run multiple
+ * jobs, so multiple calls to this method may return different lists.
+ */
+ def jobIds: Seq[Int]
+
}
@@ -150,8 +159,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
}
}
- /** Get the corresponding job id for this action. */
- def jobId = jobWaiter.jobId
+ def jobIds = Seq(jobWaiter.jobId)
}
@@ -171,6 +179,8 @@ class ComplexFutureAction[T] extends FutureAction[T] {
// is cancelled before the action was even run (and thus we have no thread to interrupt).
@volatile private var _cancelled: Boolean = false
+ @volatile private var jobs: Seq[Int] = Nil
+
// A promise used to signal the future.
private val p = promise[T]()
@@ -219,6 +229,8 @@ class ComplexFutureAction[T] extends FutureAction[T] {
}
}
+ this.jobs = jobs ++ job.jobIds
+
// Wait for the job to complete. If the action is cancelled (with an interrupt),
// cancel the job and stop the execution. This is not in a synchronized block because
// Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
@@ -255,4 +267,7 @@ class ComplexFutureAction[T] extends FutureAction[T] {
override def isCompleted: Boolean = p.isCompleted
override def value: Option[Try[T]] = p.future.value
+
+ def jobIds = jobs
+
}
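
Since jobIds is new on FutureAction, here is a hedged sketch of reading it from an async action (assumes an existing SparkContext `sc`; the async RDD actions are brought in by the SparkContext implicits):

    import scala.concurrent.Await
    import scala.concurrent.duration.Duration
    import org.apache.spark.SparkContext._   // brings in the async RDD actions

    val action = sc.parallelize(1 to 1000, 8).countAsync()  // a FutureAction[Long]
    val total  = Await.result(action, Duration.Inf)
    // For a SimpleFutureAction this is the single backing job; for a
    // ComplexFutureAction the snapshot can grow as more jobs are submitted.
    val jobs: Seq[Int] = action.jobIds
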
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 51705c895a55c..4cb0bd4142435 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -18,10 +18,12 @@
package org.apache.spark
import java.io._
+import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.{HashSet, HashMap, Map}
import scala.concurrent.Await
+import scala.collection.JavaConversions._
import akka.actor._
import akka.pattern.ask
@@ -84,6 +86,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
* On the master, it serves as the source of map outputs recorded from ShuffleMapTasks.
* On the workers, it simply serves as a cache, in which a miss triggers a fetch from the
* master's corresponding HashMap.
+ *
+ * Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a
+ * thread-safe map.
*/
protected val mapStatuses: Map[Int, Array[MapStatus]]
@@ -339,11 +344,11 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
* MapOutputTrackerMaster.
*/
private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
- protected val mapStatuses = new HashMap[Int, Array[MapStatus]]
+ protected val mapStatuses: Map[Int, Array[MapStatus]] =
+ new ConcurrentHashMap[Int, Array[MapStatus]]
}
private[spark] object MapOutputTracker {
- private val LOG_BASE = 1.1
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
@@ -379,34 +384,8 @@ private[spark] object MapOutputTracker {
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
} else {
- (status.location, decompressSize(status.compressedSizes(reduceId)))
+ (status.location, status.getSizeForBlock(reduceId))
}
}
}
-
- /**
- * Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
- * We do this by encoding the log base 1.1 of the size as an integer, which can support
- * sizes up to 35 GB with at most 10% error.
- */
- def compressSize(size: Long): Byte = {
- if (size == 0) {
- 0
- } else if (size <= 1L) {
- 1
- } else {
- math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
- }
- }
-
- /**
- * Decompress an 8-bit encoded block size, using the reverse operation of compressSize.
- */
- def decompressSize(compressedSize: Byte): Long = {
- if (compressedSize == 0) {
- 0
- } else {
- math.pow(LOG_BASE, compressedSize & 0xFF).toLong
- }
- }
}
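
The compressSize/decompressSize pair removed here encoded each map output size as a power-of-1.1 exponent in a single byte; equivalent rounding now lives behind MapStatus.getSizeForBlock. A self-contained sketch of the retired scheme, kept only to illustrate the at-most-10% overestimate the old comment describes:

    // Sketch of the log-1.1 size encoding this patch removes from MapOutputTracker.
    // Sizes round up to the next power of 1.1, so the reported size is within ~10%.
    object SizeCodec {
      private val LogBase = 1.1

      def compress(size: Long): Byte =
        if (size == 0) 0
        else if (size <= 1L) 1
        else math.min(255, math.ceil(math.log(size) / math.log(LogBase)).toInt).toByte

      def decompress(compressed: Byte): Long =
        if (compressed == 0) 0 else math.pow(LogBase, compressed & 0xFF).toLong
    }

    // SizeCodec.decompress(SizeCodec.compress(1000000L))  // within 10% of the true size
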
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 3832a780ec4bc..0e0f1a7b2377e 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -103,10 +103,9 @@ import org.apache.spark.deploy.SparkHadoopUtil
* and a Server, so for a particular connection it has to determine what to do.
* A ConnectionId was added to be able to track connections and is used to
* match up incoming messages with connections waiting for authentication.
- * If its acting as a client and trying to send a message to another ConnectionManager,
- * it blocks the thread calling sendMessage until the SASL negotiation has occurred.
* The ConnectionManager tracks all the sendingConnections using the ConnectionId
- * and waits for the response from the server and does the handshake.
+ * and waits for the response from the server and does the handshake before sending
+ * the real message.
*
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
* can be used. Yarn requires a specific AmIpFilter be installed for security to work
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 428f019b02a23..396cdd1247e07 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -187,6 +187,15 @@ class SparkContext(config: SparkConf) extends Logging {
val master = conf.get("spark.master")
val appName = conf.get("spark.app.name")
+ private[spark] val isEventLogEnabled = conf.getBoolean("spark.eventLog.enabled", false)
+ private[spark] val eventLogDir: Option[String] = {
+ if (isEventLogEnabled) {
+ Some(conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR).stripSuffix("/"))
+ } else {
+ None
+ }
+ }
+
// Generate the random name for a temp folder in Tachyon
// Add a timestamp as the suffix here to make it more safe
val tachyonFolderName = "spark-" + randomUUID.toString()
@@ -200,6 +209,7 @@ class SparkContext(config: SparkConf) extends Logging {
private[spark] val listenerBus = new LiveListenerBus
// Create the Spark execution environment (cache, map output tracker, etc)
+ conf.set("spark.executor.id", "driver")
private[spark] val env = SparkEnv.create(
conf,
"",
@@ -232,19 +242,6 @@ class SparkContext(config: SparkConf) extends Logging {
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf)
- // Optionally log Spark events
- private[spark] val eventLogger: Option[EventLoggingListener] = {
- if (conf.getBoolean("spark.eventLog.enabled", false)) {
- val logger = new EventLoggingListener(appName, conf, hadoopConfiguration)
- logger.start()
- listenerBus.addListener(logger)
- Some(logger)
- } else None
- }
-
- // At this point, all relevant SparkListeners have been registered, so begin releasing events
- listenerBus.start()
-
val startTime = System.currentTimeMillis()
// Add each JAR given through the constructor
@@ -309,6 +306,29 @@ class SparkContext(config: SparkConf) extends Logging {
// constructor
taskScheduler.start()
+ val applicationId: String = taskScheduler.applicationId()
+ conf.set("spark.app.id", applicationId)
+
+ val metricsSystem = env.metricsSystem
+
+ // The driver's metrics system needs spark.app.id to be set to the application ID,
+ // so start it only after the task scheduler has given us the app ID and spark.app.id is set.
+ metricsSystem.start()
+
+ // Optionally log Spark events
+ private[spark] val eventLogger: Option[EventLoggingListener] = {
+ if (isEventLogEnabled) {
+ val logger =
+ new EventLoggingListener(applicationId, eventLogDir.get, conf, hadoopConfiguration)
+ logger.start()
+ listenerBus.addListener(logger)
+ Some(logger)
+ } else None
+ }
+
+ // At this point, all relevant SparkListeners have been registered, so begin releasing events
+ listenerBus.start()
+
private[spark] val cleaner: Option[ContextCleaner] = {
if (conf.getBoolean("spark.cleaner.referenceTracking", true)) {
Some(new ContextCleaner(this))
@@ -411,8 +431,8 @@ class SparkContext(config: SparkConf) extends Logging {
// Post init
taskScheduler.postStartHook()
- private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
- private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
+ private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler)
+ private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager)
private def initDriverMetrics() {
SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
@@ -759,20 +779,20 @@ class SparkContext(config: SparkConf) extends Logging {
/**
* Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values
* with `+=`. Only the driver can access the accumulable's `value`.
- * @tparam T accumulator type
- * @tparam R type that can be added to the accumulator
+ * @tparam R accumulator result type
+ * @tparam T type that can be added to the accumulator
*/
- def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) =
+ def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) =
new Accumulable(initialValue, param)
/**
* Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the
* Spark UI. Tasks can add values to the accumulable using the `+=` operator. Only the driver can
* access the accumulable's `value`.
- * @tparam T accumulator type
- * @tparam R type that can be added to the accumulator
+ * @tparam R accumulator result type
+ * @tparam T type that can be added to the accumulator
*/
- def accumulable[T, R](initialValue: T, name: String)(implicit param: AccumulableParam[T, R]) =
+ def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) =
new Accumulable(initialValue, param, Some(name))
/**
@@ -1030,28 +1050,40 @@ class SparkContext(config: SparkConf) extends Logging {
}
/**
- * Support function for API backtraces.
+ * Set the thread-local property for overriding the call sites
+ * of actions and RDDs.
+ */
+ def setCallSite(shortCallSite: String) {
+ setLocalProperty(CallSite.SHORT_FORM, shortCallSite)
+ }
+
+ /**
+ * Set the thread-local property for overriding the call sites
+ * of actions and RDDs.
*/
- def setCallSite(site: String) {
- setLocalProperty("externalCallSite", site)
+ private[spark] def setCallSite(callSite: CallSite) {
+ setLocalProperty(CallSite.SHORT_FORM, callSite.shortForm)
+ setLocalProperty(CallSite.LONG_FORM, callSite.longForm)
}
/**
- * Support function for API backtraces.
+ * Clear the thread-local property for overriding the call sites
+ * of actions and RDDs.
*/
def clearCallSite() {
- setLocalProperty("externalCallSite", null)
+ setLocalProperty(CallSite.SHORT_FORM, null)
+ setLocalProperty(CallSite.LONG_FORM, null)
}
/**
* Capture the current user callsite and return a formatted version for printing. If the user
- * has overridden the call site, this will return the user's version.
+ * has overridden the call site using `setCallSite()`, this will return the user's version.
*/
private[spark] def getCallSite(): CallSite = {
- Option(getLocalProperty("externalCallSite")) match {
- case Some(callSite) => CallSite(callSite, longForm = "")
- case None => Utils.getCallSite
- }
+ Option(getLocalProperty(CallSite.SHORT_FORM)).map { case shortCallSite =>
+ val longCallSite = Option(getLocalProperty(CallSite.LONG_FORM)).getOrElse("")
+ CallSite(shortCallSite, longCallSite)
+ }.getOrElse(Utils.getCallSite())
}
/**
@@ -1266,7 +1298,7 @@ class SparkContext(config: SparkConf) extends Logging {
private def postApplicationStart() {
// Note: this code assumes that the task scheduler has been initialized and has contacted
// the cluster manager to get an application ID (in case the cluster manager provides one).
- listenerBus.post(SparkListenerApplicationStart(appName, taskScheduler.applicationId(),
+ listenerBus.post(SparkListenerApplicationStart(appName, Some(applicationId),
startTime, sparkUser))
}
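
The accumulable signature fix swaps the type parameters so that R is the accumulated result type and T the element type tasks add. A minimal sketch under that corrected signature (assumes an existing SparkContext `sc`; the Set-based parameter is purely illustrative):

    import org.apache.spark.AccumulableParam

    // R = Set[Int] is the result type, T = Int is what tasks add.
    implicit object IntSetParam extends AccumulableParam[Set[Int], Int] {
      def addAccumulator(acc: Set[Int], elem: Int): Set[Int] = acc + elem
      def addInPlace(a: Set[Int], b: Set[Int]): Set[Int] = a ++ b
      def zero(initial: Set[Int]): Set[Int] = Set.empty[Int]
    }

    val seen = sc.accumulable(Set.empty[Int])          // accumulable[Set[Int], Int]
    sc.parallelize(Seq(1, 2, 2, 3)).foreach(x => seen += x)
    // On the driver: seen.value == Set(1, 2, 3)
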
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 009ed64775844..aba713cb4267a 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -43,9 +43,8 @@ import org.apache.spark.util.{AkkaUtils, Utils}
* :: DeveloperApi ::
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
* including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
- * Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these
- * objects needs to have the right SparkEnv set. You can get the current environment with
- * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
+ * Spark code finds the SparkEnv through a global variable, so all the threads can access the same
+ * SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a SparkContext).
*
* NOTE: This is not intended for external use. This is exposed for Shark and may be made private
* in a future release.
@@ -119,30 +118,28 @@ class SparkEnv (
}
object SparkEnv extends Logging {
- private val env = new ThreadLocal[SparkEnv]
- @volatile private var lastSetSparkEnv : SparkEnv = _
+ @volatile private var env: SparkEnv = _
private[spark] val driverActorSystemName = "sparkDriver"
private[spark] val executorActorSystemName = "sparkExecutor"
def set(e: SparkEnv) {
- lastSetSparkEnv = e
- env.set(e)
+ env = e
}
/**
- * Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv
- * previously set in any thread.
+ * Returns the SparkEnv.
*/
def get: SparkEnv = {
- Option(env.get()).getOrElse(lastSetSparkEnv)
+ env
}
/**
* Returns the ThreadLocal SparkEnv.
*/
+ @deprecated("Use SparkEnv.get instead", "1.2")
def getThreadLocal: SparkEnv = {
- env.get()
+ env
}
private[spark] def create(
@@ -259,11 +256,15 @@ object SparkEnv extends Logging {
}
val metricsSystem = if (isDriver) {
+ // Don't start the metrics system for the driver yet: we need to wait for the
+ // task scheduler to give us an app ID before the metrics system can start.
MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
- MetricsSystem.createMetricsSystem("executor", conf, securityManager)
+ val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager)
+ ms.start()
+ ms
}
- metricsSystem.start()
// Set the sparkFiles directory, used when downloading dependencies. In local mode,
// this is a temporary directory; in distributed mode, this is the executor's current working
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index f6703986bdf11..376e69cd997d5 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -116,7 +116,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
}
}
} else {
- logWarning ("No need to commit output of task: " + taID.value)
+ logInfo ("No need to commit output of task: " + taID.value)
}
}
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
deleted file mode 100644
index 51b3e4d5e0936..0000000000000
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}
-
-
-/**
- * :: DeveloperApi ::
- * Contextual information about a task which can be read or mutated during execution.
- *
- * @param stageId stage id
- * @param partitionId index of the partition
- * @param attemptId the number of attempts to execute this task
- * @param runningLocally whether the task is running locally in the driver JVM
- * @param taskMetrics performance metrics of the task
- */
-@DeveloperApi
-class TaskContext(
- val stageId: Int,
- val partitionId: Int,
- val attemptId: Long,
- val runningLocally: Boolean = false,
- private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty)
- extends Serializable with Logging {
-
- @deprecated("use partitionId", "0.8.1")
- def splitId = partitionId
-
- // List of callback functions to execute when the task completes.
- @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
-
- // Whether the corresponding task has been killed.
- @volatile private var interrupted: Boolean = false
-
- // Whether the task has completed.
- @volatile private var completed: Boolean = false
-
- /** Checks whether the task has completed. */
- def isCompleted: Boolean = completed
-
- /** Checks whether the task has been killed. */
- def isInterrupted: Boolean = interrupted
-
- // TODO: Also track whether the task has completed successfully or with exception.
-
- /**
- * Add a (Java friendly) listener to be executed on task completion.
- * This will be called in all situation - success, failure, or cancellation.
- *
- * An example use is for HadoopRDD to register a callback to close the input stream.
- */
- def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
- onCompleteCallbacks += listener
- this
- }
-
- /**
- * Add a listener in the form of a Scala closure to be executed on task completion.
- * This will be called in all situation - success, failure, or cancellation.
- *
- * An example use is for HadoopRDD to register a callback to close the input stream.
- */
- def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
- onCompleteCallbacks += new TaskCompletionListener {
- override def onTaskCompletion(context: TaskContext): Unit = f(context)
- }
- this
- }
-
- /**
- * Add a callback function to be executed on task completion. An example use
- * is for HadoopRDD to register a callback to close the input stream.
- * Will be called in any situation - success, failure, or cancellation.
- * @param f Callback function.
- */
- @deprecated("use addTaskCompletionListener", "1.1.0")
- def addOnCompleteCallback(f: () => Unit) {
- onCompleteCallbacks += new TaskCompletionListener {
- override def onTaskCompletion(context: TaskContext): Unit = f()
- }
- }
-
- /** Marks the task as completed and triggers the listeners. */
- private[spark] def markTaskCompleted(): Unit = {
- completed = true
- val errorMsgs = new ArrayBuffer[String](2)
- // Process complete callbacks in the reverse order of registration
- onCompleteCallbacks.reverse.foreach { listener =>
- try {
- listener.onTaskCompletion(this)
- } catch {
- case e: Throwable =>
- errorMsgs += e.getMessage
- logError("Error in TaskCompletionListener", e)
- }
- }
- if (errorMsgs.nonEmpty) {
- throw new TaskCompletionListenerException(errorMsgs)
- }
- }
-
- /** Marks the task for interruption, i.e. cancellation. */
- private[spark] def markInterrupted(): Unit = {
- interrupted = true
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 880f61c49726e..0846225e4f992 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -469,6 +469,22 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)})
}
+ /**
+ * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or
+ * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each
+ * element (k, w) in `other`, the resulting RDD will either contain all pairs
+ * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements
+ * in `this` have key k. Uses the given Partitioner to partition the output RDD.
+ */
+ def fullOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
+ : JavaPairRDD[K, (Optional[V], Optional[W])] = {
+ val joinResult = rdd.fullOuterJoin(other, partitioner)
+ fromRDD(joinResult.mapValues{ case (v, w) =>
+ (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w))
+ })
+ }
+
/**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the existing
* partitioner/parallelism level.
@@ -563,6 +579,38 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)})
}
+ /**
+ * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or
+ * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each
+ * element (k, w) in `other`, the resulting RDD will either contain all pairs
+ * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements
+ * in `this` have key k. Hash-partitions the resulting RDD using the existing partitioner/
+ * parallelism level.
+ */
+ def fullOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Optional[V], Optional[W])] = {
+ val joinResult = rdd.fullOuterJoin(other)
+ fromRDD(joinResult.mapValues{ case (v, w) =>
+ (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w))
+ })
+ }
+
+ /**
+ * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or
+ * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each
+ * element (k, w) in `other`, the resulting RDD will either contain all pairs
+ * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements
+ * in `this` have key k. Hash-partitions the resulting RDD into the given number of partitions.
+ */
+ def fullOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int)
+ : JavaPairRDD[K, (Optional[V], Optional[W])] = {
+ val joinResult = rdd.fullOuterJoin(other, numPartitions)
+ fromRDD(joinResult.mapValues{ case (v, w) =>
+ (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w))
+ })
+ }
+
/**
* Return the key-value pairs in this RDD to the master as a Map.
*/
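
These Java wrappers delegate to the Scala fullOuterJoin on pair RDDs; a small sketch of the semantics spelled out in the scaladoc above (assumes an existing SparkContext `sc`):

    import org.apache.spark.SparkContext._   // pair RDD functions

    val left  = sc.parallelize(Seq("a" -> 1, "b" -> 2))
    val right = sc.parallelize(Seq("b" -> "x", "c" -> "y"))

    val joined = left.fullOuterJoin(right).collect().toMap
    // Map(a -> (Some(1), None), b -> (Some(2), Some(x)), c -> (None, Some(y)))
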
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 12b345a8fa7c3..c74f86548ef85 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -196,7 +196,6 @@ private[spark] class PythonRDD(
override def run(): Unit = Utils.logUncaughtExceptions {
try {
- SparkEnv.set(env)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
// Partition index
@@ -248,6 +247,11 @@ private[spark] class PythonRDD(
// will kill the whole executor (see org.apache.spark.executor.Executor).
_exception = e
worker.shutdownOutput()
+ } finally {
+ // Release memory used by this thread for shuffles
+ env.shuffleMemoryManager.releaseMemoryForThisThread()
+ // Release memory used by this thread for unrolling blocks
+ env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
}
}
}
@@ -339,26 +343,34 @@ private[spark] object PythonRDD extends Logging {
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
- val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
try {
- while (true) {
- val length = file.readInt()
- val obj = new Array[Byte](length)
- file.readFully(obj)
- objs.append(obj)
+ val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+ try {
+ while (true) {
+ val length = file.readInt()
+ val obj = new Array[Byte](length)
+ file.readFully(obj)
+ objs.append(obj)
+ }
+ } catch {
+ case eof: EOFException => {}
}
- } catch {
- case eof: EOFException => {}
+ JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+ } finally {
+ file.close()
}
- JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
- val length = file.readInt()
- val obj = new Array[Byte](length)
- file.readFully(obj)
- sc.broadcast(obj)
+ try {
+ val length = file.readInt()
+ val obj = new Array[Byte](length)
+ file.readFully(obj)
+ sc.broadcast(obj)
+ } finally {
+ file.close()
+ }
}
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
@@ -775,17 +787,36 @@ private[spark] object PythonRDD extends Logging {
}.toJavaRDD()
}
+ private class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
+ private val pickle = new Pickler()
+ private var batch = 1
+ private val buffer = new mutable.ArrayBuffer[Any]
+
+ override def hasNext(): Boolean = iter.hasNext
+
+ override def next(): Array[Byte] = {
+ while (iter.hasNext && buffer.length < batch) {
+ buffer += iter.next()
+ }
+ val bytes = pickle.dumps(buffer.toArray)
+ val size = bytes.length
+ // Keep the serialized batch size roughly between 1 MB and 10 MB
+ if (size < 1024 * 1024) {
+ batch *= 2
+ } else if (size > 1024 * 1024 * 10 && batch > 1) {
+ batch /= 2
+ }
+ buffer.clear()
+ bytes
+ }
+ }
+
/**
* Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
* PySpark.
*/
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
- jRDD.rdd.mapPartitions { iter =>
- val pickle = new Pickler
- iter.map { row =>
- pickle.dumps(row)
- }
- }
+ jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
}
/**
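
The AutoBatchedPickler above grows its batch while a pickled batch stays under 1 MB and shrinks it once a batch exceeds 10 MB, so batch sizes settle between the two bounds. A standalone sketch of that policy with a pluggable size function (the pickler itself is not reproduced here):

    import scala.collection.mutable.ArrayBuffer

    class AutoBatcher[T](iter: Iterator[T], sizeOf: Seq[T] => Int) extends Iterator[Seq[T]] {
      private var batch = 1
      override def hasNext: Boolean = iter.hasNext
      override def next(): Seq[T] = {
        val buffer = new ArrayBuffer[T]
        while (iter.hasNext && buffer.length < batch) buffer += iter.next()
        val size = sizeOf(buffer)
        if (size < 1024 * 1024) batch *= 2                          // batches were small: grow
        else if (size > 10 * 1024 * 1024 && batch > 1) batch /= 2   // too large: shrink
        buffer.toVector
      }
    }
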
diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
index 6668797f5f8be..7903457b17e13 100644
--- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -68,8 +68,8 @@ private[python] object SerDeUtil extends Logging {
construct(args ++ Array(""))
} else if (args.length == 2 && args(1).isInstanceOf[String]) {
val typecode = args(0).asInstanceOf[String].charAt(0)
- val data: String = args(1).asInstanceOf[String]
- construct(typecode, machineCodes(typecode), data.getBytes("ISO-8859-1"))
+ val data: Array[Byte] = args(1).asInstanceOf[String].getBytes("ISO-8859-1")
+ construct(typecode, machineCodes(typecode), data)
} else {
super.construct(args)
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 942dc7d7eac87..4cd4f4f96fd16 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -163,18 +163,23 @@ private[broadcast] object HttpBroadcast extends Logging {
private def write(id: Long, value: Any) {
val file = getFile(id)
- val out: OutputStream = {
- if (compress) {
- compressionCodec.compressedOutputStream(new FileOutputStream(file))
- } else {
- new BufferedOutputStream(new FileOutputStream(file), bufferSize)
+ val fileOutputStream = new FileOutputStream(file)
+ try {
+ val out: OutputStream = {
+ if (compress) {
+ compressionCodec.compressedOutputStream(fileOutputStream)
+ } else {
+ new BufferedOutputStream(fileOutputStream, bufferSize)
+ }
}
+ val ser = SparkEnv.get.serializer.newInstance()
+ val serOut = ser.serializeStream(out)
+ serOut.writeObject(value)
+ serOut.close()
+ files += file
+ } finally {
+ fileOutputStream.close()
}
- val ser = SparkEnv.get.serializer.newInstance()
- val serOut = ser.serializeStream(out)
- serOut.writeObject(value)
- serOut.close()
- files += file
}
private def read[T: ClassTag](id: Long): T = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index b66c3ba4d5fb0..79b4d7ea41a33 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -54,9 +54,10 @@ object PythonRunner {
val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*)
// Launch Python process
- val builder = new ProcessBuilder(Seq(pythonExec, "-u", formattedPythonFile) ++ otherArgs)
+ val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs)
val env = builder.environment()
env.put("PYTHONPATH", pythonPath)
+ env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
val process = builder.start()
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 5ed3575816a38..f97bf67fa5a3b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -54,7 +54,7 @@ object SparkSubmit {
private val SPARK_SHELL = "spark-shell"
private val PYSPARK_SHELL = "pyspark-shell"
- private val CLASS_NOT_FOUND_EXIT_STATUS = 1
+ private val CLASS_NOT_FOUND_EXIT_STATUS = 101
// Exposed for testing
private[spark] var exitFn: () => Unit = () => System.exit(-1)
@@ -172,7 +172,7 @@ object SparkSubmit {
// All cluster managers
OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"),
OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"),
- OptionAssigner(args.jars, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.jars"),
+ OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"),
OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT,
sysProp = "spark.driver.memory"),
OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
@@ -183,6 +183,7 @@ object SparkSubmit {
sysProp = "spark.driver.extraLibraryPath"),
// Standalone cluster only
+ OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"),
OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"),
OptionAssigner(args.driverCores, STANDALONE, CLUSTER, clOption = "--cores"),
@@ -261,7 +262,7 @@ object SparkSubmit {
}
// In yarn-cluster mode, use yarn.Client as a wrapper around the user class
- if (clusterManager == YARN && deployMode == CLUSTER) {
+ if (isYarnCluster) {
childMainClass = "org.apache.spark.deploy.yarn.Client"
if (args.primaryResource != SPARK_INTERNAL) {
childArgs += ("--jar", args.primaryResource)
@@ -279,7 +280,7 @@ object SparkSubmit {
}
// Read from default spark properties, if any
- for ((k, v) <- args.getDefaultSparkProperties) {
+ for ((k, v) <- args.defaultSparkProperties) {
sysProps.getOrElseUpdate(k, v)
}
@@ -319,6 +320,10 @@ object SparkSubmit {
} catch {
case e: ClassNotFoundException =>
e.printStackTrace(printStream)
+ if (childMainClass.contains("thriftserver")) {
+ println(s"Failed to load main class $childMainClass.")
+ println("You need to build Spark with -Phive.")
+ }
System.exit(CLASS_NOT_FOUND_EXIT_STATUS)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index d545f58c5da7e..57b251ff47714 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -29,8 +29,9 @@ import org.apache.spark.util.Utils
/**
* Parses and encapsulates arguments from the spark-submit script.
+ * The env argument is used for testing.
*/
-private[spark] class SparkSubmitArguments(args: Seq[String]) {
+private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) {
var master: String = null
var deployMode: String = null
var executorMemory: String = null
@@ -57,12 +58,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
var pyFiles: String = null
val sparkProperties: HashMap[String, String] = new HashMap[String, String]()
- parseOpts(args.toList)
- mergeSparkProperties()
- checkRequiredArguments()
-
- /** Return default present in the currently defined defaults file. */
- def getDefaultSparkProperties = {
+ /** Default properties present in the currently defined defaults file. */
+ lazy val defaultSparkProperties: HashMap[String, String] = {
val defaultProperties = new HashMap[String, String]()
if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile")
Option(propertiesFile).foreach { filename =>
@@ -79,6 +76,14 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
defaultProperties
}
+ // Respect SPARK_*_MEMORY for cluster mode
+ driverMemory = sys.env.get("SPARK_DRIVER_MEMORY").orNull
+ executorMemory = sys.env.get("SPARK_EXECUTOR_MEMORY").orNull
+
+ parseOpts(args.toList)
+ mergeSparkProperties()
+ checkRequiredArguments()
+
/**
* Fill in any undefined values based on the default properties file or options passed in through
* the '--conf' flag.
@@ -86,20 +91,12 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
private def mergeSparkProperties(): Unit = {
// Use common defaults file, if not specified by user
if (propertiesFile == null) {
- sys.env.get("SPARK_CONF_DIR").foreach { sparkConfDir =>
- val sep = File.separator
- val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf"
- val file = new File(defaultPath)
- if (file.exists()) {
- propertiesFile = file.getAbsolutePath
- }
- }
- }
+ val sep = File.separator
+ val sparkHomeConfig = env.get("SPARK_HOME").map(sparkHome => s"${sparkHome}${sep}conf")
+ val confDir = env.get("SPARK_CONF_DIR").orElse(sparkHomeConfig)
- if (propertiesFile == null) {
- sys.env.get("SPARK_HOME").foreach { sparkHome =>
- val sep = File.separator
- val defaultPath = s"${sparkHome}${sep}conf${sep}spark-defaults.conf"
+ confDir.foreach { sparkConfDir =>
+ val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf"
val file = new File(defaultPath)
if (file.exists()) {
propertiesFile = file.getAbsolutePath
@@ -107,24 +104,24 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
}
}
- val properties = getDefaultSparkProperties
+ val properties = HashMap[String, String]()
+ properties.putAll(defaultSparkProperties)
properties.putAll(sparkProperties)
// Use properties file as fallback for values which have a direct analog to
// arguments in this script.
- master = Option(master).getOrElse(properties.get("spark.master").orNull)
- executorMemory = Option(executorMemory)
- .getOrElse(properties.get("spark.executor.memory").orNull)
- executorCores = Option(executorCores)
- .getOrElse(properties.get("spark.executor.cores").orNull)
+ master = Option(master).orElse(properties.get("spark.master")).orNull
+ executorMemory = Option(executorMemory).orElse(properties.get("spark.executor.memory")).orNull
+ executorCores = Option(executorCores).orElse(properties.get("spark.executor.cores")).orNull
totalExecutorCores = Option(totalExecutorCores)
- .getOrElse(properties.get("spark.cores.max").orNull)
- name = Option(name).getOrElse(properties.get("spark.app.name").orNull)
- jars = Option(jars).getOrElse(properties.get("spark.jars").orNull)
+ .orElse(properties.get("spark.cores.max"))
+ .orNull
+ name = Option(name).orElse(properties.get("spark.app.name")).orNull
+ jars = Option(jars).orElse(properties.get("spark.jars")).orNull
// This supports env vars in older versions of Spark
- master = Option(master).getOrElse(System.getenv("MASTER"))
- deployMode = Option(deployMode).getOrElse(System.getenv("DEPLOY_MODE"))
+ master = Option(master).orElse(env.get("MASTER")).orNull
+ deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull
// Try to set main class from JAR if no --class argument is given
if (mainClass == null && !isPython && primaryResource != null) {
@@ -177,7 +174,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
}
if (master.startsWith("yarn")) {
- val hasHadoopEnv = sys.env.contains("HADOOP_CONF_DIR") || sys.env.contains("YARN_CONF_DIR")
+ val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR")
if (!hasHadoopEnv && !Utils.isTesting) {
throw new Exception(s"When running with master '$master' " +
"either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.")
@@ -213,7 +210,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
| verbose $verbose
|
|Default properties from $propertiesFile:
- |${getDefaultSparkProperties.mkString(" ", "\n ", "\n")}
+ |${defaultSparkProperties.mkString(" ", "\n ", "\n")}
""".stripMargin
}
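
The merge logic now resolves each setting with an Option(...).orElse(...).orNull chain: an explicit command-line value wins, then a property from --conf or the defaults file, and finally null. A tiny sketch of that precedence with hypothetical values:

    val fromCommandLine: String = null                      // e.g. --master was not passed
    val properties = Map("spark.master" -> "local[4]")      // from spark-defaults.conf or --conf

    val master = Option(fromCommandLine).orElse(properties.get("spark.master")).orNull
    // master == "local[4]"
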
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
index 38b5d8e1739d0..a64170a47bc1c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
@@ -154,7 +154,8 @@ private[spark] object SparkSubmitDriverBootstrapper {
process.destroy()
}
}
- process.waitFor()
+ val returnCode = process.waitFor()
+ sys.exit(returnCode)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
index c4ef8b63b0071..d25c29113d6da 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
@@ -67,6 +67,7 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
}
private val appHeader = Seq(
+ "App ID",
"App Name",
"Started",
"Completed",
@@ -81,7 +82,8 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
val duration = UIUtils.formatDuration(info.endTime - info.startTime)
val lastUpdated = UIUtils.formatDate(info.lastUpdated)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index aa85aa060d9c1..08a99bbe68578 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -83,15 +83,21 @@ private[spark] class FileSystemPersistenceEngine(
val serialized = serializer.toBinary(value)
val out = new FileOutputStream(file)
- out.write(serialized)
- out.close()
+ try {
+ out.write(serialized)
+ } finally {
+ out.close()
+ }
}
def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = {
val fileData = new Array[Byte](file.length().asInstanceOf[Int])
val dis = new DataInputStream(new FileInputStream(file))
- dis.readFully(fileData)
- dis.close()
+ try {
+ dis.readFully(fileData)
+ } finally {
+ dis.close()
+ }
val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
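
The change above applies the standard close-in-finally pattern so a failed write or read can no longer leak the stream. A small generic sketch of the same pattern (the `withCloseable` helper is hypothetical, not part of Spark):

import java.io.{Closeable, File, FileOutputStream}

// Sketch of the close-in-finally pattern; `withCloseable` is a hypothetical helper.
object CloseableSketch {
  def withCloseable[C <: Closeable, T](resource: C)(body: C => T): T = {
    try {
      body(resource)
    } finally {
      resource.close()   // runs even if body threw, so the handle never leaks
    }
  }

  def main(args: Array[String]): Unit = {
    val file = File.createTempFile("persist", ".bin")
    withCloseable(new FileOutputStream(file)) { out =>
      out.write(Array[Byte](1, 2, 3))
    }
    println(s"wrote ${file.length()} bytes")
  }
}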
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 2a3bd6ba0b9dc..f98b531316a3d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -33,8 +33,8 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import akka.serialization.SerializationExtension
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
-import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState,
- SparkHadoopUtil}
+import org.apache.spark.deploy.{ApplicationDescription, DriverDescription,
+ ExecutorState, SparkHadoopUtil}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.history.HistoryServer
import org.apache.spark.deploy.master.DriverState.DriverState
@@ -489,23 +489,24 @@ private[spark] class Master(
// First schedule drivers, they take strict precedence over applications
// Randomization helps balance drivers
val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE))
- val aliveWorkerNum = shuffledAliveWorkers.size
+ val numWorkersAlive = shuffledAliveWorkers.size
var curPos = 0
+
for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers
// We assign workers to each waiting driver in a round-robin fashion. For each driver, we
// start from the last worker that was assigned a driver, and continue onwards until we have
// explored all alive workers.
- curPos = (curPos + 1) % aliveWorkerNum
- val startPos = curPos
var launched = false
- while (curPos != startPos && !launched) {
+ var numWorkersVisited = 0
+ while (numWorkersVisited < numWorkersAlive && !launched) {
val worker = shuffledAliveWorkers(curPos)
+ numWorkersVisited += 1
if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) {
launchDriver(worker, driver)
waitingDrivers -= driver
launched = true
}
- curPos = (curPos + 1) % aliveWorkerNum
+ curPos = (curPos + 1) % numWorkersAlive
}
}
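
The rewritten loop counts visited workers instead of comparing positions, so every alive worker is considered exactly once per waiting driver while the scan still resumes where the previous driver's search stopped. A standalone sketch of the corrected round-robin placement, with a toy Worker type standing in for Spark's WorkerInfo:

// Standalone sketch of the corrected round-robin placement loop; the toy Worker
// type stands in for Spark's WorkerInfo.
object RoundRobinSketch {
  case class Worker(id: String, memFree: Int, coresFree: Int)

  def place(workers: IndexedSeq[Worker], mem: Int, cores: Int, startPos: Int): Option[Worker] = {
    var curPos = startPos
    var visited = 0
    var chosen: Option[Worker] = None
    // Visit each worker at most once, starting where the previous search stopped.
    while (visited < workers.size && chosen.isEmpty) {
      val w = workers(curPos)
      visited += 1
      if (w.memFree >= mem && w.coresFree >= cores) {
        chosen = Some(w)
      }
      curPos = (curPos + 1) % workers.size
    }
    chosen
  }

  def main(args: Array[String]): Unit = {
    val workers = Vector(Worker("a", 512, 1), Worker("b", 4096, 4), Worker("c", 2048, 2))
    println(place(workers, mem = 1024, cores = 2, startPos = 0)) // Some(Worker(b,4096,4))
    println(place(workers, mem = 8192, cores = 2, startPos = 0)) // None: nothing fits
  }
}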
@@ -692,16 +693,18 @@ private[spark] class Master(
app.desc.appUiUrl = notFoundBasePath
return false
}
- val fileSystem = Utils.getHadoopFileSystem(eventLogDir,
+
+ val appEventLogDir = EventLoggingListener.getLogDirPath(eventLogDir, app.id)
+ val fileSystem = Utils.getHadoopFileSystem(appEventLogDir,
SparkHadoopUtil.get.newConfiguration(conf))
- val eventLogInfo = EventLoggingListener.parseLoggingInfo(eventLogDir, fileSystem)
+ val eventLogInfo = EventLoggingListener.parseLoggingInfo(appEventLogDir, fileSystem)
val eventLogPaths = eventLogInfo.logPaths
val compressionCodec = eventLogInfo.compressionCodec
if (eventLogPaths.isEmpty) {
// Event logging is enabled for this application, but no event logs are found
val title = s"Application history not found (${app.id})"
- var msg = s"No event logs found for application $appName in $eventLogDir."
+ var msg = s"No event logs found for application $appName in $appEventLogDir."
logWarning(msg)
msg += " Did you specify the correct logging directory?"
msg = URLEncoder.encode(msg, "UTF-8")
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
index 12e98fd40d6c9..2e9be2a180c68 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
@@ -30,7 +30,7 @@ import org.apache.spark.util.Utils
private[spark]
object CommandUtils extends Logging {
def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = {
- val runner = getEnv("JAVA_HOME", command).map(_ + "/bin/java").getOrElse("java")
+ val runner = sys.env.get("JAVA_HOME").map(_ + "/bin/java").getOrElse("java")
// SPARK-698: do not call the run.cmd script, as process.destroy()
// fails to kill a process tree on Windows
@@ -38,9 +38,6 @@ object CommandUtils extends Logging {
command.arguments
}
- private def getEnv(key: String, command: Command): Option[String] =
- command.environment.get(key).orElse(Option(System.getenv(key)))
-
/**
* Attention: this must always be aligned with the environment variables in the run scripts and
* the way the JAVA_OPTS are assembled there.
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 00a43673e5cd3..71650cd773bcf 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -42,7 +42,7 @@ private[spark] class ExecutorRunner(
val workerId: String,
val host: String,
val sparkHome: File,
- val workDir: File,
+ val executorDir: File,
val workerUrl: String,
val conf: SparkConf,
var state: ExecutorState.Value)
@@ -130,12 +130,6 @@ private[spark] class ExecutorRunner(
*/
def fetchAndRunExecutor() {
try {
- // Create the executor's working directory
- val executorDir = new File(workDir, appId + "/" + execId)
- if (!executorDir.mkdirs()) {
- throw new IOException("Failed to create directory " + executorDir)
- }
-
// Launch the process
val command = getCommandSeq
logInfo("Launch command: " + command.mkString("\"", "\" \"", "\""))
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 0c454e4138c96..9b52cb06fb6fa 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -18,9 +18,11 @@
package org.apache.spark.deploy.worker
import java.io.File
+import java.io.IOException
import java.text.SimpleDateFormat
import java.util.Date
+import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
import scala.concurrent.duration._
import scala.language.postfixOps
@@ -191,6 +193,7 @@ private[spark] class Worker(
changeMaster(masterUrl, masterWebUiUrl)
context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat)
if (CLEANUP_ENABLED) {
+ logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir")
context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis,
CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup)
}
@@ -201,10 +204,23 @@ private[spark] class Worker(
case WorkDirCleanup =>
// Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor
val cleanupFuture = concurrent.future {
- logInfo("Cleaning up oldest application directories in " + workDir + " ...")
- Utils.findOldFiles(workDir, APP_DATA_RETENTION_SECS)
- .foreach(Utils.deleteRecursively)
+ val appDirs = workDir.listFiles()
+ if (appDirs == null) {
+        throw new IOException("ERROR: Failed to list files in " + workDir)
+ }
+ appDirs.filter { dir =>
+        // The directory is used by an application; check that the application is not
+        // still running before cleaning it up.
+ val appIdFromDir = dir.getName
+ val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir)
+ dir.isDirectory && !isAppStillRunning &&
+ !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS)
+ }.foreach { dir =>
+ logInfo(s"Removing directory: ${dir.getPath}")
+ Utils.deleteRecursively(dir)
+ }
}
+
cleanupFuture onFailure {
case e: Throwable =>
logError("App dir cleanup failed: " + e.getMessage, e)
@@ -233,8 +249,15 @@ private[spark] class Worker(
} else {
try {
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
+
+ // Create the executor's working directory
+ val executorDir = new File(workDir, appId + "/" + execId)
+ if (!executorDir.mkdirs()) {
+ throw new IOException("Failed to create directory " + executorDir)
+ }
+
val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_,
- self, workerId, host, sparkHome, workDir, akkaUrl, conf, ExecutorState.LOADING)
+ self, workerId, host, sparkHome, executorDir, akkaUrl, conf, ExecutorState.LOADING)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
@@ -242,12 +265,13 @@ private[spark] class Worker(
master ! ExecutorStateChanged(appId, execId, manager.state, None, None)
} catch {
case e: Exception => {
- logError("Failed to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
+ logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e)
if (executors.contains(appId + "/" + execId)) {
executors(appId + "/" + execId).kill()
executors -= appId + "/" + execId
}
- master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, None, None)
+ master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED,
+ Some(e.toString), None)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 13af5b6f5812d..06061edfc0844 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -106,6 +106,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
executorId: String,
hostname: String,
cores: Int,
+ appId: String,
workerUrl: Option[String]) {
SignalLogger.register(log)
@@ -122,7 +123,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
val driver = fetcher.actorSelection(driverUrl)
val timeout = AkkaUtils.askTimeout(executorConf)
val fut = Patterns.ask(driver, RetrieveSparkProps, timeout)
- val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]]
+ val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++
+ Seq[(String, String)](("spark.app.id", appId))
fetcher.shutdown()
// Create a new ActorSystem using driver's Spark properties to run the backend.
@@ -144,16 +146,16 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
def main(args: Array[String]) {
args.length match {
- case x if x < 4 =>
+ case x if x < 5 =>
System.err.println(
// Worker url is used in spark standalone mode to enforce fate-sharing with worker
"Usage: CoarseGrainedExecutorBackend " +
- " []")
+ " [] ")
System.exit(1)
- case 4 =>
- run(args(0), args(1), args(2), args(3).toInt, None)
- case x if x > 4 =>
- run(args(0), args(1), args(2), args(3).toInt, Some(args(4)))
+ case 5 =>
+ run(args(0), args(1), args(2), args(3).toInt, args(4), None)
+ case x if x > 5 =>
+ run(args(0), args(1), args(2), args(3).toInt, args(4), Some(args(5)))
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index acae448a9c66f..616c7e6a46368 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -24,6 +24,7 @@ import java.util.concurrent._
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.util.control.NonFatal
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
@@ -73,6 +74,7 @@ private[spark] class Executor(
val executorSource = new ExecutorSource(this, executorId)
// Initialize Spark environment (using system properties read above)
+ conf.set("spark.executor.id", "executor." + executorId)
private val env = {
if (!isLocal) {
val _env = SparkEnv.create(conf, executorId, slaveHostname, 0,
@@ -146,7 +148,6 @@ private[spark] class Executor(
override def run() {
val startTime = System.currentTimeMillis()
- SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
@@ -156,7 +157,6 @@ private[spark] class Executor(
val startGCTime = gcTime
try {
- SparkEnv.set(env)
Accumulators.clear()
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
@@ -375,12 +375,17 @@ private[spark] class Executor(
}
val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
- val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
- retryAttempts, retryIntervalMs, timeout)
- if (response.reregisterBlockManager) {
- logWarning("Told to re-register on heartbeat")
- env.blockManager.reregister()
+ try {
+ val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
+ retryAttempts, retryIntervalMs, timeout)
+ if (response.reregisterBlockManager) {
+ logWarning("Told to re-register on heartbeat")
+ env.blockManager.reregister()
+ }
+ } catch {
+ case NonFatal(t) => logWarning("Issue communicating with driver in heartbeater", t)
}
+
Thread.sleep(interval)
}
}
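
Wrapping the driver ask in a NonFatal handler keeps the heartbeat thread alive across transient failures while still letting fatal errors propagate. A self-contained sketch of that resilient loop (the `send` callback is hypothetical):

import scala.util.control.NonFatal

// Self-contained sketch of a heartbeat loop that survives transient failures;
// the `send` callback is hypothetical.
object HeartbeatSketch {
  def runLoop(send: () => Unit, intervalMs: Long, iterations: Int): Unit = {
    for (_ <- 1 to iterations) {
      try {
        send()
      } catch {
        // Log and keep going: one missed heartbeat must not kill the thread.
        case NonFatal(t) => println(s"heartbeat failed, will retry: ${t.getMessage}")
      }
      Thread.sleep(intervalMs)
    }
  }

  def main(args: Array[String]): Unit = {
    var attempts = 0
    runLoop(
      () => { attempts += 1; if (attempts == 2) throw new RuntimeException("driver unreachable") },
      intervalMs = 10, iterations = 3)
    println(s"attempted $attempts heartbeats")
  }
}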
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
index d6721586566c2..c4d73622c4727 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
@@ -37,8 +37,7 @@ private[spark] class ExecutorSource(val executor: Executor, executorId: String)
override val metricRegistry = new MetricRegistry()
- // TODO: It would be nice to pass the application name here
- override val sourceName = "executor.%s".format(executorId)
+ override val sourceName = "executor"
// Gauge for executor thread pool's actively executing task counts
metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] {
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index a42c8b43bbf7f..bca0b152268ad 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -52,7 +52,8 @@ private[spark] class MesosExecutorBackend
slaveInfo: SlaveInfo) {
logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
this.driver = driver
- val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
+ val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++
+ Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue))
executor = new Executor(
executorInfo.getExecutorId.getValue,
slaveInfo.getHostname,
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 99a88c13456df..3e49b6235aff3 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -137,7 +137,6 @@ class TaskMetrics extends Serializable {
merged.localBlocksFetched += depMetrics.localBlocksFetched
merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched
merged.remoteBytesRead += depMetrics.remoteBytesRead
- merged.shuffleFinishTime = math.max(merged.shuffleFinishTime, depMetrics.shuffleFinishTime)
}
_shuffleReadMetrics = Some(merged)
}
@@ -177,11 +176,6 @@ case class InputMetrics(readMethod: DataReadMethod.Value) {
*/
@DeveloperApi
class ShuffleReadMetrics extends Serializable {
- /**
- * Absolute time when this task finished reading shuffle data
- */
- var shuffleFinishTime: Long = -1
-
/**
* Number of blocks fetched in this shuffle by this task (remote or local)
*/
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
index c3dabd2e79995..3564ab2e2a162 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
@@ -36,33 +36,31 @@ private[spark] class WholeTextFileRecordReader(
index: Integer)
extends RecordReader[String, String] {
- private val path = split.getPath(index)
- private val fs = path.getFileSystem(context.getConfiguration)
+ private[this] val path = split.getPath(index)
+ private[this] val fs = path.getFileSystem(context.getConfiguration)
// True means the current file has been processed, then skip it.
- private var processed = false
+ private[this] var processed = false
- private val key = path.toString
- private var value: String = null
+ private[this] val key = path.toString
+ private[this] var value: String = null
- override def initialize(split: InputSplit, context: TaskAttemptContext) = {}
+ override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {}
- override def close() = {}
+ override def close(): Unit = {}
- override def getProgress = if (processed) 1.0f else 0.0f
+ override def getProgress: Float = if (processed) 1.0f else 0.0f
- override def getCurrentKey = key
+ override def getCurrentKey: String = key
- override def getCurrentValue = value
+ override def getCurrentValue: String = value
- override def nextKeyValue = {
+ override def nextKeyValue(): Boolean = {
if (!processed) {
val fileIn = fs.open(path)
val innerBuffer = ByteStreams.toByteArray(fileIn)
-
value = new Text(innerBuffer).toString
Closeables.close(fileIn, false)
-
processed = true
true
} else {
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 6ef817d0e587e..5dd67b0cbf683 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -63,15 +63,18 @@ import org.apache.spark.metrics.source.Source
*
* [options] is the specific property of this source or sink.
*/
-private[spark] class MetricsSystem private (val instance: String,
- conf: SparkConf, securityMgr: SecurityManager) extends Logging {
+private[spark] class MetricsSystem private (
+ val instance: String,
+ conf: SparkConf,
+ securityMgr: SecurityManager)
+ extends Logging {
- val confFile = conf.get("spark.metrics.conf", null)
- val metricsConfig = new MetricsConfig(Option(confFile))
+ private[this] val confFile = conf.get("spark.metrics.conf", null)
+ private[this] val metricsConfig = new MetricsConfig(Option(confFile))
- val sinks = new mutable.ArrayBuffer[Sink]
- val sources = new mutable.ArrayBuffer[Source]
- val registry = new MetricRegistry()
+ private val sinks = new mutable.ArrayBuffer[Sink]
+ private val sources = new mutable.ArrayBuffer[Source]
+ private val registry = new MetricRegistry()
// Treat MetricsServlet as a special sink as it should be exposed to add handlers to web ui
private var metricsServlet: Option[MetricsServlet] = None
@@ -80,10 +83,10 @@ private[spark] class MetricsSystem private (val instance: String,
def getServletHandlers = metricsServlet.map(_.getHandlers).getOrElse(Array())
metricsConfig.initialize()
- registerSources()
- registerSinks()
def start() {
+ registerSources()
+ registerSinks()
sinks.foreach(_.start)
}
@@ -91,14 +94,43 @@ private[spark] class MetricsSystem private (val instance: String,
sinks.foreach(_.stop)
}
- def report(): Unit = {
+ def report() {
sinks.foreach(_.report())
}
+ /**
+ * Build a name that uniquely identifies each metric source.
+   * The name is structured as follows: <app ID>.<executor ID>.<source name>.
+   * If either ID is not available, this defaults to just using <source name>.
+ *
+ * @param source Metric source to be named by this method.
+   * @return A unique metric name for each combination of
+ * application, executor/driver and metric source.
+ */
+ def buildRegistryName(source: Source): String = {
+ val appId = conf.getOption("spark.app.id")
+ val executorId = conf.getOption("spark.executor.id")
+ val defaultName = MetricRegistry.name(source.sourceName)
+
+ if (instance == "driver" || instance == "executor") {
+ if (appId.isDefined && executorId.isDefined) {
+ MetricRegistry.name(appId.get, executorId.get, source.sourceName)
+ } else {
+        // Only the Driver and Executor set spark.app.id and spark.executor.id.
+ // For instance, Master and Worker are not related to a specific application.
+ val warningMsg = s"Using default name $defaultName for source because %s is not set."
+ if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) }
+ if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) }
+ defaultName
+ }
+ } else { defaultName }
+ }
+
def registerSource(source: Source) {
sources += source
try {
- registry.register(source.sourceName, source.metricRegistry)
+ val regName = buildRegistryName(source)
+ registry.register(regName, source.metricRegistry)
} catch {
case e: IllegalArgumentException => logInfo("Metrics already registered", e)
}
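
buildRegistryName prefixes driver and executor source names with the application and executor IDs, so metrics from concurrent applications no longer collide in the registry. A sketch of the resulting names (the IDs below are made up):

import com.codahale.metrics.MetricRegistry

// Sketch of the naming scheme buildRegistryName produces; the IDs below are made up.
object MetricNameSketch {
  def name(appId: Option[String], executorId: Option[String], sourceName: String): String =
    (appId, executorId) match {
      case (Some(app), Some(exec)) => MetricRegistry.name(app, exec, sourceName)
      case _                       => MetricRegistry.name(sourceName)  // Master/Worker case
    }

  def main(args: Array[String]): Unit = {
    println(name(Some("app-20141001120000-0001"), Some("executor.1"), "executor"))
    // app-20141001120000-0001.executor.1.executor
    println(name(None, None, "worker"))  // worker
  }
}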
@@ -106,8 +138,9 @@ private[spark] class MetricsSystem private (val instance: String,
def removeSource(source: Source) {
sources -= source
+ val regName = buildRegistryName(source)
registry.removeMatching(new MetricFilter {
- def matches(name: String, metric: Metric): Boolean = name.startsWith(source.sourceName)
+ def matches(name: String, metric: Metric): Boolean = name.startsWith(regName)
})
}
@@ -122,7 +155,7 @@ private[spark] class MetricsSystem private (val instance: String,
val source = Class.forName(classPath).newInstance()
registerSource(source.asInstanceOf[Source])
} catch {
- case e: Exception => logError("Source class " + classPath + " cannot be instantialized", e)
+ case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e)
}
}
}
@@ -155,8 +188,8 @@ private[spark] object MetricsSystem {
val SINK_REGEX = "^sink\\.(.+)\\.(.+)".r
val SOURCE_REGEX = "^source\\.(.+)\\.(.+)".r
- val MINIMAL_POLL_UNIT = TimeUnit.SECONDS
- val MINIMAL_POLL_PERIOD = 1
+ private[this] val MINIMAL_POLL_UNIT = TimeUnit.SECONDS
+ private[this] val MINIMAL_POLL_PERIOD = 1
def checkMinimalPollingPeriod(pollUnit: TimeUnit, pollPeriod: Int) {
val period = MINIMAL_POLL_UNIT.convert(pollPeriod, pollUnit)
@@ -166,7 +199,8 @@ private[spark] object MetricsSystem {
}
}
- def createMetricsSystem(instance: String, conf: SparkConf,
- securityMgr: SecurityManager): MetricsSystem =
+ def createMetricsSystem(
+ instance: String, conf: SparkConf, securityMgr: SecurityManager): MetricsSystem = {
new MetricsSystem(instance, conf, securityMgr)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
index e990c1da6730f..a4409181ec907 100644
--- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
@@ -17,15 +17,17 @@
package org.apache.spark.network
-import java.io.{FileInputStream, RandomAccessFile, File, InputStream}
+import java.io._
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
+import scala.util.Try
+
import com.google.common.io.ByteStreams
import io.netty.buffer.{ByteBufInputStream, ByteBuf}
-import org.apache.spark.util.ByteBufferInputStream
+import org.apache.spark.util.{ByteBufferInputStream, Utils}
/**
@@ -71,18 +73,47 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt
try {
channel = new RandomAccessFile(file, "r").getChannel
channel.map(MapMode.READ_ONLY, offset, length)
+ } catch {
+ case e: IOException =>
+ Try(channel.size).toOption match {
+ case Some(fileLen) =>
+ throw new IOException(s"Error in reading $this (actual file length $fileLen)", e)
+ case None =>
+ throw new IOException(s"Error in opening $this", e)
+ }
} finally {
if (channel != null) {
- channel.close()
+ Utils.tryLog(channel.close())
}
}
}
override def inputStream(): InputStream = {
- val is = new FileInputStream(file)
- is.skip(offset)
- ByteStreams.limit(is, length)
+ var is: FileInputStream = null
+ try {
+ is = new FileInputStream(file)
+ is.skip(offset)
+ ByteStreams.limit(is, length)
+ } catch {
+ case e: IOException =>
+ if (is != null) {
+ Utils.tryLog(is.close())
+ }
+ Try(file.length).toOption match {
+ case Some(fileLen) =>
+ throw new IOException(s"Error in reading $this (actual file length $fileLen)", e)
+ case None =>
+ throw new IOException(s"Error in opening $this", e)
+ }
+ case e: Throwable =>
+ if (is != null) {
+ Utils.tryLog(is.close())
+ }
+ throw e
+ }
}
+
+ override def toString: String = s"${getClass.getName}($file, $offset, $length)"
}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
index 74074a8dcbfff..f368209980f93 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -20,23 +20,27 @@ package org.apache.spark.network.nio
import java.net._
import java.nio._
import java.nio.channels._
+import java.util.LinkedList
import org.apache.spark._
-import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}
+import scala.collection.mutable.{ArrayBuffer, HashMap}
private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
- val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
+ val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId,
+ val securityMgr: SecurityManager)
extends Logging {
var sparkSaslServer: SparkSaslServer = null
var sparkSaslClient: SparkSaslClient = null
- def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
+ def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId,
+ securityMgr_ : SecurityManager) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
- channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_)
+ channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]),
+ id_, securityMgr_)
}
channel.configureBlocking(false)
@@ -52,14 +56,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteAddress = getRemoteAddress()
- /**
- * Used to synchronize client requests: client's work-related requests must
- * wait until SASL authentication completes.
- */
- private val authenticated = new Object()
-
- def getAuthenticated(): Object = authenticated
-
def isSaslComplete(): Boolean
def resetForceReregister(): Boolean
@@ -192,22 +188,22 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
private[nio]
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
- remoteId_ : ConnectionManagerId, id_ : ConnectionId)
- extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
+ remoteId_ : ConnectionManagerId, id_ : ConnectionId,
+ securityMgr_ : SecurityManager)
+ extends Connection(SocketChannel.open, selector_, remoteId_, id_, securityMgr_) {
def isSaslComplete(): Boolean = {
if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
}
private class Outbox {
- val messages = new Queue[Message]()
+ val messages = new LinkedList[Message]()
val defaultChunkSize = 65536
var nextMessageToBeUsed = 0
def addMessage(message: Message) {
messages.synchronized {
- /* messages += message */
- messages.enqueue(message)
+ messages.add(message)
logDebug("Added [" + message + "] to outbox for sending to " +
"[" + getRemoteConnectionManagerId() + "]")
}
@@ -218,10 +214,27 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
while (!messages.isEmpty) {
/* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
/* val message = messages(nextMessageToBeUsed) */
- val message = messages.dequeue()
+
+ val message = if (securityMgr.isAuthenticationEnabled() && !isSaslComplete()) {
+ // only allow sending of security messages until sasl is complete
+ var pos = 0
+ var securityMsg: Message = null
+ while (pos < messages.size() && securityMsg == null) {
+ if (messages.get(pos).isSecurityNeg) {
+ securityMsg = messages.remove(pos)
+ }
+ pos = pos + 1
+ }
+          // Didn't find any security messages and auth isn't complete, so return
+ if (securityMsg == null) return None
+ securityMsg
+ } else {
+ messages.removeFirst()
+ }
+
val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
- messages.enqueue(message)
+ messages.add(message)
nextMessageToBeUsed = nextMessageToBeUsed + 1
if (!message.started) {
logDebug(
@@ -273,6 +286,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
changeConnectionKeyInterest(DEFAULT_INTEREST)
}
+ def registerAfterAuth(): Unit = {
+ outbox.synchronized {
+ needForceReregister = true
+ }
+ if (channel.isConnected) {
+ registerInterest()
+ }
+ }
+
def send(message: Message) {
outbox.synchronized {
outbox.addMessage(message)
@@ -415,8 +437,9 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
private[spark] class ReceivingConnection(
channel_ : SocketChannel,
selector_ : Selector,
- id_ : ConnectionId)
- extends Connection(channel_, selector_, id_) {
+ id_ : ConnectionId,
+ securityMgr_ : SecurityManager)
+ extends Connection(channel_, selector_, id_, securityMgr_) {
def isSaslComplete(): Boolean = {
if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
@@ -460,7 +483,7 @@ private[spark] class ReceivingConnection(
if (currId != null) currId else super.getRemoteConnectionManagerId()
}
- // The reciever's remote address is the local socket on remote side : which is NOT
+ // The receiver's remote address is the local socket on remote side : which is NOT
// the connection manager id of the receiver.
// We infer that from the messages we receive on the receiver socket.
private def processConnectionManagerId(header: MessageChunkHeader) {
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index 09d3ea306515b..01cd27a907eea 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -32,7 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps
import org.apache.spark._
-import org.apache.spark.util.{SystemClock, Utils}
+import org.apache.spark.util.Utils
private[nio] class ConnectionManager(
@@ -65,8 +65,6 @@ private[nio] class ConnectionManager(
private val selector = SelectorProvider.provider.openSelector()
private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
- // default to 30 second timeout waiting for authentication
- private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)
private val handleMessageExecutor = new ThreadPoolExecutor(
@@ -409,7 +407,8 @@ private[nio] class ConnectionManager(
while (newChannel != null) {
try {
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
- val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
+ val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId,
+ securityManager)
newConnection.onReceive(receiveMessage)
addListeners(newConnection)
addConnection(newConnection)
@@ -501,7 +500,7 @@ private[nio] class ConnectionManager(
def changeConnectionKeyInterest(connection: Connection, ops: Int) {
keyInterestChangeRequests += ((connection.key, ops))
- // so that registerations happen !
+ // so that registrations happen !
wakeupSelector()
}
@@ -527,9 +526,8 @@ private[nio] class ConnectionManager(
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
- waitingConn.getAuthenticated().synchronized {
- waitingConn.getAuthenticated().notifyAll()
- }
+ waitingConn.registerAfterAuth()
+ wakeupSelector()
return
} else {
var replyToken : Array[Byte] = null
@@ -538,9 +536,8 @@ private[nio] class ConnectionManager(
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
- waitingConn.getAuthenticated().synchronized {
- waitingConn.getAuthenticated().notifyAll()
- }
+ waitingConn.registerAfterAuth()
+ wakeupSelector()
return
}
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
@@ -574,9 +571,11 @@ private[nio] class ConnectionManager(
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
if (connection.isSaslComplete()) {
- logDebug("Server sasl completed: " + connection.connectionId)
+ logDebug("Server sasl completed: " + connection.connectionId +
+ " for: " + connectionId)
} else {
- logDebug("Server sasl not completed: " + connection.connectionId)
+ logDebug("Server sasl not completed: " + connection.connectionId +
+ " for: " + connectionId)
}
if (replyToken != null) {
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
@@ -723,7 +722,8 @@ private[nio] class ConnectionManager(
if (message == null) throw new Exception("Error creating security message")
connectionsAwaitingSasl += ((conn.connectionId, conn))
sendSecurityMessage(connManagerId, message)
- logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
+ logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId +
+ " to: " + connManagerId)
} catch {
case e: Exception => {
logError("Error getting first response from the SaslClient.", e)
@@ -744,7 +744,7 @@ private[nio] class ConnectionManager(
val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
- newConnectionId)
+ newConnectionId, securityManager)
logInfo("creating new sending connection for security! " + newConnectionId )
registerRequests.enqueue(newConnection)
@@ -769,61 +769,23 @@ private[nio] class ConnectionManager(
connectionManagerId.port)
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
- newConnectionId)
+ newConnectionId, securityManager)
logTrace("creating new sending connection: " + newConnectionId)
registerRequests.enqueue(newConnection)
newConnection
}
val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
- if (authEnabled) {
- checkSendAuthFirst(connectionManagerId, connection)
- }
+
message.senderAddress = id.toSocketAddress()
logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
"connectionid: " + connection.connectionId)
if (authEnabled) {
- // if we aren't authenticated yet lets block the senders until authentication completes
- try {
- connection.getAuthenticated().synchronized {
- val clock = SystemClock
- val startTime = clock.getTime()
-
- while (!connection.isSaslComplete()) {
- logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
- // have timeout in case remote side never responds
- connection.getAuthenticated().wait(500)
- if (((clock.getTime() - startTime) >= (authTimeout * 1000))
- && (!connection.isSaslComplete())) {
- // took to long to authenticate the connection, something probably went wrong
- throw new Exception("Took to long for authentication to " + connectionManagerId +
- ", waited " + authTimeout + "seconds, failing.")
- }
- }
- }
- } catch {
- case e: Exception => logError("Exception while waiting for authentication.", e)
-
- // need to tell sender it failed
- messageStatuses.synchronized {
- val s = messageStatuses.get(message.id)
- s match {
- case Some(msgStatus) => {
- messageStatuses -= message.id
- logInfo("Notifying " + msgStatus.connectionManagerId)
- msgStatus.markDone(None)
- }
- case None => {
- logError("no messageStatus for failed message id: " + message.id)
- }
- }
- }
- }
+ checkSendAuthFirst(connectionManagerId, connection)
}
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
connection.send(message)
-
wakeupSelector()
}
@@ -832,7 +794,7 @@ private[nio] class ConnectionManager(
}
/**
- * Send a message and block until an acknowldgment is received or an error occurs.
+ * Send a message and block until an acknowledgment is received or an error occurs.
* @param connectionManagerId the message's destination
* @param message the message being sent
* @return a Future that either returns the acknowledgment message or captures an exception.
diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
index 59958ee894230..b389b9a2022c6 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -200,6 +200,6 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
val buffer = blockDataManager.getBlockData(blockId).orNull
logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " and got buffer " + buffer)
- buffer.nioByteBuffer()
+ if (buffer == null) null else buffer.nioByteBuffer()
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 036dcc49664ef..6b63eb23e9ee1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -23,6 +23,7 @@ import java.io.EOFException
import scala.collection.immutable.Map
import scala.reflect.ClassTag
+import scala.collection.mutable.ListBuffer
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.mapred.FileSplit
@@ -43,6 +44,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.{DataReadMethod, InputMetrics}
import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.{NextIterator, Utils}
+import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation}
/**
@@ -194,7 +196,7 @@ class HadoopRDD[K, V](
val jobConf = getJobConf()
val inputFormat = getInputFormat(jobConf)
HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
- context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
+ context.getStageId, theSplit.index, context.getAttemptId.toInt, jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
@@ -249,9 +251,21 @@ class HadoopRDD[K, V](
}
override def getPreferredLocations(split: Partition): Seq[String] = {
- // TODO: Filtering out "localhost" in case of file:// URLs
- val hadoopSplit = split.asInstanceOf[HadoopPartition]
- hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
+ val hsplit = split.asInstanceOf[HadoopPartition].inputSplit.value
+ val locs: Option[Seq[String]] = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
+ case Some(c) =>
+ try {
+ val lsplit = c.inputSplitWithLocationInfo.cast(hsplit)
+ val infos = c.getLocationInfo.invoke(lsplit).asInstanceOf[Array[AnyRef]]
+ Some(HadoopRDD.convertSplitLocationInfo(infos))
+ } catch {
+ case e: Exception =>
+ logDebug("Failed to use InputSplitWithLocations.", e)
+ None
+ }
+ case None => None
+ }
+ locs.getOrElse(hsplit.getLocations.filter(_ != "localhost"))
}
override def checkpoint() {
@@ -261,7 +275,7 @@ class HadoopRDD[K, V](
def getConf: Configuration = getJobConf()
}
-private[spark] object HadoopRDD {
+private[spark] object HadoopRDD extends Logging {
/** Constructing Configuration objects is not threadsafe, use this lock to serialize. */
val CONFIGURATION_INSTANTIATION_LOCK = new Object()
@@ -309,4 +323,42 @@ private[spark] object HadoopRDD {
f(inputSplit, firstParent[T].iterator(split, context))
}
}
+
+ private[spark] class SplitInfoReflections {
+ val inputSplitWithLocationInfo =
+ Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
+ val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo")
+ val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit")
+ val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo")
+ val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo")
+ val isInMemory = splitLocationInfo.getMethod("isInMemory")
+ val getLocation = splitLocationInfo.getMethod("getLocation")
+ }
+
+ private[spark] val SPLIT_INFO_REFLECTIONS: Option[SplitInfoReflections] = try {
+ Some(new SplitInfoReflections)
+ } catch {
+ case e: Exception =>
+ logDebug("SplitLocationInfo and other new Hadoop classes are " +
+ "unavailable. Using the older Hadoop location info code.", e)
+ None
+ }
+
+ private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = {
+ val out = ListBuffer[String]()
+ infos.foreach { loc => {
+ val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get.
+ getLocation.invoke(loc).asInstanceOf[String]
+ if (locationStr != "localhost") {
+ if (HadoopRDD.SPLIT_INFO_REFLECTIONS.get.isInMemory.
+ invoke(loc).asInstanceOf[Boolean]) {
+ logDebug("Partition " + locationStr + " is cached by Hadoop.")
+ out += new HDFSCacheTaskLocation(locationStr).toString
+ } else {
+ out += new HostTaskLocation(locationStr).toString
+ }
+ }
+ }}
+ out.seq
+ }
}
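
SPLIT_INFO_REFLECTIONS resolves the newer Hadoop location-info classes once, by name, so the code keeps compiling and running against Hadoop versions that lack them. A generic sketch of that optional-API-via-reflection pattern, using String#isEmpty purely as a stand-in for the Hadoop method:

// Generic sketch of the optional-API-via-reflection pattern; String#isEmpty is a
// stand-in for the Hadoop method that may be missing at runtime.
object ReflectionSketch {
  // Resolve the class and method once; None means the API is not on the classpath.
  private val optionalIsEmpty: Option[java.lang.reflect.Method] =
    try {
      Some(Class.forName("java.lang.String").getMethod("isEmpty"))
    } catch {
      case _: ClassNotFoundException | _: NoSuchMethodException => None
    }

  def isEmptyViaReflection(s: String): Boolean =
    optionalIsEmpty match {
      case Some(m) => m.invoke(s).asInstanceOf[Boolean]  // newer API available
      case None    => s.length == 0                      // fallback path
    }

  def main(args: Array[String]): Unit = {
    println(isEmptyViaReflection(""))      // true
    println(isEmptyViaReflection("spark")) // false
  }
}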
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 4c84b3f62354d..0cccdefc5ee09 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -173,9 +173,21 @@ class NewHadoopRDD[K, V](
new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning)
}
- override def getPreferredLocations(split: Partition): Seq[String] = {
- val theSplit = split.asInstanceOf[NewHadoopPartition]
- theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
+ override def getPreferredLocations(hsplit: Partition): Seq[String] = {
+ val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value
+ val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
+ case Some(c) =>
+ try {
+ val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]]
+ Some(HadoopRDD.convertSplitLocationInfo(infos))
+ } catch {
+ case e : Exception =>
+ logDebug("Failed to use InputSplit#getLocationInfo.", e)
+ None
+ }
+ case None => None
+ }
+ locs.getOrElse(split.getLocations.filter(_ != "localhost"))
}
def getConf: Configuration = confBroadcast.value.value
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index f6d9d12fe9006..0d97506450a7f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -86,7 +86,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
- self.mapPartitionsWithContext((context, iter) => {
+ self.mapPartitions(iter => {
+ val context = TaskContext.get()
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
}, preservesPartitioning = true)
} else {
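
Because the task context is now stored in a thread-local (set before the task body runs and cleared afterwards), user code can reach it with TaskContext.get() inside a plain mapPartitions, which is what replaces the deprecated mapPartitionsWithContext. A usage sketch, assuming a local master purely for illustration:

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

// Usage sketch of TaskContext.get() inside mapPartitions; the local master and
// app name are for illustration only.
object TaskContextSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("task-context-sketch").setMaster("local[2]"))
    val tagged = sc.parallelize(1 to 4, numSlices = 2).mapPartitions { iter =>
      val ctx = TaskContext.get()             // non-null while a task is running
      iter.map(x => (ctx.getPartitionId, x))  // tag each element with its partition
    }
    tagged.collect().foreach(println)
    sc.stop()
  }
}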
@@ -419,6 +420,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
/**
* Group the values for each key in the RDD into a single sequence. Allows controlling the
* partitioning of the resulting key-value pair RDD by passing a Partitioner.
+ * The ordering of elements within each group is not guaranteed, and may even differ
+ * each time the resulting RDD is evaluated.
*
* Note: This operation may be very expensive. If you are grouping in order to perform an
* aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
@@ -438,7 +441,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
- * resulting RDD with into `numPartitions` partitions.
+   * resulting RDD into `numPartitions` partitions. The ordering of elements within
+ * each group is not guaranteed, and may even differ each time the resulting RDD is evaluated.
*
* Note: This operation may be very expensive. If you are grouping in order to perform an
* aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
@@ -506,6 +510,23 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
}
+ /**
+ * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or
+ * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each
+ * element (k, w) in `other`, the resulting RDD will either contain all pairs
+ * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements
+ * in `this` have key k. Uses the given Partitioner to partition the output RDD.
+ */
+ def fullOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner)
+ : RDD[(K, (Option[V], Option[W]))] = {
+ this.cogroup(other, partitioner).flatMapValues {
+ case (vs, Seq()) => vs.map(v => (Some(v), None))
+ case (Seq(), ws) => ws.map(w => (None, Some(w)))
+ case (vs, ws) => for (v <- vs; w <- ws) yield (Some(v), Some(w))
+ }
+ }
+
/**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the
* existing partitioner/parallelism level.
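
A usage sketch for the new fullOuterJoin; the data and application name are made up, and the SparkContext._ import supplies this Spark version's pair-RDD implicits:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._

// Usage sketch for the new fullOuterJoin; the data and app name are made up.
object FullOuterJoinSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("full-outer-join-sketch").setMaster("local[2]"))
    val ages   = sc.parallelize(Seq(("alice", 30), ("bob", 25)))
    val cities = sc.parallelize(Seq(("bob", "NYC"), ("carol", "SF")))
    ages.fullOuterJoin(cities).collect().sortBy(_._1).foreach(println)
    // (alice,(Some(30),None))     key only in `ages`
    // (bob,(Some(25),Some(NYC)))  key on both sides
    // (carol,(None,Some(SF)))     key only in `cities`
    sc.stop()
  }
}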
@@ -517,7 +538,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
- * resulting RDD with the existing partitioner/parallelism level.
+ * resulting RDD with the existing partitioner/parallelism level. The ordering of elements
+ * within each group is not guaranteed, and may even differ each time the resulting RDD is
+ * evaluated.
*
* Note: This operation may be very expensive. If you are grouping in order to perform an
* aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
@@ -585,6 +608,31 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
rightOuterJoin(other, new HashPartitioner(numPartitions))
}
+ /**
+ * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or
+ * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each
+ * element (k, w) in `other`, the resulting RDD will either contain all pairs
+ * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements
+ * in `this` have key k. Hash-partitions the resulting RDD using the existing partitioner/
+ * parallelism level.
+ */
+ def fullOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], Option[W]))] = {
+ fullOuterJoin(other, defaultPartitioner(self, other))
+ }
+
+ /**
+ * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or
+ * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each
+ * element (k, w) in `other`, the resulting RDD will either contain all pairs
+ * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements
+ * in `this` have key k. Hash-partitions the resulting RDD into the given number of partitions.
+ */
+ def fullOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], Option[W]))] = {
+ fullOuterJoin(other, new HashPartitioner(numPartitions))
+ }
+
/**
* Return the key-value pairs in this RDD to the master as a Map.
*
@@ -872,7 +920,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
hadoopConf.set("mapred.output.compression.codec", c.getCanonicalName)
hadoopConf.set("mapred.output.compression.type", CompressionType.BLOCK.toString)
}
- hadoopConf.setOutputCommitter(classOf[FileOutputCommitter])
+
+ // Use configured output committer if already set
+ if (conf.getOutputCommitter == null) {
+ hadoopConf.setOutputCommitter(classOf[FileOutputCommitter])
+ }
+
FileOutputFormat.setOutputPath(hadoopConf,
SparkHadoopWriter.createPathFromString(path, hadoopConf))
saveAsHadoopDataset(hadoopConf)
@@ -903,9 +956,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => {
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+ val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt
/* "reduce task" */
- val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId,
attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outfmt.newInstance
@@ -974,9 +1027,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => {
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+ val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt
- writer.setup(context.stageId, context.partitionId, attemptNumber)
+ writer.setup(context.getStageId, context.getPartitionId, attemptNumber)
writer.open()
try {
var count = 0
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 5d77d37378458..56ac7a69be0d3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -131,7 +131,6 @@ private[spark] class PipedRDD[T: ClassTag](
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + command) {
override def run() {
- SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
// input the pipe context firstly
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index a9b905b0d1a63..2aba40d152e3e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -17,7 +17,7 @@
package org.apache.spark.rdd
-import java.util.Random
+import java.util.{Properties, Random}
import scala.collection.{mutable, Map}
import scala.collection.mutable.ArrayBuffer
@@ -41,7 +41,7 @@ import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, Utils}
+import org.apache.spark.util.{BoundedPriorityQueue, Utils, CallSite}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
@@ -208,7 +208,7 @@ abstract class RDD[T: ClassTag](
}
/**
- * Get the preferred locations of a partition (as hostnames), taking into account whether the
+ * Get the preferred locations of a partition, taking into account whether the
* RDD is checkpointed.
*/
final def preferredLocations(split: Partition): Seq[String] = {
@@ -509,7 +509,8 @@ abstract class RDD[T: ClassTag](
/**
* Return an RDD of grouped items. Each group consists of a key and a sequence of elements
- * mapping to that key.
+ * mapping to that key. The ordering of elements within each group is not guaranteed, and
+ * may even differ each time the resulting RDD is evaluated.
*
* Note: This operation may be very expensive. If you are grouping in order to perform an
* aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
@@ -520,7 +521,8 @@ abstract class RDD[T: ClassTag](
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
- * mapping to that key.
+ * mapping to that key. The ordering of elements within each group is not guaranteed, and
+ * may even differ each time the resulting RDD is evaluated.
*
* Note: This operation may be very expensive. If you are grouping in order to perform an
* aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
@@ -531,7 +533,8 @@ abstract class RDD[T: ClassTag](
/**
* Return an RDD of grouped items. Each group consists of a key and a sequence of elements
- * mapping to that key.
+ * mapping to that key. The ordering of elements within each group is not guaranteed, and
+ * may even differ each time the resulting RDD is evaluated.
*
* Note: This operation may be very expensive. If you are grouping in order to perform an
* aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
@@ -619,6 +622,7 @@ abstract class RDD[T: ClassTag](
* should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
@DeveloperApi
+ @deprecated("use TaskContext.get", "1.2.0")
def mapPartitionsWithContext[U: ClassTag](
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = {
@@ -1027,8 +1031,14 @@ abstract class RDD[T: ClassTag](
* Zips this RDD with its element indices. The ordering is first based on the partition index
* and then the ordering of items within each partition. So the first item in the first
* partition gets index 0, and the last item in the last partition receives the largest index.
+ *
* This is similar to Scala's zipWithIndex but it uses Long instead of Int as the index type.
* This method needs to trigger a spark job when this RDD contains more than one partitions.
+ *
+ * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of
+ * elements in a partition. The index assigned to each element is therefore not guaranteed,
+ * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee
+ * the same index assignments, you should sort the RDD with sortByKey() or save it to a file.
*/
def zipWithIndex(): RDD[(T, Long)] = new ZippedWithIndexRDD(this)
@@ -1036,6 +1046,11 @@ abstract class RDD[T: ClassTag](
* Zips this RDD with generated unique Long ids. Items in the kth partition will get ids k, n+k,
* 2*n+k, ..., where n is the number of partitions. So there may exist gaps, but this method
* won't trigger a spark job, which is different from [[org.apache.spark.rdd.RDD#zipWithIndex]].
+ *
+ * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of
+ * elements in a partition. The unique ID assigned to each element is therefore not guaranteed,
+ * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee
+ * the same ID assignments, you should sort the RDD with sortByKey() or save it to a file.
*/
def zipWithUniqueId(): RDD[(T, Long)] = {
val n = this.partitions.size.toLong
@@ -1224,7 +1239,8 @@ abstract class RDD[T: ClassTag](
private var storageLevel: StorageLevel = StorageLevel.NONE
/** User code that created this RDD (e.g. `textFile`, `parallelize`). */
- @transient private[spark] val creationSite = Utils.getCallSite
+ @transient private[spark] val creationSite = sc.getCallSite()
+
private[spark] def getCreationSite: String = Option(creationSite).map(_.shortForm).getOrElse("")
private[spark] def elementClassTag: ClassTag[T] = classTag[T]
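To make the contract in the zipWithUniqueId doc comment above concrete (items in the kth of n partitions receive ids k, n+k, 2*n+k, ...), here is a plain-Scala sketch of the numbering scheme only; it models the documented id assignment, not the RDD implementation, and the sample partitions are made up.

// Illustrative model of the zipWithUniqueId id scheme: element j of
// partition k (out of n partitions) receives id k + j * n.
object ZipWithUniqueIdSketch {
  def uniqueIds[T](partitions: Seq[Seq[T]]): Seq[Seq[(T, Long)]] = {
    val n = partitions.size.toLong
    partitions.zipWithIndex.map { case (part, k) =>
      part.zipWithIndex.map { case (elem, j) => (elem, k + j * n) }
    }
  }

  def main(args: Array[String]): Unit = {
    // Two partitions of three and two elements: ids 0, 2, 4 and 1, 3 (gaps are expected).
    println(uniqueIds(Seq(Seq("a", "b", "c"), Seq("d", "e"))))
  }
}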
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b2774dfc47553..788eb1ff4e455 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -630,16 +630,17 @@ class DAGScheduler(
protected def runLocallyWithinThread(job: ActiveJob) {
var jobResult: JobResult = JobSucceeded
try {
- SparkEnv.set(env)
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
val taskContext =
- new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true)
+ new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
+ TaskContext.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
} finally {
taskContext.markTaskCompleted()
+ TaskContext.unset()
}
} catch {
case e: Exception =>
@@ -1207,7 +1208,7 @@ class DAGScheduler(
.format(job.jobId, stageId))
} else if (jobsForStage.get.size == 1) {
if (!stageIdToStage.contains(stageId)) {
- logError("Missing Stage for stage with id $stageId")
+ logError(s"Missing Stage for stage with id $stageId")
} else {
// This is the only job that uses this stage, so fail the stage if it is running.
val stage = stageIdToStage(stageId)
@@ -1301,7 +1302,7 @@ class DAGScheduler(
// If the RDD has some placement preferences (as is the case for input RDDs), get those
val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
if (!rddPrefs.isEmpty) {
- return rddPrefs.map(host => TaskLocation(host))
+ return rddPrefs.map(TaskLocation(_))
}
// If the RDD has narrow dependencies, pick the first partition of the first narrow dep
// that has any placement preferences. Ideally we would choose based on transfer sizes,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
index 94944399b134a..12668b6c0988e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
@@ -22,10 +22,10 @@ import com.codahale.metrics.{Gauge,MetricRegistry}
import org.apache.spark.SparkContext
import org.apache.spark.metrics.source.Source
-private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: SparkContext)
+private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler)
extends Source {
override val metricRegistry = new MetricRegistry()
- override val sourceName = "%s.DAGScheduler".format(sc.appName)
+ override val sourceName = "DAGScheduler"
metricRegistry.register(MetricRegistry.name("stage", "failedStages"), new Gauge[Int] {
override def getValue: Int = dagScheduler.failedStages.size
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 64b32ae0edaac..100c9ba9b7809 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -43,38 +43,29 @@ import org.apache.spark.util.{FileLogger, JsonProtocol, Utils}
* spark.eventLog.buffer.kb - Buffer size to use when writing to output streams
*/
private[spark] class EventLoggingListener(
- appName: String,
+ appId: String,
+ logBaseDir: String,
sparkConf: SparkConf,
hadoopConf: Configuration)
extends SparkListener with Logging {
import EventLoggingListener._
- def this(appName: String, sparkConf: SparkConf) =
- this(appName, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf))
+ def this(appId: String, logBaseDir: String, sparkConf: SparkConf) =
+ this(appId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf))
private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false)
private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false)
private val testing = sparkConf.getBoolean("spark.eventLog.testing", false)
private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024
- private val logBaseDir = sparkConf.get("spark.eventLog.dir", DEFAULT_LOG_DIR).stripSuffix("/")
- private val name = appName.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_")
- .toLowerCase + "-" + System.currentTimeMillis
- val logDir = Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/")
-
+ val logDir = EventLoggingListener.getLogDirPath(logBaseDir, appId)
+ val logDirName: String = logDir.split("/").last
protected val logger = new FileLogger(logDir, sparkConf, hadoopConf, outputBufferSize,
shouldCompress, shouldOverwrite, Some(LOG_FILE_PERMISSIONS))
// For testing. Keep track of all JSON serialized events that have been logged.
private[scheduler] val loggedEvents = new ArrayBuffer[JValue]
- /**
- * Return only the unique application directory without the base directory.
- */
- def getApplicationLogDir(): String = {
- name
- }
-
/**
* Begin logging events.
* If compression is used, log a file that indicates which compression library is used.
@@ -184,6 +175,18 @@ private[spark] object EventLoggingListener extends Logging {
} else ""
}
+ /**
+ * Return a file-system-safe path to the log directory for the given application.
+ *
+ * @param logBaseDir The base directory for the log directory of the given application.
+ * @param appId A unique app ID.
+ * @return A path which consists of file-system-safe characters.
+ */
+ def getLogDirPath(logBaseDir: String, appId: String): String = {
+ val name = appId.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_").toLowerCase
+ Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/")
+ }
+
/**
* Parse the event logging information associated with the logs in the given directory.
*
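As a rough illustration of the directory naming performed by getLogDirPath above, the sketch below applies the same character sanitization to an app ID and joins it onto a base directory; plain string handling stands in for Utils.resolveURI, and the base path and app ID are made-up examples.

// Sketch of the log-directory naming: spaces, colons and slashes in the app ID
// become '-', shell-sensitive characters become '_', and the result is lowercased.
object LogDirSketch {
  def logDirPath(logBaseDir: String, appId: String): String = {
    val name = appId.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_").toLowerCase
    logBaseDir.stripSuffix("/") + "/" + name
  }

  def main(args: Array[String]): Unit = {
    // Prints file:/tmp/spark-events/app-20140101-0007
    println(logDirPath("file:/tmp/spark-events/", "app-20140101:0007"))
  }
}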
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 4d6b5c81883b6..54904bffdf10b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -20,15 +20,12 @@ package org.apache.spark.scheduler
import java.io.{File, FileNotFoundException, IOException, PrintWriter}
import java.text.SimpleDateFormat
import java.util.{Date, Properties}
-import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.HashMap
import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.executor.{DataReadMethod, TaskMetrics}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.executor.TaskMetrics
/**
* :: DeveloperApi ::
@@ -62,24 +59,16 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
override def initialValue() = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
}
- private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent]
createLogDir()
- // The following 5 functions are used only in testing.
- private[scheduler] def getLogDir = logDir
- private[scheduler] def getJobIdToPrintWriter = jobIdToPrintWriter
- private[scheduler] def getStageIdToJobId = stageIdToJobId
- private[scheduler] def getJobIdToStageIds = jobIdToStageIds
- private[scheduler] def getEventQueue = eventQueue
-
/** Create a folder for log files, the folder's name is the creation time of jobLogger */
protected def createLogDir() {
val dir = new File(logDir + "/" + logDirName + "/")
if (dir.exists()) {
return
}
- if (dir.mkdirs() == false) {
+ if (!dir.mkdirs()) {
// JobLogger should throw a exception rather than continue to construct this object.
throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/")
}
@@ -171,7 +160,6 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
}
val shuffleReadMetrics = taskMetrics.shuffleReadMetrics match {
case Some(metrics) =>
- " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
" BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
" BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
@@ -262,7 +250,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
protected def recordJobProperties(jobId: Int, properties: Properties) {
if (properties != null) {
val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
- jobLogInfo(jobId, description, false)
+ jobLogInfo(jobId, description, withTime = false)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index d3f63ff92ac6f..e25096ea92d70 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -24,22 +24,123 @@ import org.apache.spark.storage.BlockManagerId
/**
* Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
* task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
- * The map output sizes are compressed using MapOutputTracker.compressSize.
*/
-private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte])
- extends Externalizable {
+private[spark] sealed trait MapStatus {
+ /** Location where this task was run. */
+ def location: BlockManagerId
- def this() = this(null, null) // For deserialization only
+ /** Estimated size for the reduce block, in bytes. */
+ def getSizeForBlock(reduceId: Int): Long
+}
+
+
+private[spark] object MapStatus {
+
+ def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
+ if (uncompressedSizes.length > 2000) {
+ new HighlyCompressedMapStatus(loc, uncompressedSizes)
+ } else {
+ new CompressedMapStatus(loc, uncompressedSizes)
+ }
+ }
+
+ private[this] val LOG_BASE = 1.1
+
+ /**
+ * Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
+ * We do this by encoding the log base 1.1 of the size as an integer, which can support
+ * sizes up to 35 GB with at most 10% error.
+ */
+ def compressSize(size: Long): Byte = {
+ if (size == 0) {
+ 0
+ } else if (size <= 1L) {
+ 1
+ } else {
+ math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
+ }
+ }
+
+ /**
+ * Decompress an 8-bit encoded block size, using the reverse operation of compressSize.
+ */
+ def decompressSize(compressedSize: Byte): Long = {
+ if (compressedSize == 0) {
+ 0
+ } else {
+ math.pow(LOG_BASE, compressedSize & 0xFF).toLong
+ }
+ }
+}
+
+
+/**
+ * A [[MapStatus]] implementation that tracks the size of each block. Size for each block is
+ * represented using a single byte.
+ *
+ * @param loc location where the task is being executed.
+ * @param compressedSizes size of the blocks, indexed by reduce partition id.
+ */
+private[spark] class CompressedMapStatus(
+ private[this] var loc: BlockManagerId,
+ private[this] var compressedSizes: Array[Byte])
+ extends MapStatus with Externalizable {
+
+ protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only
+
+ def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
+ this(loc, uncompressedSizes.map(MapStatus.compressSize))
+ }
- def writeExternal(out: ObjectOutput) {
- location.writeExternal(out)
+ override def location: BlockManagerId = loc
+
+ override def getSizeForBlock(reduceId: Int): Long = {
+ MapStatus.decompressSize(compressedSizes(reduceId))
+ }
+
+ override def writeExternal(out: ObjectOutput): Unit = {
+ loc.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
}
- def readExternal(in: ObjectInput) {
- location = BlockManagerId(in)
- compressedSizes = new Array[Byte](in.readInt())
+ override def readExternal(in: ObjectInput): Unit = {
+ loc = BlockManagerId(in)
+ val len = in.readInt()
+ compressedSizes = new Array[Byte](len)
in.readFully(compressedSizes)
}
}
+
+
+/**
+ * A [[MapStatus]] implementation that only stores the average size of the blocks.
+ *
+ * @param loc location where the task is being executed.
+ * @param avgSize average size of all the blocks
+ */
+private[spark] class HighlyCompressedMapStatus(
+ private[this] var loc: BlockManagerId,
+ private[this] var avgSize: Long)
+ extends MapStatus with Externalizable {
+
+ def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
+ this(loc, uncompressedSizes.sum / uncompressedSizes.length)
+ }
+
+ protected def this() = this(null, 0L) // For deserialization only
+
+ override def location: BlockManagerId = loc
+
+ override def getSizeForBlock(reduceId: Int): Long = avgSize
+
+ override def writeExternal(out: ObjectOutput): Unit = {
+ loc.writeExternal(out)
+ out.writeLong(avgSize)
+ }
+
+ override def readExternal(in: ObjectInput): Unit = {
+ loc = BlockManagerId(in)
+ avgSize = in.readLong()
+ }
+}
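The factory above switches to HighlyCompressedMapStatus beyond 2000 reduce partitions and otherwise stores one byte per block. To see what that byte encoding costs, the sketch below copies the compressSize/decompressSize arithmetic and round-trips a few sizes; per the comment in the patch, the error stays within roughly 10% for sizes up to about 35 GB.

// Round-trip check of the 8-bit, log-base-1.1 size encoding used by MapStatus.
object MapStatusSizeSketch {
  private val LogBase = 1.1

  def compressSize(size: Long): Byte =
    if (size == 0) 0
    else if (size <= 1L) 1
    else math.min(255, math.ceil(math.log(size) / math.log(LogBase)).toInt).toByte

  def decompressSize(compressed: Byte): Long =
    if (compressed == 0) 0 else math.pow(LogBase, compressed & 0xFF).toLong

  def main(args: Array[String]): Unit = {
    for (size <- Seq(0L, 1L, 1000L, 1000000L, 1L << 30)) {
      val restored = decompressSize(compressSize(size))
      println(s"$size -> $restored") // stays within roughly 10% of the original size
    }
  }
}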
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 2ccbd8edeb028..4a9ff918afe25 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -58,11 +58,7 @@ private[spark] class ResultTask[T, U](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
metrics = Some(context.taskMetrics)
- try {
- func(context, rdd.iterator(partition, context))
- } finally {
- context.markTaskCompleted()
- }
+ func(context, rdd.iterator(partition, context))
}
// This is only callable on the driver side.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index a0be8307eff27..992c477493d8e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -23,6 +23,8 @@ package org.apache.spark.scheduler
* machines become available and can launch tasks on them.
*/
private[spark] trait SchedulerBackend {
+ private val appId = "spark-application-" + System.currentTimeMillis
+
def start(): Unit
def stop(): Unit
def reviveOffers(): Unit
@@ -33,10 +35,10 @@ private[spark] trait SchedulerBackend {
def isReady(): Boolean = true
/**
- * The application ID associated with the job, if any.
+ * Get an application ID associated with the job.
*
- * @return The application ID, or None if the backend does not provide an ID.
+ * @return An application ID
*/
- def applicationId(): Option[String] = None
+ def applicationId(): String = appId
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 381eff2147e95..79709089c0da4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -69,12 +69,15 @@ private[spark] class ShuffleMapTask(
return writer.stop(success = true).get
} catch {
case e: Exception =>
- if (writer != null) {
- writer.stop(success = false)
+ try {
+ if (writer != null) {
+ writer.stop(success = false)
+ }
+ } catch {
+ case e: Exception =>
+ log.debug("Could not stop writer", e)
}
throw e
- } finally {
- context.markTaskCompleted()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 6aa0cca06878d..c6e47c84a0cb2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,13 +45,19 @@ import org.apache.spark.util.Utils
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
final def run(attemptId: Long): T = {
- context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
+ context = new TaskContext(stageId, partitionId, attemptId, false)
+ TaskContext.setTaskContext(context)
context.taskMetrics.hostname = Utils.localHostName()
taskThread = Thread.currentThread()
if (_killed) {
kill(interruptThread = false)
}
- runTask(context)
+ try {
+ runTask(context)
+ } finally {
+ context.markTaskCompleted()
+ TaskContext.unset()
+ }
}
def runTask(context: TaskContext): T
@@ -92,7 +98,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
if (interruptThread && taskThread != null) {
taskThread.interrupt()
}
- }
+ }
}
/**
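The change to Task.run above sets a thread-local TaskContext before the task body and clears it in a finally block. The shape of that pattern, isolated from Spark, looks like the sketch below; Ctx and runWithContext are stand-ins, not Spark's TaskContext API.

// Set a thread-local context, run the body, and always clear the context afterwards.
object TaskContextSketch {
  case class Ctx(stageId: Int, partitionId: Int)

  private val current = new ThreadLocal[Ctx]
  def get: Ctx = current.get()

  def runWithContext[T](ctx: Ctx)(body: => T): T = {
    current.set(ctx)
    try body
    finally current.remove() // cleared even if the body throws
  }

  def main(args: Array[String]): Unit = {
    val result = runWithContext(Ctx(stageId = 1, partitionId = 0)) {
      s"running in stage ${get.stageId}, partition ${get.partitionId}"
    }
    println(result)
    println(get == null) // true: no context outside runWithContext
  }
}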
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
index 67c9a6760b1b3..10c685f29d3ac 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
@@ -22,13 +22,51 @@ package org.apache.spark.scheduler
* In the latter case, we will prefer to launch the task on that executorID, but our next level
* of preference will be executors on the same host if this is not possible.
*/
-private[spark]
-class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable {
- override def toString: String = "TaskLocation(" + host + ", " + executorId + ")"
+private[spark] sealed trait TaskLocation {
+ def host: String
+}
+
+/**
+ * A location that includes both a host and an executor id on that host.
+ */
+private[spark] case class ExecutorCacheTaskLocation(override val host: String,
+ executorId: String) extends TaskLocation
+
+/**
+ * A location on a host.
+ */
+private[spark] case class HostTaskLocation(override val host: String) extends TaskLocation {
+ override def toString = host
+}
+
+/**
+ * A location on a host that is cached by HDFS.
+ */
+private[spark] case class HDFSCacheTaskLocation(override val host: String)
+ extends TaskLocation {
+ override def toString = TaskLocation.inMemoryLocationTag + host
}
private[spark] object TaskLocation {
- def apply(host: String, executorId: String) = new TaskLocation(host, Some(executorId))
+ // We identify hosts on which the block is cached with this prefix. Because this prefix contains
+ // underscores, which are not legal characters in hostnames, there should be no potential for
+ // confusion. See RFC 952 and RFC 1123 for information about the format of hostnames.
+ val inMemoryLocationTag = "hdfs_cache_"
+
+ def apply(host: String, executorId: String) = new ExecutorCacheTaskLocation(host, executorId)
- def apply(host: String) = new TaskLocation(host, None)
+ /**
+ * Create a TaskLocation from a string returned by getPreferredLocations.
+ * These strings have the form [hostname] or hdfs_cache_[hostname], depending on whether the
+ * location is cached.
+ */
+ def apply(str: String) = {
+ val hstr = str.stripPrefix(inMemoryLocationTag)
+ if (hstr.equals(str)) {
+ new HostTaskLocation(str)
+ } else {
+ new HDFSCacheTaskLocation(hstr)
+ }
+ }
}
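The string convention parsed by TaskLocation.apply above (a bare hostname, or the hdfs_cache_ prefix for an HDFS-cached host) can be shown with a small standalone sketch; the location types here are simplified stand-ins for the case classes in the patch.

// Parse preferred-location strings: "host1" -> host location,
// "hdfs_cache_host1" -> HDFS-cached location on host1.
object TaskLocationSketch {
  val inMemoryLocationTag = "hdfs_cache_"

  sealed trait Loc { def host: String }
  case class HostLoc(host: String) extends Loc
  case class HdfsCachedLoc(host: String) extends Loc

  def parse(str: String): Loc = {
    val hstr = str.stripPrefix(inMemoryLocationTag)
    if (hstr == str) HostLoc(str) else HdfsCachedLoc(hstr)
  }

  def main(args: Array[String]): Unit = {
    println(parse("host1"))            // HostLoc(host1)
    println(parse("hdfs_cache_host1")) // HdfsCachedLoc(host1)
  }
}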
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index df59f444b7a0e..3f345ceeaaf7a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -19,6 +19,8 @@ package org.apache.spark.scheduler
import java.nio.ByteBuffer
+import scala.util.control.NonFatal
+
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.serializer.SerializerInstance
@@ -32,7 +34,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4)
private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool(
- THREADS, "Result resolver thread")
+ THREADS, "task-result-getter")
protected val serializer = new ThreadLocal[SerializerInstance] {
override def initialValue(): SerializerInstance = {
@@ -70,7 +72,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
case cnf: ClassNotFoundException =>
val loader = Thread.currentThread.getContextClassLoader
taskSetManager.abort("ClassNotFound with classloader: " + loader)
- case ex: Exception =>
+ // Matching NonFatal so we don't catch the ControlThrowable from the "return" above.
+ case NonFatal(ex) =>
logError("Exception while getting task result", ex)
taskSetManager.abort("Exception while getting task result: %s".format(ex))
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 1c1ce666eab0f..a129a434c9a1a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -31,6 +31,8 @@ import org.apache.spark.storage.BlockManagerId
*/
private[spark] trait TaskScheduler {
+ private val appId = "spark-application-" + System.currentTimeMillis
+
def rootPool: Pool
def schedulingMode: SchedulingMode
@@ -66,10 +68,10 @@ private[spark] trait TaskScheduler {
blockManagerId: BlockManagerId): Boolean
/**
- * The application ID associated with the job, if any.
+ * Get an application ID associated with the job.
*
- * @return The application ID, or None if the backend does not provide an ID.
+ * @return An application ID
*/
- def applicationId(): Option[String] = None
+ def applicationId(): String = appId
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 633e892554c50..6d697e3d003f6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -216,8 +216,6 @@ private[spark] class TaskSchedulerImpl(
* that tasks are balanced across the cluster.
*/
def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
- SparkEnv.set(sc.env)
-
// Mark each slave as alive and remember its hostname
// Also track if new executor is added
var newExecAvail = false
@@ -492,7 +490,7 @@ private[spark] class TaskSchedulerImpl(
}
}
- override def applicationId(): Option[String] = backend.applicationId()
+ override def applicationId(): String = backend.applicationId()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index d9d53faf843ff..a6c23fc85a1b0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -181,8 +181,24 @@ private[spark] class TaskSetManager(
}
for (loc <- tasks(index).preferredLocations) {
- for (execId <- loc.executorId) {
- addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
+ loc match {
+ case e: ExecutorCacheTaskLocation =>
+ addTo(pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer))
+ case e: HDFSCacheTaskLocation => {
+ val exe = sched.getExecutorsAliveOnHost(loc.host)
+ exe match {
+ case Some(set) => {
+ for (e <- set) {
+ addTo(pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer))
+ }
+ logInfo(s"Pending task $index has a cached location at ${e.host} " +
+ ", where there are executors " + set.mkString(","))
+ }
+ case None => logDebug(s"Pending task $index has a cached location at ${e.host} " +
+ ", but there are no executors alive there.")
+ }
+ }
+ case _ =>
}
addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
for (rack <- sched.getRackForHost(loc.host)) {
@@ -283,7 +299,10 @@ private[spark] class TaskSetManager(
// on multiple nodes when we replicate cached blocks, as in Spark Streaming
for (index <- speculatableTasks if canRunOnHost(index)) {
val prefs = tasks(index).preferredLocations
- val executors = prefs.flatMap(_.executorId)
+ val executors = prefs.flatMap(_ match {
+ case e: ExecutorCacheTaskLocation => Some(e.executorId)
+ case _ => None
+ })
if (executors.contains(execId)) {
speculatableTasks -= index
return Some((index, TaskLocality.PROCESS_LOCAL))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 6abf6d930c155..fb8160abc59db 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -66,7 +66,7 @@ private[spark] object CoarseGrainedClusterMessages {
case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage
- case class AddWebUIFilter(filterName:String, filterParams: String, proxyBase :String)
+ case class AddWebUIFilter(filterName: String, filterParams: Map[String, String], proxyBase: String)
extends CoarseGrainedClusterMessage
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 9a0cb1c6c6ccd..59aed6b72fe42 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -62,15 +62,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
val createTime = System.currentTimeMillis()
class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive {
-
override protected def log = CoarseGrainedSchedulerBackend.this.log
-
- private val executorActor = new HashMap[String, ActorRef]
- private val executorAddress = new HashMap[String, Address]
- private val executorHost = new HashMap[String, String]
- private val freeCores = new HashMap[String, Int]
- private val totalCores = new HashMap[String, Int]
private val addressToExecutorId = new HashMap[Address, String]
+ private val executorDataMap = new HashMap[String, ExecutorData]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
@@ -85,16 +79,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
def receiveWithLogging = {
case RegisterExecutor(executorId, hostPort, cores) =>
Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
- if (executorActor.contains(executorId)) {
+ if (executorDataMap.contains(executorId)) {
sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
} else {
logInfo("Registered executor: " + sender + " with ID " + executorId)
sender ! RegisteredExecutor
- executorActor(executorId) = sender
- executorHost(executorId) = Utils.parseHostPort(hostPort)._1
- totalCores(executorId) = cores
- freeCores(executorId) = cores
- executorAddress(executorId) = sender.path.address
+ executorDataMap.put(executorId, new ExecutorData(sender, sender.path.address,
+ Utils.parseHostPort(hostPort)._1, cores, cores))
+
addressToExecutorId(sender.path.address) = executorId
totalCoreCount.addAndGet(cores)
totalRegisteredExecutors.addAndGet(1)
@@ -104,13 +96,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
- if (executorActor.contains(executorId)) {
- freeCores(executorId) += scheduler.CPUS_PER_TASK
- makeOffers(executorId)
- } else {
- // Ignoring the update since we don't know about the executor.
- val msg = "Ignored task status update (%d state %s) from unknown executor %s with ID %s"
- logWarning(msg.format(taskId, state, sender, executorId))
+ executorDataMap.get(executorId) match {
+ case Some(executorInfo) =>
+ executorInfo.freeCores += scheduler.CPUS_PER_TASK
+ makeOffers(executorId)
+ case None =>
+ // Ignoring the update since we don't know about the executor.
+ logWarning(s"Ignored task status update ($taskId state $state) " +
+ "from unknown executor $sender with ID $executorId")
}
}
@@ -118,7 +111,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
makeOffers()
case KillTask(taskId, executorId, interruptThread) =>
- executorActor(executorId) ! KillTask(taskId, executorId, interruptThread)
+ executorDataMap(executorId).executorActor ! KillTask(taskId, executorId, interruptThread)
case StopDriver =>
sender ! true
@@ -126,8 +119,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
case StopExecutors =>
logInfo("Asking each executor to shut down")
- for (executor <- executorActor.values) {
- executor ! StopExecutor
+ for ((_, executorData) <- executorDataMap) {
+ executorData.executorActor ! StopExecutor
}
sender ! true
@@ -138,6 +131,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
case AddWebUIFilter(filterName, filterParams, proxyBase) =>
addWebUIFilter(filterName, filterParams, proxyBase)
sender ! true
+
case DisassociatedEvent(_, address, _) =>
addressToExecutorId.get(address).foreach(removeExecutor(_,
"remote Akka client disassociated"))
@@ -148,14 +142,16 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
// Make fake resource offers on all executors
def makeOffers() {
- launchTasks(scheduler.resourceOffers(
- executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
+ launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) =>
+ new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
+ }.toSeq))
}
// Make fake resource offers on just one executor
def makeOffers(executorId: String) {
+ val executorData = executorDataMap(executorId)
launchTasks(scheduler.resourceOffers(
- Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId)))))
+ Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))))
}
// Launch tasks returned by a set of resource offers
@@ -179,25 +175,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
}
}
else {
- freeCores(task.executorId) -= scheduler.CPUS_PER_TASK
- executorActor(task.executorId) ! LaunchTask(new SerializableBuffer(serializedTask))
+ val executorData = executorDataMap(task.executorId)
+ executorData.freeCores -= scheduler.CPUS_PER_TASK
+ executorData.executorActor ! LaunchTask(new SerializableBuffer(serializedTask))
}
}
}
// Remove a disconnected slave from the cluster
def removeExecutor(executorId: String, reason: String) {
- if (executorActor.contains(executorId)) {
- logInfo("Executor " + executorId + " disconnected, so removing it")
- val numCores = totalCores(executorId)
- executorActor -= executorId
- executorHost -= executorId
- addressToExecutorId -= executorAddress(executorId)
- executorAddress -= executorId
- totalCores -= executorId
- freeCores -= executorId
- totalCoreCount.addAndGet(-numCores)
- scheduler.executorLost(executorId, SlaveLost(reason))
+ executorDataMap.get(executorId) match {
+ case Some(executorInfo) =>
+ executorDataMap -= executorId
+ totalCoreCount.addAndGet(-executorInfo.totalCores)
+ scheduler.executorLost(executorId, SlaveLost(reason))
+ case None => logError(s"Asked to remove non existant executor $executorId")
}
}
}
@@ -283,15 +275,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
}
// Add filters to the SparkUI
- def addWebUIFilter(filterName: String, filterParams: String, proxyBase: String) {
+ def addWebUIFilter(filterName: String, filterParams: Map[String, String], proxyBase: String) {
if (proxyBase != null && proxyBase.nonEmpty) {
System.setProperty("spark.ui.proxyBase", proxyBase)
}
- if (Seq(filterName, filterParams).forall(t => t != null && t.nonEmpty)) {
+ val hasFilter = (filterName != null && filterName.nonEmpty &&
+ filterParams != null && filterParams.nonEmpty)
+ if (hasFilter) {
logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
conf.set("spark.ui.filters", filterName)
- conf.set(s"spark.$filterName.params", filterParams)
+ filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) }
scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
new file mode 100644
index 0000000000000..b71bd5783d6df
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import akka.actor.{Address, ActorRef}
+
+/**
+ * Grouping of data for an executor used by CoarseGrainedSchedulerBackend.
+ *
+ * @param executorActor The ActorRef representing this executor
+ * @param executorAddress The network address of this executor
+ * @param executorHost The hostname that this executor is running on
+ * @param freeCores The current number of cores available for work on the executor
+ * @param totalCores The total number of cores available to the executor
+ */
+private[cluster] class ExecutorData(
+ val executorActor: ActorRef,
+ val executorAddress: Address,
+ val executorHost: String,
+ var freeCores: Int,
+ val totalCores: Int
+)
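ExecutorData replaces several parallel HashMaps keyed by executor ID with a single map of per-executor records. The standalone sketch below shows the same pattern: building resource offers from one map and updating a single record when cores are freed; the record and offer types are simplified stand-ins for ExecutorData and WorkerOffer.

import scala.collection.mutable

// One map of per-executor records instead of parallel maps keyed by executor id.
case class ExecutorRecord(host: String, var freeCores: Int, totalCores: Int)
case class Offer(executorId: String, host: String, cores: Int)

object ExecutorDataSketch {
  def main(args: Array[String]): Unit = {
    val executors = mutable.HashMap(
      "exec-1" -> ExecutorRecord("host-a", 4, 4),
      "exec-2" -> ExecutorRecord("host-b", 2, 8))

    // Build offers from the single map, as makeOffers() does in the backend above.
    val offers = executors.map { case (id, e) => Offer(id, e.host, e.freeCores) }.toSeq
    println(offers)

    // A finished task frees cores by touching exactly one record.
    executors("exec-2").freeCores += 1
    println(executors("exec-2"))
  }
}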
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 2f45d192e1d4d..ed209d195ec9d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -34,7 +34,7 @@ private[spark] class SparkDeploySchedulerBackend(
var client: AppClient = null
var stopping = false
var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _
- var appId: String = _
+ @volatile var appId: String = _
val registrationLock = new Object()
var registrationDone = false
@@ -68,9 +68,8 @@ private[spark] class SparkDeploySchedulerBackend(
val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend",
args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts)
val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")
- val eventLogDir = sc.eventLogger.map(_.logDir)
val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command,
- appUIAddress, eventLogDir)
+ appUIAddress, sc.eventLogDir)
client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf)
client.start()
@@ -129,7 +128,11 @@ private[spark] class SparkDeploySchedulerBackend(
totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio
}
- override def applicationId(): Option[String] = Option(appId)
+ override def applicationId(): String =
+ Option(appId).getOrElse {
+ logWarning("Application ID is not initialized yet.")
+ super.applicationId
+ }
private def waitForRegistration() = {
registrationLock.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 64568409dbafd..90828578cd88f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -76,6 +76,8 @@ private[spark] class CoarseMesosSchedulerBackend(
var nextMesosTaskId = 0
+ @volatile var appId: String = _
+
def newMesosTaskId(): Int = {
val id = nextMesosTaskId
nextMesosTaskId += 1
@@ -167,7 +169,8 @@ private[spark] class CoarseMesosSchedulerBackend(
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
- logInfo("Registered as framework ID " + frameworkId.getValue)
+ appId = frameworkId.getValue
+ logInfo("Registered as framework ID " + appId)
registeredLock.synchronized {
isRegistered = true
registeredLock.notifyAll()
@@ -198,7 +201,9 @@ private[spark] class CoarseMesosSchedulerBackend(
val slaveId = offer.getSlaveId.toString
val mem = getResource(offer.getResourcesList, "mem")
val cpus = getResource(offer.getResourcesList, "cpus").toInt
- if (totalCoresAcquired < maxCores && mem >= sc.executorMemory && cpus >= 1 &&
+ if (totalCoresAcquired < maxCores &&
+ mem >= MemoryUtils.calculateTotalMemory(sc) &&
+ cpus >= 1 &&
failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES &&
!slaveIdsWithExecutors.contains(slaveId)) {
// Launch an executor on the slave
@@ -214,7 +219,8 @@ private[spark] class CoarseMesosSchedulerBackend(
.setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave))
.setName("Task " + taskId)
.addResources(createResource("cpus", cpusToUse))
- .addResources(createResource("mem", sc.executorMemory))
+ .addResources(createResource("mem",
+ MemoryUtils.calculateTotalMemory(sc)))
.build()
d.launchTasks(
Collections.singleton(offer.getId), Collections.singletonList(task), filters)
@@ -310,4 +316,10 @@ private[spark] class CoarseMesosSchedulerBackend(
slaveLost(d, s)
}
+ override def applicationId(): String =
+ Option(appId).getOrElse {
+ logWarning("Application ID is not initialized yet.")
+ super.applicationId
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala
new file mode 100644
index 0000000000000..5101ec8352e79
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster.mesos
+
+import org.apache.spark.SparkContext
+
+private[spark] object MemoryUtils {
+ // These defaults copied from YARN
+ val OVERHEAD_FRACTION = 1.07
+ val OVERHEAD_MINIMUM = 384
+
+ def calculateTotalMemory(sc: SparkContext) = {
+ math.max(
+ sc.conf.getOption("spark.mesos.executor.memoryOverhead")
+ .getOrElse(OVERHEAD_MINIMUM.toString)
+ .toInt + sc.executorMemory,
+ OVERHEAD_FRACTION * sc.executorMemory
+ )
+ }
+}
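To make the memory arithmetic in MemoryUtils concrete: the total is the executor memory plus the configured (or default 384 MB) overhead, but never less than 7% above the executor memory. A standalone sketch with made-up example sizes:

// Total Mesos memory = max(executorMemory + overhead, 1.07 * executorMemory), in MB.
object MemorySizingSketch {
  val OverheadFraction = 1.07
  val OverheadMinimumMb = 384

  def totalMemoryMb(executorMemoryMb: Int, overheadMb: Option[Int] = None): Double =
    math.max(
      overheadMb.getOrElse(OverheadMinimumMb) + executorMemoryMb,
      OverheadFraction * executorMemoryMb)

  def main(args: Array[String]): Unit = {
    println(totalMemoryMb(8192)) // 8765.44 -- the 7% fraction dominates
    println(totalMemoryMb(1024)) // 1408.0  -- the 384 MB floor dominates
  }
}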
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index a9ef126f5de0e..e0f2fd622f54c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -30,7 +30,7 @@ import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
import org.apache.spark.{Logging, SparkContext, SparkException, TaskState}
-import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer}
+import org.apache.spark.scheduler._
import org.apache.spark.util.Utils
/**
@@ -62,6 +62,8 @@ private[spark] class MesosSchedulerBackend(
var classLoader: ClassLoader = null
+ @volatile var appId: String = _
+
override def start() {
synchronized {
classLoader = Thread.currentThread.getContextClassLoader
@@ -124,15 +126,24 @@ private[spark] class MesosSchedulerBackend(
command.setValue("cd %s*; ./sbin/spark-executor".format(basename))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
+ val cpus = Resource.newBuilder()
+ .setName("cpus")
+ .setType(Value.Type.SCALAR)
+ .setScalar(Value.Scalar.newBuilder()
+ .setValue(scheduler.CPUS_PER_TASK).build())
+ .build()
val memory = Resource.newBuilder()
.setName("mem")
.setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(sc.executorMemory).build())
+ .setScalar(
+ Value.Scalar.newBuilder()
+ .setValue(MemoryUtils.calculateTotalMemory(sc)).build())
.build()
ExecutorInfo.newBuilder()
.setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
.setCommand(command)
.setData(ByteString.copyFrom(createExecArg()))
+ .addResources(cpus)
.addResources(memory)
.build()
}
@@ -168,7 +179,8 @@ private[spark] class MesosSchedulerBackend(
override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
val oldClassLoader = setClassLoader()
try {
- logInfo("Registered as framework ID " + frameworkId.getValue)
+ appId = frameworkId.getValue
+ logInfo("Registered as framework ID " + appId)
registeredLock.synchronized {
isRegistered = true
registeredLock.notifyAll()
@@ -204,18 +216,31 @@ private[spark] class MesosSchedulerBackend(
val offerableWorkers = new ArrayBuffer[WorkerOffer]
val offerableIndices = new HashMap[String, Int]
- def enoughMemory(o: Offer) = {
+ def sufficientOffer(o: Offer) = {
val mem = getResource(o.getResourcesList, "mem")
+ val cpus = getResource(o.getResourcesList, "cpus")
val slaveId = o.getSlaveId.getValue
- mem >= sc.executorMemory || slaveIdsWithExecutors.contains(slaveId)
+ (mem >= MemoryUtils.calculateTotalMemory(sc) &&
+ // need at least 1 for executor, 1 for task
+ cpus >= 2 * scheduler.CPUS_PER_TASK) ||
+ (slaveIdsWithExecutors.contains(slaveId) &&
+ cpus >= scheduler.CPUS_PER_TASK)
}
- for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) {
- offerableIndices.put(offer.getSlaveId.getValue, index)
+ for ((offer, index) <- offers.zipWithIndex if sufficientOffer(offer)) {
+ val slaveId = offer.getSlaveId.getValue
+ offerableIndices.put(slaveId, index)
+ val cpus = if (slaveIdsWithExecutors.contains(slaveId)) {
+ getResource(offer.getResourcesList, "cpus").toInt
+ } else {
+ // If the executor doesn't exist yet, subtract CPU for executor
+ getResource(offer.getResourcesList, "cpus").toInt -
+ scheduler.CPUS_PER_TASK
+ }
offerableWorkers += new WorkerOffer(
offer.getSlaveId.getValue,
offer.getHostname,
- getResource(offer.getResourcesList, "cpus").toInt)
+ cpus)
}
// Call into the TaskSchedulerImpl
@@ -347,7 +372,20 @@ private[spark] class MesosSchedulerBackend(
recordSlaveLost(d, slaveId, ExecutorExited(status))
}
+ override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {
+ driver.killTask(
+ TaskID.newBuilder()
+ .setValue(taskId.toString).build()
+ )
+ }
+
// TODO: query Mesos for number of cores
override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8)
+ override def applicationId(): String =
+ Option(appId).getOrElse {
+ logWarning("Application ID is not initialized yet.")
+ super.applicationId
+ }
+
}
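The sufficientOffer predicate above accepts a Mesos offer either when it can host a new executor (enough memory and CPU for the executor plus one task) or when an executor already runs on that slave and at least one task's worth of CPU is free. A condensed standalone sketch, with the thresholds passed in as plain parameters:

// Condensed version of the offer check: a new executor needs memory plus two task-slots
// of CPU; an existing executor only needs one more task-slot.
object MesosOfferSketch {
  def sufficientOffer(
      memMb: Double,
      cpus: Double,
      requiredMemMb: Double,
      cpusPerTask: Int,
      hasExecutorOnSlave: Boolean): Boolean = {
    val canStartExecutor = memMb >= requiredMemMb && cpus >= 2 * cpusPerTask
    val canReuseExecutor = hasExecutorOnSlave && cpus >= cpusPerTask
    canStartExecutor || canReuseExecutor
  }

  def main(args: Array[String]): Unit = {
    println(sufficientOffer(4096, 2, 1408, 1, hasExecutorOnSlave = false)) // true
    println(sufficientOffer(512, 1, 1408, 1, hasExecutorOnSlave = true))   // true
    println(sufficientOffer(512, 1, 1408, 1, hasExecutorOnSlave = false))  // false
  }
}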
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 9ea25c2bc7090..58b78f041cd85 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -88,6 +88,7 @@ private[spark] class LocalActor(
private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int)
extends SchedulerBackend with ExecutorBackend {
+ private val appId = "local-" + System.currentTimeMillis
var localActor: ActorRef = null
override def start() {
@@ -115,4 +116,6 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores:
localActor ! StatusUpdate(taskId, state, serializedData)
}
+ override def applicationId(): String = appId
+
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 4b9454d75abb7..746ed33b54c00 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -103,13 +103,11 @@ private[spark] class HashShuffleWriter[K, V](
private def commitWritesAndBuildStatus(): MapStatus = {
// Commit the writes. Get the size of each bucket block (total block size).
- val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter =>
+ val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter =>
writer.commitAndClose()
- val size = writer.fileSegment().length
- MapOutputTracker.compressSize(size)
+ writer.fileSegment().length
}
-
- new MapStatus(blockManager.blockManagerId, compressedSizes)
+ MapStatus(blockManager.blockManagerId, sizes)
}
private def revertWrites(): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 89a78d6982ba0..927481b72cf4f 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -70,8 +70,7 @@ private[spark] class SortShuffleWriter[K, V, C](
val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
- mapStatus = new MapStatus(blockManager.blockManagerId,
- partitionLengths.map(MapOutputTracker.compressSize))
+ mapStatus = MapStatus(blockManager.blockManagerId, partitionLengths)
}
/** Close this writer, passing along whether the map completed */
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index d1bee3d2c033c..3f5d06e1aeee7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -22,6 +22,7 @@ import java.nio.{ByteBuffer, MappedByteBuffer}
import scala.concurrent.ExecutionContext.Implicits.global
+import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
@@ -112,6 +113,11 @@ private[spark] class BlockManager(
private val broadcastCleaner = new MetadataCleaner(
MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks, conf)
+ // Fields related to peer block managers that are necessary for block replication
+ @volatile private var cachedPeers: Seq[BlockManagerId] = _
+ private val peerFetchLock = new Object
+ private var lastPeerFetchTime = 0L
+
initialize()
/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
@@ -787,31 +793,111 @@ private[spark] class BlockManager(
}
/**
- * Replicate block to another node.
+ * Get peer block managers in the system.
+ */
+ private def getPeers(forceFetch: Boolean): Seq[BlockManagerId] = {
+ peerFetchLock.synchronized {
+ val cachedPeersTtl = conf.getInt("spark.storage.cachedPeersTtl", 60 * 1000) // milliseconds
+ val timeout = System.currentTimeMillis - lastPeerFetchTime > cachedPeersTtl
+ if (cachedPeers == null || forceFetch || timeout) {
+ cachedPeers = master.getPeers(blockManagerId).sortBy(_.hashCode)
+ lastPeerFetchTime = System.currentTimeMillis
+ logDebug("Fetched peers from master: " + cachedPeers.mkString("[", ",", "]"))
+ }
+ cachedPeers
+ }
+ }
+
+ /**
+ * Replicate block to another node. Not that this is a blocking call that returns after
+ * the block has been replicated.
*/
- @volatile var cachedPeers: Seq[BlockManagerId] = null
private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel): Unit = {
+ val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1)
+ val numPeersToReplicateTo = level.replication - 1
+ val peersForReplication = new ArrayBuffer[BlockManagerId]
+ val peersReplicatedTo = new ArrayBuffer[BlockManagerId]
+ val peersFailedToReplicateTo = new ArrayBuffer[BlockManagerId]
val tLevel = StorageLevel(
level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 1)
- if (cachedPeers == null) {
- cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
+ val startTime = System.currentTimeMillis
+ val random = new Random(blockId.hashCode)
+
+ var replicationFailed = false
+ var failures = 0
+ var done = false
+
+ // Get cached list of peers
+ peersForReplication ++= getPeers(forceFetch = false)
+
+ // Get a random peer. Note that this selection of a peer is deterministic on the block id.
+ // So, assuming the list of peers does not change and there are no replication failures,
+ // multiple attempts on the same node to replicate the same block will select
+ // the same set of peers.
+ def getRandomPeer(): Option[BlockManagerId] = {
+ // If replication failed, force-refresh the cached list of peers and remove the peers
+ // that have already been used
+ if (replicationFailed) {
+ peersForReplication.clear()
+ peersForReplication ++= getPeers(forceFetch = true)
+ peersForReplication --= peersReplicatedTo
+ peersForReplication --= peersFailedToReplicateTo
+ }
+ if (!peersForReplication.isEmpty) {
+ Some(peersForReplication(random.nextInt(peersForReplication.size)))
+ } else {
+ None
+ }
}
- for (peer: BlockManagerId <- cachedPeers) {
- val start = System.nanoTime
- data.rewind()
- logDebug(s"Try to replicate $blockId once; The size of the data is ${data.limit()} Bytes. " +
- s"To node: $peer")
- try {
- blockTransferService.uploadBlockSync(
- peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel)
- } catch {
- case e: Exception =>
- logError(s"Failed to replicate block to $peer", e)
+ // One by one choose a random peer and try uploading the block to it
+ // If replication fails (e.g., target peer is down), force the list of cached peers
+ // to be re-fetched from the driver and then pick another random peer for replication. Also
+ // temporarily blacklist the peer for which replication failed.
+ //
+ // This selection of a peer and replication is continued in a loop until one of the
+ // following 3 conditions is fulfilled:
+ // (i) specified number of peers have been replicated to
+ // (ii) too many failures in replicating to peers
+ // (iii) no peer left to replicate to
+ //
+ while (!done) {
+ getRandomPeer() match {
+ case Some(peer) =>
+ try {
+ val onePeerStartTime = System.currentTimeMillis
+ data.rewind()
+ logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
+ blockTransferService.uploadBlockSync(
+ peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel)
+ logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %f ms"
+ .format((System.currentTimeMillis - onePeerStartTime)))
+ peersReplicatedTo += peer
+ peersForReplication -= peer
+ replicationFailed = false
+ if (peersReplicatedTo.size == numPeersToReplicateTo) {
+ done = true // specified number of peers have been replicated to
+ }
+ } catch {
+ case e: Exception =>
+ logWarning(s"Failed to replicate $blockId to $peer, failure #$failures", e)
+ failures += 1
+ replicationFailed = true
+ peersFailedToReplicateTo += peer
+ if (failures > maxReplicationFailures) { // too many failures in replicating to peers
+ done = true
+ }
+ }
+ case None => // no peer left to replicate to
+ done = true
}
-
- logDebug("Replicating BlockId %s once used %fs; The size of the data is %d bytes."
- .format(blockId, (System.nanoTime - start) / 1e6, data.limit()))
+ }
+ val timeTakenMs = System.currentTimeMillis - startTime
+ logDebug(s"Replicating $blockId of ${data.limit()} bytes to " +
+ s"${peersReplicatedTo.size} peer(s) took $timeTakenMs ms")
+ if (peersReplicatedTo.size < numPeersToReplicateTo) {
+ logWarning(s"Block $blockId replicated to only " +
+ s"${peersReplicatedTo.size} peer(s) instead of $numPeersToReplicateTo peers")
}
}
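The getPeers change above caches the peer list and re-fetches it when either the caller forces a refresh or a TTL (spark.storage.cachedPeersTtl) expires. That caching shape, separated from BlockManager, looks like the sketch below; the fetch function and default TTL are stand-ins, not Spark configuration.

// TTL-based cache with a force-refresh flag, mirroring getPeers(forceFetch).
class TtlCache[T](fetch: () => T, ttlMs: Long = 60 * 1000) {
  private var cached: Option[T] = None
  private var lastFetchTime = 0L

  def get(forceFetch: Boolean = false): T = synchronized {
    val expired = System.currentTimeMillis - lastFetchTime > ttlMs
    if (cached.isEmpty || forceFetch || expired) {
      cached = Some(fetch())
      lastFetchTime = System.currentTimeMillis
    }
    cached.get
  }
}

object TtlCacheSketch {
  def main(args: Array[String]): Unit = {
    val peers = new TtlCache(() => Seq("bm-1", "bm-2"))
    println(peers.get())                  // first call fetches
    println(peers.get())                  // served from the cache
    println(peers.get(forceFetch = true)) // forced refresh fetches again
  }
}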
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index d4487fce49ab6..142285094342c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -59,6 +59,8 @@ class BlockManagerId private (
def port: Int = port_
+ def isDriver: Boolean = (executorId == "")
+
override def writeExternal(out: ObjectOutput) {
out.writeUTF(executorId_)
out.writeUTF(host_)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 2e262594b3538..d08e1419e3e41 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -84,13 +84,8 @@ class BlockManagerMaster(
}
/** Get ids of other nodes in the cluster from the driver */
- def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = {
- val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
- if (result.length != numPeers) {
- throw new SparkException(
- "Error getting peers, only got " + result.size + " instead of " + numPeers)
- }
- result
+ def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = {
+ askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
}
/**
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 1a6c7cb24f9ac..6a06257ed0c08 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -83,8 +83,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case GetLocationsMultipleBlockIds(blockIds) =>
sender ! getLocationsMultipleBlockIds(blockIds)
- case GetPeers(blockManagerId, size) =>
- sender ! getPeers(blockManagerId, size)
+ case GetPeers(blockManagerId) =>
+ sender ! getPeers(blockManagerId)
case GetMemoryStatus =>
sender ! memoryStatus
@@ -173,11 +173,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
* from the executors, but not from the driver.
*/
private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = {
- // TODO: Consolidate usages of
import context.dispatcher
val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver)
val requiredBlockManagers = blockManagerInfo.values.filter { info =>
- removeFromDriver || info.blockManagerId.executorId != ""
+ removeFromDriver || !info.blockManagerId.isDriver
}
Future.sequence(
requiredBlockManagers.map { bm =>
@@ -212,7 +211,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
val minSeenTime = now - slaveTimeout
val toRemove = new mutable.HashSet[BlockManagerId]
for (info <- blockManagerInfo.values) {
- if (info.lastSeenMs < minSeenTime && info.blockManagerId.executorId != "") {
+ if (info.lastSeenMs < minSeenTime && !info.blockManagerId.isDriver) {
logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: "
+ (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms")
toRemove += info.blockManagerId
@@ -232,7 +231,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
*/
private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = {
if (!blockManagerInfo.contains(blockManagerId)) {
- blockManagerId.executorId == "" && !isLocal
+ blockManagerId.isDriver && !isLocal
} else {
blockManagerInfo(blockManagerId).updateLastSeenMs()
true
@@ -355,7 +354,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
tachyonSize: Long) {
if (!blockManagerInfo.contains(blockManagerId)) {
- if (blockManagerId.executorId == "" && !isLocal) {
+ if (blockManagerId.isDriver && !isLocal) {
// We intentionally do not register the master (except in local mode),
// so we should not indicate failure.
sender ! true
@@ -403,16 +402,14 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
blockIds.map(blockId => getLocations(blockId))
}
- private def getPeers(blockManagerId: BlockManagerId, size: Int): Seq[BlockManagerId] = {
- val peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
-
- val selfIndex = peers.indexOf(blockManagerId)
- if (selfIndex == -1) {
- throw new SparkException("Self index for " + blockManagerId + " not found")
+ /** Get the list of the peers of the given block manager */
+ private def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = {
+ val blockManagerIds = blockManagerInfo.keySet
+ if (blockManagerIds.contains(blockManagerId)) {
+ blockManagerIds.filterNot { _.isDriver }.filterNot { _ == blockManagerId }.toSeq
+ } else {
+ Seq.empty
}
-
- // Note that this logic will select the same node multiple times if there aren't enough peers
- Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq
}
}
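
The new getPeers semantics are easy to see in isolation: return every registered block manager except the driver and the requester itself, and an empty list (rather than an exception) for an unknown requester. A minimal sketch, with a plain case class standing in for BlockManagerId:

    case class Id(executorId: String, host: String) {
      def isDriver: Boolean = executorId == "" // mirrors BlockManagerId.isDriver above
    }

    def getPeers(registered: Set[Id], self: Id): Seq[Id] =
      if (registered.contains(self)) {
        registered.filterNot(_.isDriver).filterNot(_ == self).toSeq
      } else {
        Seq.empty // unknown requester gets no peers instead of an exception
      }

    object PeerSelectionDemo extends App {
      val driver = Id("", "driver-host")
      val exec1 = Id("1", "host1")
      val exec2 = Id("2", "host2")
      println(getPeers(Set(driver, exec1, exec2), exec1))            // only Id(2,host2)
      println(getPeers(Set(driver, exec1, exec2), Id("3", "host3"))) // empty
    }
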
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 2ba16b8476600..3db5dd9774ae8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -88,7 +88,7 @@ private[spark] object BlockManagerMessages {
case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster
- case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
+ case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
index 14ae2f38c5670..8462871e798a5 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -58,9 +58,9 @@ class BlockManagerSlaveActor(
SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
}
- case RemoveBroadcast(broadcastId, tellMaster) =>
+ case RemoveBroadcast(broadcastId, _) =>
doAsync[Int]("removing broadcast " + broadcastId, sender) {
- blockManager.removeBroadcast(broadcastId, tellMaster)
+ blockManager.removeBroadcast(broadcastId, tellMaster = true)
}
case GetBlockStatus(blockId, _) =>
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala
index 49fea6d9e2a76..8569c6f3cbbc3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala
@@ -22,10 +22,10 @@ import com.codahale.metrics.{Gauge,MetricRegistry}
import org.apache.spark.SparkContext
import org.apache.spark.metrics.source.Source
-private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: SparkContext)
+private[spark] class BlockManagerSource(val blockManager: BlockManager)
extends Source {
override val metricRegistry = new MetricRegistry()
- override val sourceName = "%s.BlockManager".format(sc.appName)
+ override val sourceName = "BlockManager"
metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] {
override def getValue: Long = {
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index e9304f6bb45d0..bac459e835a3f 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -73,7 +73,21 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
val startTime = System.currentTimeMillis
val file = diskManager.getFile(blockId)
val outputStream = new FileOutputStream(file)
- blockManager.dataSerializeStream(blockId, outputStream, values)
+ try {
+ try {
+ blockManager.dataSerializeStream(blockId, outputStream, values)
+ } finally {
+ // Close outputStream here because it should be closed before file is deleted.
+ outputStream.close()
+ }
+ } catch {
+ case e: Throwable =>
+ if (file.exists()) {
+ file.delete()
+ }
+ throw e
+ }
+
val length = file.length
val timeTaken = System.currentTimeMillis - startTime
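
The DiskStore change above follows a close-before-delete pattern that generalizes beyond Spark; a self-contained sketch is below (the helper name writeOrCleanUp is illustrative, not part of the codebase).

    import java.io.{File, FileOutputStream}

    // Always close the stream (finally) before any delete, and remove a partially
    // written file only when the write failed, then rethrow the original error.
    def writeOrCleanUp(file: File)(write: FileOutputStream => Unit): Unit = {
      val out = new FileOutputStream(file)
      try {
        try {
          write(out)
        } finally {
          out.close() // close before a possible delete
        }
      } catch {
        case e: Throwable =>
          if (file.exists()) {
            file.delete() // drop the partial file on failure
          }
          throw e
      }
    }
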
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 0a09c24d61879..edbc729c17ade 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -132,8 +132,6 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
PutResult(res.size, res.data, droppedBlocks)
case Right(iteratorValues) =>
// Not enough space to unroll this block; drop to disk if applicable
- logWarning(s"Not enough space to store block $blockId in memory! " +
- s"Free memory is $freeMemory bytes.")
if (level.useDisk && allowPersistToDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
val res = blockManager.diskStore.putIterator(blockId, iteratorValues, level, returnValues)
@@ -265,6 +263,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
Left(vector.toArray)
} else {
// We ran out of space while unrolling the values for this block
+ logUnrollFailureMessage(blockId, vector.estimateSize())
Right(vector.iterator ++ values)
}
@@ -424,7 +423,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* Reserve additional memory for unrolling blocks used by this thread.
* Return whether the request is granted.
*/
- private[spark] def reserveUnrollMemoryForThisThread(memory: Long): Boolean = {
+ def reserveUnrollMemoryForThisThread(memory: Long): Boolean = {
accountingLock.synchronized {
val granted = freeMemory > currentUnrollMemory + memory
if (granted) {
@@ -439,7 +438,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* Release memory used by this thread for unrolling blocks.
* If the amount is not specified, remove the current thread's allocation altogether.
*/
- private[spark] def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = {
+ def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = {
val threadId = Thread.currentThread().getId
accountingLock.synchronized {
if (memory < 0) {
@@ -457,16 +456,50 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
/**
* Return the amount of memory currently occupied for unrolling blocks across all threads.
*/
- private[spark] def currentUnrollMemory: Long = accountingLock.synchronized {
+ def currentUnrollMemory: Long = accountingLock.synchronized {
unrollMemoryMap.values.sum
}
/**
* Return the amount of memory currently occupied for unrolling blocks by this thread.
*/
- private[spark] def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized {
+ def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized {
unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L)
}
+
+ /**
+ * Return the number of threads currently unrolling blocks.
+ */
+ def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size }
+
+ /**
+ * Log information about current memory usage.
+ */
+ def logMemoryUsage(): Unit = {
+ val blocksMemory = currentMemory
+ val unrollMemory = currentUnrollMemory
+ val totalMemory = blocksMemory + unrollMemory
+ logInfo(
+ s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " +
+ s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " +
+ s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. " +
+ s"Storage limit = ${Utils.bytesToString(maxMemory)}."
+ )
+ }
+
+ /**
+ * Log a warning for failing to unroll a block.
+ *
+ * @param blockId ID of the block we are trying to unroll.
+ * @param finalVectorSize Final size of the vector before unrolling failed.
+ */
+ def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = {
+ logWarning(
+ s"Not enough space to cache $blockId in memory! " +
+ s"(computed ${Utils.bytesToString(finalVectorSize)} so far)"
+ )
+ logMemoryUsage()
+ }
}
private[spark] case class ResultWithDroppedBlocks(
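
The new logging methods report on per-thread unroll bookkeeping; the following is a toy, self-contained version of that accounting. The reserve rule follows the diff above, but the class and method names are illustrative and this is not Spark's MemoryStore.

    class UnrollAccounting(maxMemory: Long, var blocksMemory: Long = 0L) {
      private val lock = new Object
      private val unrollMemoryMap = scala.collection.mutable.HashMap[Long, Long]()

      def currentUnrollMemory: Long = lock.synchronized { unrollMemoryMap.values.sum }
      def numThreadsUnrolling: Int = lock.synchronized { unrollMemoryMap.keys.size }

      // Grant a reservation only if enough free space remains, as in
      // reserveUnrollMemoryForThisThread above.
      def reserveForThisThread(bytes: Long): Boolean = lock.synchronized {
        val freeMemory = maxMemory - blocksMemory
        val granted = freeMemory > currentUnrollMemory + bytes
        if (granted) {
          val tid = Thread.currentThread().getId
          unrollMemoryMap(tid) = unrollMemoryMap.getOrElse(tid, 0L) + bytes
        }
        granted
      }

      def memoryUsageSummary: String =
        s"blocks=$blocksMemory, unroll=$currentUnrollMemory across " +
          s"$numThreadsUnrolling thread(s), limit=$maxMemory"
    }
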
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index d868758a7f549..71b276b5f18e4 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -121,7 +121,7 @@ final class ShuffleBlockFetcherIterator(
}
override def onBlockFetchFailure(e: Throwable): Unit = {
- logError("Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
+ logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
// Note that there is a chance that some blocks have been fetched successfully, but we
// still add them to the failed queue. This is fine because when the caller see a
// FetchFailedException, it is going to fail the entire task anyway.
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 6b4689291097f..2a27d49d2de05 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -21,9 +21,7 @@ import java.net.{InetSocketAddress, URL}
import javax.servlet.DispatcherType
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
-import scala.annotation.tailrec
import scala.language.implicitConversions
-import scala.util.{Failure, Success, Try}
import scala.xml.Node
import org.eclipse.jetty.server.Server
@@ -147,15 +145,19 @@ private[spark] object JettyUtils extends Logging {
val holder : FilterHolder = new FilterHolder()
holder.setClassName(filter)
// Get any parameters for each filter
- val paramName = "spark." + filter + ".params"
- val params = conf.get(paramName, "").split(',').map(_.trim()).toSet
- params.foreach {
- case param : String =>
+ conf.get("spark." + filter + ".params", "").split(',').map(_.trim()).toSet.foreach {
+ param: String =>
if (!param.isEmpty) {
val parts = param.split("=")
if (parts.length == 2) holder.setInitParameter(parts(0), parts(1))
}
}
+
+ val prefix = s"spark.$filter.param."
+ conf.getAll
+ .filter { case (k, v) => k.length() > prefix.length() && k.startsWith(prefix) }
+ .foreach { case (k, v) => holder.setInitParameter(k.substring(prefix.length()), v) }
+
val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR,
DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST)
handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) }
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index db01be596e073..2414e4c65237e 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -103,7 +103,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
val taskHeaders: Seq[String] =
Seq(
- "Index", "ID", "Attempt", "Status", "Locality Level", "Executor",
+ "Index", "ID", "Attempt", "Status", "Locality Level", "Executor ID / Host",
"Launch Time", "Duration", "GC Time", "Accumulators") ++
{if (hasInput) Seq("Input") else Nil} ++
{if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++
@@ -282,7 +282,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
}
        <td>{info.status}</td>
        <td>{info.taskLocality}</td>
-        <td>{info.host}</td>
+        <td>{info.executorId} / {info.host}</td>
{UIUtils.formatDate(new Date(info.launchTime))}
{formatDuration}
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index c4dddb2d1037e..5b2e7d3a7edb9 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -25,7 +25,6 @@ import scala.collection.Map
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.JsonAST._
-import org.json4s.jackson.JsonMethods._
import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleReadMetrics,
@@ -255,7 +254,6 @@ private[spark] object JsonProtocol {
}
def shuffleReadMetricsToJson(shuffleReadMetrics: ShuffleReadMetrics): JValue = {
- ("Shuffle Finish Time" -> shuffleReadMetrics.shuffleFinishTime) ~
("Remote Blocks Fetched" -> shuffleReadMetrics.remoteBlocksFetched) ~
("Local Blocks Fetched" -> shuffleReadMetrics.localBlocksFetched) ~
("Fetch Wait Time" -> shuffleReadMetrics.fetchWaitTime) ~
@@ -590,7 +588,6 @@ private[spark] object JsonProtocol {
def shuffleReadMetricsFromJson(json: JValue): ShuffleReadMetrics = {
val metrics = new ShuffleReadMetrics
- metrics.shuffleFinishTime = (json \ "Shuffle Finish Time").extract[Long]
metrics.remoteBlocksFetched = (json \ "Remote Blocks Fetched").extract[Int]
metrics.localBlocksFetched = (json \ "Local Blocks Fetched").extract[Int]
metrics.fetchWaitTime = (json \ "Fetch Wait Time").extract[Long]
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index c76b7af18481d..3d307b3c16d3e 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
import java.util.{Properties, Locale, Random, UUID}
import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor}
-import org.apache.log4j.PropertyConfigurator
+import org.eclipse.jetty.util.MultiException
import scala.collection.JavaConversions._
import scala.collection.Map
@@ -37,18 +37,23 @@ import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.commons.lang3.SystemUtils
import org.apache.hadoop.conf.Configuration
+import org.apache.log4j.PropertyConfigurator
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
import org.json4s._
import tachyon.client.{TachyonFile,TachyonFS}
import org.apache.spark._
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.ExecutorUncaughtExceptionHandler
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
/** CallSite represents a place in user code. It can have a short and a long form. */
private[spark] case class CallSite(shortForm: String, longForm: String)
+private[spark] object CallSite {
+ val SHORT_FORM = "callSite.short"
+ val LONG_FORM = "callSite.long"
+}
+
/**
* Various utility methods used by Spark.
*/
@@ -81,7 +86,7 @@ private[spark] object Utils extends Logging {
ois.readObject.asInstanceOf[T]
}
- /** Deserialize a Long value (used for {@link org.apache.spark.api.python.PythonPartitioner}) */
+ /** Deserialize a Long value (used for [[org.apache.spark.api.python.PythonPartitioner]]) */
def deserializeLongValue(bytes: Array[Byte]) : Long = {
// Note: we assume that we are given a Long value encoded in network (big-endian) byte order
var result = bytes(7) & 0xFFL
@@ -148,7 +153,7 @@ private[spark] object Utils extends Logging {
def classForName(className: String) = Class.forName(className, true, getContextOrSparkClassLoader)
/**
- * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}.
+ * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
*/
def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = {
if (bb.hasArray) {
@@ -328,7 +333,7 @@ private[spark] object Utils extends Logging {
val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
val targetFile = new File(targetDir, filename)
val uri = new URI(url)
- val fileOverwrite = conf.getBoolean("spark.files.overwrite", false)
+ val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false)
uri.getScheme match {
case "http" | "https" | "ftp" =>
logInfo("Fetching " + url + " to " + tempFile)
@@ -350,7 +355,7 @@ private[spark] object Utils extends Logging {
uc.connect()
val in = uc.getInputStream()
val out = new FileOutputStream(tempFile)
- Utils.copyStream(in, out, true)
+ Utils.copyStream(in, out, closeStreams = true)
if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
if (fileOverwrite) {
targetFile.delete()
@@ -397,7 +402,7 @@ private[spark] object Utils extends Logging {
val fs = getHadoopFileSystem(uri, hadoopConf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(tempFile)
- Utils.copyStream(in, out, true)
+ Utils.copyStream(in, out, closeStreams = true)
if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
if (fileOverwrite) {
targetFile.delete()
@@ -661,7 +666,7 @@ private[spark] object Utils extends Logging {
*/
def deleteRecursively(file: File) {
if (file != null) {
- if ((file.isDirectory) && !isSymlink(file)) {
+ if (file.isDirectory() && !isSymlink(file)) {
for (child <- listFilesSafely(file)) {
deleteRecursively(child)
}
@@ -696,26 +701,27 @@ private[spark] object Utils extends Logging {
new File(file.getParentFile().getCanonicalFile(), file.getName())
}
- if (fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile())) {
- return false
- } else {
- return true
- }
+ !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile())
}
/**
- * Finds all the files in a directory whose last modified time is older than cutoff seconds.
- * @param dir must be the path to a directory, or IllegalArgumentException is thrown
- * @param cutoff measured in seconds. Files older than this are returned.
+ * Determines if a directory contains any files newer than cutoff seconds.
+ *
+ * @param dir must be the path to a directory, or IllegalArgumentException is thrown
+ * @param cutoff measured in seconds. Returns true if there are any files or directories in the
+ * given directory whose last modified time is later than this many seconds ago
*/
- def findOldFiles(dir: File, cutoff: Long): Seq[File] = {
- val currentTimeMillis = System.currentTimeMillis
- if (dir.isDirectory) {
- val files = listFilesSafely(dir)
- files.filter { file => file.lastModified < (currentTimeMillis - cutoff * 1000) }
- } else {
- throw new IllegalArgumentException(dir + " is not a directory!")
+ def doesDirectoryContainAnyNewFiles(dir: File, cutoff: Long): Boolean = {
+ if (!dir.isDirectory) {
+ throw new IllegalArgumentException(s"$dir is not a directory!")
}
+ val filesAndDirs = dir.listFiles()
+ val cutoffTimeInMillis = System.currentTimeMillis - (cutoff * 1000)
+
+ filesAndDirs.exists(_.lastModified() > cutoffTimeInMillis) ||
+ filesAndDirs.filter(_.isDirectory).exists(
+ subdir => doesDirectoryContainAnyNewFiles(subdir, cutoff)
+ )
}
/**
@@ -799,7 +805,7 @@ private[spark] object Utils extends Logging {
.start()
new Thread("read stdout for " + command(0)) {
override def run() {
- for (line <- Source.fromInputStream(process.getInputStream).getLines) {
+ for (line <- Source.fromInputStream(process.getInputStream).getLines()) {
System.err.println(line)
}
}
@@ -813,8 +819,10 @@ private[spark] object Utils extends Logging {
/**
* Execute a command and get its output, throwing an exception if it yields a code other than 0.
*/
- def executeAndGetOutput(command: Seq[String], workingDir: File = new File("."),
- extraEnvironment: Map[String, String] = Map.empty): String = {
+ def executeAndGetOutput(
+ command: Seq[String],
+ workingDir: File = new File("."),
+ extraEnvironment: Map[String, String] = Map.empty): String = {
val builder = new ProcessBuilder(command: _*)
.directory(workingDir)
val environment = builder.environment()
@@ -824,7 +832,7 @@ private[spark] object Utils extends Logging {
val process = builder.start()
new Thread("read stderr for " + command(0)) {
override def run() {
- for (line <- Source.fromInputStream(process.getErrorStream).getLines) {
+ for (line <- Source.fromInputStream(process.getErrorStream).getLines()) {
System.err.println(line)
}
}
@@ -832,7 +840,7 @@ private[spark] object Utils extends Logging {
val output = new StringBuffer
val stdoutThread = new Thread("read stdout for " + command(0)) {
override def run() {
- for (line <- Source.fromInputStream(process.getInputStream).getLines) {
+ for (line <- Source.fromInputStream(process.getInputStream).getLines()) {
output.append(line)
}
}
@@ -841,8 +849,8 @@ private[spark] object Utils extends Logging {
val exitCode = process.waitFor()
stdoutThread.join() // Wait for it to finish reading output
if (exitCode != 0) {
- logError(s"Process $command exited with code $exitCode: ${output}")
- throw new SparkException("Process " + command + " exited with code " + exitCode)
+ logError(s"Process $command exited with code $exitCode: $output")
+ throw new SparkException(s"Process $command exited with code $exitCode")
}
output.toString
}
@@ -855,29 +863,37 @@ private[spark] object Utils extends Logging {
try {
block
} catch {
+ case e: ControlThrowable => throw e
case t: Throwable => ExecutorUncaughtExceptionHandler.uncaughtException(t)
}
}
- /**
- * A regular expression to match classes of the "core" Spark API that we want to skip when
- * finding the call site of a method.
- */
- private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
+ /** Default filtering function for finding call sites using `getCallSite`. */
+ private def coreExclusionFunction(className: String): Boolean = {
+ // A regular expression to match classes of the "core" Spark API that we want to skip when
+ // finding the call site of a method.
+ val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
+ val SCALA_CLASS_REGEX = """^scala""".r
+ val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined
+ val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined
+ // If the class is a Spark internal class or a Scala class, then exclude.
+ isSparkCoreClass || isScalaClass
+ }
/**
* When called inside a class in the spark package, returns the name of the user code class
* (outside the spark package) that called into Spark, as well as which Spark method they called.
* This is used, for example, to tell users where in their code each RDD got created.
+ *
+ * @param skipClass Function that is used to exclude non-user-code classes.
*/
- def getCallSite: CallSite = {
- val trace = Thread.currentThread.getStackTrace()
- .filterNot { ste:StackTraceElement =>
- // When running under some profilers, the current stack trace might contain some bogus
- // frames. This is intended to ensure that we don't crash in these situations by
- // ignoring any frames that we can't examine.
- (ste == null || ste.getMethodName == null || ste.getMethodName.contains("getStackTrace"))
- }
+ def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = {
+ val trace = Thread.currentThread.getStackTrace().filterNot { ste: StackTraceElement =>
+ // When running under some profilers, the current stack trace might contain some bogus
+ // frames. This is intended to ensure that we don't crash in these situations by
+ // ignoring any frames that we can't examine.
+ ste == null || ste.getMethodName == null || ste.getMethodName.contains("getStackTrace")
+ }
// Keep crawling up the stack trace until we find the first function not inside of the spark
// package. We track the last (shallowest) contiguous Spark method. This might be an RDD
@@ -891,7 +907,7 @@ private[spark] object Utils extends Logging {
for (el <- trace) {
if (insideSpark) {
- if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName).isDefined) {
+ if (skipClass(el.getClassName)) {
lastSparkMethod = if (el.getMethodName == "") {
// Spark method is a constructor; get its class name
el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1)
@@ -911,7 +927,7 @@ private[spark] object Utils extends Logging {
}
val callStackDepth = System.getProperty("spark.callstack.depth", "20").toInt
CallSite(
- shortForm = "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine),
+ shortForm = s"$lastSparkMethod at $firstUserFile:$firstUserLine",
longForm = callStack.take(callStackDepth).mkString("\n"))
}
@@ -1014,7 +1030,7 @@ private[spark] object Utils extends Logging {
false
}
- def isSpace(c: Char): Boolean = {
+ private def isSpace(c: Char): Boolean = {
" \t\r\n".indexOf(c) != -1
}
@@ -1166,7 +1182,7 @@ private[spark] object Utils extends Logging {
}
import scala.sys.process._
(linkCmd + src.getAbsolutePath() + " " + dst.getPath() + cmdSuffix) lines_!
- ProcessLogger(line => (logInfo(line)))
+ ProcessLogger(line => logInfo(line))
}
@@ -1247,7 +1263,7 @@ private[spark] object Utils extends Logging {
val startTime = System.currentTimeMillis
while (!terminated) {
try {
- process.exitValue
+ process.exitValue()
terminated = true
} catch {
case e: IllegalThreadStateException =>
@@ -1291,6 +1307,20 @@ private[spark] object Utils extends Logging {
}
}
+ /** Executes the given block in a Try, logging any uncaught exceptions. */
+ def tryLog[T](f: => T): Try[T] = {
+ try {
+ val res = f
+ scala.util.Success(res)
+ } catch {
+ case ct: ControlThrowable =>
+ throw ct
+ case t: Throwable =>
+ logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t)
+ scala.util.Failure(t)
+ }
+ }
+
/** Returns true if the given exception was fatal. See docs for scala.util.control.NonFatal. */
def isFatalError(e: Throwable): Boolean = {
e match {
@@ -1382,15 +1412,15 @@ private[spark] object Utils extends Logging {
}
/**
- * Default number of retries in binding to a port.
+ * Default maximum number of retries when binding to a port before giving up.
*/
val portMaxRetries: Int = {
if (sys.props.contains("spark.testing")) {
// Set a higher number of retries for tests...
- sys.props.get("spark.ports.maxRetries").map(_.toInt).getOrElse(100)
+ sys.props.get("spark.port.maxRetries").map(_.toInt).getOrElse(100)
} else {
Option(SparkEnv.get)
- .flatMap(_.conf.getOption("spark.ports.maxRetries"))
+ .flatMap(_.conf.getOption("spark.port.maxRetries"))
.map(_.toInt)
.getOrElse(16)
}
@@ -1414,7 +1444,12 @@ private[spark] object Utils extends Logging {
val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'"
for (offset <- 0 to maxRetries) {
// Do not increment port if startPort is 0, which is treated as a special port
- val tryPort = if (startPort == 0) startPort else (startPort + offset) % 65536
+ val tryPort = if (startPort == 0) {
+ startPort
+ } else {
+ // If the new port wraps around, do not try a privileged port
+ ((startPort + offset - 1024) % (65536 - 1024)) + 1024
+ }
try {
val (service, port) = startService(tryPort)
logInfo(s"Successfully started service$serviceString on port $port.")
@@ -1447,6 +1482,7 @@ private[spark] object Utils extends Logging {
return true
}
isBindCollision(e.getCause)
+ case e: MultiException => e.getThrowables.exists(isBindCollision)
case e: Exception => isBindCollision(e.getCause)
case _ => false
}
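
One detail of the port-retry change above is worth spelling out: the new formula keeps every retried port in the non-privileged range [1024, 65535] even when startPort + offset would pass 65535. A small worked example:

    def tryPort(startPort: Int, offset: Int): Int =
      if (startPort == 0) {
        startPort // 0 is special: the OS picks an ephemeral port
      } else {
        // Same expression as in the diff above; wraps back to 1024, never to a privileged port
        ((startPort + offset - 1024) % (65536 - 1024)) + 1024
      }

    object PortRetryDemo extends App {
      println(tryPort(65535, 0)) // 65535
      println(tryPort(65535, 1)) // 1024 (wrapped past the top of the range)
      println(tryPort(65535, 2)) // 1025
    }
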
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 8a015c1d26a96..0c088da46aa5e 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -66,23 +66,19 @@ class ExternalAppendOnlyMap[K, V, C](
mergeCombiners: (C, C) => C,
serializer: Serializer = SparkEnv.get.serializer,
blockManager: BlockManager = SparkEnv.get.blockManager)
- extends Iterable[(K, C)] with Serializable with Logging {
+ extends Iterable[(K, C)]
+ with Serializable
+ with Logging
+ with Spillable[SizeTracker] {
private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
private val spilledMaps = new ArrayBuffer[DiskMapIterator]
private val sparkConf = SparkEnv.get.conf
private val diskBlockManager = blockManager.diskBlockManager
- private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
// Number of pairs inserted since last spill; note that we count them even if a value is merged
// with a previous key in case we're doing something like groupBy where the result grows
- private var elementsRead = 0L
-
- // Number of in-memory pairs inserted before tracking the map's shuffle memory usage
- private val trackMemoryThreshold = 1000
-
- // How much of the shared memory pool this collection has claimed
- private var myMemoryThreshold = 0L
+ protected[this] var elementsRead = 0L
/**
* Size of object batches when reading/writing from serializers.
@@ -95,11 +91,7 @@ class ExternalAppendOnlyMap[K, V, C](
*/
private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000)
- // How many times we have spilled so far
- private var spillCount = 0
-
// Number of bytes spilled in total
- private var _memoryBytesSpilled = 0L
private var _diskBytesSpilled = 0L
private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
@@ -136,19 +128,8 @@ class ExternalAppendOnlyMap[K, V, C](
while (entries.hasNext) {
curEntry = entries.next()
- if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
- currentMap.estimateSize() >= myMemoryThreshold)
- {
- // Claim up to double our current memory from the shuffle memory pool
- val currentMemory = currentMap.estimateSize()
- val amountToRequest = 2 * currentMemory - myMemoryThreshold
- val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
- myMemoryThreshold += granted
- if (myMemoryThreshold <= currentMemory) {
- // We were granted too little memory to grow further (either tryToAcquire returned 0,
- // or we already had more memory than myMemoryThreshold); spill the current collection
- spill(currentMemory) // Will also release memory back to ShuffleMemoryManager
- }
+ if (maybeSpill(currentMap, currentMap.estimateSize())) {
+ currentMap = new SizeTrackingAppendOnlyMap[K, C]
}
currentMap.changeValue(curEntry._1, update)
elementsRead += 1
@@ -171,11 +152,7 @@ class ExternalAppendOnlyMap[K, V, C](
/**
* Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
*/
- private def spill(mapSize: Long): Unit = {
- spillCount += 1
- val threadId = Thread.currentThread().getId
- logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)"
- .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
+ override protected[this] def spill(collection: SizeTracker): Unit = {
val (blockId, file) = diskBlockManager.createTempBlock()
curWriteMetrics = new ShuffleWriteMetrics()
var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize,
@@ -231,18 +208,11 @@ class ExternalAppendOnlyMap[K, V, C](
}
}
- currentMap = new SizeTrackingAppendOnlyMap[K, C]
spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
- // Release our memory back to the shuffle pool so that other threads can grab it
- shuffleMemoryManager.release(myMemoryThreshold)
- myMemoryThreshold = 0L
-
elementsRead = 0
- _memoryBytesSpilled += mapSize
}
- def memoryBytesSpilled: Long = _memoryBytesSpilled
def diskBytesSpilled: Long = _diskBytesSpilled
/**
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 782b979e2e93d..644fa36818647 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -79,14 +79,14 @@ private[spark] class ExternalSorter[K, V, C](
aggregator: Option[Aggregator[K, V, C]] = None,
partitioner: Option[Partitioner] = None,
ordering: Option[Ordering[K]] = None,
- serializer: Option[Serializer] = None) extends Logging {
+ serializer: Option[Serializer] = None)
+ extends Logging with Spillable[SizeTrackingPairCollection[(Int, K), C]] {
private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
private val shouldPartition = numPartitions > 1
private val blockManager = SparkEnv.get.blockManager
private val diskBlockManager = blockManager.diskBlockManager
- private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
private val ser = Serializer.getSerializer(serializer)
private val serInstance = ser.newInstance()
@@ -115,22 +115,14 @@ private[spark] class ExternalSorter[K, V, C](
// Number of pairs read from input since last spill; note that we count them even if a value is
// merged with a previous key in case we're doing something like groupBy where the result grows
- private var elementsRead = 0L
-
- // What threshold of elementsRead we start estimating map size at.
- private val trackMemoryThreshold = 1000
+ protected[this] var elementsRead = 0L
// Total spilling statistics
- private var spillCount = 0
- private var _memoryBytesSpilled = 0L
private var _diskBytesSpilled = 0L
// Write metrics for current spill
private var curWriteMetrics: ShuffleWriteMetrics = _
- // How much of the shared memory pool this collection has claimed
- private var myMemoryThreshold = 0L
-
// If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need
// local aggregation and sorting, write numPartitions files directly and just concatenate them
// at the end. This avoids doing serialization and deserialization twice to merge together the
@@ -152,7 +144,7 @@ private[spark] class ExternalSorter[K, V, C](
override def compare(a: K, b: K): Int = {
val h1 = if (a == null) 0 else a.hashCode()
val h2 = if (b == null) 0 else b.hashCode()
- h1 - h2
+ if (h1 < h2) -1 else if (h1 == h2) 0 else 1
}
})
@@ -209,7 +201,7 @@ private[spark] class ExternalSorter[K, V, C](
elementsRead += 1
kv = records.next()
map.changeValue((getPartition(kv._1), kv._1), update)
- maybeSpill(usingMap = true)
+ maybeSpillCollection(usingMap = true)
}
} else {
// Stick values into our buffer
@@ -217,7 +209,7 @@ private[spark] class ExternalSorter[K, V, C](
elementsRead += 1
val kv = records.next()
buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
- maybeSpill(usingMap = false)
+ maybeSpillCollection(usingMap = false)
}
}
}
@@ -227,61 +219,31 @@ private[spark] class ExternalSorter[K, V, C](
*
* @param usingMap whether we're using a map or buffer as our current in-memory collection
*/
- private def maybeSpill(usingMap: Boolean): Unit = {
+ private def maybeSpillCollection(usingMap: Boolean): Unit = {
if (!spillingEnabled) {
return
}
- val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
-
- // TODO: factor this out of both here and ExternalAppendOnlyMap
- if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
- collection.estimateSize() >= myMemoryThreshold)
- {
- // Claim up to double our current memory from the shuffle memory pool
- val currentMemory = collection.estimateSize()
- val amountToRequest = 2 * currentMemory - myMemoryThreshold
- val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
- myMemoryThreshold += granted
- if (myMemoryThreshold <= currentMemory) {
- // We were granted too little memory to grow further (either tryToAcquire returned 0,
- // or we already had more memory than myMemoryThreshold); spill the current collection
- spill(currentMemory, usingMap) // Will also release memory back to ShuffleMemoryManager
+ if (usingMap) {
+ if (maybeSpill(map, map.estimateSize())) {
+ map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+ }
+ } else {
+ if (maybeSpill(buffer, buffer.estimateSize())) {
+ buffer = new SizeTrackingPairBuffer[(Int, K), C]
}
}
}
/**
* Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
- *
- * @param usingMap whether we're using a map or buffer as our current in-memory collection
*/
- private def spill(memorySize: Long, usingMap: Boolean): Unit = {
- val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
- val memorySize = collection.estimateSize()
-
- spillCount += 1
- val threadId = Thread.currentThread().getId
- logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)"
- .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
-
+ override protected[this] def spill(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
if (bypassMergeSort) {
spillToPartitionFiles(collection)
} else {
spillToMergeableFile(collection)
}
-
- if (usingMap) {
- map = new SizeTrackingAppendOnlyMap[(Int, K), C]
- } else {
- buffer = new SizeTrackingPairBuffer[(Int, K), C]
- }
-
- // Release our memory back to the shuffle pool so that other threads can grab it
- shuffleMemoryManager.release(myMemoryThreshold)
- myMemoryThreshold = 0
-
- _memoryBytesSpilled += memorySize
}
/**
@@ -804,8 +766,6 @@ private[spark] class ExternalSorter[K, V, C](
}
}
- def memoryBytesSpilled: Long = _memoryBytesSpilled
-
def diskBytesSpilled: Long = _diskBytesSpilled
/**
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
new file mode 100644
index 0000000000000..d7dccd4af8c6e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkEnv
+
+/**
+ * Spills contents of an in-memory collection to disk when the memory threshold
+ * has been exceeded.
+ */
+private[spark] trait Spillable[C] {
+
+ this: Logging =>
+
+ /**
+ * Spills the current in-memory collection to disk, and releases the memory.
+ *
+ * @param collection collection to spill to disk
+ */
+ protected def spill(collection: C): Unit
+
+ // Number of elements read from input since last spill
+ protected var elementsRead: Long
+
+ // Memory manager that can be used to acquire/release memory
+ private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
+
+ // What threshold of elementsRead we start estimating collection size at
+ private[this] val trackMemoryThreshold = 1000
+
+ // How much of the shared memory pool this collection has claimed
+ private[this] var myMemoryThreshold = 0L
+
+ // Number of bytes spilled in total
+ private[this] var _memoryBytesSpilled = 0L
+
+ // Number of spills
+ private[this] var _spillCount = 0
+
+ /**
+ * Spills the current in-memory collection to disk if needed. Attempts to acquire more
+ * memory before spilling.
+ *
+ * @param collection collection to spill to disk
+ * @param currentMemory estimated size of the collection in bytes
+ * @return true if `collection` was spilled to disk; false otherwise
+ */
+ protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
+ if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
+ currentMemory >= myMemoryThreshold) {
+ // Claim up to double our current memory from the shuffle memory pool
+ val amountToRequest = 2 * currentMemory - myMemoryThreshold
+ val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
+ myMemoryThreshold += granted
+ if (myMemoryThreshold <= currentMemory) {
+ // We were granted too little memory to grow further (either tryToAcquire returned 0,
+ // or we already had more memory than myMemoryThreshold); spill the current collection
+ _spillCount += 1
+ logSpillage(currentMemory)
+
+ spill(collection)
+
+ // Keep track of spills, and release memory
+ _memoryBytesSpilled += currentMemory
+ releaseMemoryForThisThread()
+ return true
+ }
+ }
+ false
+ }
+
+ /**
+ * @return number of bytes spilled in total
+ */
+ def memoryBytesSpilled: Long = _memoryBytesSpilled
+
+ /**
+ * Release our memory back to the shuffle pool so that other threads can grab it.
+ */
+ private def releaseMemoryForThisThread(): Unit = {
+ shuffleMemoryManager.release(myMemoryThreshold)
+ myMemoryThreshold = 0L
+ }
+
+ /**
+ * Prints a standard log message detailing spillage.
+ *
+ * @param size number of bytes spilled
+ */
+ @inline private def logSpillage(size: Long) {
+ val threadId = Thread.currentThread().getId
+ logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)"
+ .format(threadId, size / (1024 * 1024), _spillCount, if (_spillCount > 1) "s" else ""))
+ }
+}
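
To show the shape of the new trait without pulling in SparkEnv or the ShuffleMemoryManager, here is a simplified, self-contained variant of the Spillable pattern: the trait owns the spill decision and counters, the collection implements spill(), and callers replace the collection after a spill, just as ExternalAppendOnlyMap and ExternalSorter now do. The threshold-only trigger and all names below are illustrative.

    import scala.collection.mutable.ArrayBuffer

    trait ToySpillable[C] {
      protected def spill(collection: C): Unit
      protected var elementsRead: Long
      private[this] var spillCount = 0

      // Check only every 32 inserts (as the real trait does) and spill once the
      // estimated size crosses a fixed threshold; the real trait instead tries
      // to acquire more shuffle memory before deciding.
      protected def maybeSpill(collection: C, currentMemory: Long, threshold: Long): Boolean = {
        if (elementsRead % 32 == 0 && currentMemory >= threshold) {
          spillCount += 1
          println(s"spill #$spillCount at ~$currentMemory bytes")
          spill(collection)
          true
        } else {
          false
        }
      }
    }

    class SpillableBuffer extends ToySpillable[ArrayBuffer[String]] {
      protected var elementsRead = 0L
      private var buf = ArrayBuffer[String]()

      override protected def spill(c: ArrayBuffer[String]): Unit = {
        // Stand-in for sorting the contents and writing them to a temp file on disk
        println(s"  would write ${c.size} element(s) to disk")
      }

      def insert(s: String): Unit = {
        buf += s
        elementsRead += 1
        if (maybeSpill(buf, buf.map(_.length.toLong).sum, threshold = 1024L)) {
          buf = ArrayBuffer[String]() // start a fresh collection, as the callers above do
        }
      }
    }

    object SpillDemo extends App {
      val b = new SpillableBuffer
      (1 to 100).foreach(_ => b.insert("x" * 64)) // triggers a couple of spills
    }
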
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index b8574dfb42e6b..4a078435447e5 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -776,7 +776,7 @@ public void persist() {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics());
+ TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}
@@ -1307,4 +1307,30 @@ public void collectUnderlyingScalaRDD() {
SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect();
Assert.assertEquals(data.size(), collected.length);
}
+
+ /**
+ * Test for SPARK-3647. This test needs to use the maven-built assembly to trigger the issue,
+ * since that's the only artifact where Guava classes have been relocated.
+ */
+ @Test
+ public void testGuavaOptional() {
+ // Stop the context created in setUp() and start a local-cluster one, to force usage of the
+ // assembly.
+ sc.stop();
+ JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,512]", "JavaAPISuite");
+ try {
+      JavaRDD<Integer> rdd1 = localCluster.parallelize(Arrays.asList(1, 2, null), 3);
+      JavaRDD<Optional<Integer>> rdd2 = rdd1.map(
+        new Function<Integer, Optional<Integer>>() {
+          @Override
+          public Optional<Integer> call(Integer i) {
+            return Optional.fromNullable(i);
+          }
+        });
+ rdd2.collect();
+ } finally {
+ localCluster.stop();
+ }
+ }
+
}
diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
index af34cdb03e4d1..0944bf8cd5c71 100644
--- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
+++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
@@ -30,10 +30,9 @@ public class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
public void onTaskCompletion(TaskContext context) {
context.isCompleted();
context.isInterrupted();
- context.stageId();
- context.partitionId();
- context.runningLocally();
- context.taskMetrics();
+ context.getStageId();
+ context.getPartitionId();
+ context.isRunningLocally();
context.addTaskCompletionListener(this);
}
}
diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties
index 26b73a1b39744..9dd05f17f012b 100644
--- a/core/src/test/resources/log4j.properties
+++ b/core/src/test/resources/log4j.properties
@@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 90dcadcffd091..d735010d7c9d5 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, runningLocally = true)
+ val context = new TaskContext(0, 0, 0, true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
diff --git a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala
new file mode 100644
index 0000000000000..db9c25fc457a4
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+
+import org.apache.spark.SparkContext._
+
+class FutureActionSuite extends FunSuite with BeforeAndAfter with Matchers with LocalSparkContext {
+
+ before {
+ sc = new SparkContext("local", "FutureActionSuite")
+ }
+
+ test("simple async action") {
+ val rdd = sc.parallelize(1 to 10, 2)
+ val job = rdd.countAsync()
+ val res = Await.result(job, Duration.Inf)
+ res should be (10)
+ job.jobIds.size should be (1)
+ }
+
+ test("complex async action") {
+ val rdd = sc.parallelize(1 to 15, 3)
+ val job = rdd.takeAsync(10)
+ val res = Await.result(job, Duration.Inf)
+ res should be (1 to 10)
+ job.jobIds.size should be (2)
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 5369169811f81..1fef79ad1001f 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -23,32 +23,13 @@ import akka.actor._
import akka.testkit.TestActorRef
import org.scalatest.FunSuite
-import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.AkkaUtils
class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
private val conf = new SparkConf
- test("compressSize") {
- assert(MapOutputTracker.compressSize(0L) === 0)
- assert(MapOutputTracker.compressSize(1L) === 1)
- assert(MapOutputTracker.compressSize(2L) === 8)
- assert(MapOutputTracker.compressSize(10L) === 25)
- assert((MapOutputTracker.compressSize(1000000L) & 0xFF) === 145)
- assert((MapOutputTracker.compressSize(1000000000L) & 0xFF) === 218)
- // This last size is bigger than we can encode in a byte, so check that we just return 255
- assert((MapOutputTracker.compressSize(1000000000000000000L) & 0xFF) === 255)
- }
-
- test("decompressSize") {
- assert(MapOutputTracker.decompressSize(0) === 0)
- for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) {
- val size2 = MapOutputTracker.decompressSize(MapOutputTracker.compressSize(size))
- assert(size2 >= 0.99 * size && size2 <= 1.11 * size,
- "size " + size + " decompressed to " + size2 + ", which is out of range")
- }
- }
test("master start and stop") {
val actorSystem = ActorSystem("test")
@@ -65,14 +46,12 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
tracker.registerShuffle(10, 2)
assert(tracker.containsShuffle(10))
- val compressedSize1000 = MapOutputTracker.compressSize(1000L)
- val compressedSize10000 = MapOutputTracker.compressSize(10000L)
- val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
- val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
- Array(compressedSize1000, compressedSize10000)))
- tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
- Array(compressedSize10000, compressedSize1000)))
+ val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+ val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
+ tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
+ Array(1000L, 10000L)))
+ tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
+ Array(10000L, 1000L)))
val statuses = tracker.getServerStatuses(10, 0)
assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000),
(BlockManagerId("b", "hostB", 1000), size10000)))
@@ -84,11 +63,11 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val tracker = new MapOutputTrackerMaster(conf)
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
tracker.registerShuffle(10, 2)
- val compressedSize1000 = MapOutputTracker.compressSize(1000L)
- val compressedSize10000 = MapOutputTracker.compressSize(10000L)
- tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
+ val compressedSize1000 = MapStatus.compressSize(1000L)
+ val compressedSize10000 = MapStatus.compressSize(10000L)
+ tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
Array(compressedSize1000, compressedSize10000)))
- tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
+ tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000)))
assert(tracker.containsShuffle(10))
assert(tracker.getServerStatuses(10, 0).nonEmpty)
@@ -103,11 +82,11 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
tracker.trackerActor =
actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
tracker.registerShuffle(10, 2)
- val compressedSize1000 = MapOutputTracker.compressSize(1000L)
- val compressedSize10000 = MapOutputTracker.compressSize(10000L)
- tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
+ val compressedSize1000 = MapStatus.compressSize(1000L)
+ val compressedSize10000 = MapStatus.compressSize(10000L)
+ tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
- tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
+ tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
// As if we had two simultaneous fetch failures
@@ -142,10 +121,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
slaveTracker.updateEpoch(masterTracker.getEpoch)
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
- val compressedSize1000 = MapOutputTracker.compressSize(1000L)
- val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
- masterTracker.registerMapOutput(10, 0, new MapStatus(
- BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
+ val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+ masterTracker.registerMapOutput(10, 0, MapStatus(
+ BlockManagerId("a", "hostA", 1000), Array(1000L)))
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
@@ -173,8 +151,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
// Frame size should be ~123B, and no exception should be thrown
masterTracker.registerShuffle(10, 1)
- masterTracker.registerMapOutput(10, 0, new MapStatus(
- BlockManagerId("88", "mph", 1000), Array.fill[Byte](10)(0)))
+ masterTracker.registerMapOutput(10, 0, MapStatus(
+ BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0)))
masterActor.receive(GetMapOutputStatuses(10))
}
@@ -194,8 +172,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
// being sent.
masterTracker.registerShuffle(20, 100)
(0 until 100).foreach { i =>
- masterTracker.registerMapOutput(20, i, new MapStatus(
- BlockManagerId("999", "mps", 1000), Array.fill[Byte](4000000)(0)))
+ masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
+ BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
}
intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) }
}
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index fc0cee3e8749d..646ede30ae6ff 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -193,11 +193,13 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
assert(grouped2.join(grouped4).partitioner === grouped4.partitioner)
assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped4.partitioner)
assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.fullOuterJoin(grouped4).partitioner === grouped4.partitioner)
assert(grouped2.cogroup(grouped4).partitioner === grouped4.partitioner)
assert(grouped2.join(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner)
+ assert(grouped2.fullOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.map(_ => 1).partitioner === None)
@@ -218,6 +220,7 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array"))
assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array"))
assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.fullOuterJoin(arrPairs) }.getMessage.contains("array"))
assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array"))
assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array"))
assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array"))
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 978a6ded80829..acaf321de52fb 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -132,7 +132,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === 1)
statuses.head match { case (bm, status) =>
- assert(bm.executorId === "", "Block should only be on the driver")
+ assert(bm.isDriver, "Block should only be on the driver")
assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
assert(status.memSize > 0, "Block should be in memory store on the driver")
assert(status.diskSize === 0, "Block should not be in disk store on the driver")
diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
index 2a58c6a40d8e4..3f1cd0752e766 100644
--- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
@@ -115,11 +115,13 @@ class JsonProtocolSuite extends FunSuite {
workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis
workerInfo
}
+
def createExecutorRunner(): ExecutorRunner = {
new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host",
new File("sparkHome"), new File("workDir"), "akka://worker",
new SparkConf, ExecutorState.RUNNING)
}
+
def createDriverRunner(): DriverRunner = {
new DriverRunner(new SparkConf(), "driverId", new File("workDir"), new File("sparkHome"),
createDriverDesc(), null, "akka://worker")
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 22b369a829418..4cba90e8f2afe 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy
-import java.io.{File, OutputStream, PrintStream}
+import java.io._
import scala.collection.mutable.ArrayBuffer
@@ -26,6 +26,7 @@ import org.apache.spark.deploy.SparkSubmit._
import org.apache.spark.util.Utils
import org.scalatest.FunSuite
import org.scalatest.Matchers
+import com.google.common.io.Files
class SparkSubmitSuite extends FunSuite with Matchers {
def beforeAll() {
@@ -154,6 +155,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
sysProps("spark.app.name") should be ("beauty")
sysProps("spark.shuffle.spill") should be ("false")
sysProps("SPARK_SUBMIT") should be ("true")
+ sysProps.keys should not contain ("spark.jars")
}
test("handles YARN client mode") {
@@ -305,6 +307,21 @@ class SparkSubmitSuite extends FunSuite with Matchers {
runSparkSubmit(args)
}
+ test("SPARK_CONF_DIR overrides spark-defaults.conf") {
+ forConfDir(Map("spark.executor.memory" -> "2.3g")) { path =>
+ val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+ val args = Seq(
+ "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"),
+ "--name", "testApp",
+ "--master", "local",
+ unusedJar.toString)
+ val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path))
+ assert(appArgs.propertiesFile != null)
+ assert(appArgs.propertiesFile.startsWith(path))
+ appArgs.executorMemory should be ("2.3g")
+ }
+ }
+
// NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
def runSparkSubmit(args: Seq[String]): String = {
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
@@ -313,6 +330,22 @@ class SparkSubmitSuite extends FunSuite with Matchers {
new File(sparkHome),
Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome))
}
+
+  def forConfDir(defaults: Map[String, String])(f: String => Unit) = {
+ val tmpDir = Files.createTempDir()
+
+ val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf")
+ val writer = new OutputStreamWriter(new FileOutputStream(defaultsConf))
+ for ((key, value) <- defaults) writer.write(s"$key $value\n")
+
+ writer.close()
+
+ try {
+ f(tmpDir.getAbsolutePath)
+ } finally {
+ Utils.deleteRecursively(tmpDir)
+ }
+ }
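+
+  // Sketch (assumption about the spark-defaults.conf format, not SparkSubmit's actual loader):
+  // each non-blank, non-comment line holds a key and a value separated by whitespace, which is
+  // why forConfDir above writes "$key $value" lines.
+  def parseDefaultsSketch(lines: Seq[String]): Map[String, String] = {
+    lines.map(_.trim)
+      .filter(line => line.nonEmpty && !line.startsWith("#"))
+      .map { line =>
+        val Array(key, value) = line.split("\\s+", 2)
+        key -> value
+      }
+      .toMap
+  }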
}
object JarCreationTest {
diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
index 96a5a1231813e..3925f0ccbdbf0 100644
--- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
@@ -17,42 +17,171 @@
package org.apache.spark.metrics
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester}
+
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.master.MasterSource
+import org.apache.spark.metrics.source.Source
+
+import com.codahale.metrics.MetricRegistry
+
+import scala.collection.mutable.ArrayBuffer
-class MetricsSystemSuite extends FunSuite with BeforeAndAfter {
+class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester {
var filePath: String = _
var conf: SparkConf = null
var securityMgr: SecurityManager = null
before {
- filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile()
+ filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile
conf = new SparkConf(false).set("spark.metrics.conf", filePath)
securityMgr = new SecurityManager(conf)
}
test("MetricsSystem with default config") {
val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr)
- val sources = metricsSystem.sources
- val sinks = metricsSystem.sinks
+ metricsSystem.start()
+ val sources = PrivateMethod[ArrayBuffer[Source]]('sources)
+ val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks)
- assert(sources.length === 0)
- assert(sinks.length === 0)
- assert(!metricsSystem.getServletHandlers.isEmpty)
+ assert(metricsSystem.invokePrivate(sources()).length === 0)
+ assert(metricsSystem.invokePrivate(sinks()).length === 0)
+ assert(metricsSystem.getServletHandlers.nonEmpty)
}
test("MetricsSystem with sources add") {
val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr)
- val sources = metricsSystem.sources
- val sinks = metricsSystem.sinks
+ metricsSystem.start()
+ val sources = PrivateMethod[ArrayBuffer[Source]]('sources)
+ val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks)
- assert(sources.length === 0)
- assert(sinks.length === 1)
- assert(!metricsSystem.getServletHandlers.isEmpty)
+ assert(metricsSystem.invokePrivate(sources()).length === 0)
+ assert(metricsSystem.invokePrivate(sinks()).length === 1)
+ assert(metricsSystem.getServletHandlers.nonEmpty)
val source = new MasterSource(null)
metricsSystem.registerSource(source)
- assert(sources.length === 1)
+ assert(metricsSystem.invokePrivate(sources()).length === 1)
+ }
+
+ test("MetricsSystem with Driver instance") {
+ val source = new Source {
+ override val sourceName = "dummySource"
+ override val metricRegistry = new MetricRegistry()
+ }
+
+ val appId = "testId"
+ val executorId = "driver"
+ conf.set("spark.app.id", appId)
+ conf.set("spark.executor.id", executorId)
+
+ val instanceName = "driver"
+ val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr)
+
+ val metricName = driverMetricsSystem.buildRegistryName(source)
+ assert(metricName === s"$appId.$executorId.${source.sourceName}")
+ }
+
+ test("MetricsSystem with Driver instance and spark.app.id is not set") {
+ val source = new Source {
+ override val sourceName = "dummySource"
+ override val metricRegistry = new MetricRegistry()
+ }
+
+ val executorId = "driver"
+ conf.set("spark.executor.id", executorId)
+
+ val instanceName = "driver"
+ val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr)
+
+ val metricName = driverMetricsSystem.buildRegistryName(source)
+ assert(metricName === source.sourceName)
+ }
+
+ test("MetricsSystem with Driver instance and spark.executor.id is not set") {
+ val source = new Source {
+ override val sourceName = "dummySource"
+ override val metricRegistry = new MetricRegistry()
+ }
+
+ val appId = "testId"
+ conf.set("spark.app.id", appId)
+
+ val instanceName = "driver"
+ val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr)
+
+ val metricName = driverMetricsSystem.buildRegistryName(source)
+ assert(metricName === source.sourceName)
+ }
+
+ test("MetricsSystem with Executor instance") {
+ val source = new Source {
+ override val sourceName = "dummySource"
+ override val metricRegistry = new MetricRegistry()
+ }
+
+ val appId = "testId"
+ val executorId = "executor.1"
+ conf.set("spark.app.id", appId)
+ conf.set("spark.executor.id", executorId)
+
+ val instanceName = "executor"
+ val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr)
+
+ val metricName = driverMetricsSystem.buildRegistryName(source)
+ assert(metricName === s"$appId.$executorId.${source.sourceName}")
+ }
+
+ test("MetricsSystem with Executor instance and spark.app.id is not set") {
+ val source = new Source {
+ override val sourceName = "dummySource"
+ override val metricRegistry = new MetricRegistry()
+ }
+
+ val executorId = "executor.1"
+ conf.set("spark.executor.id", executorId)
+
+ val instanceName = "executor"
+ val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr)
+
+ val metricName = driverMetricsSystem.buildRegistryName(source)
+ assert(metricName === source.sourceName)
+ }
+
+ test("MetricsSystem with Executor instance and spark.executor.id is not set") {
+ val source = new Source {
+ override val sourceName = "dummySource"
+ override val metricRegistry = new MetricRegistry()
+ }
+
+ val appId = "testId"
+ conf.set("spark.app.id", appId)
+
+ val instanceName = "executor"
+ val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr)
+
+ val metricName = driverMetricsSystem.buildRegistryName(source)
+ assert(metricName === source.sourceName)
+ }
+
+ test("MetricsSystem with instance which is neither Driver nor Executor") {
+ val source = new Source {
+ override val sourceName = "dummySource"
+ override val metricRegistry = new MetricRegistry()
+ }
+
+ val appId = "testId"
+ val executorId = "dummyExecutorId"
+ conf.set("spark.app.id", appId)
+ conf.set("spark.executor.id", executorId)
+
+ val instanceName = "testInstance"
+ val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr)
+
+ val metricName = driverMetricsSystem.buildRegistryName(source)
+
+ // Even if spark.app.id and spark.executor.id are set, they are not used for the metric name.
+ assert(metricName != s"$appId.$executorId.${source.sourceName}")
+ assert(metricName === source.sourceName)
}
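+
+  // Sketch (assumption, not MetricsSystem's actual code): the naming behaviour exercised above
+  // suggests buildRegistryName prefixes the source name with "<appId>.<executorId>" only for
+  // driver/executor instances when both IDs are configured, and otherwise falls back to the
+  // bare source name.
+  private def expectedRegistryName(conf: SparkConf, instance: String, sourceName: String): String = {
+    val appId = conf.getOption("spark.app.id")
+    val executorId = conf.getOption("spark.executor.id")
+    if ((instance == "driver" || instance == "executor") && appId.isDefined && executorId.isDefined) {
+      s"${appId.get}.${executorId.get}.$sourceName"
+    } else {
+      sourceName
+    }
+  }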
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 63d3ddb4af98a..75b01191901b8 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -17,17 +17,21 @@
package org.apache.spark.rdd
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashSet
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.mapred._
+import org.apache.hadoop.util.Progressable
+
+import scala.collection.mutable.{ArrayBuffer, HashSet}
import scala.util.Random
-import org.scalatest.FunSuite
import com.google.common.io.Files
-import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.conf.{Configuration, Configurable}
-
-import org.apache.spark.SparkContext._
+import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter,
+  OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter,
+  TaskAttemptContext => NewTaskAttempContext}
import org.apache.spark.{Partitioner, SharedSparkContext}
+import org.apache.spark.SparkContext._
+import org.scalatest.FunSuite
class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
test("aggregateByKey") {
@@ -294,6 +298,21 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
))
}
+ test("fullOuterJoin") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.fullOuterJoin(rdd2).collect()
+ assert(joined.size === 6)
+ assert(joined.toSet === Set(
+ (1, (Some(1), Some('x'))),
+ (1, (Some(2), Some('x'))),
+ (2, (Some(1), Some('y'))),
+ (2, (Some(1), Some('z'))),
+ (3, (Some(1), None)),
+ (4, (None, Some('w')))
+ ))
+ }
+
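+  // Sketch (assumption about the implementation, not the code under test): fullOuterJoin can be
+  // modelled on local collections as a cogroup that cross-pairs each key's value buffers and
+  // pads the missing side with None, which matches the expected set asserted above.
+  private def fullOuterJoinModel[K, V, W](
+      left: Seq[(K, V)], right: Seq[(K, W)]): Seq[(K, (Option[V], Option[W]))] = {
+    val leftByKey = left.groupBy(_._1).mapValues(_.map(_._2))
+    val rightByKey = right.groupBy(_._1).mapValues(_.map(_._2))
+    (leftByKey.keySet ++ rightByKey.keySet).toSeq.flatMap { k =>
+      val vs = leftByKey.getOrElse(k, Seq.empty)
+      val ws = rightByKey.getOrElse(k, Seq.empty)
+      if (vs.isEmpty) {
+        ws.map(w => (k, (Option.empty[V], Option(w))))
+      } else if (ws.isEmpty) {
+        vs.map(v => (k, (Option(v), Option.empty[W])))
+      } else {
+        for (v <- vs; w <- ws) yield (k, (Option(v), Option(w)))
+      }
+    }
+  }
+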
test("join with no matches") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
@@ -467,7 +486,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
val pairs = sc.parallelize(Array((new Integer(1), new Integer(1))))
// No error, non-configurable formats still work
- pairs.saveAsNewAPIHadoopFile[FakeFormat]("ignored")
+ pairs.saveAsNewAPIHadoopFile[NewFakeFormat]("ignored")
/*
Check that configurable formats get configured:
@@ -478,6 +497,17 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
pairs.saveAsNewAPIHadoopFile[ConfigTestFormat]("ignored")
}
+ test("saveAsHadoopFile should respect configured output committers") {
+ val pairs = sc.parallelize(Array((new Integer(1), new Integer(1))))
+ val conf = new JobConf()
+ conf.setOutputCommitter(classOf[FakeOutputCommitter])
+
+ FakeOutputCommitter.ran = false
+ pairs.saveAsHadoopFile("ignored", pairs.keyClass, pairs.valueClass, classOf[FakeOutputFormat], conf)
+
+ assert(FakeOutputCommitter.ran, "OutputCommitter was never called")
+ }
+
test("lookup") {
val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7)))
@@ -621,40 +651,86 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
and the test will therefore throw InstantiationException when saveAsNewAPIHadoopFile
tries to instantiate them with Class.newInstance.
*/
+
+/*
+ * Original Hadoop API
+ */
class FakeWriter extends RecordWriter[Integer, Integer] {
+ override def write(key: Integer, value: Integer): Unit = ()
+
+ override def close(reporter: Reporter): Unit = ()
+}
+
+class FakeOutputCommitter() extends OutputCommitter() {
+ override def setupJob(jobContext: JobContext): Unit = ()
+
+ override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true
+
+ override def setupTask(taskContext: TaskAttemptContext): Unit = ()
+
+ override def commitTask(taskContext: TaskAttemptContext): Unit = {
+ FakeOutputCommitter.ran = true
+ ()
+ }
+
+ override def abortTask(taskContext: TaskAttemptContext): Unit = ()
+}
+
+/*
+ * Used to communicate state between the test harness and the OutputCommitter.
+ */
+object FakeOutputCommitter {
+ var ran = false
+}
+
+class FakeOutputFormat() extends OutputFormat[Integer, Integer]() {
+ override def getRecordWriter(
+ ignored: FileSystem,
+ job: JobConf, name: String,
+ progress: Progressable): RecordWriter[Integer, Integer] = {
+ new FakeWriter()
+ }
- def close(p1: TaskAttemptContext) = ()
+ override def checkOutputSpecs(ignored: FileSystem, job: JobConf): Unit = ()
+}
+
+/*
+ * New-style Hadoop API
+ */
+class NewFakeWriter extends NewRecordWriter[Integer, Integer] {
+
+ def close(p1: NewTaskAttempContext) = ()
def write(p1: Integer, p2: Integer) = ()
}
-class FakeCommitter extends OutputCommitter {
- def setupJob(p1: JobContext) = ()
+class NewFakeCommitter extends NewOutputCommitter {
+ def setupJob(p1: NewJobContext) = ()
- def needsTaskCommit(p1: TaskAttemptContext): Boolean = false
+ def needsTaskCommit(p1: NewTaskAttempContext): Boolean = false
- def setupTask(p1: TaskAttemptContext) = ()
+ def setupTask(p1: NewTaskAttempContext) = ()
- def commitTask(p1: TaskAttemptContext) = ()
+ def commitTask(p1: NewTaskAttempContext) = ()
- def abortTask(p1: TaskAttemptContext) = ()
+ def abortTask(p1: NewTaskAttempContext) = ()
}
-class FakeFormat() extends OutputFormat[Integer, Integer]() {
+class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() {
- def checkOutputSpecs(p1: JobContext) = ()
+ def checkOutputSpecs(p1: NewJobContext) = ()
- def getRecordWriter(p1: TaskAttemptContext): RecordWriter[Integer, Integer] = {
- new FakeWriter()
+ def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = {
+ new NewFakeWriter()
}
- def getOutputCommitter(p1: TaskAttemptContext): OutputCommitter = {
- new FakeCommitter()
+ def getOutputCommitter(p1: NewTaskAttempContext): NewOutputCommitter = {
+ new NewFakeCommitter()
}
}
-class ConfigTestFormat() extends FakeFormat() with Configurable {
+class ConfigTestFormat() extends NewFakeFormat() with Configurable {
var setConfCalled = false
def setConf(p1: Configuration) = {
@@ -664,7 +740,7 @@ class ConfigTestFormat() extends FakeFormat() with Configurable {
def getConf: Configuration = null
- override def getRecordWriter(p1: TaskAttemptContext): RecordWriter[Integer, Integer] = {
+ override def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = {
assert(setConfCalled, "setConf was never called")
super.getRecordWriter(p1)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index c1b501a75c8b8..465c1a8a43a79 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -193,6 +193,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(rdd.join(emptyKv).collect().size === 0)
assert(rdd.rightOuterJoin(emptyKv).collect().size === 0)
assert(rdd.leftOuterJoin(emptyKv).collect().size === 2)
+ assert(rdd.fullOuterJoin(emptyKv).collect().size === 2)
assert(rdd.cogroup(emptyKv).collect().size === 2)
assert(rdd.union(emptyKv).collect().size === 2)
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index aa73469b6acd8..a2e4f712db55b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -740,7 +740,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
}
private def makeMapStatus(host: String, reduces: Int): MapStatus =
- new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
+ MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(2))
private def makeBlockManagerId(host: String): BlockManagerId =
BlockManagerId("exec-" + host, host, 12345)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
index e5315bc93e217..3efa85431876b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
@@ -169,7 +169,9 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter {
// Verify logging directory exists
val conf = getLoggingConf(logDirPath, compressionCodec)
- val eventLogger = new EventLoggingListener("test", conf)
+ val logBaseDir = conf.get("spark.eventLog.dir")
+ val appId = EventLoggingListenerSuite.getUniqueApplicationId
+ val eventLogger = new EventLoggingListener(appId, logBaseDir, conf)
eventLogger.start()
val logPath = new Path(eventLogger.logDir)
assert(fileSystem.exists(logPath))
@@ -209,7 +211,9 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter {
// Verify that all information is correctly parsed before stop()
val conf = getLoggingConf(logDirPath, compressionCodec)
- val eventLogger = new EventLoggingListener("test", conf)
+ val logBaseDir = conf.get("spark.eventLog.dir")
+ val appId = EventLoggingListenerSuite.getUniqueApplicationId
+ val eventLogger = new EventLoggingListener(appId, logBaseDir, conf)
eventLogger.start()
var eventLoggingInfo = EventLoggingListener.parseLoggingInfo(eventLogger.logDir, fileSystem)
assertInfoCorrect(eventLoggingInfo, loggerStopped = false)
@@ -228,7 +232,9 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter {
*/
private def testEventLogging(compressionCodec: Option[String] = None) {
val conf = getLoggingConf(logDirPath, compressionCodec)
- val eventLogger = new EventLoggingListener("test", conf)
+ val logBaseDir = conf.get("spark.eventLog.dir")
+ val appId = EventLoggingListenerSuite.getUniqueApplicationId
+ val eventLogger = new EventLoggingListener(appId, logBaseDir, conf)
val listenerBus = new LiveListenerBus
val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None,
125L, "Mickey")
@@ -408,4 +414,6 @@ object EventLoggingListenerSuite {
}
conf
}
+
+ def getUniqueApplicationId = "test-" + System.currentTimeMillis
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
new file mode 100644
index 0000000000000..79e04f046e4c4
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.storage.BlockManagerId
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkConf
+import org.apache.spark.serializer.JavaSerializer
+
+
+class MapStatusSuite extends FunSuite {
+
+ test("compressSize") {
+ assert(MapStatus.compressSize(0L) === 0)
+ assert(MapStatus.compressSize(1L) === 1)
+ assert(MapStatus.compressSize(2L) === 8)
+ assert(MapStatus.compressSize(10L) === 25)
+ assert((MapStatus.compressSize(1000000L) & 0xFF) === 145)
+ assert((MapStatus.compressSize(1000000000L) & 0xFF) === 218)
+ // This last size is bigger than we can encode in a byte, so check that we just return 255
+ assert((MapStatus.compressSize(1000000000000000000L) & 0xFF) === 255)
+ }
+
+ test("decompressSize") {
+ assert(MapStatus.decompressSize(0) === 0)
+ for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) {
+ val size2 = MapStatus.decompressSize(MapStatus.compressSize(size))
+ assert(size2 >= 0.99 * size && size2 <= 1.11 * size,
+ "size " + size + " decompressed to " + size2 + ", which is out of range")
+ }
+ }
+
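+  // Sketch of the assumed encoding (not the real MapStatus code): the values asserted above are
+  // consistent with sizes being compressed into a single unsigned byte on a log-1.1 scale,
+  // roughly compress(size) = ceil(log(size) / log(1.1)) capped at 255, and
+  // decompress(b) = 1.1^b, which is why decompressed sizes are only accurate to ~10%.
+  private def compressSizeSketch(size: Long): Byte = {
+    if (size == 0) {
+      0
+    } else if (size <= 1L) {
+      1
+    } else {
+      math.min(255, math.ceil(math.log(size) / math.log(1.1)).toInt).toByte
+    }
+  }
+
+  private def decompressSizeSketch(compressed: Byte): Long = {
+    if (compressed == 0) 0L else math.pow(1.1, compressed & 0xFF).toLong
+  }
+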
+ test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) {
+ val sizes = Array.fill[Long](2001)(150L)
+ val status = MapStatus(null, sizes)
+ assert(status.isInstanceOf[HighlyCompressedMapStatus])
+ assert(status.getSizeForBlock(10) === 150L)
+ assert(status.getSizeForBlock(50) === 150L)
+ assert(status.getSizeForBlock(99) === 150L)
+ assert(status.getSizeForBlock(2000) === 150L)
+ }
+
+ test(classOf[HighlyCompressedMapStatus].getName + ": estimated size is within 10%") {
+ val sizes = Array.tabulate[Long](50) { i => i.toLong }
+ val loc = BlockManagerId("a", "b", 10)
+ val status = MapStatus(loc, sizes)
+ val ser = new JavaSerializer(new SparkConf)
+ val buf = ser.newInstance().serialize(status)
+ val status1 = ser.newInstance().deserialize[MapStatus](buf)
+ assert(status1.location == loc)
+ for (i <- 0 until sizes.length) {
+ // make sure the estimated size is within 10% of the input; note that we skip the very small
+ // sizes because the compression is very lossy there.
+ val estimate = status1.getSizeForBlock(i)
+ if (estimate > 100) {
+ assert(math.abs(estimate - sizes(i)) * 10 <= sizes(i),
+ s"incorrect estimated size $estimate, original was ${sizes(i)}")
+ }
+ }
+ }
+
+ test(classOf[HighlyCompressedMapStatus].getName + ": estimated size should be the average size") {
+ val sizes = Array.tabulate[Long](3000) { i => i.toLong }
+ val avg = sizes.sum / sizes.length
+ val loc = BlockManagerId("a", "b", 10)
+ val status = MapStatus(loc, sizes)
+ val ser = new JavaSerializer(new SparkConf)
+ val buf = ser.newInstance().serialize(status)
+ val status1 = ser.newInstance().deserialize[MapStatus](buf)
+ assert(status1.location == loc)
+ for (i <- 0 until 3000) {
+ val estimate = status1.getSizeForBlock(i)
+ assert(estimate === avg)
+ }
+ }
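+
+  // Sketch (assumption, not the real class): a HighlyCompressedMapStatus-like status can keep
+  // only the location and the average block size, so its serialized form stays small no matter
+  // how many reduce partitions there are, at the cost of per-block accuracy.
+  private case class AverageOnlyStatusSketch(location: BlockManagerId, avgSize: Long) {
+    def getSizeForBlock(reduceId: Int): Long = avgSize
+  }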
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
index 7ab351d1b4d24..48114feee6233 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
@@ -155,7 +155,8 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter {
* This child listener inherits only the event buffering functionality, but does not actually
* log the events.
*/
- private class EventMonster(conf: SparkConf) extends EventLoggingListener("test", conf) {
+ private class EventMonster(conf: SparkConf)
+ extends EventLoggingListener("test", "testdir", conf) {
logger.close()
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 93e8ddacf8865..c0b07649eb6dd 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -642,6 +642,28 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
assert(manager.resourceOffer("execC", "host3", ANY) !== None)
}
+ test("Test that locations with HDFSCacheTaskLocation are treated as PROCESS_LOCAL.") {
+ // Regression test for SPARK-2931
+ sc = new SparkContext("local", "test")
+ val sched = new FakeTaskScheduler(sc,
+ ("execA", "host1"), ("execB", "host2"), ("execC", "host3"))
+ val taskSet = FakeTask.createTaskSet(3,
+ Seq(HostTaskLocation("host1")),
+ Seq(HostTaskLocation("host2")),
+ Seq(HDFSCacheTaskLocation("host3")))
+ val clock = new FakeClock
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
+ assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+ sched.removeExecutor("execA")
+ manager.executorAdded()
+ assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+ sched.removeExecutor("execB")
+ manager.executorAdded()
+ assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+ sched.removeExecutor("execC")
+ manager.executorAdded()
+ assert(manager.myLocalityLevels.sameElements(Array(ANY)))
+ }
def createTaskResult(id: Int): DirectTaskResult[Int] = {
val valueSer = SparkEnv.get.serializer.newInstance()
diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
index aad6599589420..d037e2c19a64d 100644
--- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
@@ -50,8 +50,7 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex
"flatMap" -> xflatMap _,
"filter" -> xfilter _,
"mapPartitions" -> xmapPartitions _,
- "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _,
- "mapPartitionsWithContext" -> xmapPartitionsWithContext _)) {
+ "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _)) {
val (name, xf) = transformation
test(s"$name transformations throw proactive serialization exceptions") {
@@ -78,8 +77,5 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex
private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y)))
-
- private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y)))
}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
new file mode 100644
index 0000000000000..1f1d53a1ee3b0
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.language.implicitConversions
+import scala.language.postfixOps
+
+import akka.actor.{ActorSystem, Props}
+import org.mockito.Mockito.{mock, when}
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
+import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.nio.NioBlockTransferService
+import org.apache.spark.scheduler.LiveListenerBus
+import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.shuffle.hash.HashShuffleManager
+import org.apache.spark.storage.StorageLevel._
+import org.apache.spark.util.{AkkaUtils, SizeEstimator}
+
+/** Test suite for block replication in BlockManager. */
+class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter {
+
+ private val conf = new SparkConf(false)
+ var actorSystem: ActorSystem = null
+ var master: BlockManagerMaster = null
+ val securityMgr = new SecurityManager(conf)
+ val mapOutputTracker = new MapOutputTrackerMaster(conf)
+ val shuffleManager = new HashShuffleManager(conf)
+
+  // List of block managers created during a unit test, so that all of them can be stopped
+  // after the unit test.
+ val allStores = new ArrayBuffer[BlockManager]
+
+ // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
+ conf.set("spark.kryoserializer.buffer.mb", "1")
+ val serializer = new KryoSerializer(conf)
+
+ // Implicitly convert strings to BlockIds for test clarity.
+ implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
+
+ private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = {
+ val transfer = new NioBlockTransferService(conf, securityMgr)
+ val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
+ mapOutputTracker, shuffleManager, transfer)
+ allStores += store
+ store
+ }
+
+ before {
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
+ "test", "localhost", 0, conf = conf, securityManager = securityMgr)
+ this.actorSystem = actorSystem
+
+ conf.set("spark.authenticate", "false")
+ conf.set("spark.driver.port", boundPort.toString)
+ conf.set("spark.storage.unrollFraction", "0.4")
+ conf.set("spark.storage.unrollMemoryThreshold", "512")
+
+    // to make a replication attempt to an inactive store fail fast
+ conf.set("spark.core.connection.ack.wait.timeout", "1")
+ // to make cached peers refresh frequently
+ conf.set("spark.storage.cachedPeersTtl", "10")
+
+ master = new BlockManagerMaster(
+ actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))),
+ conf, true)
+ allStores.clear()
+ }
+
+ after {
+ allStores.foreach { _.stop() }
+ allStores.clear()
+ actorSystem.shutdown()
+ actorSystem.awaitTermination()
+ actorSystem = null
+ master = null
+ }
+
+
+ test("get peers with addition and removal of block managers") {
+ val numStores = 4
+ val stores = (1 to numStores - 1).map { i => makeBlockManager(1000, s"store$i") }
+ val storeIds = stores.map { _.blockManagerId }.toSet
+ assert(master.getPeers(stores(0).blockManagerId).toSet ===
+ storeIds.filterNot { _ == stores(0).blockManagerId })
+ assert(master.getPeers(stores(1).blockManagerId).toSet ===
+ storeIds.filterNot { _ == stores(1).blockManagerId })
+ assert(master.getPeers(stores(2).blockManagerId).toSet ===
+ storeIds.filterNot { _ == stores(2).blockManagerId })
+
+ // Add driver store and test whether it is filtered out
+ val driverStore = makeBlockManager(1000, "")
+ assert(master.getPeers(stores(0).blockManagerId).forall(!_.isDriver))
+ assert(master.getPeers(stores(1).blockManagerId).forall(!_.isDriver))
+ assert(master.getPeers(stores(2).blockManagerId).forall(!_.isDriver))
+
+ // Add a new store and test whether get peers returns it
+ val newStore = makeBlockManager(1000, s"store$numStores")
+ assert(master.getPeers(stores(0).blockManagerId).toSet ===
+ storeIds.filterNot { _ == stores(0).blockManagerId } + newStore.blockManagerId)
+ assert(master.getPeers(stores(1).blockManagerId).toSet ===
+ storeIds.filterNot { _ == stores(1).blockManagerId } + newStore.blockManagerId)
+ assert(master.getPeers(stores(2).blockManagerId).toSet ===
+ storeIds.filterNot { _ == stores(2).blockManagerId } + newStore.blockManagerId)
+ assert(master.getPeers(newStore.blockManagerId).toSet === storeIds)
+
+    // Remove a store and test that getPeers no longer returns it
+ val storeIdToRemove = stores(0).blockManagerId
+ master.removeExecutor(storeIdToRemove.executorId)
+ assert(!master.getPeers(stores(1).blockManagerId).contains(storeIdToRemove))
+ assert(!master.getPeers(stores(2).blockManagerId).contains(storeIdToRemove))
+ assert(!master.getPeers(newStore.blockManagerId).contains(storeIdToRemove))
+
+    // Test whether asking for peers of an unregistered block manager id returns an empty list
+ assert(master.getPeers(stores(0).blockManagerId).isEmpty)
+ assert(master.getPeers(BlockManagerId("", "", 1)).isEmpty)
+ }
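+
+  // Sketch (assumption, not the master's actual code): the peer lookup exercised above behaves
+  // as if it filtered the set of registered block managers, dropping the requesting store itself
+  // and any driver store.
+  private def expectedPeers(allIds: Set[BlockManagerId], self: BlockManagerId): Set[BlockManagerId] = {
+    allIds.filter(id => id != self && !id.isDriver)
+  }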
+
+
+ test("block replication - 2x replication") {
+ testReplication(2,
+ Seq(MEMORY_ONLY, MEMORY_ONLY_SER, DISK_ONLY, MEMORY_AND_DISK_2, MEMORY_AND_DISK_SER_2)
+ )
+ }
+
+ test("block replication - 3x replication") {
+ // Generate storage levels with 3x replication
+ val storageLevels = {
+ Seq(MEMORY_ONLY, MEMORY_ONLY_SER, DISK_ONLY, MEMORY_AND_DISK, MEMORY_AND_DISK_SER).map {
+ level => StorageLevel(
+ level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 3)
+ }
+ }
+ testReplication(3, storageLevels)
+ }
+
+ test("block replication - mixed between 1x to 5x") {
+ // Generate storage levels with varying replication
+ val storageLevels = Seq(
+ MEMORY_ONLY,
+ MEMORY_ONLY_SER_2,
+ StorageLevel(true, false, false, false, 3),
+ StorageLevel(true, true, false, true, 4),
+ StorageLevel(true, true, false, false, 5),
+ StorageLevel(true, true, false, true, 4),
+ StorageLevel(true, false, false, false, 3),
+ MEMORY_ONLY_SER_2,
+ MEMORY_ONLY
+ )
+ testReplication(5, storageLevels)
+ }
+
+ test("block replication - 2x replication without peers") {
+ intercept[org.scalatest.exceptions.TestFailedException] {
+ testReplication(1,
+ Seq(StorageLevel.MEMORY_AND_DISK_2, StorageLevel(true, false, false, false, 3)))
+ }
+ }
+
+ test("block replication - deterministic node selection") {
+ val blockSize = 1000
+ val storeSize = 10000
+ val stores = (1 to 5).map {
+ i => makeBlockManager(storeSize, s"store$i")
+ }
+ val storageLevel2x = StorageLevel.MEMORY_AND_DISK_2
+ val storageLevel3x = StorageLevel(true, true, false, true, 3)
+ val storageLevel4x = StorageLevel(true, true, false, true, 4)
+
+ def putBlockAndGetLocations(blockId: String, level: StorageLevel): Set[BlockManagerId] = {
+ stores.head.putSingle(blockId, new Array[Byte](blockSize), level)
+ val locations = master.getLocations(blockId).sortBy { _.executorId }.toSet
+ stores.foreach { _.removeBlock(blockId) }
+ master.removeBlock(blockId)
+ locations
+ }
+
+    // Test whether two attempts at 2x replication return the same set of locations
+ val a1Locs = putBlockAndGetLocations("a1", storageLevel2x)
+ assert(putBlockAndGetLocations("a1", storageLevel2x) === a1Locs,
+ "Inserting a 2x replicated block second time gave different locations from the first")
+
+    // Test whether two attempts at 3x replication return the same set of locations
+ val a2Locs3x = putBlockAndGetLocations("a2", storageLevel3x)
+ assert(putBlockAndGetLocations("a2", storageLevel3x) === a2Locs3x,
+ "Inserting a 3x replicated block second time gave different locations from the first")
+
+ // Test if 2x replication of a2 returns a strict subset of the locations of 3x replication
+ val a2Locs2x = putBlockAndGetLocations("a2", storageLevel2x)
+ assert(
+ a2Locs2x.subsetOf(a2Locs3x),
+ "Inserting a with 2x replication gave locations that are not a subset of locations" +
+ s" with 3x replication [3x: ${a2Locs3x.mkString(",")}; 2x: ${a2Locs2x.mkString(",")}"
+ )
+
+ // Test if 4x replication of a2 returns a strict superset of the locations of 3x replication
+ val a2Locs4x = putBlockAndGetLocations("a2", storageLevel4x)
+ assert(
+ a2Locs3x.subsetOf(a2Locs4x),
+ "Inserting a with 4x replication gave locations that are not a superset of locations " +
+ s"with 3x replication [3x: ${a2Locs3x.mkString(",")}; 4x: ${a2Locs4x.mkString(",")}"
+ )
+
+ // Test if 3x replication of two different blocks gives two different sets of locations
+ val a3Locs3x = putBlockAndGetLocations("a3", storageLevel3x)
+ assert(a3Locs3x !== a2Locs3x, "Two blocks gave same locations with 3x replication")
+ }
+
+ test("block replication - replication failures") {
+ /*
+ Create a system of three block managers / stores. One of them (say, failableStore)
+      cannot receive blocks, so attempts to use it as a replication target fail.
+
+ +-----------/fails/-----------> failableStore
+ |
+ normalStore
+ |
+ +-----------/works/-----------> anotherNormalStore
+
+ We are first going to add a normal block manager (i.e. normalStore) and the failable block
+ manager (i.e. failableStore), and test whether 2x replication fails to create two
+ copies of a block. Then we are going to add another normal block manager
+ (i.e., anotherNormalStore), and test that now 2x replication works as the
+ new store will be used for replication.
+ */
+
+ // Add a normal block manager
+ val store = makeBlockManager(10000, "store")
+
+ // Insert a block with 2x replication and return the number of copies of the block
+ def replicateAndGetNumCopies(blockId: String): Int = {
+ store.putSingle(blockId, new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK_2)
+ val numLocations = master.getLocations(blockId).size
+ allStores.foreach { _.removeBlock(blockId) }
+ numLocations
+ }
+
+ // Add a failable block manager with a mock transfer service that does not
+ // allow receiving of blocks. So attempts to use it as a replication target will fail.
+    val failableTransfer = mock(classOf[BlockTransferService]) // this won't actually work
+ when(failableTransfer.hostName).thenReturn("some-hostname")
+ when(failableTransfer.port).thenReturn(1000)
+ val failableStore = new BlockManager("failable-store", actorSystem, master, serializer,
+ 10000, conf, mapOutputTracker, shuffleManager, failableTransfer)
+ allStores += failableStore // so that this gets stopped after test
+ assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId))
+
+ // Test that 2x replication fails by creating only one copy of the block
+ assert(replicateAndGetNumCopies("a1") === 1)
+
+ // Add another normal block manager and test that 2x replication works
+ makeBlockManager(10000, "anotherStore")
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ assert(replicateAndGetNumCopies("a2") === 2)
+ }
+ }
+
+ test("block replication - addition and deletion of block managers") {
+ val blockSize = 1000
+ val storeSize = 10000
+ val initialStores = (1 to 2).map { i => makeBlockManager(storeSize, s"store$i") }
+
+    // Insert a block with the given replication factor and return the number of copies of the block
+ def replicateAndGetNumCopies(blockId: String, replicationFactor: Int): Int = {
+ val storageLevel = StorageLevel(true, true, false, true, replicationFactor)
+ initialStores.head.putSingle(blockId, new Array[Byte](blockSize), storageLevel)
+ val numLocations = master.getLocations(blockId).size
+ allStores.foreach { _.removeBlock(blockId) }
+ numLocations
+ }
+
+ // 2x replication should work, 3x replication should only replicate 2x
+ assert(replicateAndGetNumCopies("a1", 2) === 2)
+ assert(replicateAndGetNumCopies("a2", 3) === 2)
+
+ // Add another store, 3x replication should work now, 4x replication should only replicate 3x
+ val newStore1 = makeBlockManager(storeSize, s"newstore1")
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ assert(replicateAndGetNumCopies("a3", 3) === 3)
+ }
+ assert(replicateAndGetNumCopies("a4", 4) === 3)
+
+ // Add another store, 4x replication should work now
+ val newStore2 = makeBlockManager(storeSize, s"newstore2")
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ assert(replicateAndGetNumCopies("a5", 4) === 4)
+ }
+
+ // Remove all but the 1st store, 2x replication should fail
+ (initialStores.tail ++ Seq(newStore1, newStore2)).foreach {
+ store =>
+ master.removeExecutor(store.blockManagerId.executorId)
+ store.stop()
+ }
+ assert(replicateAndGetNumCopies("a6", 2) === 1)
+
+ // Add new stores, 3x replication should work
+ val newStores = (3 to 5).map {
+ i => makeBlockManager(storeSize, s"newstore$i")
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ assert(replicateAndGetNumCopies("a7", 3) === 3)
+ }
+ }
+
+ /**
+ * Test replication of blocks with different storage levels (various combinations of
+ * memory, disk & serialization). For each storage level, this function tests every store
+ * whether the block is present and also tests the master whether its knowledge of blocks
+ * is correct. Then it also drops the block from memory of each store (using LRU) and
+ * again checks whether the master's knowledge gets updated.
+ */
+ private def testReplication(maxReplication: Int, storageLevels: Seq[StorageLevel]) {
+ import org.apache.spark.storage.StorageLevel._
+
+ assert(maxReplication > 1,
+ s"Cannot test replication factor $maxReplication")
+
+ // storage levels to test with the given replication factor
+
+ val storeSize = 10000
+ val blockSize = 1000
+
+ // As many stores as the replication factor
+ val stores = (1 to maxReplication).map {
+ i => makeBlockManager(storeSize, s"store$i")
+ }
+
+ storageLevels.foreach { storageLevel =>
+ // Put the block into one of the stores
+ val blockId = new TestBlockId(
+ "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase)
+ stores(0).putSingle(blockId, new Array[Byte](blockSize), storageLevel)
+
+      // Assert that the master knows the expected number of locations for the block
+ val blockLocations = master.getLocations(blockId).map(_.executorId).toSet
+ assert(blockLocations.size === storageLevel.replication,
+ s"master did not have ${storageLevel.replication} locations for $blockId")
+
+ // Test state of the stores that contain the block
+ stores.filter {
+ testStore => blockLocations.contains(testStore.blockManagerId.executorId)
+ }.foreach { testStore =>
+ val testStoreName = testStore.blockManagerId.executorId
+ assert(testStore.getLocal(blockId).isDefined, s"$blockId was not found in $testStoreName")
+ assert(master.getLocations(blockId).map(_.executorId).toSet.contains(testStoreName),
+ s"master does not have status for ${blockId.name} in $testStoreName")
+
+ val blockStatus = master.getBlockStatus(blockId)(testStore.blockManagerId)
+
+ // Assert that block status in the master for this store has expected storage level
+ assert(
+ blockStatus.storageLevel.useDisk === storageLevel.useDisk &&
+ blockStatus.storageLevel.useMemory === storageLevel.useMemory &&
+ blockStatus.storageLevel.useOffHeap === storageLevel.useOffHeap &&
+ blockStatus.storageLevel.deserialized === storageLevel.deserialized,
+ s"master does not know correct storage level for ${blockId.name} in $testStoreName")
+
+ // Assert that the block status in the master for this store has correct memory usage info
+ assert(!blockStatus.storageLevel.useMemory || blockStatus.memSize >= blockSize,
+ s"master does not know size of ${blockId.name} stored in memory of $testStoreName")
+
+
+      // If the block is supposed to be in memory, then drop the copy of the block in
+      // this store and test whether the master is updated with zero memory usage for this store
+ if (storageLevel.useMemory) {
+ // Force the block to be dropped by adding a number of dummy blocks
+ (1 to 10).foreach {
+ i =>
+ testStore.putSingle(s"dummy-block-$i", new Array[Byte](1000), MEMORY_ONLY_SER)
+ }
+ (1 to 10).foreach {
+ i => testStore.removeBlock(s"dummy-block-$i")
+ }
+
+ val newBlockStatusOption = master.getBlockStatus(blockId).get(testStore.blockManagerId)
+
+ // Assert that the block status in the master either does not exist (block removed
+ // from every store) or has zero memory usage for this store
+ assert(
+ newBlockStatusOption.isEmpty || newBlockStatusOption.get.memSize === 0,
+ s"after dropping, master does not know size of ${blockId.name} " +
+ s"stored in memory of $testStoreName"
+ )
+ }
+
+      // If the block is supposed to be on disk (after dropping or otherwise), then
+ // test whether master has correct disk usage for this store
+ if (storageLevel.useDisk) {
+ assert(master.getBlockStatus(blockId)(testStore.blockManagerId).diskSize >= blockSize,
+ s"after dropping, master does not know size of ${blockId.name} " +
+ s"stored in disk of $testStoreName"
+ )
+ }
+ }
+ master.removeBlock(blockId)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index e251660dae5de..9d96202a3e7ac 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -21,8 +21,6 @@ import java.nio.{ByteBuffer, MappedByteBuffer}
import java.util.Arrays
import java.util.concurrent.TimeUnit
-import org.apache.spark.network.nio.NioBlockTransferService
-
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Await
import scala.concurrent.duration._
@@ -35,13 +33,13 @@ import akka.util.Timeout
import org.mockito.Mockito.{mock, when}
-import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester}
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester}
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts._
-import org.scalatest.Matchers
import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
import org.apache.spark.executor.DataReadMethod
+import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.shuffle.hash.HashShuffleManager
@@ -189,7 +187,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
store = makeBlockManager(2000, "exec1")
store2 = makeBlockManager(2000, "exec2")
- val peers = master.getPeers(store.blockManagerId, 1)
+ val peers = master.getPeers(store.blockManagerId)
assert(peers.size === 1, "master did not return the other manager as a peer")
assert(peers.head === store2.blockManagerId, "peer returned by master is not the other manager")
@@ -448,7 +446,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
val list2DiskGet = store.get("list2disk")
assert(list2DiskGet.isDefined, "list2memory expected to be in store")
assert(list2DiskGet.get.data.size === 3)
- System.out.println(list2DiskGet)
// We don't know the exact size of the data on disk, but it should certainly be > 0.
assert(list2DiskGet.get.inputMetrics.bytesRead > 0)
assert(list2DiskGet.get.inputMetrics.readMethod === DataReadMethod.Disk)
diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
index 76bf4cfd11267..7bca1711ae226 100644
--- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
@@ -106,10 +106,9 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
- val compressedSize1000 = MapOutputTracker.compressSize(1000L)
- val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
- masterTracker.registerMapOutput(10, 0, new MapStatus(
- BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
+ val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+ masterTracker.registerMapOutput(10, 0,
+ MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L)))
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
@@ -157,10 +156,9 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
- val compressedSize1000 = MapOutputTracker.compressSize(1000L)
- val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
- masterTracker.registerMapOutput(10, 0, new MapStatus(
- BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
+ val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+ masterTracker.registerMapOutput(10, 0, MapStatus(
+ BlockManagerId("a", "hostA", 1000), Array(1000L)))
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 2b45d8b695853..f1f88c5fd3634 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -356,7 +356,6 @@ class JsonProtocolSuite extends FunSuite {
}
private def assertEquals(metrics1: ShuffleReadMetrics, metrics2: ShuffleReadMetrics) {
- assert(metrics1.shuffleFinishTime === metrics2.shuffleFinishTime)
assert(metrics1.remoteBlocksFetched === metrics2.remoteBlocksFetched)
assert(metrics1.localBlocksFetched === metrics2.localBlocksFetched)
assert(metrics1.fetchWaitTime === metrics2.fetchWaitTime)
@@ -568,7 +567,6 @@ class JsonProtocolSuite extends FunSuite {
t.inputMetrics = Some(inputMetrics)
} else {
val sr = new ShuffleReadMetrics
- sr.shuffleFinishTime = b + c
sr.remoteBytesRead = b + d
sr.localBlocksFetched = e
sr.fetchWaitTime = a + d
@@ -806,7 +804,6 @@ class JsonProtocolSuite extends FunSuite {
| "Memory Bytes Spilled": 800,
| "Disk Bytes Spilled": 0,
| "Shuffle Read Metrics": {
- | "Shuffle Finish Time": 900,
| "Remote Blocks Fetched": 800,
| "Local Blocks Fetched": 700,
| "Fetch Wait Time": 900,
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 70d423ba8a04d..e63d9d085e385 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -189,17 +189,28 @@ class UtilsSuite extends FunSuite {
assert(Utils.getIteratorSize(iterator) === 5L)
}
- test("findOldFiles") {
+ test("doesDirectoryContainFilesNewerThan") {
// create some temporary directories and files
val parent: File = Utils.createTempDir()
val child1: File = Utils.createTempDir(parent.getCanonicalPath) // The parent directory has two child directories
val child2: File = Utils.createTempDir(parent.getCanonicalPath)
- // set the last modified time of child1 to 10 secs old
- child1.setLastModified(System.currentTimeMillis() - (1000 * 10))
+ val child3: File = Utils.createTempDir(child1.getCanonicalPath)
+ // set the last modified time of child1 to 30 secs old
+ child1.setLastModified(System.currentTimeMillis() - (1000 * 30))
- val result = Utils.findOldFiles(parent, 5) // find files older than 5 secs
- assert(result.size.equals(1))
- assert(result(0).getCanonicalPath.equals(child1.getCanonicalPath))
+    // although child1 is old, child2 is still new, so this should return true
+ assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5))
+
+ child2.setLastModified(System.currentTimeMillis - (1000 * 30))
+ assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5))
+
+ parent.setLastModified(System.currentTimeMillis - (1000 * 30))
+    // although the parent and its immediate children are now old, child3 (a grandchild) is still new;
+    // we expect a full recursive search for new files to find it
+ assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5))
+
+ child3.setLastModified(System.currentTimeMillis - (1000 * 30))
+ assert(!Utils.doesDirectoryContainAnyNewFiles(parent, 5))
}
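+
+  // Sketch of the assumed semantics (not Spark's actual Utils code): the check recurses into
+  // subdirectories and returns true if the directory or anything beneath it was modified within
+  // the last cutoffSeconds, which is why the new grandchild child3 keeps the result true above.
+  private def containsAnyNewFilesSketch(dir: File, cutoffSeconds: Long): Boolean = {
+    val cutoff = System.currentTimeMillis - cutoffSeconds * 1000
+    def isNew(f: File): Boolean = {
+      f.lastModified > cutoff ||
+        (f.isDirectory && Option(f.listFiles).getOrElse(Array.empty[File]).exists(isNew))
+    }
+    isNew(dir)
+  }
+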
test("resolveURI") {
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 706faed980f31..f26e40fbd4b36 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -24,6 +24,8 @@ import org.scalatest.{PrivateMethodTester, FunSuite}
import org.apache.spark._
import org.apache.spark.SparkContext._
+import scala.util.Random
+
class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester {
private def createSparkConf(loadDefaults: Boolean): SparkConf = {
val conf = new SparkConf(loadDefaults)
@@ -707,4 +709,57 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None)
assertDidNotBypassMergeSort(sorter4)
}
+
+ test("sort without breaking sorting contracts") {
+ val conf = createSparkConf(true)
+ conf.set("spark.shuffle.memoryFraction", "0.01")
+ conf.set("spark.shuffle.manager", "sort")
+ sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+    // Using wrongOrdering to show the exception introduced by integer overflow.
+ val rand = new Random(100L)
+ val wrongOrdering = new Ordering[String] {
+ override def compare(a: String, b: String) = {
+ val h1 = if (a == null) 0 else a.hashCode()
+ val h2 = if (b == null) 0 else b.hashCode()
+ h1 - h2
+ }
+ }
+
+ val testData = Array.tabulate(100000) { _ => rand.nextInt().toString }
+
+ val sorter1 = new ExternalSorter[String, String, String](
+ None, None, Some(wrongOrdering), None)
+ val thrown = intercept[IllegalArgumentException] {
+ sorter1.insertAll(testData.iterator.map(i => (i, i)))
+ sorter1.iterator
+ }
+
+ assert(thrown.getClass() === classOf[IllegalArgumentException])
+ assert(thrown.getMessage().contains("Comparison method violates its general contract"))
+ sorter1.stop()
+
+ // Use aggregation and external spill to make sure ExternalSorter uses the
+ // partitionKeyComparator.
+ def createCombiner(i: String) = ArrayBuffer(i)
+ def mergeValue(c: ArrayBuffer[String], i: String) = c += i
+ def mergeCombiners(c1: ArrayBuffer[String], c2: ArrayBuffer[String]) = c1 ++= c2
+
+ val agg = new Aggregator[String, String, ArrayBuffer[String]](
+ createCombiner, mergeValue, mergeCombiners)
+
+ val sorter2 = new ExternalSorter[String, String, ArrayBuffer[String]](
+ Some(agg), None, None, None)
+ sorter2.insertAll(testData.iterator.map(i => (i, i)))
+
+ // Validate the hash ordering of the keys
+ var minKey = Int.MinValue
+ sorter2.iterator.foreach { case (k, v) =>
+ val h = k.hashCode()
+ assert(h >= minKey)
+ minKey = h
+ }
+
+ sorter2.stop()
+ }
}
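
The `wrongOrdering` above is broken on purpose: `h1 - h2` overflows `Int` whenever the two hash codes have opposite signs, so the comparison is not transitive and the sort fails with "Comparison method violates its general contract". A tiny standalone illustration (not part of the patch):

```scala
// Demonstrates why subtraction is not a safe comparison for Int values.
object ComparatorOverflow {
  def main(args: Array[String]): Unit = {
    val h1 = Int.MinValue            // hash code of one key
    val h2 = 1                       // hash code of another key
    println(h1 - h2)                 // 2147483647: overflow claims h1 > h2
    println(Integer.compare(h1, h2)) // -1: the overflow-safe comparison
  }
}
```

An ordering that used `Integer.compare(h1, h2)` instead of subtraction would satisfy the contract.
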
diff --git a/dev/check-license b/dev/check-license
index 9ff0929e9a5e8..72b1013479964 100755
--- a/dev/check-license
+++ b/dev/check-license
@@ -20,11 +20,10 @@
acquire_rat_jar () {
- URL1="http://search.maven.org/remotecontent?filepath=org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar"
- URL2="http://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar"
+ URL="http://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar"
JAR="$rat_jar"
-
+
if [[ ! -f "$rat_jar" ]]; then
# Download rat launch jar if it hasn't been downloaded yet
if [ ! -f "$JAR" ]; then
@@ -32,15 +31,17 @@ acquire_rat_jar () {
printf "Attempting to fetch rat\n"
JAR_DL="${JAR}.part"
if hash curl 2>/dev/null; then
- (curl --silent "${URL1}" > "$JAR_DL" || curl --silent "${URL2}" > "$JAR_DL") && mv "$JAR_DL" "$JAR"
+ curl --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR"
elif hash wget 2>/dev/null; then
- (wget --quiet ${URL1} -O "$JAR_DL" || wget --quiet ${URL2} -O "$JAR_DL") && mv "$JAR_DL" "$JAR"
+ wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR"
else
printf "You do not have curl or wget installed, please install rat manually.\n"
exit -1
fi
fi
- if [ ! -f "$JAR" ]; then
+
+ unzip -tq $JAR &> /dev/null
+ if [ $? -ne 0 ]; then
# We failed to download
printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n"
exit -1
@@ -55,7 +56,7 @@ cd "$FWDIR"
if test -x "$JAVA_HOME/bin/java"; then
declare java_cmd="$JAVA_HOME/bin/java"
-else
+else
declare java_cmd=java
fi
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index a8e92e36fe0d8..02ac20984add9 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -73,11 +73,10 @@ def fail(msg):
def run_cmd(cmd):
+ print cmd
if isinstance(cmd, list):
- print " ".join(cmd)
return subprocess.check_output(cmd)
else:
- print cmd
return subprocess.check_output(cmd.split(" "))
diff --git a/dev/run-tests b/dev/run-tests
index 53148d23f385f..4be2baaf48cd1 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -24,6 +24,16 @@ cd "$FWDIR"
# Remove work directory
rm -rf ./work
+source "$FWDIR/dev/run-tests-codes.sh"
+
+CURRENT_BLOCK=$BLOCK_GENERAL
+
+function handle_error () {
+ echo "[error] Got a return code of $? on line $1 of the run-tests script."
+ exit $CURRENT_BLOCK
+}
+
+
# Build against the right version of Hadoop.
{
if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then
@@ -91,26 +101,34 @@ if [ -n "$AMPLAB_JENKINS" ]; then
fi
fi
-# Fail fast
-set -e
set -o pipefail
+trap 'handle_error $LINENO' ERR
echo ""
echo "========================================================================="
echo "Running Apache RAT checks"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_RAT
+
./dev/check-license
echo ""
echo "========================================================================="
echo "Running Scala style checks"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_SCALA_STYLE
+
./dev/lint-scala
echo ""
echo "========================================================================="
echo "Running Python style checks"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_PYTHON_STYLE
+
./dev/lint-python
echo ""
@@ -118,6 +136,8 @@ echo "========================================================================="
echo "Building Spark"
echo "========================================================================="
+CURRENT_BLOCK=$BLOCK_BUILD
+
{
# We always build with Hive because the PySpark Spark SQL tests need it.
BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive"
@@ -127,6 +147,8 @@ echo "========================================================================="
# NOTE: echo "q" is needed because sbt on encountering a build file with failure
#+ (either resolution or compilation) prompts the user for input either q, r, etc
#+ to quit or retry. This echo is there to make it not block.
+ # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a
+ #+ single argument!
# QUESTION: Why doesn't 'yes "q"' work?
# QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work?
echo -e "q\n" \
@@ -139,27 +161,35 @@ echo "========================================================================="
echo "Running Spark unit tests"
echo "========================================================================="
+CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS
+
{
# If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled.
+ # This must be a single argument, as it is.
if [ -n "$_RUN_SQL_TESTS" ]; then
SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive"
fi
if [ -n "$_SQL_TESTS_ONLY" ]; then
- SBT_MAVEN_TEST_ARGS="catalyst/test sql/test hive/test"
+ # This must be an array of individual arguments. Otherwise, having one long string
+ #+ will be interpreted as a single test, which doesn't work.
+ SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "hive-thriftserver/test")
else
- SBT_MAVEN_TEST_ARGS="test"
+ SBT_MAVEN_TEST_ARGS=("test")
fi
- echo "[info] Running Spark tests with these arguments: $SBT_MAVEN_PROFILES_ARGS $SBT_MAVEN_TEST_ARGS"
+ echo "[info] Running Spark tests with these arguments: $SBT_MAVEN_PROFILES_ARGS ${SBT_MAVEN_TEST_ARGS[@]}"
# NOTE: echo "q" is needed because sbt on encountering a build file with failure
#+ (either resolution or compilation) prompts the user for input either q, r, etc
#+ to quit or retry. This echo is there to make it not block.
+ # NOTE: Do not quote $SBT_MAVEN_PROFILES_ARGS or else it will be interpreted as a
+ #+ single argument!
+ #+ "${SBT_MAVEN_TEST_ARGS[@]}" is cool because it's an array.
# QUESTION: Why doesn't 'yes "q"' work?
# QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work?
echo -e "q\n" \
- | sbt/sbt "$SBT_MAVEN_PROFILES_ARGS" "$SBT_MAVEN_TEST_ARGS" \
+ | sbt/sbt $SBT_MAVEN_PROFILES_ARGS "${SBT_MAVEN_TEST_ARGS[@]}" \
| grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
}
@@ -167,10 +197,16 @@ echo ""
echo "========================================================================="
echo "Running PySpark tests"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS
+
./python/run-tests
echo ""
echo "========================================================================="
echo "Detecting binary incompatibilites with MiMa"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_MIMA
+
./dev/mima
diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh
new file mode 100644
index 0000000000000..1348e0609dda4
--- /dev/null
+++ b/dev/run-tests-codes.sh
@@ -0,0 +1,27 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+readonly BLOCK_GENERAL=10
+readonly BLOCK_RAT=11
+readonly BLOCK_SCALA_STYLE=12
+readonly BLOCK_PYTHON_STYLE=13
+readonly BLOCK_BUILD=14
+readonly BLOCK_SPARK_UNIT_TESTS=15
+readonly BLOCK_PYSPARK_UNIT_TESTS=16
+readonly BLOCK_MIMA=17
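
Each of these constants doubles as the exit code `run-tests` returns when the corresponding block fails, and `run-tests-jenkins` (patched below) turns that code into a readable message. Purely as an illustration of that protocol, the mapping looks roughly like this:

```scala
// Illustration only: mirrors the exit-code -> failure-label mapping used by the
// Jenkins wrapper. The numeric values come from dev/run-tests-codes.sh above.
object TestBlockCodes {
  def failingTest(exitCode: Int): String = exitCode match {
    case 11 => "RAT tests"           // BLOCK_RAT
    case 12 => "Scala style tests"   // BLOCK_SCALA_STYLE
    case 13 => "Python style tests"  // BLOCK_PYTHON_STYLE
    case 14 => "to build"            // BLOCK_BUILD
    case 15 => "Spark unit tests"    // BLOCK_SPARK_UNIT_TESTS
    case 16 => "PySpark unit tests"  // BLOCK_PYSPARK_UNIT_TESTS
    case 17 => "MiMa tests"          // BLOCK_MIMA
    case _  => "some tests"          // BLOCK_GENERAL or anything unexpected
  }
}
```
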
diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins
index 06c3781eb3ccf..451f3b771cc76 100755
--- a/dev/run-tests-jenkins
+++ b/dev/run-tests-jenkins
@@ -26,9 +26,23 @@
FWDIR="$(cd `dirname $0`/..; pwd)"
cd "$FWDIR"
+source "$FWDIR/dev/run-tests-codes.sh"
+
COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments"
PULL_REQUEST_URL="https://github.com/apache/spark/pull/$ghprbPullId"
+# Important Environment Variables
+# ---
+# $ghprbActualCommit
+#+ This is the hash of the most recent commit in the PR.
+#+ The merge-base of this and master is the commit from which the PR was branched.
+# $sha1
+#+ If the patch merges cleanly, this is a reference to the merge commit hash
+#+ (e.g. "origin/pr/2606/merge").
+#+ If the patch does not merge cleanly, it is equal to $ghprbActualCommit.
+#+ The merge-base of this and master in the case of a clean merge is the most recent commit
+#+ against master.
+
COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}"
# GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :(
SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}"
@@ -84,42 +98,46 @@ function post_message () {
fi
}
+
+# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR
+#+ and not anything else added to master since the PR was branched.
+
# check PR merge-ability and check for new public classes
{
if [ "$sha1" == "$ghprbActualCommit" ]; then
- merge_note=" * This patch **does not** merge cleanly!"
+ merge_note=" * This patch **does not merge cleanly**."
else
merge_note=" * This patch merges cleanly."
+ fi
+
+ source_files=$(
+ git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \
+ | grep -v -e "\/test" `# ignore files in test directories` \
+ | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \
+ | tr "\n" " "
+ )
+ new_public_classes=$(
+ git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \
+ | grep "^\+" `# filter in only added lines` \
+ | sed -r -e "s/^\+//g" `# remove the leading +` \
+ | grep -e "trait " -e "class " `# filter in lines with these key words` \
+ | grep -e "{" -e "(" `# filter in lines with these key words, too` \
+ | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \
+ | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \
+ | sed -r -e "s/\{.*//g" `# remove from the { onwards` \
+ | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \
+ | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \
+ | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \
+ | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \
+ | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \
+ | tr -d "\n" `# remove actual LF characters`
+ )
- source_files=$(
- git diff master --name-only \
- | grep -v -e "\/test" `# ignore files in test directories` \
- | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \
- | tr "\n" " "
- )
- new_public_classes=$(
- git diff master ${source_files} `# diff this patch against master and...` \
- | grep "^\+" `# filter in only added lines` \
- | sed -r -e "s/^\+//g" `# remove the leading +` \
- | grep -e "trait " -e "class " `# filter in lines with these key words` \
- | grep -e "{" -e "(" `# filter in lines with these key words, too` \
- | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \
- | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \
- | sed -r -e "s/\{.*//g" `# remove from the { onwards` \
- | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \
- | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \
- | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \
- | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \
- | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \
- | tr -d "\n" `# remove actual LF characters`
- )
-
- if [ "$new_public_classes" == "" ]; then
- public_classes_note=" * This patch adds no public classes."
- else
- public_classes_note=" * This patch adds the following public classes _(experimental)_:"
- public_classes_note="${public_classes_note}\n${new_public_classes}"
- fi
+ if [ -z "$new_public_classes" ]; then
+ public_classes_note=" * This patch adds no public classes."
+ else
+ public_classes_note=" * This patch adds the following public classes _(experimental)_:"
+ public_classes_note="${public_classes_note}\n${new_public_classes}"
fi
}
@@ -141,16 +159,36 @@ function post_message () {
test_result="$?"
if [ "$test_result" -eq "124" ]; then
- fail_message="**[Tests timed out](${BUILD_URL}consoleFull)** after \
- a configured wait of \`${TESTS_TIMEOUT}\`."
+ fail_message="**[Tests timed out](${BUILD_URL}consoleFull)** \
+ for PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL}) \
+ after a configured wait of \`${TESTS_TIMEOUT}\`."
+
post_message "$fail_message"
exit $test_result
+ elif [ "$test_result" -eq "0" ]; then
+ test_result_note=" * This patch **passes all tests**."
else
- if [ "$test_result" -eq "0" ]; then
- test_result_note=" * This patch **passes** unit tests."
+ if [ "$test_result" -eq "$BLOCK_GENERAL" ]; then
+ failing_test="some tests"
+ elif [ "$test_result" -eq "$BLOCK_RAT" ]; then
+ failing_test="RAT tests"
+ elif [ "$test_result" -eq "$BLOCK_SCALA_STYLE" ]; then
+ failing_test="Scala style tests"
+ elif [ "$test_result" -eq "$BLOCK_PYTHON_STYLE" ]; then
+ failing_test="Python style tests"
+ elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then
+ failing_test="to build"
+ elif [ "$test_result" -eq "$BLOCK_SPARK_UNIT_TESTS" ]; then
+ failing_test="Spark unit tests"
+ elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then
+ failing_test="PySpark unit tests"
+ elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then
+ failing_test="MiMa tests"
else
- test_result_note=" * This patch **fails** unit tests."
+ failing_test="some tests"
fi
+
+ test_result_note=" * This patch **fails $failing_test**."
fi
}
diff --git a/docs/README.md b/docs/README.md
index fdc89d2eb767a..79708c3df9106 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -20,12 +20,16 @@ In this directory you will find textfiles formatted using Markdown, with an ".md
read those text files directly if you want. Start with index.md.
The markdown code can be compiled to HTML using the [Jekyll tool](http://jekyllrb.com).
-To use the `jekyll` command, you will need to have Jekyll installed.
-The easiest way to do this is via a Ruby Gem, see the
-[jekyll installation instructions](http://jekyllrb.com/docs/installation).
-If not already installed, you need to install `kramdown` and `jekyll-redirect-from` Gems
-with `sudo gem install kramdown jekyll-redirect-from`.
-Execute `jekyll build` from the `docs/` directory. Compiling the site with Jekyll will create a directory
+`Jekyll` and a few dependencies must be installed for this to work. We recommend
+installing via the Ruby Gem dependency manager. Since the exact HTML output
+varies between versions of Jekyll and its dependencies, we list specific versions here
+in some cases:
+
+ $ sudo gem install jekyll -v 1.4.3
+ $ sudo gem uninstall kramdown -v 1.4.1
+ $ sudo gem install jekyll-redirect-from
+
+Execute `jekyll build` from the `docs/` directory. Compiling the site with Jekyll will create a directory
called `_site` containing index.html as well as the rest of the compiled files.
You can modify the default Jekyll build as follows:
diff --git a/docs/_config.yml b/docs/_config.yml
index d3ea2625c7448..7bc3a78e2d265 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -3,6 +3,11 @@ markdown: kramdown
gems:
- jekyll-redirect-from
+# For some reason kramdown seems to behave differently on different
+# OS/packages wrt encoding. So we hard code this config.
+kramdown:
+ entity_output: numeric
+
# These allow the documentation to be updated with new releases
# of Spark, Scala, and Mesos.
SPARK_VERSION: 1.0.0-SNAPSHOT
diff --git a/docs/building-spark.md b/docs/building-spark.md
index 2378092d4a1a8..b2940ee4029e8 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -169,7 +169,22 @@ compilation. More advanced developers may wish to use SBT.
The SBT build is derived from the Maven POM files, and so the same Maven profiles and variables
can be set to control the SBT build. For example:
- sbt/sbt -Pyarn -Phadoop-2.3 compile
+ sbt/sbt -Pyarn -Phadoop-2.3 assembly
+
+# Testing with SBT
+
+Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time. The following is an example of a correct (build, test) sequence:
+
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive assembly
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive test
+
+To run only a specific test suite:
+
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive "test-only org.apache.spark.repl.ReplSuite"
+
+To run the test suites of a specific sub-project:
+
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive core/test
# Speeding up Compilation with Zinc
diff --git a/docs/configuration.md b/docs/configuration.md
index 99faf51c6f3db..1c33855365170 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -206,6 +206,25 @@ Apart from these, the following properties are also available, and may be useful
used during aggregation goes above this amount, it will spill the data into disks.
+<tr>
+  <td><code>spark.python.profile</code></td>
+  <td>false</td>
+  <td>
+    Enable profiling in the Python worker; the profile result will show up via `sc.show_profiles()`,
+    or it will be displayed before the driver exits. It can also be dumped to disk with
+    `sc.dump_profiles(path)`. If some of the profile results have been displayed manually,
+    they will not be displayed automatically before the driver exits.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.python.profile.dump</code></td>
+  <td>(none)</td>
+  <td>
+    The directory used to dump the profile result before the driver exits.
+    The results are dumped as a separate file for each RDD and can be loaded
+    with `pstats.Stats()`. If this is specified, the profile result will not be displayed
+    automatically.
+  </td>
+</tr>
  <td><code>spark.python.worker.reuse</code></td>
  <td>true</td>
@@ -234,6 +253,17 @@ Apart from these, the following properties are also available, and may be useful
spark.executor.uri.
+<tr>
+  <td><code>spark.mesos.executor.memoryOverhead</code></td>
+  <td>executor memory * 0.07, with minimum of 384</td>
+  <td>
+    The amount of memory, in MiB, to add to `spark.executor.memory` when
+    calculating the total Mesos task memory. A value of 384 implies a 384MiB
+    overhead. Additionally, there is a hard-coded 7% minimum overhead. The final
+    overhead will be the larger of either
+    `spark.mesos.executor.memoryOverhead` or 7% of `spark.executor.memory`.
+  </td>
+</tr>
#### Shuffle Behavior
@@ -394,10 +424,11 @@ Apart from these, the following properties are also available, and may be useful
spark.io.compression.codec
snappy
- The codec used to compress internal data such as RDD partitions and shuffle outputs. By default,
- Spark provides three codecs: lz4, lzf, and snappy. You
- can also use fully qualified class names to specify the codec, e.g.
- org.apache.spark.io.LZ4CompressionCodec,
+ The codec used to compress internal data such as RDD partitions, broadcast variables and
+ shuffle outputs. By default, Spark provides three codecs: lz4, lzf,
+ and snappy. You can also use fully qualified class names to specify the codec,
+ e.g.
+ org.apache.spark.io.LZ4CompressionCodec,
org.apache.spark.io.LZFCompressionCodec,
and org.apache.spark.io.SnappyCompressionCodec.
@@ -657,7 +688,7 @@ Apart from these, the following properties are also available, and may be useful
spark.port.maxRetries
16
- Maximum number of retries when binding to a port before giving up.
+ Default maximum number of retries when binding to a port before giving up.
@@ -1088,3 +1119,10 @@ compute `SPARK_LOCAL_IP` by looking up the IP of a specific network interface.
Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can configure it by adding a
`log4j.properties` file in the `conf` directory. One way to start is to copy the existing
`log4j.properties.template` located there.
+
+# Overriding configuration directory
+
+To specify a configuration directory other than the default "SPARK_HOME/conf",
+you can set SPARK_CONF_DIR. Spark will use the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc.)
+from this directory.
+
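
To make the `spark.mesos.executor.memoryOverhead` rule documented above concrete, here is a small Scala sketch of the described formula (an illustration only, not the actual Mesos scheduler code):

```scala
// Sketch of the documented rule: the overhead is the larger of the configured
// value and 7% of spark.executor.memory, with 384 MiB as the configured default.
object MesosMemoryOverhead {
  def effectiveOverheadMiB(executorMemoryMiB: Int, configuredOverheadMiB: Option[Int]): Int = {
    val sevenPercent = (executorMemoryMiB * 0.07).toInt
    math.max(configuredOverheadMiB.getOrElse(384), sevenPercent)
  }

  def main(args: Array[String]): Unit = {
    // 8 GiB executor, no explicit overhead: max(384, 573) = 573 MiB overhead.
    println(effectiveOverheadMiB(8192, None))
    // 8 GiB executor with a 1024 MiB override: max(1024, 573) = 1024 MiB.
    println(effectiveOverheadMiB(8192, Some(1024)))
  }
}
```

For an 8 GiB executor the 7% floor (573 MiB) wins over the 384 MiB default, while an explicit 1024 MiB setting wins over both.
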
diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md
index b2ca6a9b48f32..530798f2b8022 100644
--- a/docs/ec2-scripts.md
+++ b/docs/ec2-scripts.md
@@ -48,6 +48,15 @@ by looking for the "Name" tag of the instance in the Amazon EC2 Console.
key pair, `<num-slaves>` is the number of slave nodes to launch (try
1 at first), and `<cluster-name>` is the name to give to your
cluster.
+
+ For example:
+
+ ```bash
+ export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU
+ export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123
+ ./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a --spark-version=1.1.0 launch my-spark-cluster
+ ```
+
- After everything launches, check that the cluster scheduler is up and sees
all the slaves by going to its web UI, which will be printed at the end of
the script (typically `http://:8080`).
@@ -55,27 +64,27 @@ by looking for the "Name" tag of the instance in the Amazon EC2 Console.
You can also run `./spark-ec2 --help` to see more usage options. The
following options are worth pointing out:
-- `--instance-type=<instance-type>` can be used to specify an EC2
+- `--instance-type=<instance-type>` can be used to specify an EC2
instance type to use. For now, the script only supports 64-bit instance
types, and the default type is `m1.large` (which has 2 cores and 7.5 GB
RAM). Refer to the Amazon pages about [EC2 instance
types](http://aws.amazon.com/ec2/instance-types) and [EC2
pricing](http://aws.amazon.com/ec2/#pricing) for information about other
instance types.
-- `--region=<ec2-region>` specifies an EC2 region in which to launch
+- `--region=<ec2-region>` specifies an EC2 region in which to launch
instances. The default region is `us-east-1`.
-- `--zone=<ec2-zone>` can be used to specify an EC2 availability zone
+- `--zone=<ec2-zone>` can be used to specify an EC2 availability zone
to launch instances in. Sometimes, you will get an error because there
is not enough capacity in one zone, and you should try to launch in
another.
-- `--ebs-vol-size=GB` will attach an EBS volume with a given amount
+- `--ebs-vol-size=<GB>` will attach an EBS volume with a given amount
of space to each node so that you can have a persistent HDFS cluster
on your nodes across cluster restarts (see below).
-- `--spot-price=PRICE` will launch the worker nodes as
+- `--spot-price=<price>` will launch the worker nodes as
[Spot Instances](http://aws.amazon.com/ec2/spot-instances/),
bidding for the given maximum price (in dollars).
-- `--spark-version=VERSION` will pre-load the cluster with the
- specified version of Spark. VERSION can be a version number
+- `--spark-version=<version>` will pre-load the cluster with the
+ specified version of Spark. The `<version>` can be a version number
(e.g. "0.7.3") or a specific git hash. By default, a recent
version will be used.
- If one of your launches fails due to e.g. not having the right
@@ -137,11 +146,11 @@ cost you any EC2 cycles, but ***will*** continue to cost money for EBS
storage.
- To stop one of your clusters, go into the `ec2` directory and run
-`./spark-ec2 stop <cluster-name>`.
+`./spark-ec2 --region=<ec2-region> stop <cluster-name>`.
- To restart it later, run
-`./spark-ec2 -i <key-file> start <cluster-name>`.
+`./spark-ec2 -i <key-file> --region=<ec2-region> start <cluster-name>`.
- To ultimately destroy the cluster and stop consuming EBS space, run
-`./spark-ec2 destroy <cluster-name>` as described in the previous
+`./spark-ec2 --region=<ec2-region> destroy <cluster-name>` as described in the previous
section.
# Limitations
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index dfd9cd572888c..d10bd63746629 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -52,7 +52,7 @@ import org.apache.spark.mllib.linalg.Vectors
// Load and parse the data
val data = sc.textFile("data/mllib/kmeans_data.txt")
-val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble)))
+val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache()
// Cluster the data into two classes using KMeans
val numClusters = 2
@@ -100,6 +100,7 @@ public class KMeansExample {
}
}
);
+ parsedData.cache();
// Cluster the data into two classes using KMeans
int numClusters = 2;
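
The `.cache()` calls added in these clustering examples matter because k-means is iterative: without caching, the input text file would be re-read and re-parsed on every iteration. A hedged end-to-end sketch of the documented Scala pattern (paths and parameters are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Sketch of the documented pattern: cache the parsed vectors before handing
// them to an iterative algorithm such as KMeans.
object KMeansCachingExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("KMeansCachingExample").setMaster("local[2]"))
    val data = sc.textFile("data/mllib/kmeans_data.txt")
    val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache()
    val model = KMeans.train(parsedData, 2, 20) // k = 2, 20 iterations over the cached RDD
    println(s"Within-set sum of squared errors: ${model.computeCost(parsedData)}")
    sc.stop()
  }
}
```
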
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index 44f0f76220b6e..1511ae6dda4ed 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -68,7 +68,7 @@ val sc: SparkContext = ...
val documents: RDD[Seq[String]] = sc.textFile("...").map(_.split(" ").toSeq)
val hashingTF = new HashingTF()
-val tf: RDD[Vector] = hasingTF.transform(documents)
+val tf: RDD[Vector] = hashingTF.transform(documents)
{% endhighlight %}
While applying `HashingTF` only needs a single pass to the data, applying `IDF` needs two passes:
@@ -82,6 +82,21 @@ tf.cache()
val idf = new IDF().fit(tf)
val tfidf: RDD[Vector] = idf.transform(tf)
{% endhighlight %}
+
+MLlib's IDF implementation provides an option for ignoring terms that occur in fewer than a
+minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature
+can be used by passing the `minDocFreq` value to the IDF constructor.
+
+{% highlight scala %}
+import org.apache.spark.mllib.feature.IDF
+
+// ... continue from the previous example
+tf.cache()
+val idf = new IDF(minDocFreq = 2).fit(tf)
+val tfidf: RDD[Vector] = idf.transform(tf)
+{% endhighlight %}
+
+
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md
index 9137f9dc1b692..d31bec3e1bd01 100644
--- a/docs/mllib-linear-methods.md
+++ b/docs/mllib-linear-methods.md
@@ -396,7 +396,7 @@ val data = sc.textFile("data/mllib/ridge-data/lpsa.data")
val parsedData = data.map { line =>
val parts = line.split(',')
LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
-}
+}.cache()
// Building the model
val numIterations = 100
@@ -455,6 +455,7 @@ public class LinearRegression {
}
}
);
+ parsedData.cache();
// Building the model
int numIterations = 100;
@@ -470,7 +471,7 @@ public class LinearRegression {
}
}
);
- JavaRDD