then `rdd` contains * {{{ diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 29ca751519abd..61b125ef7c6c1 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,7 +19,6 @@ package org.apache.spark.api.python import java.io._ import java.net._ -import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ @@ -27,6 +26,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials +import com.google.common.base.Charsets.UTF_8 import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.hadoop.conf.Configuration @@ -75,6 +75,7 @@ private[spark] class PythonRDD( var complete_cleanly = false context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() + writerThread.join() if (reuse_worker && complete_cleanly) { env.releasePythonWorker(pythonExec, envVars.toMap, worker) } else { @@ -133,7 +134,7 @@ private[spark] class PythonRDD( val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) - throw new PythonException(new String(obj, "utf-8"), + throw new PythonException(new String(obj, UTF_8), writerThread.exception.getOrElse(null)) case SpecialLengths.END_OF_DATA_SECTION => // We've finished the data section of the output, but we can still @@ -145,7 +146,9 @@ private[spark] class PythonRDD( stream.readFully(update) accumulator += Collections.singletonList(update) } - complete_cleanly = true + if (stream.readInt() == SpecialLengths.END_OF_STREAM) { + complete_cleanly = true + } null } } catch { @@ -154,6 +157,10 @@ private[spark] class PythonRDD( logDebug("Exception thrown after task interruption", e) throw new TaskKilledException + case e: Exception if env.isStopped => + logDebug("Exception thrown after context is stopped", e) + null // exit silently + case e: Exception if writerThread.exception.isDefined => logError("Python worker exited unexpectedly (crashed)", e) logError("This may have been caused by a prior exception:", writerThread.exception.get) @@ -235,6 +242,7 @@ private[spark] class PythonRDD( // Data values PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() } catch { case e: Exception if context.isCompleted || context.isInterrupted => @@ -306,10 +314,10 @@ private object SpecialLengths { val END_OF_DATA_SECTION = -1 val PYTHON_EXCEPTION_THROWN = -2 val TIMING_DATA = -3 + val END_OF_STREAM = -4 } private[spark] object PythonRDD extends Logging { - val UTF8 = Charset.forName("UTF-8") // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() @@ -577,7 +585,7 @@ private[spark] object PythonRDD extends Logging { } def writeUTF(str: String, dataOut: DataOutputStream) { - val bytes = str.getBytes(UTF8) + val bytes = str.getBytes(UTF_8) dataOut.writeInt(bytes.length) dataOut.write(bytes) } @@ -840,7 +848,7 @@ private[spark] object PythonRDD extends Logging { private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { - override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8) + override def call(arr: Array[Byte]) : 
String = new String(arr, UTF_8) } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index d11db978b842e..e9ca9166eb4d6 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -18,7 +18,8 @@ package org.apache.spark.api.python import java.io.{DataOutput, DataInput} -import java.nio.charset.Charset + +import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.io._ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat @@ -136,7 +137,7 @@ object WriteInputFormatTestDataGenerator { sc.parallelize(intKeys).saveAsSequenceFile(intPath) sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath) sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath) - sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(Charset.forName("UTF-8"))) } + sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(UTF_8)) } ).saveAsSequenceFile(bytesPath) val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false)) sc.parallelize(bools).saveAsSequenceFile(boolPath) diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 15fd30e65761d..87f5cf944ed85 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -20,6 +20,8 @@ package org.apache.spark.broadcast import java.io.Serializable import org.apache.spark.SparkException +import org.apache.spark.Logging +import org.apache.spark.util.Utils import scala.reflect.ClassTag @@ -52,7 +54,7 @@ import scala.reflect.ClassTag * @param id A unique identifier for the broadcast variable. * @tparam T Type of the data contained in the broadcast variable. */ -abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { +abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Logging { /** * Flag signifying whether the broadcast variable is valid @@ -60,6 +62,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { */ @volatile private var _isValid = true + private var _destroySite = "" + /** Get the broadcasted value. */ def value: T = { assertValid() @@ -84,13 +88,26 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { doUnpersist(blocking) } + + /** + * Destroy all data and metadata related to this broadcast variable. Use this with caution; + * once a broadcast variable has been destroyed, it cannot be used again. + * This method blocks until destroy has completed + */ + def destroy() { + destroy(blocking = true) + } + /** * Destroy all data and metadata related to this broadcast variable. Use this with caution; * once a broadcast variable has been destroyed, it cannot be used again. + * @param blocking Whether to block until destroy has completed */ private[spark] def destroy(blocking: Boolean) { assertValid() _isValid = false + _destroySite = Utils.getCallSite().shortForm + logInfo("Destroying %s (from %s)".format(toString, _destroySite)) doDestroy(blocking) } @@ -124,7 +141,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { /** Check if this broadcast is valid. 
If not valid, exception is thrown. */ protected def assertValid() { if (!_isValid) { - throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString)) + throw new SparkException( + "Attempted to use %s after it was destroyed (%s) ".format(toString, _destroySite)) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 4cd4f4f96fd16..7dade04273b08 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -72,13 +72,13 @@ private[spark] class HttpBroadcast[T: ClassTag]( } /** Used by the JVM when serializing this object. */ - private def writeObject(out: ObjectOutputStream) { + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { assertValid() out.defaultWriteObject() } /** Used by the JVM when deserializing this object. */ - private def readObject(in: ObjectInputStream) { + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() HttpBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(blockId) match { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 99af2e9608ea7..94142d33369c7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -28,7 +28,7 @@ import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} -import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.{ByteBufferInputStream, Utils} import org.apache.spark.util.io.ByteArrayChunkOutputStream /** @@ -56,11 +56,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) extends Broadcast[T](id) with Logging with Serializable { /** - * Value of the broadcast object. On driver, this is set directly by the constructor. - * On executors, this is reconstructed by [[readObject]], which builds this value by reading - * blocks from the driver and/or other executors. + * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]], + * which builds this value by reading blocks from the driver and/or other executors. + * + * On the driver, if the value is required, it is read lazily from the block manager. */ - @transient private var _value: T = obj + @transient private lazy val _value: T = readBroadcastBlock() + /** The compression codec to use, or None if compression is disabled */ @transient private var compressionCodec: Option[CompressionCodec] = _ /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */ @@ -79,22 +81,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) private val broadcastId = BroadcastBlockId(id) /** Total number of blocks this broadcast variable contains. */ - private val numBlocks: Int = writeBlocks() + private val numBlocks: Int = writeBlocks(obj) - override protected def getValue() = _value + override protected def getValue() = { + _value + } /** * Divide the object into multiple blocks and put those blocks in the block manager. 
- * + * @param value the object to divide * @return number of blocks this broadcast variable is divided into */ - private def writeBlocks(): Int = { + private def writeBlocks(value: T): Int = { // Store a copy of the broadcast variable in the driver so that tasks run on the driver // do not create a duplicate copy of the broadcast variable's value. - SparkEnv.get.blockManager.putSingle(broadcastId, _value, StorageLevel.MEMORY_AND_DISK, + SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK, tellMaster = false) val blocks = - TorrentBroadcast.blockifyObject(_value, blockSize, SparkEnv.get.serializer, compressionCodec) + TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec) blocks.zipWithIndex.foreach { case (block, i) => SparkEnv.get.blockManager.putBytes( BroadcastBlockId(id, "piece" + i), @@ -152,36 +156,35 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } /** Used by the JVM when serializing this object. */ - private def writeObject(out: ObjectOutputStream) { + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { assertValid() out.defaultWriteObject() } - /** Used by the JVM when deserializing this object. */ - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() + private def readBroadcastBlock(): T = Utils.tryOrIOException { TorrentBroadcast.synchronized { setConf(SparkEnv.get.conf) SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match { case Some(x) => - _value = x.asInstanceOf[T] + x.asInstanceOf[T] case None => logInfo("Started reading broadcast variable " + id) - val start = System.nanoTime() + val startTimeMs = System.currentTimeMillis() val blocks = readBlocks() - val time = (System.nanoTime() - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") + logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) - _value = - TorrentBroadcast.unBlockifyObject[T](blocks, SparkEnv.get.serializer, compressionCodec) + val obj = TorrentBroadcast.unBlockifyObject[T]( + blocks, SparkEnv.get.serializer, compressionCodec) // Store the merged copy in BlockManager so other tasks on this executor don't // need to re-fetch it. 
SparkEnv.get.blockManager.putSingle( - broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + obj } } } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 39150deab863c..4e802e02c4149 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -17,6 +17,8 @@ package org.apache.spark.deploy +import java.net.{URI, URISyntaxException} + import scala.collection.mutable.ListBuffer import org.apache.log4j.Level @@ -114,5 +116,12 @@ private[spark] class ClientArguments(args: Array[String]) { } object ClientArguments { - def isValidJarUrl(s: String): Boolean = s.matches("(.+):(.+)jar") + def isValidJarUrl(s: String): Boolean = { + try { + val uri = new URI(s) + uri.getScheme != null && uri.getAuthority != null && s.endsWith("jar") + } catch { + case _: URISyntaxException => false + } + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index fe0ad9ebbca12..e28eaad8a5180 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -20,12 +20,15 @@ package org.apache.spark.deploy import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.Utils import scala.collection.JavaConversions._ @@ -121,6 +124,33 @@ class SparkHadoopUtil extends Logging { UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename) } + /** + * Returns a function that can be called to find Hadoop FileSystem bytes read. If + * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will + * return the bytes read on r since t. Reflection is required because thread-level FileSystem + * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). + * Returns None if the required method can't be found. 
+ */ + private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration) + : Option[() => Long] = { + val qualifiedPath = path.getFileSystem(conf).makeQualified(path) + val scheme = qualifiedPath.toUri().getScheme() + val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) + try { + val threadStats = stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) + val statisticsDataClass = + Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") + val getBytesReadMethod = statisticsDataClass.getDeclaredMethod("getBytesRead") + val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum + val baselineBytesRead = f() + Some(() => f() - baselineBytesRead) + } catch { + case e: NoSuchMethodException => { + logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e) + None + } + } + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index f97bf67fa5a3b..b43e68e40f791 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -158,8 +158,9 @@ object SparkSubmit { args.files = mergeFileLists(args.files, args.primaryResource) } args.files = mergeFileLists(args.files, args.pyFiles) - // Format python file paths properly before adding them to the PYTHONPATH - sysProps("spark.submit.pyFiles") = PythonRunner.formatPaths(args.pyFiles).mkString(",") + if (args.pyFiles != null) { + sysProps("spark.submit.pyFiles") = args.pyFiles + } } // Special flag to avoid deprecation warnings at the client @@ -273,15 +274,32 @@ object SparkSubmit { } } - // Properties given with --conf are superceded by other options, but take precedence over - // properties in the defaults file. + // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { sysProps.getOrElseUpdate(k, v) } - // Read from default spark properties, if any - for ((k, v) <- args.defaultSparkProperties) { - sysProps.getOrElseUpdate(k, v) + // Resolve paths in certain spark properties + val pathConfigs = Seq( + "spark.jars", + "spark.files", + "spark.yarn.jar", + "spark.yarn.dist.files", + "spark.yarn.dist.archives") + pathConfigs.foreach { config => + // Replace old URIs with resolved URIs, if they exist + sysProps.get(config).foreach { oldValue => + sysProps(config) = Utils.resolveURIs(oldValue) + } + } + + // Resolve and format python file paths properly before adding them to the PYTHONPATH. + // The resolving part is redundant in the case of --py-files, but necessary if the user + // explicitly sets `spark.submit.pyFiles` in his/her default properties file. 
+ sysProps.get("spark.submit.pyFiles").foreach { pyFiles => + val resolvedPyFiles = Utils.resolveURIs(pyFiles) + val formattedPyFiles = PythonRunner.formatPaths(resolvedPyFiles).mkString(",") + sysProps("spark.submit.pyFiles") = formattedPyFiles } (childArgs, childClasspath, sysProps, childMainClass) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 72a452e0aefb5..f0e9ee67f6a67 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy import java.util.jar.JarFile -import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark.util.Utils @@ -72,39 +71,54 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St defaultProperties } - // Respect SPARK_*_MEMORY for cluster mode - driverMemory = sys.env.get("SPARK_DRIVER_MEMORY").orNull - executorMemory = sys.env.get("SPARK_EXECUTOR_MEMORY").orNull - + // Set parameters from command line arguments parseOpts(args.toList) - mergeSparkProperties() + // Populate `sparkProperties` map from properties file + mergeDefaultSparkProperties() + // Use `sparkProperties` map along with env vars to fill in any missing parameters + loadEnvironmentArguments() + checkRequiredArguments() /** - * Fill in any undefined values based on the default properties file or options passed in through - * the '--conf' flag. + * Merge values from the default properties file with those specified through --conf. + * When this is called, `sparkProperties` is already filled with configs from the latter. */ - private def mergeSparkProperties(): Unit = { + private def mergeDefaultSparkProperties(): Unit = { // Use common defaults file, if not specified by user propertiesFile = Option(propertiesFile).getOrElse(Utils.getDefaultPropertiesFile(env)) + // Honor --conf before the defaults file + defaultSparkProperties.foreach { case (k, v) => + if (!sparkProperties.contains(k)) { + sparkProperties(k) = v + } + } + } - val properties = HashMap[String, String]() - properties.putAll(defaultSparkProperties) - properties.putAll(sparkProperties) - - // Use properties file as fallback for values which have a direct analog to - // arguments in this script. - master = Option(master).orElse(properties.get("spark.master")).orNull - executorMemory = Option(executorMemory).orElse(properties.get("spark.executor.memory")).orNull - executorCores = Option(executorCores).orElse(properties.get("spark.executor.cores")).orNull + /** + * Load arguments from environment variables, Spark properties etc. 
+ */ + private def loadEnvironmentArguments(): Unit = { + master = Option(master) + .orElse(sparkProperties.get("spark.master")) + .orElse(env.get("MASTER")) + .orNull + driverMemory = Option(driverMemory) + .orElse(sparkProperties.get("spark.driver.memory")) + .orElse(env.get("SPARK_DRIVER_MEMORY")) + .orNull + executorMemory = Option(executorMemory) + .orElse(sparkProperties.get("spark.executor.memory")) + .orElse(env.get("SPARK_EXECUTOR_MEMORY")) + .orNull + executorCores = Option(executorCores) + .orElse(sparkProperties.get("spark.executor.cores")) + .orNull totalExecutorCores = Option(totalExecutorCores) - .orElse(properties.get("spark.cores.max")) + .orElse(sparkProperties.get("spark.cores.max")) .orNull - name = Option(name).orElse(properties.get("spark.app.name")).orNull - jars = Option(jars).orElse(properties.get("spark.jars")).orNull - - // This supports env vars in older versions of Spark - master = Option(master).orElse(env.get("MASTER")).orNull + name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull + jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull // Try to set main class from JAR if no --class argument is given @@ -131,7 +145,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } /** Ensure that required fields exists. Call this only once all defaults are loaded. */ - private def checkRequiredArguments() = { + private def checkRequiredArguments(): Unit = { if (args.length == 0) { printUsageAndExit(-1) } @@ -166,7 +180,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } - override def toString = { + override def toString = { s"""Parsed arguments: | master $master | deployMode $deployMode @@ -174,7 +188,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | executorCores $executorCores | totalExecutorCores $totalExecutorCores | propertiesFile $propertiesFile - | extraSparkProperties $sparkProperties | driverMemory $driverMemory | driverCores $driverCores | driverExtraClassPath $driverExtraClassPath @@ -193,8 +206,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | jars $jars | verbose $verbose | - |Default properties from $propertiesFile: - |${defaultSparkProperties.mkString(" ", "\n ", "\n")} + |Spark properties used, including those specified through + | --conf and those from the properties file $propertiesFile: + |${sparkProperties.mkString(" ", "\n ", "\n")} """.stripMargin } @@ -327,7 +341,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } - private def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { val outStream = SparkSubmit.printStream if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala index 0125330589da5..2b894a796c8c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -82,17 +82,8 @@ private[spark] object SparkSubmitDriverBootstrapper { .orElse(confDriverMemory) .getOrElse(defaultDriverMemory) - val newLibraryPath = - if 
(submitLibraryPath.isDefined) { - // SPARK_SUBMIT_LIBRARY_PATH is already captured in JAVA_OPTS - "" - } else { - confLibraryPath.map("-Djava.library.path=" + _).getOrElse("") - } - val newClasspath = if (submitClasspath.isDefined) { - // SPARK_SUBMIT_CLASSPATH is already captured in CLASSPATH classpath } else { classpath + confClasspath.map(sys.props("path.separator") + _).getOrElse("") @@ -114,7 +105,6 @@ private[spark] object SparkSubmitDriverBootstrapper { val command: Seq[String] = Seq(runner) ++ Seq("-cp", newClasspath) ++ - Seq(newLibraryPath) ++ filteredJavaOpts ++ Seq(s"-Xms$newDriverMemory", s"-Xmx$newDriverMemory") ++ Seq("org.apache.spark.deploy.SparkSubmit") ++ @@ -130,6 +120,13 @@ private[spark] object SparkSubmitDriverBootstrapper { // Start the driver JVM val filteredCommand = command.filter(_.nonEmpty) val builder = new ProcessBuilder(filteredCommand) + val env = builder.environment() + + if (submitLibraryPath.isEmpty && confLibraryPath.nonEmpty) { + val libraryPaths = confLibraryPath ++ sys.env.get(Utils.libraryPathEnvName) + env.put(Utils.libraryPathEnvName, libraryPaths.mkString(sys.props("path.separator"))) + } + val process = builder.start() // Redirect stdout and stderr from the child JVM diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 481f6c93c6a8d..2d1609b973607 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -112,7 +112,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis val ui = { val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) - new SparkUI(conf, appSecManager, replayBus, appId, + SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId, s"${HistoryServer.UI_PATH_PREFIX}/$appId") // Do not call ui.bind() to avoid creating a new server for each application } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index d25c29113d6da..0e249e51a77d8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -84,11 +84,11 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
{acc.name} | {acc.value} | {UIUtils.formatDuration(millis.toLong)} | + } } - val serializationQuantiles = -Result serialization time | +: Distribution(serializationTimes). - get.getQuantiles().map(ms =>{UIUtils.formatDuration(ms.toLong)} | ) val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.executorRunTime.toDouble } - val serviceQuantiles =Duration | +: Distribution(serviceTimes).get.getQuantiles() - .map(ms =>{UIUtils.formatDuration(ms.toLong)} | ) + val serviceQuantiles =Duration | +: getFormattedTimeQuantiles(serviceTimes) + + val gcTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.jvmGCTime.toDouble + } + val gcQuantiles = ++ GC Time + + | +: getFormattedTimeQuantiles(gcTimes) + + val serializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.resultSerializationTime.toDouble + } + val serializationQuantiles = ++ + Result Serialization Time + + | +: getFormattedTimeQuantiles(serializationTimes) val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => if (info.gettingResultTime > 0) { @@ -142,76 +212,75 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { 0.0 } } - val gettingResultQuantiles =Time spent fetching task results | +: - Distribution(gettingResultTimes).get.getQuantiles().map { millis => -{UIUtils.formatDuration(millis.toLong)} | - } + val gettingResultQuantiles = ++ + Getting Result Time + + | +: + getFormattedTimeQuantiles(gettingResultTimes) // The scheduler delay includes the network delay to send the task to the worker // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) => - val totalExecutionTime = { - if (info.gettingResultTime > 0) { - (info.gettingResultTime - info.launchTime).toDouble - } else { - (info.finishTime - info.launchTime).toDouble - } - } - totalExecutionTime - metrics.get.executorRunTime + getSchedulerDelay(info, metrics.get).toDouble } val schedulerDelayTitle =Scheduler delay | + title={ToolTips.SCHEDULER_DELAY} data-placement="right">Scheduler Delay val schedulerDelayQuantiles = schedulerDelayTitle +: - Distribution(schedulerDelays).get.getQuantiles().map { millis => -{UIUtils.formatDuration(millis.toLong)} | - } + getFormattedTimeQuantiles(schedulerDelays) - def getQuantileCols(data: Seq[Double]) = + def getFormattedSizeQuantiles(data: Seq[Double]) = Distribution(data).get.getQuantiles().map(d =>{Utils.bytesToString(d.toLong)} | ) val inputSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble } - val inputQuantiles =Input | +: getQuantileCols(inputSizes) + val inputQuantiles =Input | +: getFormattedSizeQuantiles(inputSizes) val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } val shuffleReadQuantiles =Shuffle Read (Remote) | +: - getQuantileCols(shuffleReadSizes) + getFormattedSizeQuantiles(shuffleReadSizes) val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble } - val shuffleWriteQuantiles =Shuffle Write | +: getQuantileCols(shuffleWriteSizes) + val shuffleWriteQuantiles =Shuffle Write | +: + getFormattedSizeQuantiles(shuffleWriteSizes) val memoryBytesSpilledSizes = validTasks.map { 
case TaskUIData(_, metrics, _) => metrics.get.memoryBytesSpilled.toDouble } val memoryBytesSpilledQuantiles =Shuffle spill (memory) | +: - getQuantileCols(memoryBytesSpilledSizes) + getFormattedSizeQuantiles(memoryBytesSpilledSizes) val diskBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.diskBytesSpilled.toDouble } val diskBytesSpilledQuantiles =Shuffle spill (disk) | +: - getQuantileCols(diskBytesSpilledSizes) + getFormattedSizeQuantiles(diskBytesSpilledSizes) val listings: Seq[Seq[Node]] = Seq( - serializationQuantiles, - serviceQuantiles, - gettingResultQuantiles, - schedulerDelayQuantiles, - if (hasInput) inputQuantiles else Nil, - if (hasShuffleRead) shuffleReadQuantiles else Nil, - if (hasShuffleWrite) shuffleWriteQuantiles else Nil, - if (hasBytesSpilled) memoryBytesSpilledQuantiles else Nil, - if (hasBytesSpilled) diskBytesSpilledQuantiles else Nil) +
{formatDuration} | -+ | + {UIUtils.formatDuration(schedulerDelay.toLong)} + | +{if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} | -
- {Unparsed(
- info.accumulables.map{acc => s"${acc.name}: ${acc.update.get}"}.mkString(" ") - )} + | + {UIUtils.formatDuration(serializationTime)} | - + {if (hasAccumulators) { +
+ {Unparsed(accumulatorsReadable.mkString(" "))} + |
+ }}
{if (hasInput) {
{inputReadable}
@@ -333,4 +415,15 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
}
}
+
+ private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = {
+ val totalExecutionTime = {
+ if (info.gettingResultTime > 0) {
+ (info.gettingResultTime - info.launchTime)
+ } else {
+ (info.finishTime - info.launchTime)
+ }
+ }
+ totalExecutionTime - metrics.executorRunTime
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
new file mode 100644
index 0000000000000..23d672cabda07
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+/**
+ * Names of the CSS classes corresponding to each type of task detail. Used to allow users
+ * to optionally show/hide columns.
+ */
+private object TaskDetailsClassNames {
+ val SCHEDULER_DELAY = "scheduler_delay"
+ val GC_TIME = "gc_time"
+ val RESULT_SERIALIZATION_TIME = "serialization_time"
+ val GETTING_RESULT_TIME = "getting_result_time"
+}
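
// --- Editorial sketch (not part of the patch): how these class names are meant to be used.
// A metric cell in the task table is tagged with the class so that every cell for that metric
// can be shown or hidden together from the page's metric-selection controls. This assumes code
// inside the org.apache.spark.ui.jobs package (the object is private to it); `gcTime` is a
// stand-in for a value available to the stage page renderer.
import org.apache.spark.ui.UIUtils

val gcTime: Long = 42L
val gcTimeCell =
  <td class={TaskDetailsClassNames.GC_TIME}>
    {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""}
  </td>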
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index a336bf7e1ed02..e2813f8eb5ab9 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ui.jobs
+import org.apache.spark.JobExecutionStatus
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
import org.apache.spark.util.collection.OpenHashSet
@@ -36,6 +37,13 @@ private[jobs] object UIData {
var diskBytesSpilled : Long = 0
}
+ class JobUIData(
+ var jobId: Int = -1,
+ var stageIds: Seq[Int] = Seq.empty,
+ var jobGroup: Option[String] = None,
+ var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN
+ )
+
class StageUIData {
var numActiveTasks: Int = _
var numCompleteTasks: Int = _
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
index 8a0075ae8daf7..12d23a92878cf 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
@@ -39,7 +39,8 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
// Worker table
val workers = storageStatusList.map((rddId, _))
- val workerTable = UIUtils.listingTable(workerHeader, workerRow, workers)
+ val workerTable = UIUtils.listingTable(workerHeader, workerRow, workers,
+ id = Some("rdd-storage-by-worker-table"))
// Block table
val blockLocations = StorageUtils.getRddBlockLocations(rddId, storageStatusList)
@@ -49,7 +50,8 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
.map { case (blockId, status) =>
(blockId, status, blockLocations.get(blockId).getOrElse(Seq[String]("Unknown")))
}
- val blockTable = UIUtils.listingTable(blockHeader, blockRow, blocks)
+ val blockTable = UIUtils.listingTable(blockHeader, blockRow, blocks,
+ id = Some("rdd-storage-by-block-table"))
val content =
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
index 83489ca0679ee..6ced6052d2b18 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
@@ -31,7 +31,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") {
def render(request: HttpServletRequest): Seq[Node] = {
val rdds = listener.rddInfoList
- val content = UIUtils.listingTable(rddHeader, rddRow, rdds)
+ val content = UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))
UIUtils.headerSparkPage("Storage", content, parent)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
index 76097f1c51f8e..a81291d505583 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
@@ -26,11 +26,10 @@ import org.apache.spark.storage._
/** Web UI showing storage status of all RDD's in the given SparkContext. */
private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storage") {
- val listener = new StorageListener(parent.storageStatusListener)
+ val listener = parent.storageListener
attachPage(new StoragePage(this))
attachPage(new RDDPage(this))
- parent.registerListener(listener)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index f41c8d0315cb3..79e398eb8c104 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -159,17 +159,28 @@ private[spark] object AkkaUtils extends Logging {
def askWithReply[T](
message: Any,
actor: ActorRef,
- retryAttempts: Int,
+ timeout: FiniteDuration): T = {
+ askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout)
+ }
+
+ /**
+ * Send a message to the given actor and get its result within a default timeout, or
+ * throw a SparkException if this fails even after the specified number of retries.
+ */
+ def askWithReply[T](
+ message: Any,
+ actor: ActorRef,
+ maxAttempts: Int,
retryInterval: Int,
timeout: FiniteDuration): T = {
// TODO: Consider removing multiple attempts
if (actor == null) {
- throw new SparkException("Error sending message as driverActor is null " +
+ throw new SparkException("Error sending message as actor is null " +
"[message = " + message + "]")
}
var attempts = 0
var lastException: Exception = null
- while (attempts < retryAttempts) {
+ while (attempts < maxAttempts) {
attempts += 1
try {
val future = actor.ask(message)(timeout)
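
// --- Usage sketch (editorial, not part of the patch). The new single-attempt overload simply
// delegates to the retrying overload with maxAttempts = 1. A caller that wants retries would do
// something like the following; `trackerActor` is a hypothetical ActorRef, and the code is
// assumed to live inside the org.apache.spark package since AkkaUtils is private[spark].
import scala.concurrent.duration._
import akka.actor.ActorRef
import org.apache.spark.util.AkkaUtils

def askTracker[T](trackerActor: ActorRef, message: Any): T = {
  // Up to 3 attempts, 3000 ms between attempts, 30-second ask timeout per attempt.
  AkkaUtils.askWithReply[T](message, trackerActor, maxAttempts = 3, retryInterval = 3000,
    timeout = 30.seconds)
}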
diff --git a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala
index 2b452ad33b021..770ff9d5ad6ae 100644
--- a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala
@@ -29,7 +29,7 @@ private[spark]
class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable {
def value = buffer
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
val length = in.readInt()
buffer = ByteBuffer.allocate(length)
var amountRead = 0
@@ -44,7 +44,7 @@ class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable
buffer.rewind() // Allow us to read it later
}
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
out.writeInt(buffer.limit())
if (Channels.newChannel(out).write(buffer) != buffer.limit()) {
throw new IOException("Could not fully write buffer to output stream")
diff --git a/core/src/main/scala/org/apache/spark/util/SparkExitCode.scala b/core/src/main/scala/org/apache/spark/util/SparkExitCode.scala
new file mode 100644
index 0000000000000..c93b1cca9f564
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/SparkExitCode.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+private[spark] object SparkExitCode {
+ /** The default uncaught exception handler was reached. */
+ val UNCAUGHT_EXCEPTION = 50
+
+ /** The default uncaught exception handler was called and an exception was encountered while
+ logging the exception. */
+ val UNCAUGHT_EXCEPTION_TWICE = 51
+
+ /** The default uncaught exception handler was reached, and the uncaught exception was an
+ OutOfMemoryError. */
+ val OOM = 52
+
+}
diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
new file mode 100644
index 0000000000000..ad3db1fbb57ed
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.apache.spark.Logging
+
+/**
+ * The default uncaught exception handler for Executors terminates the whole process, to avoid
+ * getting into a bad state indefinitely. Since Executors are relatively lightweight, it's better
+ * to fail fast when things go wrong.
+ */
+private[spark] object SparkUncaughtExceptionHandler
+ extends Thread.UncaughtExceptionHandler with Logging {
+
+ override def uncaughtException(thread: Thread, exception: Throwable) {
+ try {
+ logError("Uncaught exception in thread " + thread, exception)
+
+ // We may have been called from a shutdown hook. If so, we must not call System.exit().
+ // (If we do, we will deadlock.)
+ if (!Utils.inShutdown()) {
+ if (exception.isInstanceOf[OutOfMemoryError]) {
+ System.exit(SparkExitCode.OOM)
+ } else {
+ System.exit(SparkExitCode.UNCAUGHT_EXCEPTION)
+ }
+ }
+ } catch {
+ case oom: OutOfMemoryError => Runtime.getRuntime.halt(SparkExitCode.OOM)
+ case t: Throwable => Runtime.getRuntime.halt(SparkExitCode.UNCAUGHT_EXCEPTION_TWICE)
+ }
+ }
+
+ def uncaughtException(exception: Throwable) {
+ uncaughtException(Thread.currentThread(), exception)
+ }
+}
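
// --- Usage sketch (editorial, not part of the patch): the handler is meant to be installed as
// the JVM-wide default so that any thread dying with an uncaught exception fails the process
// fast with the corresponding SparkExitCode. An executor's main entry point is the kind of place
// where this would be wired up; the call below is the standard JDK API for doing so, and assumes
// code inside the org.apache.spark package since the object is private[spark].
Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler)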
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 0aeff6455b3fe..063895d3c548d 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -23,8 +23,6 @@ import java.nio.ByteBuffer
import java.util.{Properties, Locale, Random, UUID}
import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor}
-import org.eclipse.jetty.util.MultiException
-
import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
@@ -33,17 +31,17 @@ import scala.reflect.ClassTag
import scala.util.Try
import scala.util.control.{ControlThrowable, NonFatal}
-import com.google.common.io.Files
+import com.google.common.io.{ByteStreams, Files}
import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.commons.lang3.SystemUtils
import org.apache.hadoop.conf.Configuration
import org.apache.log4j.PropertyConfigurator
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
+import org.eclipse.jetty.util.MultiException
import org.json4s._
import tachyon.client.{TachyonFile,TachyonFS}
import org.apache.spark._
-import org.apache.spark.executor.ExecutorUncaughtExceptionHandler
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
/** CallSite represents a place in user code. It can have a short and a long form. */
@@ -347,15 +345,84 @@ private[spark] object Utils extends Logging {
}
/**
- * Download a file requested by the executor. Supports fetching the file in a variety of ways,
+ * Download a file to target directory. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
*
+ * If `useCache` is true, first attempts to fetch the file to a local cache that's shared
+ * across executors running the same application. `useCache` is used mainly for
+ * the executors, and not in local mode.
+ *
* Throws SparkException if the target file already exists and has different contents than
* the requested file.
*/
- def fetchFile(url: String, targetDir: File, conf: SparkConf, securityMgr: SecurityManager,
- hadoopConf: Configuration) {
- val filename = url.split("/").last
+ def fetchFile(
+ url: String,
+ targetDir: File,
+ conf: SparkConf,
+ securityMgr: SecurityManager,
+ hadoopConf: Configuration,
+ timestamp: Long,
+ useCache: Boolean) {
+ val fileName = url.split("/").last
+ val targetFile = new File(targetDir, fileName)
+ if (useCache) {
+ val cachedFileName = s"${url.hashCode}${timestamp}_cache"
+ val lockFileName = s"${url.hashCode}${timestamp}_lock"
+ val localDir = new File(getLocalDir(conf))
+ val lockFile = new File(localDir, lockFileName)
+ val raf = new RandomAccessFile(lockFile, "rw")
+ // Only one executor should download the file into the cache at a time.
+ // The FileLock is only used to synchronize executors downloading the same file;
+ // it is safe regardless of the lock type (mandatory or advisory).
+ val lock = raf.getChannel().lock()
+ val cachedFile = new File(localDir, cachedFileName)
+ try {
+ if (!cachedFile.exists()) {
+ doFetchFile(url, localDir, cachedFileName, conf, securityMgr, hadoopConf)
+ }
+ } finally {
+ lock.release()
+ }
+ if (targetFile.exists && !Files.equal(cachedFile, targetFile)) {
+ if (conf.getBoolean("spark.files.overwrite", false)) {
+ targetFile.delete()
+ logInfo((s"File $targetFile exists and does not match contents of $url, " +
+ s"replacing it with $url"))
+ } else {
+ throw new SparkException(s"File $targetFile exists and does not match contents of $url")
+ }
+ }
+ Files.copy(cachedFile, targetFile)
+ } else {
+ doFetchFile(url, targetDir, fileName, conf, securityMgr, hadoopConf)
+ }
+
+ // Decompress the file if it's a .tar or .tar.gz
+ if (fileName.endsWith(".tar.gz") || fileName.endsWith(".tgz")) {
+ logInfo("Untarring " + fileName)
+ Utils.execute(Seq("tar", "-xzf", fileName), targetDir)
+ } else if (fileName.endsWith(".tar")) {
+ logInfo("Untarring " + fileName)
+ Utils.execute(Seq("tar", "-xf", fileName), targetDir)
+ }
+ // Make the file executable - That's necessary for scripts
+ FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
+ }
+
+ /**
+ * Download a file to target directory. Supports fetching the file in a variety of ways,
+ * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
+ *
+ * Throws SparkException if the target file already exists and has different contents than
+ * the requested file.
+ */
+ private def doFetchFile(
+ url: String,
+ targetDir: File,
+ filename: String,
+ conf: SparkConf,
+ securityMgr: SecurityManager,
+ hadoopConf: Configuration) {
val tempDir = getLocalDir(conf)
val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
val targetFile = new File(targetDir, filename)
@@ -443,16 +510,6 @@ private[spark] object Utils extends Logging {
}
Files.move(tempFile, targetFile)
}
- // Decompress the file if it's a .tar or .tar.gz
- if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
- logInfo("Untarring " + filename)
- Utils.execute(Seq("tar", "-xzf", filename), targetDir)
- } else if (filename.endsWith(".tar")) {
- logInfo("Untarring " + filename)
- Utils.execute(Seq("tar", "-xf", filename), targetDir)
- }
- // Make the file executable - That's necessary for scripts
- FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
}
/**
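
// --- Editorial sketch of the lock-file pattern introduced above (a stripped-down, hypothetical
// standalone helper, not Spark's actual code). The first caller to acquire the exclusive lock
// downloads into the shared cache; later callers block on lock() and then find the cached file
// already present, so each node downloads a given URL at most once per timestamp.
import java.io.{File, RandomAccessFile}

def fetchOnce(cacheDir: File, cachedName: String, lockName: String)(download: File => Unit): File = {
  val cachedFile = new File(cacheDir, cachedName)
  val raf = new RandomAccessFile(new File(cacheDir, lockName), "rw")
  val lock = raf.getChannel.lock()   // blocks until no other process holds the lock
  try {
    if (!cachedFile.exists()) {
      download(cachedFile)           // only the first caller actually downloads
    }
  } finally {
    lock.release()
    raf.close()
  }
  cachedFile
}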
@@ -680,11 +737,15 @@ private[spark] object Utils extends Logging {
}
private def listFilesSafely(file: File): Seq[File] = {
- val files = file.listFiles()
- if (files == null) {
- throw new IOException("Failed to list files for dir: " + file)
+ if (file.exists()) {
+ val files = file.listFiles()
+ if (files == null) {
+ throw new IOException("Failed to list files for dir: " + file)
+ }
+ files
+ } else {
+ List()
}
- files
}
/**
@@ -906,7 +967,37 @@ private[spark] object Utils extends Logging {
block
} catch {
case e: ControlThrowable => throw e
- case t: Throwable => ExecutorUncaughtExceptionHandler.uncaughtException(t)
+ case t: Throwable => SparkUncaughtExceptionHandler.uncaughtException(t)
+ }
+ }
+
+ /**
+ * Execute a block of code that evaluates to Unit, re-throwing any non-fatal uncaught
+ * exceptions as IOException. This is used when implementing Externalizable and Serializable's
+ * read and write methods, since Java's serializer will not report non-IOExceptions properly;
+ * see SPARK-4080 for more context.
+ */
+ def tryOrIOException(block: => Unit) {
+ try {
+ block
+ } catch {
+ case e: IOException => throw e
+ case NonFatal(t) => throw new IOException(t)
+ }
+ }
+
+ /**
+ * Execute a block of code that returns a value, re-throwing any non-fatal uncaught
+ * exceptions as IOException. This is used when implementing Externalizable and Serializable's
+ * read and write methods, since Java's serializer will not report non-IOExceptions properly;
+ * see SPARK-4080 for more context.
+ */
+ def tryOrIOException[T](block: => T): T = {
+ try {
+ block
+ } catch {
+ case e: IOException => throw e
+ case NonFatal(t) => throw new IOException(t)
}
}
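
// --- Usage sketch (editorial): the intended call pattern for the two helpers above, mirroring
// the readObject/writeObject changes elsewhere in this patch. Any non-fatal, non-IOException
// thrown inside the block reaches Java serialization as an IOException, which
// ObjectInputStream/ObjectOutputStream report properly instead of swallowing (see SPARK-4080).
// `SerializableBlob` is an illustrative class, and the code assumes it lives inside the
// org.apache.spark package since Utils is private[spark].
import java.io.{ObjectInputStream, ObjectOutputStream}

class SerializableBlob(@transient var payload: Array[Byte]) extends Serializable {
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    out.defaultWriteObject()
    out.writeInt(payload.length)
    out.write(payload)
  }

  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    payload = new Array[Byte](in.readInt())
    in.readFully(payload)
  }
}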
@@ -914,7 +1005,8 @@ private[spark] object Utils extends Logging {
private def coreExclusionFunction(className: String): Boolean = {
// A regular expression to match classes of the "core" Spark API that we want to skip when
// finding the call site of a method.
- val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
+ val SPARK_CORE_CLASS_REGEX =
+ """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r
val SCALA_CLASS_REGEX = """^scala""".r
val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined
val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined
@@ -983,8 +1075,8 @@ private[spark] object Utils extends Logging {
val stream = new FileInputStream(file)
try {
- stream.skip(effectiveStart)
- stream.read(buff)
+ ByteStreams.skipFully(stream, effectiveStart)
+ ByteStreams.readFully(stream, buff)
} finally {
stream.close()
}
@@ -1178,12 +1270,28 @@ private[spark] object Utils extends Logging {
/**
* Timing method based on iterations that permit JVM JIT optimization.
* @param numIters number of iterations
- * @param f function to be executed
+ * @param f function to be executed. If prepare is not None, the running time of each call to f
+ * must be an order of magnitude longer than one millisecond for accurate timing.
+ * @param prepare function to be executed before each call to f. Its running time doesn't count.
+ * @return the total time across all iterations (not counting preparation time)
*/
- def timeIt(numIters: Int)(f: => Unit): Long = {
- val start = System.currentTimeMillis
- times(numIters)(f)
- System.currentTimeMillis - start
+ def timeIt(numIters: Int)(f: => Unit, prepare: Option[() => Unit] = None): Long = {
+ if (prepare.isEmpty) {
+ val start = System.currentTimeMillis
+ times(numIters)(f)
+ System.currentTimeMillis - start
+ } else {
+ var i = 0
+ var sum = 0L
+ while (i < numIters) {
+ prepare.get.apply()
+ val start = System.currentTimeMillis
+ f
+ sum += System.currentTimeMillis - start
+ i += 1
+ }
+ sum
+ }
}
/**
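
// --- Usage sketch (editorial): time 10 sorts of a 1-million-element array, rebuilding the
// unsorted input before every iteration so each call to `f` sorts fresh data. Only the time
// spent inside `f` is accumulated; the prepare step is excluded from the returned total.
val unsorted = Array.fill(1000000)(scala.util.Random.nextInt())
var buffer: Array[Int] = null
val elapsedMs = Utils.timeIt(10)(
  java.util.Arrays.sort(buffer),
  prepare = Some(() => { buffer = unsorted.clone() })
)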
@@ -1272,6 +1380,11 @@ private[spark] object Utils extends Logging {
*/
val isWindows = SystemUtils.IS_OS_WINDOWS
+ /**
+ * Whether the underlying operating system is Mac OS X.
+ */
+ val isMac = SystemUtils.IS_OS_MAC_OSX
+
/**
* Pattern for matching a Windows drive, which contains only a single alphabet character.
*/
@@ -1594,6 +1707,51 @@ private[spark] object Utils extends Logging {
PropertyConfigurator.configure(pro)
}
+ def invoke(
+ clazz: Class[_],
+ obj: AnyRef,
+ methodName: String,
+ args: (Class[_], AnyRef)*): AnyRef = {
+ val (types, values) = args.unzip
+ val method = clazz.getDeclaredMethod(methodName, types: _*)
+ method.setAccessible(true)
+ method.invoke(obj, values.toSeq: _*)
+ }
+
+ /**
+ * Return the name of the library path environment variable for the current operating system
+ */
+ def libraryPathEnvName: String = {
+ if (isWindows) {
+ "PATH"
+ } else if (isMac) {
+ "DYLD_LIBRARY_PATH"
+ } else {
+ "LD_LIBRARY_PATH"
+ }
+ }
+
+ /**
+ * Return the prefix of a command that appends the given library paths to the
+ * system-specific library path environment variable. On Unix, for instance,
+ * this returns the string LD_LIBRARY_PATH="path1:path2:$LD_LIBRARY_PATH".
+ */
+ def libraryPathEnvPrefix(libraryPaths: Seq[String]): String = {
+ val libraryPathScriptVar = if (isWindows) {
+ s"%${libraryPathEnvName}%"
+ } else {
+ "$" + libraryPathEnvName
+ }
+ val libraryPath = (libraryPaths :+ libraryPathScriptVar).mkString("\"",
+ File.pathSeparator, "\"")
+ val ampersand = if (Utils.isWindows) {
+ " &"
+ } else {
+ ""
+ }
+ s"$libraryPathEnvName=$libraryPath$ampersand"
+ }
+
}
/**
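
// --- Editorial example of what the new helper produces on Linux (assuming neither Windows nor
// Mac OS X, so libraryPathEnvName is "LD_LIBRARY_PATH"). Prepending the prefix to a shell
// command makes the child process see the extra native-library directories:
//
//   libraryPathEnvPrefix(Seq("/opt/native", "/usr/lib/hadoop/native"))
//     returns: LD_LIBRARY_PATH="/opt/native:/usr/lib/hadoop/native:$LD_LIBRARY_PATH"
//
// `bin/my-native-tool` is a hypothetical command used only for illustration.
val launchCommand =
  Utils.libraryPathEnvPrefix(Seq("/opt/native", "/usr/lib/hadoop/native")) + " bin/my-native-tool"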
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala
index ac1528969f0be..4f0bf8384afc9 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala
@@ -27,33 +27,51 @@ import scala.reflect.ClassTag
* Example format: an array of numbers, where each element is also the key.
* See [[KVArraySortDataFormat]] for a more exciting format.
*
- * This trait extends Any to ensure it is universal (and thus compiled to a Java interface).
+ * Note: Declaring and instantiating multiple subclasses of this class would prevent the JIT from
+ * inlining overridden methods, and hence decrease shuffle performance.
*
* @tparam K Type of the sort key of each element
* @tparam Buffer Internal data structure used by a particular format (e.g., Array[Int]).
*/
// TODO: Making Buffer a real trait would be a better abstraction, but adds some complexity.
-private[spark] trait SortDataFormat[K, Buffer] extends Any {
+private[spark]
+abstract class SortDataFormat[K, Buffer] {
+
+ /**
+ * Creates a new mutable key for reuse. This should be implemented if you want to override
+ * [[getKey(Buffer, Int, K)]].
+ */
+ def newKey(): K = null.asInstanceOf[K]
+
/** Return the sort key for the element at the given index. */
protected def getKey(data: Buffer, pos: Int): K
+ /**
+ * Returns the sort key for the element at the given index, reusing the input key if possible.
+ * The default implementation ignores the reuse parameter and invokes [[getKey(Buffer, Int)]].
+ * If you want to override this method, you must implement [[newKey()]].
+ */
+ def getKey(data: Buffer, pos: Int, reuse: K): K = {
+ getKey(data, pos)
+ }
+
/** Swap two elements. */
- protected def swap(data: Buffer, pos0: Int, pos1: Int): Unit
+ def swap(data: Buffer, pos0: Int, pos1: Int): Unit
/** Copy a single element from src(srcPos) to dst(dstPos). */
- protected def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit
+ def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit
/**
* Copy a range of elements starting at src(srcPos) to dst, starting at dstPos.
* Overlapping ranges are allowed.
*/
- protected def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit
+ def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit
/**
* Allocates a Buffer that can hold up to 'length' elements.
* All elements of the buffer should be considered invalid until data is explicitly copied in.
*/
- protected def allocate(length: Int): Buffer
+ def allocate(length: Int): Buffer
}
/**
@@ -67,9 +85,9 @@ private[spark] trait SortDataFormat[K, Buffer] extends Any {
private[spark]
class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K, Array[T]] {
- override protected def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K]
+ override def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K]
- override protected def swap(data: Array[T], pos0: Int, pos1: Int) {
+ override def swap(data: Array[T], pos0: Int, pos1: Int) {
val tmpKey = data(2 * pos0)
val tmpVal = data(2 * pos0 + 1)
data(2 * pos0) = data(2 * pos1)
@@ -78,17 +96,16 @@ class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K,
data(2 * pos1 + 1) = tmpVal
}
- override protected def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) {
+ override def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) {
dst(2 * dstPos) = src(2 * srcPos)
dst(2 * dstPos + 1) = src(2 * srcPos + 1)
}
- override protected def copyRange(src: Array[T], srcPos: Int,
- dst: Array[T], dstPos: Int, length: Int) {
+ override def copyRange(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int, length: Int) {
System.arraycopy(src, 2 * srcPos, dst, 2 * dstPos, 2 * length)
}
- override protected def allocate(length: Int): Array[T] = {
+ override def allocate(length: Int): Array[T] = {
new Array[T](2 * length)
}
}
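As the class doc above notes, the simplest format is an array of numbers where each element is its own key. Here is a hedged sketch of such a format against the new abstract class; the `IntArraySortDataFormat` name is illustrative and assumes the code compiles in a scope that can see the `private[spark]` definitions above.
{{{
// Minimal SortDataFormat over Array[Int]: each element is its own sort key.
class IntArraySortDataFormat extends SortDataFormat[Int, Array[Int]] {

  override protected def getKey(data: Array[Int], pos: Int): Int = data(pos)

  override def swap(data: Array[Int], pos0: Int, pos1: Int): Unit = {
    val tmp = data(pos0)
    data(pos0) = data(pos1)
    data(pos1) = tmp
  }

  override def copyElement(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int): Unit = {
    dst(dstPos) = src(srcPos)
  }

  override def copyRange(src: Array[Int], srcPos: Int,
      dst: Array[Int], dstPos: Int, length: Int): Unit = {
    System.arraycopy(src, srcPos, dst, dstPos, length)
  }

  override def allocate(length: Int): Array[Int] = new Array[Int](length)
}
}}}
Because `getKey` returns a primitive `Int` here, overriding `newKey()` and `getKey(data, pos, reuse)` would buy nothing; key reuse only pays off for mutable reference keys.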
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala b/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala
new file mode 100644
index 0000000000000..39f66b8c428c6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+/**
+ * A simple wrapper over the Java implementation [[TimSort]].
+ *
+ * The Java implementation is package private, and hence it cannot be called outside package
+ * org.apache.spark.util.collection. This class is a thin wrapper that makes it available to Spark.
+ */
+private[spark]
+class Sorter[K, Buffer](private val s: SortDataFormat[K, Buffer]) {
+
+ private val timSort = new TimSort(s)
+
+ /**
+ * Sorts the input buffer within range [lo, hi).
+ */
+ def sort(a: Buffer, lo: Int, hi: Int, c: Comparator[_ >: K]): Unit = {
+ timSort.sort(a, lo, hi, c)
+ }
+}
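For a sense of how the wrapper is meant to be driven, here is a hedged usage sketch combining `Sorter` with the `KVArraySortDataFormat` from the previous file. The `SorterExample` object is illustrative; it has to live under `org.apache.spark.util.collection` because both classes are `private[spark]`.
{{{
package org.apache.spark.util.collection

import java.util.Comparator

object SorterExample {
  def main(args: Array[String]): Unit = {
    // Key/value pairs laid out flat as [k0, v0, k1, v1, ...].
    val kvArray: Array[AnyRef] =
      Array(Int.box(3), "c", Int.box(1), "a", Int.box(2), "b")

    val sorter = new Sorter(new KVArraySortDataFormat[Int, AnyRef])
    val byKey = new Comparator[Int] {
      override def compare(a: Int, b: Int): Int = java.lang.Integer.compare(a, b)
    }

    // The [lo, hi) range counts pairs, not array slots: sort all 3 pairs in place.
    sorter.sort(kvArray, 0, 3, byKey)
    println(kvArray.mkString(", "))  // 1, a, 2, b, 3, c
  }
}
}}}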
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 32c5fdad75e58..76e7a2760bcd1 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -19,8 +19,10 @@ package org.apache.spark.util.random
import java.util.Random
-import cern.jet.random.Poisson
-import cern.jet.random.engine.DRand
+import scala.reflect.ClassTag
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark.annotation.DeveloperApi
@@ -39,13 +41,47 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable
/** take a random sample */
def sample(items: Iterator[T]): Iterator[U]
+ /** return a copy of the RandomSampler object */
override def clone: RandomSampler[T, U] =
throw new NotImplementedError("clone() is not implemented.")
}
+private[spark]
+object RandomSampler {
+ /** Default random number generator used by random samplers. */
+ def newDefaultRNG: Random = new XORShiftRandom
+
+ /**
+ * Default maximum gap-sampling fraction.
+ * For sampling fractions <= this value, the gap sampling optimization will be applied.
+ * Above this value, it is assumed that "traditional" Bernoulli sampling is faster. The
+ * optimal value for this will depend on the RNG. More expensive RNGs will tend to make
+ * the optimal value higher. The most reliable way to determine this value for a new RNG
+ * is to experiment. When tuning for a new RNG, a value around 0.5 is a reasonable
+ * initial guess in most cases.
+ */
+ val defaultMaxGapSamplingFraction = 0.4
+
+ /**
+ * Default epsilon for floating point numbers sampled from the RNG.
+ * The gap-sampling compute logic requires taking log(x), where x is sampled from an RNG.
+ * To guard against errors from taking log(0), a positive epsilon lower bound is applied.
+ * A good value for this parameter is at or near the minimum positive floating
+ * point value returned by "nextDouble()" (or equivalent), for the RNG being used.
+ */
+ val rngEpsilon = 5e-11
+
+ /**
+ * Sampling fraction arguments may be results of computation, and subject to floating-point
+ * jitter. The arguments are checked against this epsilon slop factor to prevent spurious
+ * warnings for cases such as summing some numbers to get a sampling fraction of 1.000000001
+ */
+ val roundingEpsilon = 1e-6
+}
+
/**
* :: DeveloperApi ::
- * A sampler based on Bernoulli trials.
+ * A sampler based on Bernoulli trials for partitioning a data sequence.
*
* @param lb lower bound of the acceptance range
* @param ub upper bound of the acceptance range
@@ -53,56 +89,262 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable
* @tparam T item type
*/
@DeveloperApi
-class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
+class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = false)
extends RandomSampler[T, T] {
- private[random] var rng: Random = new XORShiftRandom
+ /** epsilon slop to avoid failure from floating point jitter. */
+ require(
+ lb <= (ub + RandomSampler.roundingEpsilon),
+ s"Lower bound ($lb) must be <= upper bound ($ub)")
+ require(
+ lb >= (0.0 - RandomSampler.roundingEpsilon),
+ s"Lower bound ($lb) must be >= 0.0")
+ require(
+ ub <= (1.0 + RandomSampler.roundingEpsilon),
+ s"Upper bound ($ub) must be <= 1.0")
- def this(ratio: Double) = this(0.0d, ratio)
+ private val rng: Random = new XORShiftRandom
override def setSeed(seed: Long) = rng.setSeed(seed)
override def sample(items: Iterator[T]): Iterator[T] = {
- items.filter { item =>
- val x = rng.nextDouble()
- (x >= lb && x < ub) ^ complement
+ if (ub - lb <= 0.0) {
+ if (complement) items else Iterator.empty
+ } else {
+ if (complement) {
+ items.filter { item => {
+ val x = rng.nextDouble()
+ (x < lb) || (x >= ub)
+ }}
+ } else {
+ items.filter { item => {
+ val x = rng.nextDouble()
+ (x >= lb) && (x < ub)
+ }}
+ }
}
}
/**
* Return a sampler that is the complement of the range specified of the current sampler.
*/
- def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
+ def cloneComplement(): BernoulliCellSampler[T] =
+ new BernoulliCellSampler[T](lb, ub, !complement)
+
+ override def clone = new BernoulliCellSampler[T](lb, ub, complement)
+}
+
+
+/**
+ * :: DeveloperApi ::
+ * A sampler based on Bernoulli trials.
+ *
+ * @param fraction the sampling fraction, aka Bernoulli sampling probability
+ * @tparam T item type
+ */
+@DeveloperApi
+class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {
+
+ /** epsilon slop to avoid failure from floating point jitter */
+ require(
+ fraction >= (0.0 - RandomSampler.roundingEpsilon)
+ && fraction <= (1.0 + RandomSampler.roundingEpsilon),
+ s"Sampling fraction ($fraction) must be on interval [0, 1]")
- override def clone = new BernoulliSampler[T](lb, ub, complement)
+ private val rng: Random = RandomSampler.newDefaultRNG
+
+ override def setSeed(seed: Long) = rng.setSeed(seed)
+
+ override def sample(items: Iterator[T]): Iterator[T] = {
+ if (fraction <= 0.0) {
+ Iterator.empty
+ } else if (fraction >= 1.0) {
+ items
+ } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
+ new GapSamplingIterator(items, fraction, rng, RandomSampler.rngEpsilon)
+ } else {
+ items.filter { _ => rng.nextDouble() <= fraction }
+ }
+ }
+
+ override def clone = new BernoulliSampler[T](fraction)
}
+
/**
* :: DeveloperApi ::
- * A sampler based on values drawn from Poisson distribution.
+ * A sampler for sampling with replacement, based on values drawn from Poisson distribution.
*
- * @param mean Poisson mean
+ * @param fraction the sampling fraction (with replacement)
* @tparam T item type
*/
@DeveloperApi
-class PoissonSampler[T](mean: Double) extends RandomSampler[T, T] {
+class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {
+
+ /** Epsilon slop to avoid failure from floating point jitter. */
+ require(
+ fraction >= (0.0 - RandomSampler.roundingEpsilon),
+ s"Sampling fraction ($fraction) must be >= 0")
- private[random] var rng = new Poisson(mean, new DRand)
+ // PoissonDistribution throws an exception when fraction <= 0
+ // If fraction is <= 0, Iterator.empty is used below, so we can use any placeholder value.
+ private val rng = new PoissonDistribution(if (fraction > 0.0) fraction else 1.0)
+ private val rngGap = RandomSampler.newDefaultRNG
override def setSeed(seed: Long) {
- rng = new Poisson(mean, new DRand(seed.toInt))
+ rng.reseedRandomGenerator(seed)
+ rngGap.setSeed(seed)
}
override def sample(items: Iterator[T]): Iterator[T] = {
- items.flatMap { item =>
- val count = rng.nextInt()
- if (count == 0) {
- Iterator.empty
- } else {
- Iterator.fill(count)(item)
- }
+ if (fraction <= 0.0) {
+ Iterator.empty
+ } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
+ new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
+ } else {
+ items.flatMap { item => {
+ val count = rng.sample()
+ if (count == 0) Iterator.empty else Iterator.fill(count)(item)
+ }}
+ }
+ }
+
+ override def clone = new PoissonSampler[T](fraction)
+}
+
+
+private[spark]
+class GapSamplingIterator[T: ClassTag](
+ var data: Iterator[T],
+ f: Double,
+ rng: Random = RandomSampler.newDefaultRNG,
+ epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
+
+ require(f > 0.0 && f < 1.0, s"Sampling fraction ($f) must reside on open interval (0, 1)")
+ require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
+
+ /** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. */
+ private val iterDrop: Int => Unit = {
+ val arrayClass = Array.empty[T].iterator.getClass
+ val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
+ data.getClass match {
+ case `arrayClass` => ((n: Int) => { data = data.drop(n) })
+ case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) })
+ case _ => ((n: Int) => {
+ var j = 0
+ while (j < n && data.hasNext) {
+ data.next()
+ j += 1
+ }
+ })
+ }
+ }
+
+ override def hasNext: Boolean = data.hasNext
+
+ override def next(): T = {
+ val r = data.next()
+ advance
+ r
+ }
+
+ private val lnq = math.log1p(-f)
+
+ /** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */
+ private def advance: Unit = {
+ val u = math.max(rng.nextDouble(), epsilon)
+ val k = (math.log(u) / lnq).toInt
+ iterDrop(k)
+ }
+
+ /** advance to first sample as part of object construction. */
+ advance
+ // Invoking this earlier, together with the other field initializers, broke in
+ // subtle ways, so it is invoked last, which works reliably.
+}
+
+private[spark]
+class GapSamplingReplacementIterator[T: ClassTag](
+ var data: Iterator[T],
+ f: Double,
+ rng: Random = RandomSampler.newDefaultRNG,
+ epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
+
+ require(f > 0.0, s"Sampling fraction ($f) must be > 0")
+ require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
+
+ /** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. */
+ private val iterDrop: Int => Unit = {
+ val arrayClass = Array.empty[T].iterator.getClass
+ val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
+ data.getClass match {
+ case `arrayClass` => ((n: Int) => { data = data.drop(n) })
+ case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) })
+ case _ => ((n: Int) => {
+ var j = 0
+ while (j < n && data.hasNext) {
+ data.next()
+ j += 1
+ }
+ })
+ }
+ }
+
+ /** current sampling value, and its replication factor, as we are sampling with replacement. */
+ private var v: T = _
+ private var rep: Int = 0
+
+ override def hasNext: Boolean = data.hasNext || rep > 0
+
+ override def next(): T = {
+ val r = v
+ rep -= 1
+ if (rep <= 0) advance
+ r
+ }
+
+ /**
+ * Skip elements with replication factor zero (i.e. elements that won't be sampled).
+ * Samples 'k' from the geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), i.e.
+ * q is the probability of Poisson(0; f).
+ */
+ private def advance: Unit = {
+ val u = math.max(rng.nextDouble(), epsilon)
+ val k = (math.log(u) / (-f)).toInt
+ iterDrop(k)
+ // set the value and replication factor for the next value
+ if (data.hasNext) {
+ v = data.next()
+ rep = poissonGE1
+ }
+ }
+
+ private val q = math.exp(-f)
+
+ /**
+ * Sample from Poisson distribution, conditioned such that the sampled value is >= 1.
+ * This is an adaptation of the algorithm for generating Poisson-distributed random variables:
+ * http://en.wikipedia.org/wiki/Poisson_distribution
+ */
+ private def poissonGE1: Int = {
+ // simulate the first draw of the standard Poisson sampling loop, conditioned
+ // so that it yields a sample of at least 1
+ var pp = q + ((1.0 - q) * rng.nextDouble())
+ var r = 1
+
+ // now continue with standard poisson sampling algorithm
+ pp *= rng.nextDouble()
+ while (pp > q) {
+ r += 1
+ pp *= rng.nextDouble()
}
+ r
}
- override def clone = new PoissonSampler[T](mean)
+ /** advance to first sample as part of object construction. */
+ advance
+ // Invoking this earlier, together with the other field initializers, broke in
+ // subtle ways, so it is invoked last, which works reliably.
}
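The rewritten samplers switch between plain Bernoulli/Poisson draws and the gap-sampling iterators based on `defaultMaxGapSamplingFraction`. A hedged usage sketch follows; the `SamplerExample` object and the concrete fractions are illustrative only, and it assumes the public `@DeveloperApi` classes above are on the classpath.
{{{
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler}

object SamplerExample {
  def main(args: Array[String]): Unit = {
    // Without replacement: fraction 0.1 is below the 0.4 gap-sampling threshold,
    // so BernoulliSampler delegates to GapSamplingIterator internally.
    val bernoulli = new BernoulliSampler[Int](0.1)
    bernoulli.setSeed(42L)
    val kept = bernoulli.sample((1 to 100000).iterator).size
    println(s"kept ~10%: $kept")

    // With replacement: each element is replicated Poisson(0.5) times,
    // so the expected output size is about half the input size.
    val poisson = new PoissonSampler[Int](0.5)
    poisson.setSeed(42L)
    val drawn = poisson.sample((1 to 100000).iterator).size
    println(s"drew ~50%: $drawn")
  }
}
}}}
The speedup for small fractions comes from the gap samplers drawing the skip length directly from a geometric distribution, k = floor(log(u) / log(1 - f)), instead of drawing one uniform per element.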
diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
index 8f95d7c6b799b..4fa357edd6f07 100644
--- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
@@ -22,8 +22,7 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
-import cern.jet.random.Poisson
-import cern.jet.random.engine.DRand
+import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark.Logging
import org.apache.spark.SparkContext._
@@ -209,7 +208,7 @@ private[spark] object StratifiedSamplingUtils extends Logging {
samplingRateByKey = computeThresholdByKey(finalResult, fractions)
}
(idx: Int, iter: Iterator[(K, V)]) => {
- val rng = new RandomDataGenerator
+ val rng = new RandomDataGenerator()
rng.reSeed(seed + idx)
// Must use the same invoke pattern on the rng as in getSeqOp for without replacement
// in order to generate the same sequence of random numbers when creating the sample
@@ -245,9 +244,9 @@ private[spark] object StratifiedSamplingUtils extends Logging {
// Must use the same invoke pattern on the rng as in getSeqOp for with replacement
// in order to generate the same sequence of random numbers when creating the sample
val copiesAccepted = if (acceptBound == 0) 0L else rng.nextPoisson(acceptBound)
- val copiesWailisted = rng.nextPoisson(finalResult(key).waitListBound)
+ val copiesWaitlisted = rng.nextPoisson(finalResult(key).waitListBound)
val copiesInSample = copiesAccepted +
- (0 until copiesWailisted).count(i => rng.nextUniform() < thresholdByKey(key))
+ (0 until copiesWaitlisted).count(i => rng.nextUniform() < thresholdByKey(key))
if (copiesInSample > 0) {
Iterator.fill(copiesInSample.toInt)(item)
} else {
@@ -261,10 +260,10 @@ private[spark] object StratifiedSamplingUtils extends Logging {
rng.reSeed(seed + idx)
iter.flatMap { item =>
val count = rng.nextPoisson(fractions(item._1))
- if (count > 0) {
- Iterator.fill(count)(item)
- } else {
+ if (count == 0) {
Iterator.empty
+ } else {
+ Iterator.fill(count)(item)
}
}
}
@@ -274,15 +273,24 @@ private[spark] object StratifiedSamplingUtils extends Logging {
/** A random data generator that generates both uniform values and Poisson values. */
private class RandomDataGenerator {
val uniform = new XORShiftRandom()
- var poisson = new Poisson(1.0, new DRand)
+ // commons-math3 doesn't have a method to generate Poisson from an arbitrary mean;
+ // maintain a cache of Poisson(m) distributions for various m
+ val poissonCache = mutable.Map[Double, PoissonDistribution]()
+ var poissonSeed = 0L
- def reSeed(seed: Long) {
+ def reSeed(seed: Long): Unit = {
uniform.setSeed(seed)
- poisson = new Poisson(1.0, new DRand(seed.toInt))
+ poissonSeed = seed
+ poissonCache.clear()
}
def nextPoisson(mean: Double): Int = {
- poisson.nextInt(mean)
+ val poisson = poissonCache.getOrElseUpdate(mean, {
+ val newPoisson = new PoissonDistribution(mean)
+ newPoisson.reseedRandomGenerator(poissonSeed)
+ newPoisson
+ })
+ poisson.sample()
}
def nextUniform(): Double = {
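The `RandomDataGenerator` change works around the fact that a commons-math3 `PoissonDistribution` is constructed with a fixed mean: it caches one distribution per distinct mean and reseeds each from the generator's seed. A standalone sketch of that caching pattern (the `CachedPoisson` class name is illustrative, not part of the patch):
{{{
import scala.collection.mutable

import org.apache.commons.math3.distribution.PoissonDistribution

// One PoissonDistribution per distinct mean, all reseeded from a common seed.
class CachedPoisson(seed: Long) {
  private val cache = mutable.Map[Double, PoissonDistribution]()

  def next(mean: Double): Int = {
    val dist = cache.getOrElseUpdate(mean, {
      val d = new PoissonDistribution(mean)
      d.reseedRandomGenerator(seed)
      d
    })
    dist.sample()
  }
}
}}}
Reseeding happens only when a mean is first seen, so repeated calls with the same mean continue the same random stream, mirroring the reSeed()/clear() behaviour in the patch.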
diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
index 55b5713706178..467b890fb4bb9 100644
--- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
@@ -96,13 +96,9 @@ private[spark] object XORShiftRandom {
xorRand.nextInt()
}
- val iters = timeIt(numIters)(_)
-
/* Return results as a map instead of just printing to screen
in case the user wants to do something with them */
- Map("javaTime" -> iters {javaRand.nextInt()},
- "xorTime" -> iters {xorRand.nextInt()})
-
+ Map("javaTime" -> timeIt(numIters) { javaRand.nextInt() },
+ "xorTime" -> timeIt(numIters) { xorRand.nextInt() })
}
-
}
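The benchmark cleanup above drops the partially applied `timeIt(numIters)(_)` in favour of passing each expression to `timeIt` directly. A hedged sketch of that pattern with a local stand-in for the timing helper (`timeIt` here is illustrative, not the actual helper in `Utils`):
{{{
import java.util.Random

object RngBenchmarkSketch {
  // Local stand-in for a "time this block over N iterations" helper.
  def timeIt(numIters: Int)(f: => Unit): Long = {
    val start = System.currentTimeMillis()
    var i = 0
    while (i < numIters) { f; i += 1 }
    System.currentTimeMillis() - start
  }

  def main(args: Array[String]): Unit = {
    val javaRand = new Random(42L)
    val numIters = 1000000
    // Return results as a map so callers can post-process them instead of
    // just printing to the screen.
    val results = Map("javaTime" -> timeIt(numIters) { javaRand.nextInt() })
    println(results)
  }
}
}}}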
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 814e40c4f77cc..c21a4b30d7726 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -140,11 +140,10 @@ public void intersection() {
public void sample() {
List YARN version | Profile required | 0.23.x to 2.1.x | yarn-alpha | 0.23.x to 2.1.x | yarn-alpha (Deprecated.) | 2.2.x and later | yarn | |
spark.ui.retainedStages
spark.ui.retainedJobs