diff --git a/assembly/pom.xml b/assembly/pom.xml index c65192bde64c6..4e2b773e7d2f3 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index 93db0d5efda5f..0327ffa402671 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml diff --git a/bin/spark-submit b/bin/spark-submit index c557311b4b20e..f92d90c3a66b0 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -22,6 +22,9 @@ export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" ORIG_ARGS=("$@") +# Set COLUMNS for progress bar +export COLUMNS=`tput cols` + while (($#)); do if [ "$1" = "--deploy-mode" ]; then SPARK_SUBMIT_DEPLOY_MODE=$2 diff --git a/core/pom.xml b/core/pom.xml index 492eddda744c2..1feb00b3a7fb8 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml diff --git a/core/src/main/java/org/apache/spark/SparkStageInfo.java b/core/src/main/java/org/apache/spark/SparkStageInfo.java index 04e2247210ecc..fd74321093658 100644 --- a/core/src/main/java/org/apache/spark/SparkStageInfo.java +++ b/core/src/main/java/org/apache/spark/SparkStageInfo.java @@ -26,6 +26,7 @@ public interface SparkStageInfo { int stageId(); int currentAttemptId(); + long submissionTime(); String name(); int numTasks(); int numActiveTasks(); diff --git a/core/src/main/java/org/apache/spark/api/java/function/package.scala b/core/src/main/java/org/apache/spark/api/java/function/package.scala index 7f91de653a64a..0f9bac7164162 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/package.scala +++ b/core/src/main/java/org/apache/spark/api/java/function/package.scala @@ -22,4 +22,4 @@ package org.apache.spark.api.java * these interfaces to pass functions to various Java API methods for Spark. Please visit Spark's * Java programming guide for more details. */ -package object function \ No newline at end of file +package object function diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index badd85ed48c82..14ba37d7c9bd9 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -26,26 +26,24 @@ $(function() { // Switch the class of the arrow from open to closed. $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open'); $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed'); - - // If clicking caused the metrics to expand, automatically check all options for additional - // metrics (don't trigger a click when collapsing metrics, because it leads to weird - // toggling behavior). - if (!$(additionalMetricsDiv).hasClass('collapsed')) { - $(this).parent().find('input:checkbox:not(:checked)').trigger('click'); - } }); - $("input:checkbox:not(:checked)").each(function() { - var column = "table ." + $(this).attr("name"); - $(column).hide(); - }); - // Stripe table rows after rows have been hidden to ensure correct striping. - stripeTables(); + stripeSummaryTable(); $("input:checkbox").click(function() { var column = "table ." 
+ $(this).attr("name"); $(column).toggle(); - stripeTables(); + stripeSummaryTable(); + }); + + $("#select-all-metrics").click(function() { + if (this.checked) { + // Toggle all un-checked options. + $('input:checkbox:not(:checked)').trigger('click'); + } else { + // Toggle all checked options. + $('input:checkbox:checked').trigger('click'); + } }); // Trigger a click on the checkbox if a user clicks the label next to it. diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js index 6bb03015abb51..656147e40d13e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/table.js +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -15,16 +15,18 @@ * limitations under the License. */ -/* Adds background colors to stripe table rows. This is necessary (instead of using css or the - * table striping provided by bootstrap) to appropriately stripe tables with hidden rows. */ -function stripeTables() { - $("table.table-striped-custom").each(function() { - $(this).find("tr:not(:hidden)").each(function (index) { - if (index % 2 == 1) { - $(this).css("background-color", "#f9f9f9"); - } else { - $(this).css("background-color", "#ffffff"); - } - }); +/* Adds background colors to stripe table rows in the summary table (on the stage page). This is + * necessary (instead of using css or the table striping provided by bootstrap) because the summary + * table has hidden rows. + * + * An ID selector (rather than a class selector) is used to ensure this runs quickly even on pages + * with thousands of task rows (ID selectors are much faster than class selectors). */ +function stripeSummaryTable() { + $("#task-summary-table").find("tr:not(:hidden)").each(function (index) { + if (index % 2 == 1) { + $(this).css("background-color", "#f9f9f9"); + } else { + $(this).css("background-color", "#ffffff"); + } }); } diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index db57712c83503..cdf85bfbf326f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -168,3 +168,9 @@ span.additional-metric-title { border-left: 5px solid black; display: inline-block; } + +/* Hide all additional metrics by default. This is done here rather than using JavaScript to + * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */ +.scheduler_delay, .gc_time, .deserialization_time, .serialization_time, .getting_result_time { + display: none; +} diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 2301caafb07ff..dc1e8f6c21b62 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -244,6 +244,36 @@ trait AccumulatorParam[T] extends AccumulableParam[T, T] { } } +object AccumulatorParam { + + // The following implicit objects were in SparkContext before 1.2 and users had to + // `import SparkContext._` to enable them. Now we move them here to make the compiler find + // them automatically. However, as there are duplicate codes in SparkContext for backward + // compatibility, please update them accordingly if you modify the following implicit objects. 
+ + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { + def addInPlace(t1: Double, t2: Double): Double = t1 + t2 + def zero(initialValue: Double) = 0.0 + } + + implicit object IntAccumulatorParam extends AccumulatorParam[Int] { + def addInPlace(t1: Int, t2: Int): Int = t1 + t2 + def zero(initialValue: Int) = 0 + } + + implicit object LongAccumulatorParam extends AccumulatorParam[Long] { + def addInPlace(t1: Long, t2: Long) = t1 + t2 + def zero(initialValue: Long) = 0L + } + + implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { + def addInPlace(t1: Float, t2: Float) = t1 + t2 + def zero(initialValue: Float) = 0f + } + + // TODO: Add AccumulatorParams for other types, e.g. lists and strings +} + // TODO: The multi-thread support in accumulators is kind of lame; check // if there's a more intuitive way of doing it right private object Accumulators { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 65edeeffb837a..9b0d5be7a7ab2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -50,7 +50,7 @@ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkD import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ -import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.{SparkUI, ConsoleProgressBar} import org.apache.spark.ui.jobs.JobProgressListener import org.apache.spark.util._ @@ -58,17 +58,33 @@ import org.apache.spark.util._ * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. * + * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before + * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details. + * * @param config a Spark Config object describing the application configuration. Any settings in * this config overrides the default configs as well as system properties. */ - class SparkContext(config: SparkConf) extends Logging { + // The call site where this SparkContext was constructed. + private val creationSite: CallSite = Utils.getCallSite() + + // If true, log warnings instead of throwing exceptions when multiple SparkContexts are active + private val allowMultipleContexts: Boolean = + config.getBoolean("spark.driver.allowMultipleContexts", false) + + // In order to prevent multiple SparkContexts from being active at the same time, mark this + // context as having started construction. + // NOTE: this must be placed at the beginning of the SparkContext constructor. + SparkContext.markPartiallyConstructed(this, allowMultipleContexts) + // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It // contains a map from hostname to a list of input format splits on the host. private[spark] var preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map() + val startTime = System.currentTimeMillis() + /** * Create a SparkContext that loads settings from system properties (for instance, when * launching with ./bin/spark-submit). 
@@ -231,6 +247,13 @@ class SparkContext(config: SparkConf) extends Logging { val statusTracker = new SparkStatusTracker(this) + private[spark] val progressBar: Option[ConsoleProgressBar] = + if (conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { + Some(new ConsoleProgressBar(this)) + } else { + None + } + // Initialize the Spark UI private[spark] val ui: Option[SparkUI] = if (conf.getBoolean("spark.ui.enabled", true)) { @@ -248,8 +271,6 @@ class SparkContext(config: SparkConf) extends Logging { /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) - val startTime = System.currentTimeMillis() - // Add each JAR given through the constructor if (jars != null) { jars.foreach(addJar) @@ -1166,27 +1187,30 @@ class SparkContext(config: SparkConf) extends Logging { /** Shut down the SparkContext. */ def stop() { - postApplicationEnd() - ui.foreach(_.stop()) - // Do this only if not stopped already - best case effort. - // prevent NPE if stopped more than once. - val dagSchedulerCopy = dagScheduler - dagScheduler = null - if (dagSchedulerCopy != null) { - env.metricsSystem.report() - metadataCleaner.cancel() - env.actorSystem.stop(heartbeatReceiver) - cleaner.foreach(_.stop()) - dagSchedulerCopy.stop() - taskScheduler = null - // TODO: Cache.stop()? - env.stop() - SparkEnv.set(null) - listenerBus.stop() - eventLogger.foreach(_.stop()) - logInfo("Successfully stopped SparkContext") - } else { - logInfo("SparkContext already stopped") + SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + postApplicationEnd() + ui.foreach(_.stop()) + // Do this only if not stopped already - best case effort. + // prevent NPE if stopped more than once. + val dagSchedulerCopy = dagScheduler + dagScheduler = null + if (dagSchedulerCopy != null) { + env.metricsSystem.report() + metadataCleaner.cancel() + env.actorSystem.stop(heartbeatReceiver) + cleaner.foreach(_.stop()) + dagSchedulerCopy.stop() + taskScheduler = null + // TODO: Cache.stop()? + env.stop() + SparkEnv.set(null) + listenerBus.stop() + eventLogger.foreach(_.stop()) + logInfo("Successfully stopped SparkContext") + SparkContext.clearActiveContext() + } else { + logInfo("SparkContext already stopped") + } } } @@ -1257,6 +1281,7 @@ class SparkContext(config: SparkConf) extends Logging { logInfo("Starting job: " + callSite.shortForm) dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, resultHandler, localProperties.get) + progressBar.foreach(_.finishAll()) rdd.doCheckpoint() } @@ -1475,6 +1500,11 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def cleanup(cleanupTime: Long) { persistentRdds.clearOldValues(cleanupTime) } + + // In order to prevent multiple SparkContexts from being active at the same time, mark this + // context as having finished construction. + // NOTE: this must be placed at the end of the SparkContext constructor. + SparkContext.setActiveContext(this, allowMultipleContexts) } /** @@ -1483,6 +1513,107 @@ class SparkContext(config: SparkConf) extends Logging { */ object SparkContext extends Logging { + /** + * Lock that guards access to global variables that track SparkContext construction. + */ + private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object() + + /** + * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `None`. 
+ * + * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK + */ + private var activeContext: Option[SparkContext] = None + + /** + * Points to a partially-constructed SparkContext if some thread is in the SparkContext + * constructor, or `None` if no SparkContext is being constructed. + * + * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK + */ + private var contextBeingConstructed: Option[SparkContext] = None + + /** + * Called to ensure that no other SparkContext is running in this JVM. + * + * Throws an exception if a running context is detected and logs a warning if another thread is + * constructing a SparkContext. This warning is necessary because the current locking scheme + * prevents us from reliably distinguishing between cases where another context is being + * constructed and cases where another constructor threw an exception. + */ + private def assertNoOtherContextIsRunning( + sc: SparkContext, + allowMultipleContexts: Boolean): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + contextBeingConstructed.foreach { otherContext => + if (otherContext ne sc) { // checks for reference equality + // Since otherContext might point to a partially-constructed context, guard against + // its creationSite field being null: + val otherContextCreationSite = + Option(otherContext.creationSite).map(_.longForm).getOrElse("unknown location") + val warnMsg = "Another SparkContext is being constructed (or threw an exception in its" + + " constructor). This may indicate an error, since only one SparkContext may be" + + " running in this JVM (see SPARK-2243)." + + s" The other SparkContext was created at:\n$otherContextCreationSite" + logWarning(warnMsg) + } + + activeContext.foreach { ctx => + val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." + + " To ignore this error, set spark.driver.allowMultipleContexts = true. " + + s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}" + val exception = new SparkException(errMsg) + if (allowMultipleContexts) { + logWarning("Multiple running SparkContexts detected in the same JVM!", exception) + } else { + throw exception + } + } + } + } + } + + /** + * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is + * running. Throws an exception if a running context is detected and logs a warning if another + * thread is constructing a SparkContext. This warning is necessary because the current locking + * scheme prevents us from reliably distinguishing between cases where another context is being + * constructed and cases where another constructor threw an exception. + */ + private[spark] def markPartiallyConstructed( + sc: SparkContext, + allowMultipleContexts: Boolean): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + assertNoOtherContextIsRunning(sc, allowMultipleContexts) + contextBeingConstructed = Some(sc) + } + } + + /** + * Called at the end of the SparkContext constructor to ensure that no other SparkContext has + * raced with this constructor and started. + */ + private[spark] def setActiveContext( + sc: SparkContext, + allowMultipleContexts: Boolean): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + assertNoOtherContextIsRunning(sc, allowMultipleContexts) + contextBeingConstructed = None + activeContext = Some(sc) + } + } + + /** + * Clears the active SparkContext metadata. This is called by `SparkContext#stop()`. 
It's + * also called in unit tests to prevent a flood of warnings from test suites that don't / can't + * properly clean up their SparkContexts. + */ + private[spark] def clearActiveContext(): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + activeContext = None + } + } + private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description" private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id" @@ -1493,47 +1624,74 @@ object SparkContext extends Logging { private[spark] val DRIVER_IDENTIFIER = "" - implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { + // The following deprecated objects have already been copied to `object AccumulatorParam` to + // make the compiler find them automatically. They are duplicate codes only for backward + // compatibility, please update `object AccumulatorParam` accordingly if you plan to modify the + // following ones. + + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0 } - implicit object IntAccumulatorParam extends AccumulatorParam[Int] { + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object IntAccumulatorParam extends AccumulatorParam[Int] { def addInPlace(t1: Int, t2: Int): Int = t1 + t2 def zero(initialValue: Int) = 0 } - implicit object LongAccumulatorParam extends AccumulatorParam[Long] { + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object LongAccumulatorParam extends AccumulatorParam[Long] { def addInPlace(t1: Long, t2: Long) = t1 + t2 def zero(initialValue: Long) = 0L } - implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object FloatAccumulatorParam extends AccumulatorParam[Float] { def addInPlace(t1: Float, t2: Float) = t1 + t2 def zero(initialValue: Float) = 0f } - // TODO: Add AccumulatorParams for other types, e.g. lists and strings + // The following deprecated functions have already been moved to `object RDD` to + // make the compiler find them automatically. They are still kept here for backward compatibility + // and just call the corresponding functions in `object RDD`. - implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { - new PairRDDFunctions(rdd) + RDD.rddToPairRDDFunctions(rdd) } - implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd) + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = RDD.rddToAsyncRDDActions(rdd) - implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( + @deprecated("Replaced by implicit functions in the RDD companion object. 
This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( rdd: RDD[(K, V)]) = - new SequenceFileRDDFunctions(rdd) + RDD.rddToSequenceFileRDDFunctions(rdd) - implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( rdd: RDD[(K, V)]) = - new OrderedRDDFunctions[K, V, (K, V)](rdd) + RDD.rddToOrderedRDDFunctions(rdd) - implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd) + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = RDD.doubleRDDToDoubleRDDFunctions(rdd) - implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = - new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + RDD.numericRDDToDoubleRDDFunctions(rdd) // Implicit conversions to common Writable types, for saveAsSequenceFile @@ -1559,40 +1717,49 @@ object SparkContext extends Logging { arr.map(x => anyToWritable(x)).toArray) } - // Helper objects for converting common types to Writable - private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) - : WritableConverter[T] = { - val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]] - new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) - } + // The following deprecated functions have already been moved to `object WritableConverter` to + // make the compiler find them automatically. They are still kept here for backward compatibility + // and just call the corresponding functions in `object WritableConverter`. - implicit def intWritableConverter(): WritableConverter[Int] = - simpleWritableConverter[Int, IntWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def intWritableConverter(): WritableConverter[Int] = + WritableConverter.intWritableConverter() - implicit def longWritableConverter(): WritableConverter[Long] = - simpleWritableConverter[Long, LongWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def longWritableConverter(): WritableConverter[Long] = + WritableConverter.longWritableConverter() - implicit def doubleWritableConverter(): WritableConverter[Double] = - simpleWritableConverter[Double, DoubleWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def doubleWritableConverter(): WritableConverter[Double] = + WritableConverter.doubleWritableConverter() - implicit def floatWritableConverter(): WritableConverter[Float] = - simpleWritableConverter[Float, FloatWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. 
This is kept here only for " + + "backward compatibility.", "1.2.0") + def floatWritableConverter(): WritableConverter[Float] = + WritableConverter.floatWritableConverter() - implicit def booleanWritableConverter(): WritableConverter[Boolean] = - simpleWritableConverter[Boolean, BooleanWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def booleanWritableConverter(): WritableConverter[Boolean] = + WritableConverter.booleanWritableConverter() - implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = { - simpleWritableConverter[Array[Byte], BytesWritable](bw => - // getBytes method returns array which is longer then data to be returned - Arrays.copyOfRange(bw.getBytes, 0, bw.getLength) - ) - } + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def bytesWritableConverter(): WritableConverter[Array[Byte]] = + WritableConverter.bytesWritableConverter() - implicit def stringWritableConverter(): WritableConverter[String] = - simpleWritableConverter[String, Text](_.toString) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def stringWritableConverter(): WritableConverter[String] = + WritableConverter.stringWritableConverter() - implicit def writableWritableConverter[T <: Writable]() = - new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def writableWritableConverter[T <: Writable]() = + WritableConverter.writableWritableConverter() /** * Find the JAR from which a given class was loaded, to make it easy for users to pass @@ -1682,6 +1849,9 @@ object SparkContext extends Logging { def localCpuCount = Runtime.getRuntime.availableProcessors() // local[*] estimates the number of cores on the machine; local[N] uses exactly N threads. val threadCount = if (threads == "*") localCpuCount else threads.toInt + if (threadCount <= 0) { + throw new SparkException(s"Asked to run locally with $threadCount threads") + } val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) val backend = new LocalBackend(scheduler, threadCount) scheduler.initialize(backend) @@ -1816,3 +1986,46 @@ private[spark] class WritableConverter[T]( val writableClass: ClassTag[T] => Class[_ <: Writable], val convert: Writable => T) extends Serializable + +object WritableConverter { + + // Helper objects for converting common types to Writable + private[spark] def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) + : WritableConverter[T] = { + val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]] + new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) + } + + // The following implicit functions were in SparkContext before 1.2 and users had to + // `import SparkContext._` to enable them. Now we move them here to make the compiler find + // them automatically. However, we still keep the old functions in SparkContext for backward + // compatibility and forward to the following functions directly. 
+ + implicit def intWritableConverter(): WritableConverter[Int] = + simpleWritableConverter[Int, IntWritable](_.get) + + implicit def longWritableConverter(): WritableConverter[Long] = + simpleWritableConverter[Long, LongWritable](_.get) + + implicit def doubleWritableConverter(): WritableConverter[Double] = + simpleWritableConverter[Double, DoubleWritable](_.get) + + implicit def floatWritableConverter(): WritableConverter[Float] = + simpleWritableConverter[Float, FloatWritable](_.get) + + implicit def booleanWritableConverter(): WritableConverter[Boolean] = + simpleWritableConverter[Boolean, BooleanWritable](_.get) + + implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = { + simpleWritableConverter[Array[Byte], BytesWritable](bw => + // getBytes method returns array which is longer then data to be returned + Arrays.copyOfRange(bw.getBytes, 0, bw.getLength) + ) + } + + implicit def stringWritableConverter(): WritableConverter[String] = + simpleWritableConverter[String, Text](_.toString) + + implicit def writableWritableConverter[T <: Writable]() = + new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) +} diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index c18d763d7ff4d..edbdda8a0bcb6 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -96,6 +96,7 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { new SparkStageInfoImpl( stageId, info.attemptId, + info.submissionTime.getOrElse(0), info.name, info.numTasks, data.numActiveTasks, diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala index 90b47c847fbca..e5c7c8d0db578 100644 --- a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala +++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala @@ -26,6 +26,7 @@ private class SparkJobInfoImpl ( private class SparkStageInfoImpl( val stageId: Int, val currentAttemptId: Int, + val submissionTime: Long, val name: String, val numTasks: Int, val numActiveTasks: Int, diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index e37f3acaf6e30..7af3538262fd6 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -32,13 +32,13 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.Partitioner._ -import org.apache.spark.SparkContext.rddToPairRDDFunctions import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} +import org.apache.spark.rdd.RDD.rddToPairRDDFunctions import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index d50ed32ca085c..97f5c9f257e09 100644 --- 
a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ -import org.apache.spark.SparkContext._ +import org.apache.spark.AccumulatorParam._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast @@ -42,6 +42,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} /** * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns * [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones. + * + * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before + * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details. */ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround with Closeable { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 45beb8fc8c925..e0bc00e1eb249 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.api.python import java.io._ import java.net._ -import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} +import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections} import org.apache.spark.input.PortableDataStream @@ -47,7 +47,7 @@ private[spark] class PythonRDD( pythonIncludes: JList[String], preservePartitoning: Boolean, pythonExec: String, - broadcastVars: JList[Broadcast[Array[Byte]]], + broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { @@ -230,8 +230,7 @@ private[spark] class PythonRDD( if (!oldBids.contains(broadcast.id)) { // send new broadcast dataOut.writeLong(broadcast.id) - dataOut.writeInt(broadcast.value.length) - dataOut.write(broadcast.value) + PythonRDD.writeUTF(broadcast.value.path, dataOut) oldBids.add(broadcast.id) } } @@ -368,16 +367,8 @@ private[spark] object PythonRDD extends Logging { } } - def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = { - val file = new DataInputStream(new FileInputStream(filename)) - try { - val length = file.readInt() - val obj = new Array[Byte](length) - file.readFully(obj) - sc.broadcast(obj) - } finally { - file.close() - } + def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = { + sc.broadcast(new PythonBroadcast(path)) } def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { @@ -816,3 +807,49 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: } } } + +/** + * An Wrapper for Python Broadcast, which is written into disk by Python. It also will + * write the data into disk after deserialization, then Python can read it from disks. 
+ */ +private[spark] class PythonBroadcast(@transient var path: String) extends Serializable { + + /** + * Read data from disks, then copy it to `out` + */ + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { + val in = new FileInputStream(new File(path)) + try { + Utils.copyStream(in, out) + } finally { + in.close() + } + } + + /** + * Write data into disk, using randomly generated name. + */ + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + val dir = new File(Utils.getLocalDir(SparkEnv.get.conf)) + val file = File.createTempFile("broadcast", "", dir) + path = file.getAbsolutePath + val out = new FileOutputStream(file) + try { + Utils.copyStream(in, out) + } finally { + out.close() + } + } + + /** + * Delete the file once the object is GCed. + */ + override def finalize() { + if (!path.isEmpty) { + val file = new File(path) + if (file.exists()) { + file.delete() + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 4e802e02c4149..2e1e52906ceeb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -75,7 +75,8 @@ private[spark] class ClientArguments(args: Array[String]) { if (!ClientArguments.isValidJarUrl(_jarUrl)) { println(s"Jar url '${_jarUrl}' is not in valid format.") - println(s"Must be a jar file path in URL format (e.g. hdfs://XX.jar, file://XX.jar)") + println(s"Must be a jar file path in URL format " + + "(e.g. hdfs://host:port/XX.jar, file:///XX.jar)") printUsageAndExit(-1) } @@ -119,7 +120,7 @@ object ClientArguments { def isValidJarUrl(s: String): Boolean = { try { val uri = new URI(s) - uri.getScheme != null && uri.getAuthority != null && s.endsWith("jar") + uri.getScheme != null && uri.getPath != null && uri.getPath.endsWith(".jar") } catch { case _: URISyntaxException => false } diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index b9dd8557ee904..c46f84de8444a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -92,6 +92,8 @@ private[deploy] object DeployMessages { case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders + case object ReregisterWithMaster // used when a worker attempts to reconnect to a master + // AppClient to Master case class RegisterApplication(appDescription: ApplicationDescription) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala index aa3743ca7df63..d2687faad62b1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -134,7 +134,7 @@ private[spark] object SparkSubmitDriverBootstrapper { override def run() = { if (process != null) { process.destroy() - sys.exit(process.waitFor()) + process.waitFor() } } }) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 2d1609b973607..82a54dbfb5330 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -29,22 +29,27 @@ import org.apache.spark.scheduler._ import org.apache.spark.ui.SparkUI import org.apache.spark.util.Utils +/** + * A class that provides application history from event logs stored in the file system. + * This provider checks for new finished applications in the background periodically and + * renders the history application UI by parsing the associated event logs. + */ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHistoryProvider with Logging { + import FsHistoryProvider._ + private val NOT_STARTED = "" // Interval between each check for event log updates private val UPDATE_INTERVAL_MS = conf.getInt("spark.history.fs.updateInterval", conf.getInt("spark.history.updateInterval", 10)) * 1000 - private val logDir = conf.get("spark.history.fs.logDirectory", null) - private val resolvedLogDir = Option(logDir) - .map { d => Utils.resolveURI(d) } - .getOrElse { throw new IllegalArgumentException("Logging directory must be specified.") } + private val logDir = conf.getOption("spark.history.fs.logDirectory") + .map { d => Utils.resolveURI(d).toString } + .getOrElse(DEFAULT_LOG_DIR) - private val fs = Utils.getHadoopFileSystem(resolvedLogDir, - SparkHadoopUtil.get.newConfiguration(conf)) + private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf)) // A timestamp of when the disk was last accessed to check for log updates private var lastLogCheckTimeMs = -1L @@ -87,14 +92,17 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private def initialize() { // Validate the log directory. - val path = new Path(resolvedLogDir) + val path = new Path(logDir) if (!fs.exists(path)) { - throw new IllegalArgumentException( - "Logging directory specified does not exist: %s".format(resolvedLogDir)) + var msg = s"Log directory specified does not exist: $logDir." + if (logDir == DEFAULT_LOG_DIR) { + msg += " Did you configure the correct one through spark.fs.history.logDirectory?" + } + throw new IllegalArgumentException(msg) } if (!fs.getFileStatus(path).isDir) { throw new IllegalArgumentException( - "Logging directory specified is not a directory: %s".format(resolvedLogDir)) + "Logging directory specified is not a directory: %s".format(logDir)) } checkForLogs() @@ -134,8 +142,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } } - override def getConfig(): Map[String, String] = - Map("Event Log Location" -> resolvedLogDir.toString) + override def getConfig(): Map[String, String] = Map("Event log directory" -> logDir.toString) /** * Builds the application list based on the current contents of the log directory. @@ -146,7 +153,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis lastLogCheckTimeMs = getMonotonicTimeMs() logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs)) try { - val logStatus = fs.listStatus(new Path(resolvedLogDir)) + val logStatus = fs.listStatus(new Path(logDir)) val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]() // Load all new logs from the log directory. 
Only directories that have a modification time @@ -244,6 +251,10 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } +private object FsHistoryProvider { + val DEFAULT_LOG_DIR = "file:/tmp/spark-events" +} + private class FsApplicationHistoryInfo( val logDir: String, id: String, diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 0e249e51a77d8..5fdc350cd8512 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -58,7 +58,13 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { ++ appTable } else { -
          <h4>No Completed Applications Found</h4>
+          <h4>No completed applications found!</h4> ++
+            <p>Did you specify the correct logging directory?
+              Please verify your setting of
+              spark.history.fs.logDirectory and whether you have the permissions to
+              access it.<br/> It is also possible that your application did not run to
+              completion or did not stop the SparkContext.
+            </p>
} } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 5bce32a04d16d..b1270ade9f750 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -17,14 +17,13 @@ package org.apache.spark.deploy.history -import org.apache.spark.SparkConf +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.util.Utils /** * Command-line parser for the master. */ -private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) { - private var logDir: String = null +private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { private var propertiesFile: String = null parse(args.toList) @@ -32,7 +31,8 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] private def parse(args: List[String]): Unit = { args match { case ("--dir" | "-d") :: value :: tail => - logDir = value + logWarning("Setting log directory through the command line is deprecated as of " + + "Spark 1.1.0. Please set this through spark.history.fs.logDirectory instead.") conf.set("spark.history.fs.logDirectory", value) System.setProperty("spark.history.fs.logDirectory", value) parse(tail) @@ -78,9 +78,10 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] | (default 50) |FsHistoryProvider options: | - | spark.history.fs.logDirectory Directory where app logs are stored (required) - | spark.history.fs.updateInterval How often to reload log data from storage (in seconds, - | default 10) + | spark.history.fs.logDirectory Directory where app logs are stored + | (default: file:/tmp/spark-events) + | spark.history.fs.updateInterval How often to reload log data from storage + | (in seconds, default: 10) |""".stripMargin) System.exit(exitCode) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index 6ff2aa5244847..36a2e2c6a6349 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -18,12 +18,13 @@ package org.apache.spark.deploy.master import java.io._ -import java.nio.ByteBuffer + +import scala.reflect.ClassTag + +import akka.serialization.Serialization import org.apache.spark.Logging -import org.apache.spark.serializer.Serializer -import scala.reflect.ClassTag /** * Stores data in a single on-disk directory with one file per application and worker. 
@@ -34,10 +35,9 @@ import scala.reflect.ClassTag */ private[spark] class FileSystemPersistenceEngine( val dir: String, - val serialization: Serializer) + val serialization: Serialization) extends PersistenceEngine with Logging { - val serializer = serialization.newInstance() new File(dir).mkdir() override def persist(name: String, obj: Object): Unit = { @@ -56,17 +56,17 @@ private[spark] class FileSystemPersistenceEngine( private def serializeIntoFile(file: File, value: AnyRef) { val created = file.createNewFile() if (!created) { throw new IllegalStateException("Could not create file: " + file) } - - val out = serializer.serializeStream(new FileOutputStream(file)) + val serializer = serialization.findSerializerFor(value) + val serialized = serializer.toBinary(value) + val out = new FileOutputStream(file) try { - out.writeObject(value) + out.write(serialized) } finally { out.close() } - } - def deserializeFromFile[T](file: File): T = { + private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = { val fileData = new Array[Byte](file.length().asInstanceOf[Int]) val dis = new DataInputStream(new FileInputStream(file)) try { @@ -74,7 +74,9 @@ private[spark] class FileSystemPersistenceEngine( } finally { dis.close() } - - serializer.deserializeStream(dis).readObject() + val clazz = m.runtimeClass.asInstanceOf[Class[T]] + val serializer = serialization.serializerFor(clazz) + serializer.fromBinary(fileData).asInstanceOf[T] } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 021454e25804c..7b32c505def9b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -30,6 +30,7 @@ import scala.util.Random import akka.actor._ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import akka.serialization.Serialization import akka.serialization.SerializationExtension import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} @@ -132,15 +133,18 @@ private[spark] class Master( val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match { case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") - val zkFactory = new ZooKeeperRecoveryModeFactory(conf) + val zkFactory = + new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system)) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => - val fsFactory = new FileSystemRecoveryModeFactory(conf) + val fsFactory = + new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system)) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) - val factory = clazz.getConstructor(conf.getClass) - .newInstance(conf).asInstanceOf[StandaloneRecoveryModeFactory] + val factory = clazz.getConstructor(conf.getClass, Serialization.getClass) + .newInstance(conf, SerializationExtension(context.system)) + .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => (new BlackHolePersistenceEngine(), new MonarchyLeaderAgent(this)) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala index d9d36c1ed5f9f..1096eb0368357 100644 
--- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -17,9 +17,10 @@ package org.apache.spark.deploy.master +import akka.serialization.Serialization + import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.serializer.JavaSerializer /** * ::DeveloperApi:: @@ -29,7 +30,7 @@ import org.apache.spark.serializer.JavaSerializer * */ @DeveloperApi -abstract class StandaloneRecoveryModeFactory(conf: SparkConf) { +abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) { /** * PersistenceEngine defines how the persistent data(Information about worker, driver etc..) @@ -48,21 +49,21 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf) { * LeaderAgent in this case is a no-op. Since leader is forever leader as the actual * recovery is made by restoring from filesystem. */ -private[spark] class FileSystemRecoveryModeFactory(conf: SparkConf) - extends StandaloneRecoveryModeFactory(conf) with Logging { +private[spark] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) + extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") def createPersistenceEngine() = { logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) - new FileSystemPersistenceEngine(RECOVERY_DIR, new JavaSerializer(conf)) + new FileSystemPersistenceEngine(RECOVERY_DIR, serializer) } def createLeaderElectionAgent(master: LeaderElectable) = new MonarchyLeaderAgent(master) } -private[spark] class ZooKeeperRecoveryModeFactory(conf: SparkConf) - extends StandaloneRecoveryModeFactory(conf) { - def createPersistenceEngine() = new ZooKeeperPersistenceEngine(new JavaSerializer(conf), conf) +private[spark] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) + extends StandaloneRecoveryModeFactory(conf, serializer) { + def createPersistenceEngine() = new ZooKeeperPersistenceEngine(conf, serializer) def createLeaderElectionAgent(master: LeaderElectable) = new ZooKeeperLeaderElectionAgent(master, conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 96c2139eb02f0..e11ac031fb9c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -17,27 +17,24 @@ package org.apache.spark.deploy.master +import akka.serialization.Serialization + import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.serializer.Serializer -import java.nio.ByteBuffer -import scala.reflect.ClassTag - -private[spark] class ZooKeeperPersistenceEngine(val serialization: Serializer, conf: SparkConf) +private[spark] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) extends PersistenceEngine with Logging { val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" val zk: CuratorFramework = SparkCuratorUtil.newClient(conf) - val serializer = serialization.newInstance() - SparkCuratorUtil.mkdir(zk, WORKING_DIR) @@ -59,14 
+56,17 @@ private[spark] class ZooKeeperPersistenceEngine(val serialization: Serializer, c } private def serializeIntoFile(path: String, value: AnyRef) { - val serialized = serializer.serialize(value) - zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized.array()) + val serializer = serialization.findSerializerFor(value) + val serialized = serializer.toBinary(value) + zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized) } - def deserializeFromFile[T](filename: String): Option[T] = { + def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) + val clazz = m.runtimeClass.asInstanceOf[Class[T]] + val serializer = serialization.serializerFor(clazz) try { - Some(serializer.deserialize(ByteBuffer.wrap(fileData))) + Some(serializer.fromBinary(fileData).asInstanceOf[T]) } catch { case e: Exception => { logWarning("Exception while reading persisted file, deleting", e) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index ca262de832e25..eb11163538b20 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -21,7 +21,6 @@ import java.io.File import java.io.IOException import java.text.SimpleDateFormat import java.util.{UUID, Date} -import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap @@ -177,6 +176,9 @@ private[spark] class Worker( throw new SparkException("Invalid spark URL: " + x) } connected = true + // Cancel any outstanding re-registration attempts because we found a new master + registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer = None } private def tryRegisterAllMasters() { @@ -187,7 +189,12 @@ private[spark] class Worker( } } - private def retryConnectToMaster() { + /** + * Re-register with the master because a network failure or a master failure has occurred. + * If the re-registration attempt threshold is exceeded, the worker exits with error. + * Note that for thread-safety this should only be called from the actor. + */ + private def reregisterWithMaster(): Unit = { Utils.tryOrExit { connectionAttemptCount += 1 if (registered) { @@ -195,12 +202,40 @@ private[spark] class Worker( registrationRetryTimer = None } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") - tryRegisterAllMasters() + /** + * Re-register with the active master this worker has been communicating with. If there + * is none, then it means this worker is still bootstrapping and hasn't established a + * connection with a master yet, in which case we should re-register with all masters. + * + * It is important to re-register only with the active master during failures. 
Otherwise, + * if the worker unconditionally attempts to re-register with all masters, the following + * race condition may arise and cause a "duplicate worker" error detailed in SPARK-4592: + * + * (1) Master A fails and Worker attempts to reconnect to all masters + * (2) Master B takes over and notifies Worker + * (3) Worker responds by registering with Master B + * (4) Meanwhile, Worker's previous reconnection attempt reaches Master B, + * causing the same Worker to register with Master B twice + * + * Instead, if we only register with the known active master, we can assume that the + * old master must have died because another master has taken over. Note that this is + * still not safe if the old master recovers within this interval, but this is a much + * less likely scenario. + */ + if (master != null) { + master ! RegisterWorker( + workerId, host, port, cores, memory, webUi.boundPort, publicAddress) + } else { + // We are retrying the initial registration + tryRegisterAllMasters() + } + // We have exceeded the initial registration retry threshold + // All retries from now on should use a higher interval if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { registrationRetryTimer.foreach(_.cancel()) registrationRetryTimer = Some { context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL, - PROLONGED_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster) + PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) } } } else { @@ -220,7 +255,7 @@ private[spark] class Worker( connectionAttemptCount = 0 registrationRetryTimer = Some { context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL, - INITIAL_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster) + INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) } case Some(_) => logInfo("Not spawning another attempt to register with the master, since there is an" + @@ -400,12 +435,15 @@ private[spark] class Worker( logInfo(s"$x Disassociated !") masterDisconnected() - case RequestWorkerState => { + case RequestWorkerState => sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, finishedExecutors.values.toList, drivers.values.toList, finishedDrivers.values.toList, activeMasterUrl, cores, memory, coresUsed, memoryUsed, activeMasterWebUiUrl) - } + + case ReregisterWithMaster => + reregisterWithMaster() + } private def masterDisconnected() { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 4c378a278b4c1..835157fc520aa 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -221,7 +221,7 @@ private[spark] class Executor( // directSend = sending directly back to the driver val serializedResult = { - if (resultSize > maxResultSize) { + if (maxResultSize > 0 && resultSize > maxResultSize) { logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " + s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + s"dropping it.") @@ -334,7 +334,7 @@ private[spark] class Executor( * SparkContext. Also adds any new JARs we fetched to the class loader. 
*/ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) synchronized { // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index ce4225cae6d88..cef203006d685 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -20,7 +20,24 @@ package org.apache.spark.network.netty import org.apache.spark.SparkConf import org.apache.spark.network.util.{TransportConf, ConfigProvider} +/** + * Provides a utility for transforming from a SparkConf inside a Spark JVM (e.g., Executor, + * Driver, or a standalone shuffle service) into a TransportConf with details on our environment + * like the number of cores that are allocated to this JVM. + */ object SparkTransportConf { + /** + * Specifies an upper bound on the number of Netty threads that Spark requires by default. + * In practice, only 2-4 cores should be required to transfer roughly 10 Gb/s, and each core + * that we use will have an initial overhead of roughly 32 MB of off-heap memory, which comes + * at a premium. + * + * Thus, this value should still retain maximum throughput and reduce wasted off-heap memory + * allocation. It can be overridden by setting the number of serverThreads and clientThreads + * manually in Spark's configuration. + */ + private val MAX_DEFAULT_NETTY_THREADS = 8 + /** * Utility for creating a [[TransportConf]] from a [[SparkConf]]. * @param numUsableCores if nonzero, this will restrict the server and client threads to only @@ -29,15 +46,28 @@ object SparkTransportConf { */ def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = { val conf = _conf.clone - if (numUsableCores > 0) { - // Only set if serverThreads/clientThreads not already set. - conf.set("spark.shuffle.io.serverThreads", - conf.get("spark.shuffle.io.serverThreads", numUsableCores.toString)) - conf.set("spark.shuffle.io.clientThreads", - conf.get("spark.shuffle.io.clientThreads", numUsableCores.toString)) - } + + // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily + // assuming we have all the machine's cores). + // NB: Only set if serverThreads/clientThreads not already set. + val numThreads = defaultNumThreads(numUsableCores) + conf.set("spark.shuffle.io.serverThreads", + conf.get("spark.shuffle.io.serverThreads", numThreads.toString)) + conf.set("spark.shuffle.io.clientThreads", + conf.get("spark.shuffle.io.clientThreads", numThreads.toString)) + new TransportConf(new ConfigProvider { override def get(name: String): String = conf.get(name) }) } + + /** + * Returns the default number of threads for both the Netty client and server thread pools. + * If numUsableCores is 0, we will use Runtime get an approximate number of available cores. 
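For reference, a hypothetical configuration sketch (example values only): because fromSparkConf only fills in these properties when they are unset, setting them explicitly overrides the capped default computed below.

    import org.apache.spark.SparkConf

    // Explicit settings take precedence over the defaultNumThreads cap.
    val conf = new SparkConf()
      .set("spark.shuffle.io.serverThreads", "16")  // example value
      .set("spark.shuffle.io.clientThreads", "16")  // example value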
+ */ + private def defaultNumThreads(numUsableCores: Int): Int = { + val availableCores = + if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors() + math.min(availableCores, MAX_DEFAULT_NETTY_THREADS) + } } diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index e2fc9c649925e..436dbed1730bc 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -44,5 +44,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.2.0-SNAPSHOT" + val SPARK_VERSION = "1.3.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index cb64d43c6c54a..3add4a76192ca 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -21,6 +21,7 @@ import java.util.{Properties, Random} import scala.collection.{mutable, Map} import scala.collection.mutable.ArrayBuffer +import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus @@ -28,6 +29,7 @@ import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text +import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ @@ -1309,7 +1311,7 @@ abstract class RDD[T: ClassTag]( def debugSelf (rdd: RDD[_]): Seq[String] = { import Utils.bytesToString - val persistence = storageLevel.description + val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else "" val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info => " CachedPartitions: %d; MemorySize: %s; TachyonSize: %s; DiskSize: %s".format( info.numCachedPartitions, bytesToString(info.memSize), @@ -1383,3 +1385,31 @@ abstract class RDD[T: ClassTag]( new JavaRDD(this)(elementClassTag) } } + +object RDD { + + // The following implicit functions were in SparkContext before 1.2 and users had to + // `import SparkContext._` to enable them. Now we move them here to make the compiler find + // them automatically. However, we still keep the old functions in SparkContext for backward + // compatibility and forward to the following functions directly. 
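As a usage illustration (a minimal sketch, not part of this patch; the object name is made up), pair-RDD operations now resolve without `import org.apache.spark.SparkContext._`, because the compiler finds the conversions on the RDD companion object below:

    import org.apache.spark.{SparkConf, SparkContext}

    object ImplicitLookupSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("implicit-sketch").setMaster("local[2]"))
        val counts = sc.parallelize(Seq("a", "b", "a"))
          .map(word => (word, 1))
          .reduceByKey(_ + _)  // found via RDD.rddToPairRDDFunctions, no extra import needed
        counts.collect().foreach(println)
        sc.stop()
      }
    }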
+ + implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { + new PairRDDFunctions(rdd) + } + + implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd) + + implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( + rdd: RDD[(K, V)]) = + new SequenceFileRDDFunctions(rdd) + + implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( + rdd: RDD[(K, V)]) = + new OrderedRDDFunctions[K, V, (K, V)](rdd) + + implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd) + + implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index e2c301603b4a5..8c43a559409f2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -39,21 +39,24 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) private[spark] class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) { - override def getPartitions: Array[Partition] = { + /** The start index of each partition. */ + @transient private val startIndices: Array[Long] = { val n = prev.partitions.size - val startIndices: Array[Long] = - if (n == 0) { - Array[Long]() - } else if (n == 1) { - Array(0L) - } else { - prev.context.runJob( - prev, - Utils.getIteratorSize _, - 0 until n - 1, // do not need to count the last partition - false - ).scanLeft(0L)(_ + _) - } + if (n == 0) { + Array[Long]() + } else if (n == 1) { + Array(0L) + } else { + prev.context.runJob( + prev, + Utils.getIteratorSize _, + 0 until n - 1, // do not need to count the last partition + allowLocal = false + ).scanLeft(0L)(_ + _) + } + } + + override def getPartitions: Array[Partition] = { firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index))) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 22449517d100f..b1222af662e9b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -751,14 +751,15 @@ class DAGScheduler( localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 if (shouldRunLocally) { // Compute very short actions like first() or take() with no parent stages locally. 
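Alongside the scheduler changes that follow, SparkListenerJobStart now carries full StageInfo objects while the old stageIds field is still derived from them. A hedged listener sketch (hypothetical class, registered via sc.addSparkListener) that reads both:

    import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}

    class JobStartLogger extends SparkListener {
      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
        // Stages whose output is not already available; mirrors how the UI estimates task counts.
        val remainingTasks = jobStart.stageInfos.filter(_.completionTime.isEmpty).map(_.numTasks).sum
        println(s"Job ${jobStart.jobId}: stages ${jobStart.stageIds.mkString(", ")}, " +
          s"roughly $remainingTasks tasks still to run")
      }
    }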
- listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties)) + listenerBus.post(SparkListenerJobStart(job.jobId, Seq.empty, properties)) runLocally(job) } else { jobIdToActiveJob(jobId) = job activeJobs += job finalStage.resultOfJob = Some(job) - listenerBus.post(SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray, - properties)) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post(SparkListenerJobStart(job.jobId, stageInfos, properties)) submitStage(finalStage) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 86afe3bd5265f..b62b0c1312693 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -56,8 +56,15 @@ case class SparkListenerTaskEnd( extends SparkListenerEvent @DeveloperApi -case class SparkListenerJobStart(jobId: Int, stageIds: Seq[Int], properties: Properties = null) - extends SparkListenerEvent +case class SparkListenerJobStart( + jobId: Int, + stageInfos: Seq[StageInfo], + properties: Properties = null) + extends SparkListenerEvent { + // Note: this is here for backwards-compatibility with older versions of this event which + // only stored stageIds and not StageInfos: + val stageIds: Seq[Int] = stageInfos.map(_.stageId) +} @DeveloperApi case class SparkListenerJobEnd(jobId: Int, jobResult: JobResult) extends SparkListenerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d8fb640350343..cabdc655f89bf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -536,7 +536,7 @@ private[spark] class TaskSetManager( calculatedTasks += 1 if (maxResultSize > 0 && totalResultSize > maxResultSize) { val msg = s"Total size of serialized results of ${calculatedTasks} tasks " + - s"(${Utils.bytesToString(totalResultSize)}) is bigger than maxResultSize " + + s"(${Utils.bytesToString(totalResultSize)}) is bigger than spark.driver.maxResultSize " + s"(${Utils.bytesToString(maxResultSize)})" logError(msg) abort(msg) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7a6ee56f81689..047fae104b485 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -46,6 +46,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed var totalCoreCount = new AtomicInteger(0) + // Total number of executors that are currently registered var totalRegisteredExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf private val timeout = AkkaUtils.askTimeout(conf) @@ -204,6 +205,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste executorsPendingToRemove -= executorId } totalCoreCount.addAndGet(-executorInfo.totalCores) + totalRegisteredExecutors.addAndGet(-1) scheduler.executorLost(executorId, SlaveLost(reason)) case None 
=> logError(s"Asked to remove non-existent executor $executorId") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index d13795186c48e..10e6886c16a4f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -208,10 +208,12 @@ private[spark] class MesosSchedulerBackend( */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { inClassLoader() { - val (acceptedOffers, declinedOffers) = offers.partition { o => + // Fail-fast on offers we know will be rejected + val (usableOffers, unUsableOffers) = offers.partition { o => val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue + // TODO(pwendell): Should below be 1 + scheduler.CPUS_PER_TASK? (mem >= MemoryUtils.calculateTotalMemory(sc) && // need at least 1 for executor, 1 for task cpus >= 2 * scheduler.CPUS_PER_TASK) || @@ -219,11 +221,12 @@ private[spark] class MesosSchedulerBackend( cpus >= scheduler.CPUS_PER_TASK) } - val offerableWorkers = acceptedOffers.map { o => + val workerOffers = usableOffers.map { o => val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { getResource(o.getResourcesList, "cpus").toInt } else { // If the executor doesn't exist yet, subtract CPU for executor + // TODO(pwendell): Should below just subtract "1"? getResource(o.getResourcesList, "cpus").toInt - scheduler.CPUS_PER_TASK } @@ -233,17 +236,20 @@ private[spark] class MesosSchedulerBackend( cpus) } - val slaveIdToOffer = acceptedOffers.map(o => o.getSlaveId.getValue -> o).toMap + val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]] + val slavesIdsOfAcceptedOffers = HashSet[String]() + // Call into the TaskSchedulerImpl - scheduler.resourceOffers(offerableWorkers) - .filter(!_.isEmpty) + val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty) + acceptedOffers .foreach { offer => offer.foreach { taskDesc => val slaveId = taskDesc.executorId slaveIdsWithExecutors += slaveId + slavesIdsOfAcceptedOffers += slaveId taskIdToSlaveId(taskDesc.taskId) = slaveId mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) .add(createMesosTask(taskDesc, slaveId)) @@ -257,7 +263,14 @@ private[spark] class MesosSchedulerBackend( d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) } - declinedOffers.foreach(o => d.declineOffer(o.getId)) + // Decline offers that weren't used + // NOTE: This logic assumes that we only get a single offer for each host in a given batch + for (o <- usableOffers if !slavesIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) { + d.declineOffer(o.getId) + } + + // Decline offers we ruled out immediately + unUsableOffers.foreach(o => d.declineOffer(o.getId)) } } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index 6908a59a79e60..af873034215a9 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -148,6 +148,7 @@ private[spark] class TachyonBlockManager( logError("Exception while deleting 
tachyon spark dir: " + tachyonDir, e) } } + client.close() } }) } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala index 6dbad5ff0518e..233d1e2b7c616 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala @@ -116,6 +116,8 @@ private[spark] class TachyonStore( case ioe: IOException => logWarning(s"Failed to fetch the block $blockId from Tachyon", ioe) None + } finally { + is.close() } } diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala new file mode 100644 index 0000000000000..27ba9e18237b5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import java.util.{Timer, TimerTask} + +import org.apache.spark._ + +/** + * ConsoleProgressBar shows the progress of stages in the next line of the console. It poll the + * status of active stages from `sc.statusTracker` periodically, the progress bar will be showed + * up after the stage has ran at least 500ms. If multiple stages run in the same time, the status + * of them will be combined together, showed in one line. + */ +private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { + + // Carrige return + val CR = '\r' + // Update period of progress bar, in milliseconds + val UPDATE_PERIOD = 200L + // Delay to show up a progress bar, in milliseconds + val FIRST_DELAY = 500L + + // The width of terminal + val TerminalWidth = if (!sys.env.getOrElse("COLUMNS", "").isEmpty) { + sys.env.get("COLUMNS").get.toInt + } else { + 80 + } + + var lastFinishTime = 0L + var lastUpdateTime = 0L + var lastProgressBar = "" + + // Schedule a refresh thread to run periodically + private val timer = new Timer("refresh progress", true) + timer.schedule(new TimerTask{ + override def run() { + refresh() + } + }, FIRST_DELAY, UPDATE_PERIOD) + + /** + * Try to refresh the progress bar in every cycle + */ + private def refresh(): Unit = synchronized { + val now = System.currentTimeMillis() + if (now - lastFinishTime < FIRST_DELAY) { + return + } + val stageIds = sc.statusTracker.getActiveStageIds() + val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1) + .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId()) + if (stages.size > 0) { + show(now, stages.take(3)) // display at most 3 stages in same time + } + } + + /** + * Show progress bar in console. 
The progress bar is displayed in the next line + * after your last output, keeps overwriting itself to hold in one line. The logging will follow + * the progress bar, then progress bar will be showed in next line without overwrite logs. + */ + private def show(now: Long, stages: Seq[SparkStageInfo]) { + val width = TerminalWidth / stages.size + val bar = stages.map { s => + val total = s.numTasks() + val header = s"[Stage ${s.stageId()}:" + val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]" + val w = width - header.size - tailer.size + val bar = if (w > 0) { + val percent = w * s.numCompletedTasks() / total + (0 until w).map { i => + if (i < percent) "=" else if (i == percent) ">" else " " + }.mkString("") + } else { + "" + } + header + bar + tailer + }.mkString("") + + // only refresh if it's changed of after 1 minute (or the ssh connection will be closed + // after idle some time) + if (bar != lastProgressBar || now - lastUpdateTime > 60 * 1000L) { + System.err.print(CR + bar) + lastUpdateTime = now + } + lastProgressBar = bar + } + + /** + * Clear the progress bar if showed. + */ + private def clear() { + if (!lastProgressBar.isEmpty) { + System.err.printf(CR + " " * TerminalWidth + CR) + lastProgressBar = "" + } + } + + /** + * Mark all the stages as finished, clear the progress bar if showed, then the progress will not + * interweave with output of jobs. + */ + def finishAll(): Unit = synchronized { + clear() + lastFinishTime = System.currentTimeMillis() + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 049938f827291..176907dffa46a 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -23,7 +23,7 @@ import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab} import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab} -import org.apache.spark.ui.jobs.{JobProgressListener, JobProgressTab} +import org.apache.spark.ui.jobs.{JobsTab, JobProgressListener, StagesTab} import org.apache.spark.ui.storage.{StorageListener, StorageTab} /** @@ -43,17 +43,20 @@ private[spark] class SparkUI private ( extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI") with Logging { + val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false) + /** Initialize all components of the server. 
*/ def initialize() { - val jobProgressTab = new JobProgressTab(this) - attachTab(jobProgressTab) + attachTab(new JobsTab(this)) + val stagesTab = new StagesTab(this) + attachTab(stagesTab) attachTab(new StorageTab(this)) attachTab(new EnvironmentTab(this)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(createRedirectHandler("/", "/stages", basePath = basePath)) + attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) attachHandler( - createRedirectHandler("/stages/stage/kill", "/stages", jobProgressTab.handleKillRequest)) + createRedirectHandler("/stages/stage/kill", "/stages", stagesTab.handleKillRequest)) // If the UI is live, then serve sc.foreach { _.env.metricsSystem.getServletHandlers.foreach(attachHandler) } } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 7bc1e24d58711..09079bbd43f6f 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -26,7 +26,8 @@ import org.apache.spark.Logging /** Utility functions for generating XML pages with spark content. */ private[spark] object UIUtils extends Logging { - val TABLE_CLASS = "table table-bordered table-striped-custom table-condensed sortable" + val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed sortable" + val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { @@ -169,7 +170,8 @@ private[spark] object UIUtils extends Logging { title: String, content: => Seq[Node], activeTab: SparkUITab, - refreshInterval: Option[Int] = None): Seq[Node] = { + refreshInterval: Option[Int] = None, + helpText: Option[String] = None): Seq[Node] = { val appName = activeTab.appName val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." @@ -178,6 +180,9 @@ private[spark] object UIUtils extends Logging { {tab.name} } + val helpButton: Seq[Node] = helpText.map { helpText => + (?) + }.getOrElse(Seq.empty) @@ -201,6 +206,7 @@ private[spark] object UIUtils extends Logging {

{title} + {helpButton}

@@ -243,12 +249,10 @@ private[spark] object UIUtils extends Logging { data: Iterable[T], fixedWidth: Boolean = false, id: Option[String] = None, - headerClasses: Seq[String] = Seq.empty): Seq[Node] = { + headerClasses: Seq[String] = Seq.empty, + stripeRowsWithCss: Boolean = true): Seq[Node] = { - var listingTableClass = TABLE_CLASS - if (fixedWidth) { - listingTableClass += " table-fixed" - } + val listingTableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED val colWidth = 100.toDouble / headers.size val colWidthAttr = if (fixedWidth) colWidth + "%" else "" @@ -283,4 +287,24 @@ private[spark] object UIUtils extends Logging { } + + def makeProgressBar( + started: Int, + completed: Int, + failed: Int, + skipped:Int, + total: Int): Seq[Node] = { + val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) + val startWidth = "width: %s%%".format((started.toDouble/total)*100) + +
+ + {completed}/{total} + { if (failed > 0) s"($failed failed)" } + { if (skipped > 0) s"($skipped skipped)" } + +
+
+
+ } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 71b59b1d078ca..363cb96de7998 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -57,7 +57,7 @@ private[ui] class ExecutorsPage( val execInfoSorted = execInfo.sortBy(_.id) val execTable = - +
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala new file mode 100644 index 0000000000000..ea2d187a0e8e4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import scala.xml.{Node, NodeSeq} + +import javax.servlet.http.HttpServletRequest + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.jobs.UIData.JobUIData + +/** Page showing list of all ongoing and recently finished jobs */ +private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { + private val startTime: Option[Long] = parent.sc.map(_.startTime) + private val listener = parent.listener + + private def jobsTable(jobs: Seq[JobUIData]): Seq[Node] = { + val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined) + + val columns: Seq[Node] = { + + + + + + + } + + def makeRow(job: JobUIData): Seq[Node] = { + val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max) + val lastStageData = lastStageInfo.flatMap { s => + listener.stageIdToData.get((s.stageId, s.attemptId)) + } + val isComplete = job.status == JobExecutionStatus.SUCCEEDED + val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") + val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("") + val duration: Option[Long] = { + job.startTime.map { start => + val end = job.endTime.getOrElse(System.currentTimeMillis()) + end - start + } + } + val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") + val formattedSubmissionTime = job.startTime.map(UIUtils.formatDate).getOrElse("Unknown") + val detailUrl = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), job.jobId) + + + + + + + + + } + +
Executor ID Address{if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"}DescriptionSubmittedDurationStages: Succeeded/TotalTasks (for all stages): Succeeded/Total
+ {job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")} + +
{lastStageDescription}
+ {lastStageName} +
+ {formattedSubmissionTime} + {formattedDuration} + {job.completedStageIndices.size}/{job.stageIds.size - job.numSkippedStages} + {if (job.numFailedStages > 0) s"(${job.numFailedStages} failed)"} + {if (job.numSkippedStages > 0) s"(${job.numSkippedStages} skipped)"} + + {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks, + failed = job.numFailedTasks, skipped = job.numSkippedTasks, + total = job.numTasks - job.numSkippedTasks)} +
+ {columns} + + {jobs.map(makeRow)} + +
+ } + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val activeJobs = listener.activeJobs.values.toSeq + val completedJobs = listener.completedJobs.reverse.toSeq + val failedJobs = listener.failedJobs.reverse.toSeq + val now = System.currentTimeMillis + + val activeJobsTable = + jobsTable(activeJobs.sortBy(_.startTime.getOrElse(-1L)).reverse) + val completedJobsTable = + jobsTable(completedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + val failedJobsTable = + jobsTable(failedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + + val summary: NodeSeq = +
+
    + {if (startTime.isDefined) { + // Total duration is not meaningful unless the UI is live +
  • + Total Duration: + {UIUtils.formatDuration(now - startTime.get)} +
  • + }} +
  • + Scheduling Mode: + {listener.schedulingMode.map(_.toString).getOrElse("Unknown")} +
  • +
  • + Active Jobs: + {activeJobs.size} +
  • +
  • + Completed Jobs: + {completedJobs.size} +
  • +
  • + Failed Jobs: + {failedJobs.size} +
  • +
+
+ + val content = summary ++ +

Active Jobs ({activeJobs.size})

++ activeJobsTable ++ +

Completed Jobs ({completedJobs.size})

++ completedJobsTable ++ +

Failed Jobs ({failedJobs.size})

++ failedJobsTable + + val helpText = """A job is triggered by an action, like "count()" or "saveAsTextFile()".""" + + " Click on a job's title to see information about the stages of tasks associated with" + + " the job." + + UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala similarity index 87% rename from core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala rename to core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 83a7898071c9b..b0f8ca2ab0d3f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -25,7 +25,7 @@ import org.apache.spark.scheduler.Schedulable import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing list of all ongoing and recently finished stages and pools */ -private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") { +private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { private val sc = parent.sc private val listener = parent.listener private def isFairScheduler = parent.isFairScheduler @@ -41,11 +41,14 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") val activeStagesTable = new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, - parent, parent.killEnabled) + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = parent.killEnabled) val completedStagesTable = - new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent) + new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false) val failedStagesTable = - new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent) + new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler) // For now, pool information is only accessible in live UIs val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]) @@ -93,7 +96,7 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")

Failed Stages ({numFailedStages})

++ failedStagesTable.toNodeSeq - UIUtils.headerSparkPage("Spark Stages", content, parent) + UIUtils.headerSparkPage("Spark Stages (for all jobs)", content, parent) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index fa0f96bff34ff..9836d11a6d85f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -25,7 +25,7 @@ import org.apache.spark.ui.jobs.UIData.StageUIData import org.apache.spark.util.Utils /** Stage summary grouped by executors. */ -private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobProgressTab) { +private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: StagesTab) { private val listener = parent.listener def toNodeSeq: Seq[Node] = { @@ -36,7 +36,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr /** Special table which merges two header cells. */ private def executorTable[T](): Seq[Node] = { - +
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala new file mode 100644 index 0000000000000..77d36209c6048 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import scala.collection.mutable +import scala.xml.{NodeSeq, Node} + +import javax.servlet.http.HttpServletRequest + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.scheduler.StageInfo +import org.apache.spark.ui.{UIUtils, WebUIPage} + +/** Page showing statistics and stage list for a given job */ +private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { + private val listener = parent.listener + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val jobId = request.getParameter("id").toInt + val jobDataOption = listener.jobIdToData.get(jobId) + if (jobDataOption.isEmpty) { + val content = +
+

No information to display for job {jobId}

+
+ return UIUtils.headerSparkPage( + s"Details for Job $jobId", content, parent) + } + val jobData = jobDataOption.get + val isComplete = jobData.status != JobExecutionStatus.RUNNING + val stages = jobData.stageIds.map { stageId => + // This could be empty if the JobProgressListener hasn't received information about the + // stage or if the stage information has been garbage collected + listener.stageIdToInfo.getOrElse(stageId, + new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, "Unknown")) + } + + val activeStages = mutable.Buffer[StageInfo]() + val completedStages = mutable.Buffer[StageInfo]() + // If the job is completed, then any pending stages are displayed as "skipped": + val pendingOrSkippedStages = mutable.Buffer[StageInfo]() + val failedStages = mutable.Buffer[StageInfo]() + for (stage <- stages) { + if (stage.submissionTime.isEmpty) { + pendingOrSkippedStages += stage + } else if (stage.completionTime.isDefined) { + if (stage.failureReason.isDefined) { + failedStages += stage + } else { + completedStages += stage + } + } else { + activeStages += stage + } + } + + val activeStagesTable = + new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = parent.killEnabled) + val pendingOrSkippedStagesTable = + new StageTableBase(pendingOrSkippedStages.sortBy(_.stageId).reverse, + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = false) + val completedStagesTable = + new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false) + val failedStagesTable = + new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler) + + val shouldShowActiveStages = activeStages.nonEmpty + val shouldShowPendingStages = !isComplete && pendingOrSkippedStages.nonEmpty + val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowSkippedStages = isComplete && pendingOrSkippedStages.nonEmpty + val shouldShowFailedStages = failedStages.nonEmpty + + val summary: NodeSeq = +
+
    +
  • + Status: + {jobData.status} +
  • + { + if (jobData.jobGroup.isDefined) { +
  • + Job Group: + {jobData.jobGroup.get} +
  • + } + } + { + if (shouldShowActiveStages) { +
  • + Active Stages: + {activeStages.size} +
  • + } + } + { + if (shouldShowPendingStages) { +
  • + + Pending Stages: + {pendingOrSkippedStages.size} +
  • + } + } + { + if (shouldShowCompletedStages) { +
  • + Completed Stages: + {completedStages.size} +
  • + } + } + { + if (shouldShowSkippedStages) { +
  • + Skipped Stages: + {pendingOrSkippedStages.size} +
  • + } + } + { + if (shouldShowFailedStages) { +
  • + Failed Stages: + {failedStages.size} +
  • + } + } +
+
+ + var content = summary + if (shouldShowActiveStages) { + content ++=

Active Stages ({activeStages.size})

++ + activeStagesTable.toNodeSeq + } + if (shouldShowPendingStages) { + content ++=

Pending Stages ({pendingOrSkippedStages.size})

++ + pendingOrSkippedStagesTable.toNodeSeq + } + if (shouldShowCompletedStages) { + content ++=

Completed Stages ({completedStages.size})

++ + completedStagesTable.toNodeSeq + } + if (shouldShowSkippedStages) { + content ++=

Skipped Stages ({pendingOrSkippedStages.size})

++ + pendingOrSkippedStagesTable.toNodeSeq + } + if (shouldShowFailedStages) { + content ++=

Failed Stages ({failedStages.size})

++ + failedStagesTable.toNodeSeq + } + UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 8bbde51e1801c..72935beb3a34a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import scala.collection.mutable.{HashMap, ListBuffer} +import scala.collection.mutable.{HashMap, HashSet, ListBuffer} import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi @@ -40,48 +40,145 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { import JobProgressListener._ + // Define a handful of type aliases so that data structures' types can serve as documentation. + // These type aliases are public because they're used in the types of public fields: + type JobId = Int type StageId = Int type StageAttemptId = Int + type PoolName = String + type ExecutorId = String - // How many stages to remember - val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES) - // How many jobs to remember - val retailedJobs = conf.getInt("spark.ui.retainedJobs", DEFAULT_RETAINED_JOBS) - + // Jobs: val activeJobs = new HashMap[JobId, JobUIData] val completedJobs = ListBuffer[JobUIData]() val failedJobs = ListBuffer[JobUIData]() val jobIdToData = new HashMap[JobId, JobUIData] + // Stages: val activeStages = new HashMap[StageId, StageInfo] val completedStages = ListBuffer[StageInfo]() + val skippedStages = ListBuffer[StageInfo]() val failedStages = ListBuffer[StageInfo]() val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData] val stageIdToInfo = new HashMap[StageId, StageInfo] - - // Number of completed and failed stages, may not actually equal to completedStages.size and - // failedStages.size respectively due to completedStage and failedStages only maintain the latest - // part of the stages, the earlier ones will be removed when there are too many stages for - // memory sake. + val stageIdToActiveJobIds = new HashMap[StageId, HashSet[JobId]] + val poolToActiveStages = HashMap[PoolName, HashMap[StageId, StageInfo]]() + // Total of completed and failed stages that have ever been run. These may be greater than + // `completedStages.size` and `failedStages.size` if we have run more stages or jobs than + // JobProgressListener's retention limits. var numCompletedStages = 0 var numFailedStages = 0 - // Map from pool name to a hash map (map from stage id to StageInfo). 
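To make the new skipped-stage bookkeeping concrete, a hypothetical driver snippet (assumes an active SparkContext named sc): re-running an action on an already-shuffled RDD can reuse the map output of the first run, so the parent stage is never resubmitted and is counted as skipped rather than completed.

    val pairs = sc.parallelize(1 to 1000).map(i => (i % 10, i))
    val sums = pairs.reduceByKey(_ + _)
    sums.count()  // job 1: runs the shuffle map stage and the result stage
    sums.count()  // job 2: the map stage's output may be reused, so it shows up as skipped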
- val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]() - - val executorIdToBlockManagerId = HashMap[String, BlockManagerId]() + // Misc: + val executorIdToBlockManagerId = HashMap[ExecutorId, BlockManagerId]() + def blockManagerIds = executorIdToBlockManagerId.values.toSeq var schedulingMode: Option[SchedulingMode] = None - def blockManagerIds = executorIdToBlockManagerId.values.toSeq + // To limit the total memory usage of JobProgressListener, we only track information for a fixed + // number of non-active jobs and stages (there is no limit for active jobs and stages): + + val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES) + val retainedJobs = conf.getInt("spark.ui.retainedJobs", DEFAULT_RETAINED_JOBS) + + // We can test for memory leaks by ensuring that collections that track non-active jobs and + // stages do not grow without bound and that collections for active jobs/stages eventually become + // empty once Spark is idle. Let's partition our collections into ones that should be empty + // once Spark is idle and ones that should have a hard- or soft-limited sizes. + // These methods are used by unit tests, but they're defined here so that people don't forget to + // update the tests when adding new collections. Some collections have multiple levels of + // nesting, etc, so this lets us customize our notion of "size" for each structure: + + // These collections should all be empty once Spark is idle (no active stages / jobs): + private[spark] def getSizesOfActiveStateTrackingCollections: Map[String, Int] = { + Map( + "activeStages" -> activeStages.size, + "activeJobs" -> activeJobs.size, + "poolToActiveStages" -> poolToActiveStages.values.map(_.size).sum, + "stageIdToActiveJobIds" -> stageIdToActiveJobIds.values.map(_.size).sum + ) + } + + // These collections should stop growing once we have run at least `spark.ui.retainedStages` + // stages and `spark.ui.retainedJobs` jobs: + private[spark] def getSizesOfHardSizeLimitedCollections: Map[String, Int] = { + Map( + "completedJobs" -> completedJobs.size, + "failedJobs" -> failedJobs.size, + "completedStages" -> completedStages.size, + "skippedStages" -> skippedStages.size, + "failedStages" -> failedStages.size + ) + } + + // These collections may grow arbitrarily, but once Spark becomes idle they should shrink back to + // some bound based on the `spark.ui.retainedStages` and `spark.ui.retainedJobs` settings: + private[spark] def getSizesOfSoftSizeLimitedCollections: Map[String, Int] = { + Map( + "jobIdToData" -> jobIdToData.size, + "stageIdToData" -> stageIdToData.size, + "stageIdToStageInfo" -> stageIdToInfo.size + ) + } + + /** If stages is too large, remove and garbage collect old stages */ + private def trimStagesIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { + if (stages.size > retainedStages) { + val toRemove = math.max(retainedStages / 10, 1) + stages.take(toRemove).foreach { s => + stageIdToData.remove((s.stageId, s.attemptId)) + stageIdToInfo.remove(s.stageId) + } + stages.trimStart(toRemove) + } + } + + /** If jobs is too large, remove and garbage collect old jobs */ + private def trimJobsIfNecessary(jobs: ListBuffer[JobUIData]) = synchronized { + if (jobs.size > retainedJobs) { + val toRemove = math.max(retainedJobs / 10, 1) + jobs.take(toRemove).foreach { job => + jobIdToData.remove(job.jobId) + } + jobs.trimStart(toRemove) + } + } override def onJobStart(jobStart: SparkListenerJobStart) = synchronized { - val jobGroup = 
Option(jobStart.properties).map(_.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) + val jobGroup = for ( + props <- Option(jobStart.properties); + group <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) + ) yield group val jobData: JobUIData = - new JobUIData(jobStart.jobId, jobStart.stageIds, jobGroup, JobExecutionStatus.RUNNING) + new JobUIData( + jobId = jobStart.jobId, + startTime = Some(System.currentTimeMillis), + endTime = None, + stageIds = jobStart.stageIds, + jobGroup = jobGroup, + status = JobExecutionStatus.RUNNING) + // Compute (a potential underestimate of) the number of tasks that will be run by this job. + // This may be an underestimate because the job start event references all of the result + // stages's transitive stage dependencies, but some of these stages might be skipped if their + // output is available from earlier runs. + // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. + jobData.numTasks = { + val allStages = jobStart.stageInfos + val missingStages = allStages.filter(_.completionTime.isEmpty) + missingStages.map(_.numTasks).sum + } jobIdToData(jobStart.jobId) = jobData activeJobs(jobStart.jobId) = jobData + for (stageId <- jobStart.stageIds) { + stageIdToActiveJobIds.getOrElseUpdate(stageId, new HashSet[StageId]).add(jobStart.jobId) + } + // If there's no information for a stage, store the StageInfo received from the scheduler + // so that we can display stage descriptions for pending stages: + for (stageInfo <- jobStart.stageInfos) { + stageIdToInfo.getOrElseUpdate(stageInfo.stageId, stageInfo) + stageIdToData.getOrElseUpdate((stageInfo.stageId, stageInfo.attemptId), new StageUIData) + } } override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { @@ -89,14 +186,31 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { logWarning(s"Job completed for unknown job ${jobEnd.jobId}") new JobUIData(jobId = jobEnd.jobId) } + jobData.endTime = Some(System.currentTimeMillis()) jobEnd.jobResult match { case JobSucceeded => completedJobs += jobData + trimJobsIfNecessary(completedJobs) jobData.status = JobExecutionStatus.SUCCEEDED case JobFailed(exception) => failedJobs += jobData + trimJobsIfNecessary(failedJobs) jobData.status = JobExecutionStatus.FAILED } + for (stageId <- jobData.stageIds) { + stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage => + jobsUsingStage.remove(jobEnd.jobId) + stageIdToInfo.get(stageId).foreach { stageInfo => + if (stageInfo.submissionTime.isEmpty) { + // if this stage is pending, it won't complete, so mark it as "skipped": + skippedStages += stageInfo + trimStagesIfNecessary(skippedStages) + jobData.numSkippedStages += 1 + jobData.numSkippedTasks += stageInfo.numTasks + } + } + } + } } override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { @@ -118,23 +232,24 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { if (stage.failureReason.isEmpty) { completedStages += stage numCompletedStages += 1 - trimIfNecessary(completedStages) + trimStagesIfNecessary(completedStages) } else { failedStages += stage numFailedStages += 1 - trimIfNecessary(failedStages) + trimStagesIfNecessary(failedStages) } - } - /** If stages is too large, remove and garbage collect old stages */ - private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { - if (stages.size > retainedStages) { - val toRemove = math.max(retainedStages / 10, 1) - stages.take(toRemove).foreach { s => - 
stageIdToData.remove((s.stageId, s.attemptId)) - stageIdToInfo.remove(s.stageId) + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveStages -= 1 + if (stage.failureReason.isEmpty) { + jobData.completedStageIndices.add(stage.stageId) + } else { + jobData.numFailedStages += 1 } - stages.trimStart(toRemove) } } @@ -157,6 +272,14 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo]) stages(stage.stageId) = stage + + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveStages += 1 + } } override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { @@ -169,6 +292,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.numActiveTasks += 1 stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo)) } + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveTasks += 1 + } } override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { @@ -226,6 +356,20 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskData.taskInfo = info taskData.taskMetrics = metrics taskData.errorMessage = errorMessage + + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskEnd.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveTasks -= 1 + taskEnd.reason match { + case Success => + jobData.numCompletedTasks += 1 + case _ => + jobData.numFailedTasks += 1 + } + } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala new file mode 100644 index 0000000000000..b2bbfdee56946 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import org.apache.spark.scheduler.SchedulingMode +import org.apache.spark.ui.{SparkUI, SparkUITab} + +/** Web UI showing progress status of all jobs in the given SparkContext. 
*/ +private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { + val sc = parent.sc + val killEnabled = parent.killEnabled + def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) + val listener = parent.jobProgressListener + + attachPage(new AllJobsPage(this)) + attachPage(new JobPage(this)) +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 770d99eea1c9d..5fc6cc7533150 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -25,7 +25,7 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo} import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing specific pool details */ -private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") { +private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { private val sc = parent.sc private val listener = parent.listener @@ -37,8 +37,9 @@ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") { case Some(s) => s.values.toSeq case None => Seq[StageInfo]() } - val activeStagesTable = - new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, parent) + val activeStagesTable = new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = parent.killEnabled) // For now, pool information is only accessible in live UIs val pools = sc.map(_.getPoolForName(poolName).get).toSeq diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index 64178e1e33d41..df1899e7a9b84 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -24,7 +24,7 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo} import org.apache.spark.ui.UIUtils /** Table showing list of pools */ -private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) { +private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) { private val listener = parent.listener def toNodeSeq: Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 16bc3f6c18d09..bfa54f8492068 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.{Utils, Distribution} import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} /** Page showing statistics and task list for a given stage */ -private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { +private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -114,6 +114,10 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { ++ @@ -73,25 +72,11 @@ private[ui] class StageTableBase(
Executor ID Address Stage Id
} - private def makeProgressBar(started: Int, completed: Int, failed: Int, total: Int): Seq[Node] = - { - val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) - val startWidth = "width: %s%%".format((started.toDouble/total)*100) - -
- - {completed}/{total} { if (failed > 0) s"($failed failed)" else "" } - -
-
-
- } - private def makeDescription(s: StageInfo): Seq[Node] = { // scalastyle:off val killLink = if (killEnabled) { val killLinkUri = "%s/stages/stage/kill?id=%s&terminate=true" - .format(UIUtils.prependBaseUri(parent.basePath), s.stageId) + .format(UIUtils.prependBaseUri(basePath), s.stageId) val confirm = "return window.confirm('Are you sure you want to kill stage %s ?');" .format(s.stageId) @@ -101,7 +86,7 @@ private[ui] class StageTableBase( // scalastyle:on val nameLinkUri ="%s/stages/stage?id=%s&attempt=%s" - .format(UIUtils.prependBaseUri(parent.basePath), s.stageId, s.attemptId) + .format(UIUtils.prependBaseUri(basePath), s.stageId, s.attemptId) val nameLink = {s.name} val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) @@ -115,7 +100,7 @@ private[ui] class StageTableBase( Text("RDD: ") ++ // scalastyle:off cachedRddInfos.map { i => - {i.name} + {i.name} } // scalastyle:on }} @@ -167,7 +152,7 @@ private[ui] class StageTableBase( {if (isFairScheduler) { + .format(UIUtils.prependBaseUri(basePath), stageData.schedulingPool)}> {stageData.schedulingPool} @@ -180,8 +165,9 @@ private[ui] class StageTableBase( {formattedDuration} - {makeProgressBar(stageData.numActiveTasks, stageData.completedIndices.size, - stageData.numFailedTasks, s.numTasks)} + {UIUtils.makeProgressBar(started = stageData.numActiveTasks, + completed = stageData.completedIndices.size, failed = stageData.numFailedTasks, + skipped = 0, total = s.numTasks)} {inputReadWithUnit} {outputWriteWithUnit} @@ -195,9 +181,10 @@ private[ui] class StageTableBase( private[ui] class FailedStageTable( stages: Seq[StageInfo], - parent: JobProgressTab, - killEnabled: Boolean = false) - extends StageTableBase(stages, parent, killEnabled) { + basePath: String, + listener: JobProgressListener, + isFairScheduler: Boolean) + extends StageTableBase(stages, basePath, listener, isFairScheduler, killEnabled = false) { override protected def columns: Seq[Node] = super.columns ++ Failure Reason diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala rename to core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 03ca918e2e8b3..937261de00e3a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -19,18 +19,16 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest -import org.apache.spark.SparkConf import org.apache.spark.scheduler.SchedulingMode import org.apache.spark.ui.{SparkUI, SparkUITab} -/** Web UI showing progress status of all jobs in the given SparkContext. */ -private[ui] class JobProgressTab(parent: SparkUI) extends SparkUITab(parent, "stages") { +/** Web UI showing progress status of all stages in the given SparkContext. 
*/ +private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") { val sc = parent.sc - val conf = sc.map(_.conf).getOrElse(new SparkConf) - val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false) + val killEnabled = parent.killEnabled val listener = parent.jobProgressListener - attachPage(new JobProgressPage(this)) + attachPage(new AllStagesPage(this)) attachPage(new StagePage(this)) attachPage(new PoolPage(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index eb371bd0ea7ed..ca942c4051c84 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -20,6 +20,9 @@ package org.apache.spark.ui.jobs /** * Names of the CSS classes corresponding to each type of task detail. Used to allow users * to optionally show/hide columns. + * + * If new optional metrics are added here, they should also be added to the end of webui.css + * to have the style set to "display: none;" by default. */ private object TaskDetailsClassNames { val SCHEDULER_DELAY = "scheduler_delay" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 2f7d618df5f6f..48fd7caa1a1ed 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -40,9 +40,28 @@ private[jobs] object UIData { class JobUIData( var jobId: Int = -1, + var startTime: Option[Long] = None, + var endTime: Option[Long] = None, var stageIds: Seq[Int] = Seq.empty, var jobGroup: Option[String] = None, - var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN + var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN, + /* Tasks */ + // `numTasks` is a potential underestimate of the true number of tasks that this job will run. + // This may be an underestimate because the job start event references all of the result + // stages's transitive stage dependencies, but some of these stages might be skipped if their + // output is available from earlier runs. + // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. + var numTasks: Int = 0, + var numActiveTasks: Int = 0, + var numCompletedTasks: Int = 0, + var numSkippedTasks: Int = 0, + var numFailedTasks: Int = 0, + /* Stages */ + var numActiveStages: Int = 0, + // This needs to be a set instead of a simple count to prevent double-counting of rerun stages: + var completedStageIndices: OpenHashSet[Int] = new OpenHashSet[Int](), + var numSkippedStages: Int = 0, + var numFailedStages: Int = 0 ) class StageUIData { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 7e536edfe807b..e7b80e8774b9c 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -31,6 +31,21 @@ import org.apache.spark.scheduler._ import org.apache.spark.storage._ import org.apache.spark._ +/** + * Serializes SparkListener events to/from JSON. This protocol provides strong backwards- + * and forwards-compatibility guarantees: any version of Spark should be able to read JSON output + * written by any other version, including newer versions. 
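To make the compatibility guarantee concrete, here is a small, hedged sketch of the read-side pattern it requires: new fields must be read optionally so that JSON written before the field existed still parses. The "Submission Time" field below is purely illustrative, and jsonOption mirrors the Utils.jsonOption helper mentioned in the rules that follow rather than reproducing Spark's code.

import org.json4s._
import org.json4s.jackson.JsonMethods.parse

object OptionalJsonFieldExample {
  // A missing field extracts to JNothing; treat that as None instead of failing.
  def jsonOption(json: JValue): Option[JValue] = json match {
    case JNothing => None
    case value => Some(value)
  }

  def main(args: Array[String]): Unit = {
    implicit val formats: Formats = DefaultFormats
    // Event written by an older Spark version: the optional field is simply absent.
    val oldJson = parse("""{"Event": "SparkListenerJobStart", "Job ID": 10}""")
    val submissionTime: Long =
      jsonOption(oldJson \ "Submission Time").map(_.extract[Long]).getOrElse(-1L)
    println(submissionTime) // -1 for old logs, the recorded value for new ones
  }
}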
+ * + * JsonProtocolSuite contains backwards-compatibility tests which check that the current version of + * JsonProtocol is able to read output written by earlier versions. We do not currently have tests + * for reading newer JSON output with older Spark versions. + * + * To ensure that we provide these guarantees, follow these rules when modifying these methods: + * + * - Never delete any JSON fields. + * - Any new JSON fields should be optional; use `Utils.jsonOption` when reading these fields + * in `*FromJson` methods. + */ private[spark] object JsonProtocol { // TODO: Remove this file and put JSON serialization into each individual class. @@ -121,6 +136,7 @@ private[spark] object JsonProtocol { val properties = propertiesToJson(jobStart.properties) ("Event" -> Utils.getFormattedClassName(jobStart)) ~ ("Job ID" -> jobStart.jobId) ~ + ("Stage Infos" -> jobStart.stageInfos.map(stageInfoToJson)) ~ // Added in Spark 1.2.0 ("Stage IDs" -> jobStart.stageIds) ~ ("Properties" -> properties) } @@ -455,7 +471,12 @@ private[spark] object JsonProtocol { val jobId = (json \ "Job ID").extract[Int] val stageIds = (json \ "Stage IDs").extract[List[JValue]].map(_.extract[Int]) val properties = propertiesFromJson(json \ "Properties") - SparkListenerJobStart(jobId, stageIds, properties) + // The "Stage Infos" field was added in Spark 1.2.0 + val stageInfos = Utils.jsonOption(json \ "Stage Infos") + .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { + stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, "unknown")) + } + SparkListenerJobStart(jobId, stageInfos, properties) } def jobEndFromJson(json: JValue): SparkListenerJobEnd = { @@ -667,6 +688,10 @@ private[spark] object JsonProtocol { } def blockManagerIdFromJson(json: JValue): BlockManagerId = { + // On metadata fetch fail, block manager ID can be null (SPARK-4471) + if (json == JNothing) { + return null + } val executorId = (json \ "Executor ID").extract[String] val host = (json \ "Host").extract[String] val port = (json \ "Port").extract[Int] diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 2889e171f627e..ac40f19ed6799 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -52,7 +52,7 @@ private[spark] class MetadataCleaner( logDebug( "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " + "and period of " + periodSeconds + " secs") - timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) + timer.schedule(task, delaySeconds * 1000, periodSeconds * 1000) } def cancel() { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 26fa0cb6d7bde..8a0f5a602de12 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -76,10 +76,6 @@ class ExternalAppendOnlyMap[K, V, C]( private val sparkConf = SparkEnv.get.conf private val diskBlockManager = blockManager.diskBlockManager - // Number of pairs inserted since last spill; note that we count them even if a value is merged - // with a previous key in case we're doing something like groupBy where the result grows - protected[this] var elementsRead = 0L - /** * Size of object batches when reading/writing 
from serializers. * @@ -132,7 +128,7 @@ class ExternalAppendOnlyMap[K, V, C]( currentMap = new SizeTrackingAppendOnlyMap[K, C] } currentMap.changeValue(curEntry._1, update) - elementsRead += 1 + addElementsRead() } } @@ -209,8 +205,6 @@ class ExternalAppendOnlyMap[K, V, C]( } spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) - - elementsRead = 0 } def diskBytesSpilled: Long = _diskBytesSpilled diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index c1ce13683b569..15bda1c9cc29c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -119,10 +119,6 @@ private[spark] class ExternalSorter[K, V, C]( private var map = new SizeTrackingAppendOnlyMap[(Int, K), C] private var buffer = new SizeTrackingPairBuffer[(Int, K), C] - // Number of pairs read from input since last spill; note that we count them even if a value is - // merged with a previous key in case we're doing something like groupBy where the result grows - protected[this] var elementsRead = 0L - // Total spilling statistics private var _diskBytesSpilled = 0L @@ -204,15 +200,22 @@ private[spark] class ExternalSorter[K, V, C]( if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) } while (records.hasNext) { - elementsRead += 1 + addElementsRead() kv = records.next() map.changeValue((getPartition(kv._1), kv._1), update) maybeSpillCollection(usingMap = true) } + } else if (bypassMergeSort) { + // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies + if (records.hasNext) { + spillToPartitionFiles(records.map { kv => + ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) + }) + } } else { // Stick values into our buffer while (records.hasNext) { - elementsRead += 1 + addElementsRead() val kv = records.next() buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) maybeSpillCollection(usingMap = false) @@ -340,6 +343,10 @@ private[spark] class ExternalSorter[K, V, C]( * @param collection whichever collection we're using (map or buffer) */ private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { + spillToPartitionFiles(collection.iterator) + } + + private def spillToPartitionFiles(iterator: Iterator[((Int, K), C)]): Unit = { assert(bypassMergeSort) // Create our file writers if we haven't done so yet @@ -354,9 +361,9 @@ private[spark] class ExternalSorter[K, V, C]( } } - val it = collection.iterator // No need to sort stuff, just write each element out - while (it.hasNext) { - val elem = it.next() + // No need to sort stuff, just write each element out + while (iterator.hasNext) { + val elem = iterator.next() val partitionId = elem._1._1 val key = elem._1._2 val value = elem._2 @@ -752,6 +759,12 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled context.taskMetrics.diskBytesSpilled += diskBytesSpilled + context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m => + if (curWriteMetrics != null) { + m.shuffleBytesWritten += curWriteMetrics.shuffleBytesWritten + m.shuffleWriteTime += curWriteMetrics.shuffleWriteTime + } + } lengths } diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 0e4c6d633a4a9..9f54312074856 100644 --- 
a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -24,10 +24,7 @@ import org.apache.spark.SparkEnv * Spills contents of an in-memory collection to disk when the memory threshold * has been exceeded. */ -private[spark] trait Spillable[C] { - - this: Logging => - +private[spark] trait Spillable[C] extends Logging { /** * Spills the current in-memory collection to disk, and releases the memory. * @@ -36,16 +33,29 @@ private[spark] trait Spillable[C] { protected def spill(collection: C): Unit // Number of elements read from input since last spill - protected var elementsRead: Long + protected def elementsRead: Long = _elementsRead + + // Called by subclasses every time a record is read + // It's used for checking spilling frequency + protected def addElementsRead(): Unit = { _elementsRead += 1 } // Memory manager that can be used to acquire/release memory private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager - // What threshold of elementsRead we start estimating collection size at + // Threshold for `elementsRead` before we start tracking this collection's memory usage private[this] val trackMemoryThreshold = 1000 - // How much of the shared memory pool this collection has claimed - private[this] var myMemoryThreshold = 0L + // Initial threshold for the size of a collection before we start tracking its memory usage + // Exposed for testing + private[this] val initialMemoryThreshold: Long = + SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) + + // Threshold for this collection's size in bytes before we start tracking its memory usage + // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 + private[this] var myMemoryThreshold = initialMemoryThreshold + + // Number of elements read from input since last spill + private[this] var _elementsRead = 0L // Number of bytes spilled in total private[this] var _memoryBytesSpilled = 0L @@ -76,6 +86,7 @@ private[spark] trait Spillable[C] { spill(collection) + _elementsRead = 0 // Keep track of spills, and release memory _memoryBytesSpilled += currentMemory releaseMemoryForThisThread() @@ -94,8 +105,9 @@ private[spark] trait Spillable[C] { * Release our memory back to the shuffle pool so that other threads can grab it. 
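The Spillable changes above turn elementsRead into private state behind an addElementsRead() hook and seed myMemoryThreshold with a configurable initial value (spark.shuffle.spill.initialMemoryThreshold, 5 MB by default) so that tiny collections never trigger spills. A toy model of that bookkeeping is sketched below; the doubling heuristic for the memory request is an assumption made for illustration, and this is not Spark's actual trait.

// Illustrative stand-in for the threshold bookkeeping; `requestMemory` plays the role of the
// shuffle memory manager and returns how many bytes were actually granted.
class SpillThresholdSketch(initialMemoryThreshold: Long, trackMemoryThreshold: Long = 1000) {
  private var elementsRead = 0L
  private var myMemoryThreshold = initialMemoryThreshold

  def addElementsRead(): Unit = { elementsRead += 1 }

  /** Returns true if the caller should spill its in-memory collection now. */
  def maybeSpill(currentMemory: Long, requestMemory: Long => Long): Boolean = {
    val shouldSpill =
      if (elementsRead > trackMemoryThreshold && currentMemory >= myMemoryThreshold) {
        // Assumed heuristic: try to roughly double our claim before deciding to spill.
        val granted = requestMemory(2 * currentMemory - myMemoryThreshold)
        myMemoryThreshold += granted
        currentMemory >= myMemoryThreshold
      } else {
        false
      }
    if (shouldSpill) {
      // Reset the counters after a spill; note that only the portion above the initial
      // threshold was ever requested from (and must be returned to) the memory manager.
      elementsRead = 0
      myMemoryThreshold = initialMemoryThreshold
    }
    shouldSpill
  }
}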
*/ private def releaseMemoryForThisThread(): Unit = { - shuffleMemoryManager.release(myMemoryThreshold) - myMemoryThreshold = 0L + // The amount we requested does not include the initial memory tracking threshold + shuffleMemoryManager.release(myMemoryThreshold - initialMemoryThreshold) + myMemoryThreshold = initialMemoryThreshold } /** @@ -106,7 +118,7 @@ private[spark] trait Spillable[C] { @inline private def logSpillage(size: Long) { val threadId = Thread.currentThread().getId logInfo("Thread %d spilling in-memory map of %s to disk (%d time%s so far)" - .format(threadId, org.apache.spark.util.Utils.bytesToString(size), - _spillCount, if (_spillCount > 1) "s" else "")) + .format(threadId, org.apache.spark.util.Utils.bytesToString(size), + _spillCount, if (_spillCount > 1) "s" else "")) } } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 4b27477790212..ce804f94f3267 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -37,20 +37,24 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { .set("spark.dynamicAllocation.enabled", "true") intercept[SparkException] { new SparkContext(conf) } SparkEnv.get.stop() // cleanup the created environment + SparkContext.clearActiveContext() // Only min val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "1") intercept[SparkException] { new SparkContext(conf1) } SparkEnv.get.stop() + SparkContext.clearActiveContext() // Only max val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "2") intercept[SparkException] { new SparkContext(conf2) } SparkEnv.get.stop() + SparkContext.clearActiveContext() // Both min and max, but min > max intercept[SparkException] { createSparkContext(2, 1) } SparkEnv.get.stop() + SparkContext.clearActiveContext() // Both min and max, and min == max val sc1 = createSparkContext(1, 1) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index cda942e15a704..85e5f9ab444b3 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -95,14 +95,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - // 10 partitions from 4 keys - val NUM_BLOCKS = 10 + // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys + val NUM_BLOCKS = 201 val a = sc.parallelize(1 to 4, NUM_BLOCKS) val b = a.map(x => (x, x*2)) // NOTE: The default Java serializer doesn't create zero-sized blocks. 
// So, use Kryo - val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10)) + val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(NUM_BLOCKS)) .setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId @@ -122,13 +122,13 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - // 10 partitions from 4 keys - val NUM_BLOCKS = 10 + // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys + val NUM_BLOCKS = 201 val a = sc.parallelize(1 to 4, NUM_BLOCKS) val b = a.map(x => (x, x*2)) // NOTE: The default Java serializer should create zero-sized blocks - val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10)) + val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(NUM_BLOCKS)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 31edad1c56c73..1362022104195 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -21,15 +21,68 @@ import org.scalatest.FunSuite import org.apache.hadoop.io.BytesWritable -class SparkContextSuite extends FunSuite { - //Regression test for SPARK-3121 +class SparkContextSuite extends FunSuite with LocalSparkContext { + + /** Allows system properties to be changed in tests */ + private def withSystemProperty[T](property: String, value: String)(block: => T): T = { + val originalValue = System.getProperty(property) + try { + System.setProperty(property, value) + block + } finally { + if (originalValue == null) { + System.clearProperty(property) + } else { + System.setProperty(property, originalValue) + } + } + } + + test("Only one SparkContext may be active at a time") { + // Regression test for SPARK-4180 + withSystemProperty("spark.driver.allowMultipleContexts", "false") { + val conf = new SparkConf().setAppName("test").setMaster("local") + sc = new SparkContext(conf) + // A SparkContext is already running, so we shouldn't be able to create a second one + intercept[SparkException] { new SparkContext(conf) } + // After stopping the running context, we should be able to create a new one + resetSparkContext() + sc = new SparkContext(conf) + } + } + + test("Can still construct a new SparkContext after failing to construct a previous one") { + withSystemProperty("spark.driver.allowMultipleContexts", "false") { + // This is an invalid configuration (no app name or master URL) + intercept[SparkException] { + new SparkContext(new SparkConf()) + } + // Even though those earlier calls failed, we should still be able to create a new context + sc = new SparkContext(new SparkConf().setMaster("local").setAppName("test")) + } + } + + test("Check for multiple SparkContexts can be disabled via undocumented debug option") { + withSystemProperty("spark.driver.allowMultipleContexts", "true") { + var secondSparkContext: SparkContext = null + try { + val conf = new SparkConf().setAppName("test").setMaster("local") + sc = new SparkContext(conf) + secondSparkContext = new SparkContext(conf) + } finally { + Option(secondSparkContext).foreach(_.stop()) + } + } + } + test("BytesWritable implicit 
conversion is correct") { + // Regression test for SPARK-3121 val bytesWritable = new BytesWritable() val inputArray = (1 to 10).map(_.toByte).toArray bytesWritable.set(inputArray, 0, 10) bytesWritable.set(inputArray, 0, 5) - val converter = SparkContext.bytesWritableConverter() + val converter = WritableConverter.bytesWritableConverter() val byteArray = converter.convert(bytesWritable) assert(byteArray.length === 5) diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala index 94a2bdd74e744..d2dae34be7bfb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala @@ -23,17 +23,26 @@ import org.scalatest.Matchers class ClientSuite extends FunSuite with Matchers { test("correctly validates driver jar URL's") { ClientArguments.isValidJarUrl("http://someHost:8080/foo.jar") should be (true) - ClientArguments.isValidJarUrl("file://some/path/to/a/jarFile.jar") should be (true) + + // file scheme with authority and path is valid. + ClientArguments.isValidJarUrl("file://somehost/path/to/a/jarFile.jar") should be (true) + + // file scheme without path is not valid. + // In this case, jarFile.jar is recognized as authority. + ClientArguments.isValidJarUrl("file://jarFile.jar") should be (false) + + // file scheme without authority but with triple slash is valid. + ClientArguments.isValidJarUrl("file:///some/path/to/a/jarFile.jar") should be (true) ClientArguments.isValidJarUrl("hdfs://someHost:1234/foo.jar") should be (true) ClientArguments.isValidJarUrl("hdfs://someHost:1234/foo") should be (false) ClientArguments.isValidJarUrl("/missing/a/protocol/jarfile.jar") should be (false) ClientArguments.isValidJarUrl("not-even-a-path.jar") should be (false) - // No authority + // This URI doesn't have authority and path. ClientArguments.isValidJarUrl("hdfs:someHost:1234/jarfile.jar") should be (false) - // Invalid syntax + // Invalid syntax. 
ClientArguments.isValidJarUrl("hdfs:") should be (false) } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 6d2e696dc2fc4..e079ca3b1e896 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -739,6 +739,11 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + test("zipWithIndex chained with other RDDs (SPARK-4433)") { + val count = sc.parallelize(0 until 10, 2).zipWithIndex().repartition(4).count() + assert(count === 10) + } + test("zipWithUniqueId") { val n = 10 val data = sc.parallelize(0 until n, 3) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 1809b5396d53e..472191551a01f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -579,13 +579,13 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // single 10M result val thrown = intercept[SparkException] {sc.makeRDD(genBytes(10 << 20)(0), 1).collect()} - assert(thrown.getMessage().contains("bigger than maxResultSize")) + assert(thrown.getMessage().contains("bigger than spark.driver.maxResultSize")) // multiple 1M results val thrown2 = intercept[SparkException] { sc.makeRDD(0 until 10, 10).map(genBytes(1 << 20)).collect() } - assert(thrown2.getMessage().contains("bigger than maxResultSize")) + assert(thrown2.getMessage().contains("bigger than spark.driver.maxResultSize")) } test("speculative and noPref task should be scheduled after node-local") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala index bef8d3a58ba63..e60e70afd3218 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala @@ -30,9 +30,11 @@ import java.nio.ByteBuffer import java.util.Collections import java.util import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with EasyMockSugar { - test("mesos resource offer is launching tasks") { + + test("mesos resource offers result in launching tasks") { def createOffer(id: Int, mem: Int, cpu: Int) = { val builder = Offer.newBuilder() builder.addResourcesBuilder() @@ -43,46 +45,61 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Scalar.newBuilder().setValue(cpu)) - builder.setId(OfferID.newBuilder().setValue(id.toString).build()).setFrameworkId(FrameworkID.newBuilder().setValue("f1")) - .setSlaveId(SlaveID.newBuilder().setValue("s1")).setHostname("localhost").build() + builder.setId(OfferID.newBuilder().setValue(s"o${id.toString}").build()).setFrameworkId(FrameworkID.newBuilder().setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")).setHostname(s"host${id.toString}").build() } val driver = EasyMock.createMock(classOf[SchedulerDriver]) val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) val sc = EasyMock.createMock(classOf[SparkContext]) - EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes() 
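The ClientSuite cases above hinge on how a file: URL decomposes into authority and path; the comment "jarFile.jar is recognized as authority" refers to exactly that decomposition. A small standalone illustration using java.net.URI, shown for intuition only (ClientArguments.isValidJarUrl applies its own checks on top of this):

import java.net.URI

object JarUrlPartsExample {
  def main(args: Array[String]): Unit = {
    val urls = Seq(
      "file://somehost/path/to/a/jarFile.jar", // authority = somehost, path = /path/to/a/jarFile.jar
      "file://jarFile.jar",                    // authority = jarFile.jar, path is empty
      "file:///some/path/to/a/jarFile.jar",    // no authority, path = /some/path/to/a/jarFile.jar
      "hdfs:someHost:1234/jarfile.jar"         // opaque URI: neither authority nor path
    )
    urls.foreach { url =>
      val uri = new URI(url)
      println(s"$url -> authority=${uri.getAuthority}, path=${uri.getPath}")
    }
  }
}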
EasyMock.expect(sc.getSparkHome()).andReturn(Option("/path")).anyTimes() EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes() EasyMock.expect(sc.conf).andReturn(new SparkConf).anyTimes() EasyMock.replay(sc) + val minMem = MemoryUtils.calculateTotalMemory(sc).toInt val minCpu = 4 - val offers = new java.util.ArrayList[Offer] - offers.add(createOffer(1, minMem, minCpu)) - offers.add(createOffer(1, minMem - 1, minCpu)) + + val mesosOffers = new java.util.ArrayList[Offer] + mesosOffers.add(createOffer(1, minMem, minCpu)) + mesosOffers.add(createOffer(2, minMem - 1, minCpu)) + mesosOffers.add(createOffer(3, minMem, minCpu)) + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val workerOffers = Seq(offers.get(0)).map(o => new WorkerOffer( - o.getSlaveId.getValue, - o.getHostname, + + val expectedWorkerOffers = new ArrayBuffer[WorkerOffer](2) + expectedWorkerOffers.append(new WorkerOffer( + mesosOffers.get(0).getSlaveId.getValue, + mesosOffers.get(0).getHostname, + 2 + )) + expectedWorkerOffers.append(new WorkerOffer( + mesosOffers.get(2).getSlaveId.getValue, + mesosOffers.get(2).getHostname, 2 )) val taskDesc = new TaskDescription(1L, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) - EasyMock.expect(taskScheduler.resourceOffers(EasyMock.eq(workerOffers))).andReturn(Seq(Seq(taskDesc))) + EasyMock.expect(taskScheduler.resourceOffers(EasyMock.eq(expectedWorkerOffers))).andReturn(Seq(Seq(taskDesc))) EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() EasyMock.replay(taskScheduler) + val capture = new Capture[util.Collection[TaskInfo]] EasyMock.expect( driver.launchTasks( - EasyMock.eq(Collections.singleton(offers.get(0).getId)), + EasyMock.eq(Collections.singleton(mesosOffers.get(0).getId)), EasyMock.capture(capture), EasyMock.anyObject(classOf[Filters]) ) - ).andReturn(Status.valueOf(1)) - EasyMock.expect(driver.declineOffer(offers.get(1).getId)).andReturn(Status.valueOf(1)) + ).andReturn(Status.valueOf(1)).once + EasyMock.expect(driver.declineOffer(mesosOffers.get(1).getId)).andReturn(Status.valueOf(1)).times(1) + EasyMock.expect(driver.declineOffer(mesosOffers.get(2).getId)).andReturn(Status.valueOf(1)).times(1) EasyMock.replay(driver) - backend.resourceOffers(driver, offers) + + backend.resourceOffers(driver, mesosOffers) + + EasyMock.verify(driver) assert(capture.getValue.size() == 1) val taskInfo = capture.getValue.iterator().next() assert(taskInfo.getName.equals("n1")) @@ -90,5 +107,19 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea assert(cpus.getName.equals("cpus")) assert(cpus.getScalar.getValue.equals(2.0)) assert(taskInfo.getSlaveId.getValue.equals("s1")) + + // Unwanted resources offered on an existing node. 
Make sure they are declined + val mesosOffers2 = new java.util.ArrayList[Offer] + mesosOffers2.add(createOffer(1, minMem, minCpu)) + EasyMock.reset(taskScheduler) + EasyMock.reset(driver) + EasyMock.expect(taskScheduler.resourceOffers(EasyMock.anyObject(classOf[Seq[WorkerOffer]])).andReturn(Seq(Seq()))) + EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() + EasyMock.replay(taskScheduler) + EasyMock.expect(driver.declineOffer(mesosOffers2.get(0).getId)).andReturn(Status.valueOf(1)).times(1) + EasyMock.replay(driver) + + backend.resourceOffers(driver, mesosOffers2) + EasyMock.verify(driver) } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index bacf6a16fc233..d2857b8b55664 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -17,16 +17,20 @@ package org.apache.spark.ui -import org.apache.spark.api.java.StorageLevels -import org.apache.spark.{SparkException, SparkConf, SparkContext} -import org.openqa.selenium.WebDriver +import scala.collection.JavaConversions._ + +import org.openqa.selenium.{By, WebDriver} import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ +import org.apache.spark._ +import org.apache.spark.SparkContext._ import org.apache.spark.LocalSparkContext._ +import org.apache.spark.api.java.StorageLevels +import org.apache.spark.shuffle.FetchFailedException /** * Selenium tests for the Spark Web UI. These tests are not run by default @@ -89,7 +93,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { sc.parallelize(1 to 10).map { x => throw new Exception()}.collect() } eventually(timeout(5 seconds), interval(50 milliseconds)) { - go to sc.ui.get.appUIAddress + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") find(id("active")).get.text should be("Active Stages (0)") find(id("failed")).get.text should be("Failed Stages (1)") } @@ -101,7 +105,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { sc.parallelize(1 to 10).map { x => unserializableObject}.collect() } eventually(timeout(5 seconds), interval(50 milliseconds)) { - go to sc.ui.get.appUIAddress + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") find(id("active")).get.text should be("Active Stages (0)") // The failure occurs before the stage becomes active, hence we should still show only one // failed stage, not two: @@ -109,4 +113,191 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { } } } + + test("spark.ui.killEnabled should properly control kill button display") { + def getSparkContext(killEnabled: Boolean): SparkContext = { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.ui.enabled", "true") + .set("spark.ui.killEnabled", killEnabled.toString) + new SparkContext(conf) + } + + def hasKillLink = find(className("kill-link")).isDefined + def runSlowJob(sc: SparkContext) { + sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() + } + + withSpark(getSparkContext(killEnabled = true)) { sc => + runSlowJob(sc) + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") + assert(hasKillLink) + } + } + + withSpark(getSparkContext(killEnabled = false)) { sc => + runSlowJob(sc) + 
eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") + assert(!hasKillLink) + } + } + } + + test("jobs page should not display job group name unless some job was submitted in a job group") { + withSpark(newSparkContext()) { sc => + // If no job has been run in a job group, then "(Job Group)" should not appear in the header + sc.parallelize(Seq(1, 2, 3)).count() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + val tableHeaders = findAll(cssSelector("th")).map(_.text).toSeq + tableHeaders should not contain "Job Id (Job Group)" + } + // Once at least one job has been run in a job group, then we should display the group name: + sc.setJobGroup("my-job-group", "my-job-group-description") + sc.parallelize(Seq(1, 2, 3)).count() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + val tableHeaders = findAll(cssSelector("th")).map(_.text).toSeq + tableHeaders should contain ("Job Id (Job Group)") + } + } + } + + test("job progress bars should handle stage / task failures") { + withSpark(newSparkContext()) { sc => + val data = sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity) + val shuffleHandle = + data.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle + // Simulate fetch failures: + val mappedData = data.map { x => + val taskContext = TaskContext.get + if (taskContext.attemptId() == 1) { // Cause this stage to fail on its first attempt. + val env = SparkEnv.get + val bmAddress = env.blockManager.blockManagerId + val shuffleId = shuffleHandle.shuffleId + val mapId = 0 + val reduceId = taskContext.partitionId() + val message = "Simulated fetch failure" + throw new FetchFailedException(bmAddress, shuffleId, mapId, reduceId, message) + } else { + x + } + } + mappedData.count() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + find(cssSelector(".stage-progress-cell")).get.text should be ("2/2 (1 failed)") + // Ideally, the following test would pass, but currently we overcount completed tasks + // if task recomputations occur: + // find(cssSelector(".progress-cell .progress")).get.text should be ("2/2 (1 failed)") + // Instead, we guarantee that the total number of tasks is always correct, while the number + // of completed tasks may be higher: + find(cssSelector(".progress-cell .progress")).get.text should be ("3/2 (1 failed)") + } + } + } + + test("job details page should display useful information for stages that haven't started") { + withSpark(newSparkContext()) { sc => + // Create a multi-stage job with a long delay in the first stage: + val rdd = sc.parallelize(Seq(1, 2, 3)).map { x => + // This long sleep call won't slow down the tests because we don't actually need to wait + // for the job to finish. + Thread.sleep(20000) + }.groupBy(identity).map(identity).groupBy(identity).map(identity) + // Start the job: + rdd.countAsync() + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/?id=0") + find(id("active")).get.text should be ("Active Stages (1)") + find(id("pending")).get.text should be ("Pending Stages (2)") + // Essentially, we want to check that none of the stage rows show + // "No data available for this stage". 
Checking for the absence of that string is brittle + // because someone could change the error message and cause this test to pass by accident. + // Instead, it's safer to check that each row contains a link to a stage details page. + findAll(cssSelector("tbody tr")).foreach { row => + val link = row.underlying.findElement(By.xpath(".//a")) + link.getAttribute("href") should include ("stage") + } + } + } + } + + test("job progress bars / cells reflect skipped stages / tasks") { + withSpark(newSparkContext()) { sc => + // Create an RDD that involves multiple stages: + val rdd = sc.parallelize(1 to 8, 8) + .map(x => x).groupBy((x: Int) => x, numPartitions = 8) + .flatMap(x => x._2).groupBy((x: Int) => x, numPartitions = 8) + // Run it twice; this will cause the second job to have two "phantom" stages that were + // mentioned in its job start event but which were never actually executed: + rdd.count() + rdd.count() + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + // The completed jobs table should have two rows. The first row will be the most recent job: + val firstRow = find(cssSelector("tbody tr")).get.underlying + val firstRowColumns = firstRow.findElements(By.tagName("td")) + firstRowColumns(0).getText should be ("1") + firstRowColumns(4).getText should be ("1/1 (2 skipped)") + firstRowColumns(5).getText should be ("8/8 (16 skipped)") + // The second row is the first run of the job, where nothing was skipped: + val secondRow = findAll(cssSelector("tbody tr")).toSeq(1).underlying + val secondRowColumns = secondRow.findElements(By.tagName("td")) + secondRowColumns(0).getText should be ("0") + secondRowColumns(4).getText should be ("3/3") + secondRowColumns(5).getText should be ("24/24") + } + } + } + + test("stages that aren't run appear as 'skipped stages' after a job finishes") { + withSpark(newSparkContext()) { sc => + // Create an RDD that involves multiple stages: + val rdd = + sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) + // Run it twice; this will cause the second job to have two "phantom" stages that were + // mentioned in its job start event but which were never actually executed: + rdd.count() + rdd.count() + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/?id=1") + find(id("pending")) should be (None) + find(id("active")) should be (None) + find(id("failed")) should be (None) + find(id("completed")).get.text should be ("Completed Stages (1)") + find(id("skipped")).get.text should be ("Skipped Stages (2)") + // Essentially, we want to check that none of the stage rows show + // "No data available for this stage". Checking for the absence of that string is brittle + // because someone could change the error message and cause this test to pass by accident. + // Instead, it's safer to check that each row contains a link to a stage details page. 
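As an aside on the mechanism these UI tests rely on: a stage is "skipped" when it is listed in a job's SparkListenerJobStart.stageInfos (the field added by this patch) but is never actually submitted, because its shuffle output is still available from an earlier job. A hedged, standalone sketch of observing that from a listener; it is not part of the test suite:

import scala.collection.mutable
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerStageSubmitted}

object SkippedStagesExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("skipped-stages"))
    val referenced = mutable.Set[Int]()
    val submitted = mutable.Set[Int]()
    sc.addSparkListener(new SparkListener {
      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
        referenced ++= jobStart.stageInfos.map(_.stageId)
      }
      override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
        submitted += stageSubmitted.stageInfo.stageId
      }
    })
    val rdd = sc.parallelize(1 to 100).map(x => (x % 10, x)).groupByKey()
    rdd.count() // first run: both stages execute
    rdd.count() // second run: the shuffle map stage's output is reused, so it is skipped
    sc.stop()
    println(s"Stages referenced at job start but never submitted (skipped): ${referenced -- submitted}")
  }
}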
+ findAll(cssSelector("tbody tr")).foreach { row => + val link = row.underlying.findElement(By.xpath(".//a")) + link.getAttribute("href") should include ("stage") + } + } + } + } + + test("jobs with stages that are skipped should show correct link descriptions on all jobs page") { + withSpark(newSparkContext()) { sc => + // Create an RDD that involves multiple stages: + val rdd = + sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) + // Run it twice; this will cause the second job to have two "phantom" stages that were + // mentioned in its job start event but which were never actually executed: + rdd.count() + rdd.count() + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + findAll(cssSelector("tbody tr a")).foreach { link => + link.text.toLowerCase should include ("count") + link.text.toLowerCase should not include "unknown" + } + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 7c102cc7f4049..12af60caf7d54 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -28,32 +28,106 @@ import org.apache.spark.util.Utils class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers { - test("test LRU eviction of stages") { - val conf = new SparkConf() - conf.set("spark.ui.retainedStages", 5.toString) - val listener = new JobProgressListener(conf) - def createStageStartEvent(stageId: Int) = { - val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") - SparkListenerStageSubmitted(stageInfo) + private def createStageStartEvent(stageId: Int) = { + val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") + SparkListenerStageSubmitted(stageInfo) + } + + private def createStageEndEvent(stageId: Int, failed: Boolean = false) = { + val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") + if (failed) { + stageInfo.failureReason = Some("Failed!") } + SparkListenerStageCompleted(stageInfo) + } - def createStageEndEvent(stageId: Int) = { - val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") - SparkListenerStageCompleted(stageInfo) + private def createJobStartEvent(jobId: Int, stageIds: Seq[Int]) = { + val stageInfos = stageIds.map { stageId => + new StageInfo(stageId, 0, stageId.toString, 0, null, "") } + SparkListenerJobStart(jobId, stageInfos) + } + + private def createJobEndEvent(jobId: Int, failed: Boolean = false) = { + val result = if (failed) JobFailed(new Exception("dummy failure")) else JobSucceeded + SparkListenerJobEnd(jobId, result) + } + + private def runJob(listener: SparkListener, jobId: Int, shouldFail: Boolean = false) { + val stagesThatWontBeRun = jobId * 200 to jobId * 200 + 10 + val stageIds = jobId * 100 to jobId * 100 + 50 + listener.onJobStart(createJobStartEvent(jobId, stageIds ++ stagesThatWontBeRun)) + for (stageId <- stageIds) { + listener.onStageSubmitted(createStageStartEvent(stageId)) + listener.onStageCompleted(createStageEndEvent(stageId, failed = stageId % 2 == 0)) + } + listener.onJobEnd(createJobEndEvent(jobId, shouldFail)) + } + + private def assertActiveJobsStateIsEmpty(listener: JobProgressListener) { + listener.getSizesOfActiveStateTrackingCollections.foreach { case (fieldName, size) => + assert(size === 0, 
s"$fieldName was not empty") + } + } + + test("test LRU eviction of stages") { + val conf = new SparkConf() + conf.set("spark.ui.retainedStages", 5.toString) + val listener = new JobProgressListener(conf) for (i <- 1 to 50) { listener.onStageSubmitted(createStageStartEvent(i)) listener.onStageCompleted(createStageEndEvent(i)) } + assertActiveJobsStateIsEmpty(listener) listener.completedStages.size should be (5) - listener.completedStages.count(_.stageId == 50) should be (1) - listener.completedStages.count(_.stageId == 49) should be (1) - listener.completedStages.count(_.stageId == 48) should be (1) - listener.completedStages.count(_.stageId == 47) should be (1) - listener.completedStages.count(_.stageId == 46) should be (1) + listener.completedStages.map(_.stageId).toSet should be (Set(50, 49, 48, 47, 46)) + } + + test("test LRU eviction of jobs") { + val conf = new SparkConf() + conf.set("spark.ui.retainedStages", 5.toString) + conf.set("spark.ui.retainedJobs", 5.toString) + val listener = new JobProgressListener(conf) + + // Run a bunch of jobs to get the listener into a state where we've exceeded both the + // job and stage retention limits: + for (jobId <- 1 to 10) { + runJob(listener, jobId, shouldFail = false) + } + for (jobId <- 200 to 210) { + runJob(listener, jobId, shouldFail = true) + } + assertActiveJobsStateIsEmpty(listener) + // Snapshot the sizes of various soft- and hard-size-limited collections: + val softLimitSizes = listener.getSizesOfSoftSizeLimitedCollections + val hardLimitSizes = listener.getSizesOfHardSizeLimitedCollections + // Run some more jobs: + for (jobId <- 11 to 50) { + runJob(listener, jobId, shouldFail = false) + // We shouldn't exceed the hard / soft limit sizes after the jobs have finished: + listener.getSizesOfSoftSizeLimitedCollections should be (softLimitSizes) + listener.getSizesOfHardSizeLimitedCollections should be (hardLimitSizes) + } + + listener.completedJobs.size should be (5) + listener.completedJobs.map(_.jobId).toSet should be (Set(50, 49, 48, 47, 46)) + + for (jobId <- 51 to 100) { + runJob(listener, jobId, shouldFail = true) + // We shouldn't exceed the hard / soft limit sizes after the jobs have finished: + listener.getSizesOfSoftSizeLimitedCollections should be (softLimitSizes) + listener.getSizesOfHardSizeLimitedCollections should be (hardLimitSizes) + } + assertActiveJobsStateIsEmpty(listener) + + // Completed and failed jobs each their own size limits, so this should still be the same: + listener.completedJobs.size should be (5) + listener.completedJobs.map(_.jobId).toSet should be (Set(50, 49, 48, 47, 46)) + listener.failedJobs.size should be (5) + listener.failedJobs.map(_.jobId).toSet should be (Set(100, 99, 98, 97, 96)) } test("test executor id to summary") { diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 50f42054b9296..593d6dd8c3794 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.util import java.util.Properties +import org.apache.spark.shuffle.MetadataFetchFailedException + import scala.collection.Map import org.json4s.jackson.JsonMethods._ @@ -47,7 +49,12 @@ class JsonProtocolSuite extends FunSuite { val taskEndWithOutput = SparkListenerTaskEnd(1, 0, "ResultTask", Success, makeTaskInfo(123L, 234, 67, 345L, false), makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = 
true, hasOutput = true)) - val jobStart = SparkListenerJobStart(10, Seq[Int](1, 2, 3, 4), properties) + val jobStart = { + val stageIds = Seq[Int](1, 2, 3, 4) + val stageInfos = stageIds.map(x => + makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L)) + SparkListenerJobStart(10, stageInfos, properties) + } val jobEnd = SparkListenerJobEnd(20, JobSucceeded) val environmentUpdate = SparkListenerEnvironmentUpdate(Map[String, Seq[(String, String)]]( "JVM Information" -> Seq(("GC speed", "9999 objects/s"), ("Java home", "Land of coffee")), @@ -111,10 +118,13 @@ class JsonProtocolSuite extends FunSuite { // TaskEndReason val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, "Some exception") + val fetchMetadataFailed = new MetadataFetchFailedException(17, + 19, "metadata Fetch failed exception").toTaskEndReason val exceptionFailure = new ExceptionFailure(exception, None) testTaskEndReason(Success) testTaskEndReason(Resubmitted) testTaskEndReason(fetchFailed) + testTaskEndReason(fetchMetadataFailed) testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) testTaskEndReason(TaskKilled) @@ -224,6 +234,19 @@ class JsonProtocolSuite extends FunSuite { assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent)) } + test("SparkListenerJobStart backward compatibility") { + // Prior to Spark 1.2.0, SparkListenerJobStart did not have a "Stage Infos" property. + val stageIds = Seq[Int](1, 2, 3, 4) + val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400, x * 500)) + val dummyStageInfos = + stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, "unknown")) + val jobStart = SparkListenerJobStart(10, stageInfos, properties) + val oldEvent = JsonProtocol.jobStartToJson(jobStart).removeField({_._1 == "Stage Infos"}) + val expectedJobStart = + SparkListenerJobStart(10, dummyStageInfos, properties) + assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldEvent)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -306,7 +329,7 @@ class JsonProtocolSuite extends FunSuite { case (e1: SparkListenerJobStart, e2: SparkListenerJobStart) => assert(e1.jobId === e2.jobId) assert(e1.properties === e2.properties) - assertSeqEquals(e1.stageIds, e2.stageIds, (i1: Int, i2: Int) => assert(i1 === i2)) + assert(e1.stageIds === e2.stageIds) case (e1: SparkListenerJobEnd, e2: SparkListenerJobEnd) => assert(e1.jobId === e2.jobId) assertEquals(e1.jobResult, e2.jobResult) @@ -413,9 +436,13 @@ class JsonProtocolSuite extends FunSuite { } private def assertEquals(bm1: BlockManagerId, bm2: BlockManagerId) { - assert(bm1.executorId === bm2.executorId) - assert(bm1.host === bm2.host) - assert(bm1.port === bm2.port) + if (bm1 == null || bm2 == null) { + assert(bm1 === bm2) + } else { + assert(bm1.executorId === bm2.executorId) + assert(bm1.host === bm2.host) + assert(bm1.port === bm2.port) + } } private def assertEquals(result1: JobResult, result2: JobResult) { @@ -1051,6 +1078,260 @@ class JsonProtocolSuite extends FunSuite { |{ | "Event": "SparkListenerJobStart", | "Job ID": 10, + | "Stage Infos": [ + | { + | "Stage ID": 1, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 200, + | "RDD Info": [ + | { + | "RDD ID": 1, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 200, + | "Number of Cached 
Partitions": 300, + | "Memory Size": 400, + | "Tachyon Size": 0, + | "Disk Size": 500 + | } + | ], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": " Accumulable 2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": " Accumulable 1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | }, + | { + | "Stage ID": 2, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 400, + | "RDD Info": [ + | { + | "RDD ID": 2, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 400, + | "Number of Cached Partitions": 600, + | "Memory Size": 800, + | "Tachyon Size": 0, + | "Disk Size": 1000 + | }, + | { + | "RDD ID": 3, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 401, + | "Number of Cached Partitions": 601, + | "Memory Size": 801, + | "Tachyon Size": 0, + | "Disk Size": 1001 + | } + | ], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": " Accumulable 2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": " Accumulable 1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | }, + | { + | "Stage ID": 3, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 600, + | "RDD Info": [ + | { + | "RDD ID": 3, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 600, + | "Number of Cached Partitions": 900, + | "Memory Size": 1200, + | "Tachyon Size": 0, + | "Disk Size": 1500 + | }, + | { + | "RDD ID": 4, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 601, + | "Number of Cached Partitions": 901, + | "Memory Size": 1201, + | "Tachyon Size": 0, + | "Disk Size": 1501 + | }, + | { + | "RDD ID": 5, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 602, + | "Number of Cached Partitions": 902, + | "Memory Size": 1202, + | "Tachyon Size": 0, + | "Disk Size": 1502 + | } + | ], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": " Accumulable 2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": " Accumulable 1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | }, + | { + | "Stage ID": 4, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 800, + | "RDD Info": [ + | { + | "RDD ID": 4, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 800, + | "Number of Cached Partitions": 1200, + | "Memory Size": 1600, + | "Tachyon Size": 0, + | "Disk Size": 2000 + | }, + | { + | "RDD ID": 5, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of 
Partitions": 801, + | "Number of Cached Partitions": 1201, + | "Memory Size": 1601, + | "Tachyon Size": 0, + | "Disk Size": 2001 + | }, + | { + | "RDD ID": 6, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 802, + | "Number of Cached Partitions": 1202, + | "Memory Size": 1602, + | "Tachyon Size": 0, + | "Disk Size": 2002 + | }, + | { + | "RDD ID": 7, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 803, + | "Number of Cached Partitions": 1203, + | "Memory Size": 1603, + | "Tachyon Size": 0, + | "Disk Size": 2003 + | } + | ], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": " Accumulable 2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": " Accumulable 1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | } + | ], | "Stage IDs": [ | 1, | 2, diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index f26e40fbd4b36..3cb42d416de4f 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -127,6 +127,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe test("empty partitions with spilling") { val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.spill.initialMemoryThreshold", "512") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -152,6 +153,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe test("empty partitions with spilling, bypass merge-sort") { val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.spill.initialMemoryThreshold", "512") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) diff --git a/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala new file mode 100644 index 0000000000000..4918e2d92beb4 --- /dev/null +++ b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.sparktest + +/** + * A test suite to make sure all `implicit` functions work correctly. + * Please don't `import org.apache.spark.SparkContext._` in this class. + * + * As `implicit` is a compiler feature, we don't need to run this class. + * What we need to do is to make the compiler happy. + */ +class ImplicitSuite { + + // We only want to test if `implicit` works well with the compiler, so we don't need a real + // SparkContext. + def mockSparkContext[T]: org.apache.spark.SparkContext = null + + // We only want to test if `implicit` works well with the compiler, so we don't need a real RDD. + def mockRDD[T]: org.apache.spark.rdd.RDD[T] = null + + def testRddToPairRDDFunctions(): Unit = { + val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD + rdd.groupByKey() + } + + def testRddToAsyncRDDActions(): Unit = { + val rdd: org.apache.spark.rdd.RDD[Int] = mockRDD + rdd.countAsync() + } + + def testRddToSequenceFileRDDFunctions(): Unit = { + // TODO: eliminating `import intToIntWritable` requires refactoring SequenceFileRDDFunctions. + // That will be a breaking change. + import org.apache.spark.SparkContext.intToIntWritable + val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD + rdd.saveAsSequenceFile("/a/test/path") + } + + def testRddToOrderedRDDFunctions(): Unit = { + val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD + rdd.sortByKey() + } + + def testDoubleRDDToDoubleRDDFunctions(): Unit = { + val rdd: org.apache.spark.rdd.RDD[Double] = mockRDD + rdd.stats() + } + + def testNumericRDDToDoubleRDDFunctions(): Unit = { + val rdd: org.apache.spark.rdd.RDD[Int] = mockRDD + rdd.stats() + } + + def testDoubleAccumulatorParam(): Unit = { + val sc = mockSparkContext + sc.accumulator(123.4) + } + + def testIntAccumulatorParam(): Unit = { + val sc = mockSparkContext + sc.accumulator(123) + } + + def testLongAccumulatorParam(): Unit = { + val sc = mockSparkContext + sc.accumulator(123L) + } + + def testFloatAccumulatorParam(): Unit = { + val sc = mockSparkContext + sc.accumulator(123F) + } + + def testIntWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Int, Int]("/a/test/path") + } + + def testLongWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Long, Long]("/a/test/path") + } + + def testDoubleWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Double, Double]("/a/test/path") + } + + def testFloatWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Float, Float]("/a/test/path") + } + + def testBooleanWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Boolean, Boolean]("/a/test/path") + } + + def testBytesWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Array[Byte], Array[Byte]]("/a/test/path") + } + + def testStringWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[String, String]("/a/test/path") + } + + def testWritableWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[org.apache.hadoop.io.Text, org.apache.hadoop.io.Text]("/a/test/path") + } +} diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index a6e90a15ee84b..e0aca467ac949 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -28,13 +28,19 @@ # - Send output to stderr and have useful logging in stdout # Note: The following variables must be set before use!
-GIT_USERNAME=${GIT_USERNAME:-pwendell} -GIT_PASSWORD=${GIT_PASSWORD:-XXX} +ASF_USERNAME=${ASF_USERNAME:-pwendell} +ASF_PASSWORD=${ASF_PASSWORD:-XXX} GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX} GIT_BRANCH=${GIT_BRANCH:-branch-1.0} -RELEASE_VERSION=${RELEASE_VERSION:-1.0.0} +RELEASE_VERSION=${RELEASE_VERSION:-1.2.0} +NEXT_VERSION=${NEXT_VERSION:-1.2.1} RC_NAME=${RC_NAME:-rc2} -USER_NAME=${USER_NAME:-pwendell} + +M2_REPO=~/.m2/repository +SPARK_REPO=$M2_REPO/org/apache/spark +NEXUS_ROOT=https://repository.apache.org/service/local/staging +NEXUS_UPLOAD=$NEXUS_ROOT/deploy/maven2 +NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads if [ -z "$JAVA_HOME" ]; then echo "Error: JAVA_HOME is not set, cannot proceed." @@ -47,31 +53,90 @@ set -e GIT_TAG=v$RELEASE_VERSION-$RC_NAME if [[ ! "$@" =~ --package-only ]]; then - echo "Creating and publishing release" + echo "Creating release commit and publishing to Apache repository" # Artifact publishing - git clone https://git-wip-us.apache.org/repos/asf/spark.git -b $GIT_BRANCH - cd spark + git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git \ + -b $GIT_BRANCH + pushd spark export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g" - mvn -Pyarn release:clean - - mvn -DskipTests \ - -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ - -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \ - -Dmaven.javadoc.skip=true \ - -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Dtag=$GIT_TAG -DautoVersionSubmodules=true \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - --batch-mode release:prepare - - mvn -DskipTests \ - -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ - -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Dmaven.javadoc.skip=true \ + # Create release commits and push them to github + # NOTE: This is done "eagerly" i.e. we don't check if we can succesfully build + # or before we coin the release commit. This helps avoid races where + # other people add commits to this branch while we are in the middle of building. + old=" ${RELEASE_VERSION}-SNAPSHOT<\/version>" + new=" ${RELEASE_VERSION}<\/version>" + find . -name pom.xml -o -name package.scala | grep -v dev | xargs -I {} sed -i \ + -e "s/$old/$new/" {} + git commit -a -m "Preparing Spark release $GIT_TAG" + echo "Creating tag $GIT_TAG at the head of $GIT_BRANCH" + git tag $GIT_TAG + + old=" ${RELEASE_VERSION}<\/version>" + new=" ${NEXT_VERSION}-SNAPSHOT<\/version>" + find . 
-name pom.xml -o -name package.scala | grep -v dev | xargs -I {} sed -i \ + -e "s/$old/$new/" {} + git commit -a -m "Preparing development version ${NEXT_VERSION}-SNAPSHOT" + git push origin $GIT_TAG + git push origin HEAD:$GIT_BRANCH + git checkout -f $GIT_TAG + + # Using Nexus API documented here: + # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API + echo "Creating Nexus staging repository" + repo_request="Apache Spark $GIT_TAG" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) + staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") + echo "Created Nexus staging repository: $staged_repo_id" + + rm -rf $SPARK_REPO + + mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - release:perform + clean install - cd .. + ./dev/change-version-to-2.11.sh + + mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ + -Dscala-2.11 -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ + clean install + + ./dev/change-version-to-2.10.sh + + pushd $SPARK_REPO + + # Remove any extra files generated during install + find . -type f |grep -v \.jar |grep -v \.pom | xargs rm + + echo "Creating hash and signature files" + for file in $(find . -type f) + do + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; + gpg --print-md MD5 $file > $file.md5; + gpg --print-md SHA1 $file > $file.sha1 + done + + echo "Uplading files to $NEXUS_UPLOAD" + for file in $(find . -type f) + do + # strip leading ./ + file_short=$(echo $file | sed -e "s/\.\///") + dest_url="$NEXUS_UPLOAD/org/apache/spark/$file_short" + echo " Uploading $file_short" + curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url + done + + echo "Closing nexus staging repository" + repo_request="$staged_repo_idApache Spark $GIT_TAG" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) + echo "Closed Nexus staging repository: $staged_repo_id" + + popd + popd rm -rf spark fi @@ -102,6 +167,12 @@ make_binary_release() { cp -r spark spark-$RELEASE_VERSION-bin-$NAME cd spark-$RELEASE_VERSION-bin-$NAME + + # TODO There should probably be a flag to make-distribution to allow 2.11 support + if [[ $FLAGS == *scala-2.11* ]]; then + ./dev/change-version-to-2.11.sh + fi + ./make-distribution.sh --name $NAME --tgz $FLAGS 2>&1 | tee ../binary-release-$NAME.log cd .. cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . 
@@ -118,22 +189,24 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } + make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" & +make_binary_release "hadoop1-scala2.11" "-Phive -Dscala-2.11" & make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" & make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" & -make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" & make_binary_release "mapr3" "-Pmapr3 -Phive -Phive-thriftserver" & make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive -Phive-thriftserver" & +make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" & wait # Copy data echo "Copying release tarballs" rc_folder=spark-$RELEASE_VERSION-$RC_NAME -ssh $USER_NAME@people.apache.org \ - mkdir /home/$USER_NAME/public_html/$rc_folder +ssh $ASF_USERNAME@people.apache.org \ + mkdir /home/$ASF_USERNAME/public_html/$rc_folder scp spark-* \ - $USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_folder/ + $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/ # Docs cd spark @@ -143,12 +216,12 @@ cd docs JAVA_HOME=$JAVA_7_HOME PRODUCTION=1 jekyll build echo "Copying release documentation" rc_docs_folder=${rc_folder}-docs -ssh $USER_NAME@people.apache.org \ - mkdir /home/$USER_NAME/public_html/$rc_docs_folder -rsync -r _site/* $USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_docs_folder +ssh $ASF_USERNAME@people.apache.org \ + mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder +rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder echo "Release $RELEASE_VERSION completed:" echo "Git tag:\t $GIT_TAG" echo "Release commit:\t $release_hash" -echo "Binary location:\t http://people.apache.org/~$USER_NAME/$rc_folder" -echo "Doc location:\t http://people.apache.org/~$USER_NAME/$rc_docs_folder" +echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder" +echo "Doc location:\t http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder" diff --git a/docs/README.md b/docs/README.md index d2d58e435d4c4..119484038083f 100644 --- a/docs/README.md +++ b/docs/README.md @@ -43,7 +43,7 @@ You can modify the default Jekyll build as follows: ## Pygments We also use pygments (http://pygments.org) for syntax highlighting in documentation markdown pages, -so you will also need to install that (it requires Python) by running `sudo easy_install Pygments`. +so you will also need to install that (it requires Python) by running `sudo pip install Pygments`. To mark a block of code in your markdown to be syntax highlighted by jekyll during the compile phase, use the following sytax: @@ -53,6 +53,11 @@ phase, use the following sytax: // supported languages too. {% endhighlight %} +## Sphinx + +We use Sphinx to generate Python API docs, so you will need to install it by running +`sudo pip install sphinx`. + ## API Docs (Scaladoc and Sphinx) You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory. diff --git a/docs/_config.yml b/docs/_config.yml index cdea02fcffbc5..a96a76dd9ab5e 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -13,8 +13,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. 
-SPARK_VERSION: 1.2.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.2.0 +SPARK_VERSION: 1.3.0-SNAPSHOT +SPARK_VERSION_SHORT: 1.3.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.18.1 diff --git a/docs/building-spark.md b/docs/building-spark.md index bb18414092aae..40a47410e683a 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -92,8 +92,11 @@ mvn -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -DskipTests clean package # Apache Hadoop 2.3.X mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -DskipTests clean package -# Apache Hadoop 2.4.X -mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package +# Apache Hadoop 2.4.X or 2.5.X +mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=VERSION -DskipTests clean package + +Versions of Hadoop after 2.5.X may or may not work with the -Phadoop-2.4 profile (they were +released after this version of Spark). # Different versions of HDFS and YARN. mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package @@ -109,7 +112,7 @@ Hive 0.12.0 using the `-Phive-0.12.0` profile. mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package # Apache Hadoop 2.4.X with Hive 12 support -mvn -Pyarn -Phive -Phive-thriftserver-0.12.0 -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package +mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-0.12.0 -Phive-thriftserver -DskipTests clean package {% endhighlight %} # Building for Scala 2.11 diff --git a/docs/configuration.md b/docs/configuration.md index 8839162c3a13e..0b77f5ab645c9 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -224,6 +224,7 @@ Apart from these, the following properties are also available, and may be useful (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading classes in Executors. This feature can be used to mitigate conflicts between Spark's dependencies and user dependencies. It is currently an experimental feature. + (Currently, this setting does not work for YARN, see SPARK-2996 for more details). diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index 530798f2b8022..66bf5f1a855ed 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -12,16 +12,14 @@ on the [Amazon Web Services site](http://aws.amazon.com/). `spark-ec2` is designed to manage multiple named clusters. You can launch a new cluster (telling the script its size and giving it a name), -shutdown an existing cluster, or log into a cluster. Each cluster -launches a set of instances, which are tagged with the cluster name, -and placed into EC2 security groups. If you don't specify a security -group, the `spark-ec2` script will create security groups based on the -cluster name you request. For example, a cluster named +shutdown an existing cluster, or log into a cluster. Each cluster is +identified by placing its machines into EC2 security groups whose names +are derived from the name of the cluster. For example, a cluster named `test` will contain a master node in a security group called `test-master`, and a number of slave nodes in a security group called -`test-slaves`. You can also specify a security group prefix to be used -in place of the cluster name. Machines in a cluster can be identified -by looking for the "Name" tag of the instance in the Amazon EC2 Console. +`test-slaves`. The `spark-ec2` script will create these security groups +for you based on the cluster name you request. 
You can also use them to +identify machines belonging to each cluster in the Amazon EC2 Console. # Before You Start diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index fdb9f98e214e5..e298c51f8a5b7 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -6,6 +6,47 @@ title: GraphX Programming Guide * This will become a table of contents (this text will be scraped). {:toc} + + +[EdgeRDD]: api/scala/index.html#org.apache.spark.graphx.EdgeRDD +[Edge]: api/scala/index.html#org.apache.spark.graphx.Edge +[EdgeTriplet]: api/scala/index.html#org.apache.spark.graphx.EdgeTriplet +[Graph]: api/scala/index.html#org.apache.spark.graphx.Graph +[GraphOps]: api/scala/index.html#org.apache.spark.graphx.GraphOps +[Graph.mapVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@mapVertices[VD2]((VertexId,VD)⇒VD2)(ClassTag[VD2]):Graph[VD2,ED] +[Graph.reverse]: api/scala/index.html#org.apache.spark.graphx.Graph@reverse:Graph[VD,ED] +[Graph.subgraph]: api/scala/index.html#org.apache.spark.graphx.Graph@subgraph((EdgeTriplet[VD,ED])⇒Boolean,(VertexId,VD)⇒Boolean):Graph[VD,ED] +[Graph.mask]: api/scala/index.html#org.apache.spark.graphx.Graph@mask[VD2,ED2](Graph[VD2,ED2])(ClassTag[VD2],ClassTag[ED2]):Graph[VD,ED] +[Graph.groupEdges]: api/scala/index.html#org.apache.spark.graphx.Graph@groupEdges((ED,ED)⇒ED):Graph[VD,ED] +[GraphOps.joinVertices]: api/scala/index.html#org.apache.spark.graphx.GraphOps@joinVertices[U](RDD[(VertexId,U)])((VertexId,VD,U)⇒VD)(ClassTag[U]):Graph[VD,ED] +[Graph.outerJoinVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@outerJoinVertices[U,VD2](RDD[(VertexId,U)])((VertexId,VD,Option[U])⇒VD2)(ClassTag[U],ClassTag[VD2]):Graph[VD2,ED] +[Graph.aggregateMessages]: api/scala/index.html#org.apache.spark.graphx.Graph@aggregateMessages[A]((EdgeContext[VD,ED,A])⇒Unit,(A,A)⇒A,TripletFields)(ClassTag[A]):VertexRDD[A] +[EdgeContext]: api/scala/index.html#org.apache.spark.graphx.EdgeContext +[Graph.mapReduceTriplets]: api/scala/index.html#org.apache.spark.graphx.Graph@mapReduceTriplets[A](mapFunc:org.apache.spark.graphx.EdgeTriplet[VD,ED]=>Iterator[(org.apache.spark.graphx.VertexId,A)],reduceFunc:(A,A)=>A,activeSetOpt:Option[(org.apache.spark.graphx.VertexRDD[_],org.apache.spark.graphx.EdgeDirection)])(implicitevidence$10:scala.reflect.ClassTag[A]):org.apache.spark.graphx.VertexRDD[A] +[GraphOps.collectNeighborIds]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighborIds(EdgeDirection):VertexRDD[Array[VertexId]] +[GraphOps.collectNeighbors]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighbors(EdgeDirection):VertexRDD[Array[(VertexId,VD)]] +[RDD Persistence]: programming-guide.html#rdd-persistence +[Graph.cache]: api/scala/index.html#org.apache.spark.graphx.Graph@cache():Graph[VD,ED] +[GraphOps.pregel]: api/scala/index.html#org.apache.spark.graphx.GraphOps@pregel[A](A,Int,EdgeDirection)((VertexId,VD,A)⇒VD,(EdgeTriplet[VD,ED])⇒Iterator[(VertexId,A)],(A,A)⇒A)(ClassTag[A]):Graph[VD,ED] +[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy$ +[GraphLoader.edgeListFile]: api/scala/index.html#org.apache.spark.graphx.GraphLoader$@edgeListFile(SparkContext,String,Boolean,Int):Graph[Int,Int] +[Graph.apply]: api/scala/index.html#org.apache.spark.graphx.Graph$@apply[VD,ED](RDD[(VertexId,VD)],RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] +[Graph.fromEdgeTuples]: 
api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdgeTuples[VD](RDD[(VertexId,VertexId)],VD,Option[PartitionStrategy])(ClassTag[VD]):Graph[VD,Int] +[Graph.fromEdges]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdges[VD,ED](RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] +[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy +[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph$@partitionBy(partitionStrategy:org.apache.spark.graphx.PartitionStrategy):org.apache.spark.graphx.Graph[VD,ED] +[PageRank]: api/scala/index.html#org.apache.spark.graphx.lib.PageRank$ +[ConnectedComponents]: api/scala/index.html#org.apache.spark.graphx.lib.ConnectedComponents$ +[TriangleCount]: api/scala/index.html#org.apache.spark.graphx.lib.TriangleCount$ +[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph@partitionBy(PartitionStrategy):Graph[VD,ED] +[EdgeContext.sendToSrc]: api/scala/index.html#org.apache.spark.graphx.EdgeContext@sendToSrc(msg:A):Unit +[EdgeContext.sendToDst]: api/scala/index.html#org.apache.spark.graphx.EdgeContext@sendToDst(msg:A):Unit +[TripletFields]: api/java/org/apache/spark/graphx/TripletFields.html +[TripletFields.All]: api/java/org/apache/spark/graphx/TripletFields.html#All +[TripletFields.None]: api/java/org/apache/spark/graphx/TripletFields.html#None +[TripletFields.Src]: api/java/org/apache/spark/graphx/TripletFields.html#Src +[TripletFields.Dst]: api/java/org/apache/spark/graphx/TripletFields.html#Dst +

- [removed figure: Data-Parallel vs. Graph-Parallel]

+1. To improve performance we have introduced a new version of +[`mapReduceTriplets`][Graph.mapReduceTriplets] called +[`aggregateMessages`][Graph.aggregateMessages] which takes the messages previously returned from +[`mapReduceTriplets`][Graph.mapReduceTriplets] through a callback ([`EdgeContext`][EdgeContext]) +rather than by return value. +We are deprecating [`mapReduceTriplets`][Graph.mapReduceTriplets] and encourage users to consult +the [transition guide](#mrTripletsTransition). -However, the same restrictions that enable these substantial performance gains also make it -difficult to express many of the important stages in a typical graph-analytics pipeline: -constructing the graph, modifying its structure, or expressing computation that spans multiple -graphs. Furthermore, how we look at data depends on our objectives and the same raw data may have -many different table and graph views. - -

- [removed figure: Tables and Graphs]

- -As a consequence, it is often necessary to be able to move between table and graph views of the same -physical data and to leverage the properties of each view to easily and efficiently express -computation. However, existing graph analytics pipelines must compose graph-parallel and data- -parallel systems, leading to extensive data movement and duplication and a complicated programming -model. - -

- [removed figure: Graph Analytics Pipeline]

- -The goal of the GraphX project is to unify graph-parallel and data-parallel computation in one -system with a single composable API. The GraphX API enables users to view data both as a graph and -as collections (i.e., RDDs) without data movement or duplication. By incorporating recent advances -in graph-parallel systems, GraphX is able to optimize the execution of graph operations. - -## GraphX Replaces the Spark Bagel API - -Prior to the release of GraphX, graph computation in Spark was expressed using Bagel, an -implementation of Pregel. GraphX improves upon Bagel by exposing a richer property graph API, a -more streamlined version of the Pregel abstraction, and system optimizations to improve performance -and reduce memory overhead. While we plan to eventually deprecate Bagel, we will continue to -support the [Bagel API](api/scala/index.html#org.apache.spark.bagel.package) and -[Bagel programming guide](bagel-programming-guide.html). However, we encourage Bagel users to -explore the new GraphX API and comment on issues that may complicate the transition from Bagel. - -## Migrating from Spark 0.9.1 - -GraphX in Spark {{site.SPARK_VERSION}} contains one user-facing interface change from Spark 0.9.1. [`EdgeRDD`][EdgeRDD] may now store adjacent vertex attributes to construct the triplets, so it has gained a type parameter. The edges of a graph of type `Graph[VD, ED]` are of type `EdgeRDD[ED, VD]` rather than `EdgeRDD[ED]`. - -[EdgeRDD]: api/scala/index.html#org.apache.spark.graphx.EdgeRDD +2. In Spark 1.0 and 1.1, the type signature of [`EdgeRDD`][EdgeRDD] switched from +`EdgeRDD[ED]` to `EdgeRDD[ED, VD]` to enable some caching optimizations. We have since discovered +a more elegant solution and have restored the type signature to the more natural `EdgeRDD[ED]` type. # Getting Started @@ -108,9 +96,10 @@ import org.apache.spark.rdd.RDD If you are not using the Spark shell you will also need a `SparkContext`. To learn more about getting started with Spark refer to the [Spark Quick Start Guide](quick-start.html). -# The Property Graph +# The Property Graph + The [property graph](api/scala/index.html#org.apache.spark.graphx.Graph) is a directed multigraph with user defined objects attached to each vertex and edge. A directed multigraph is a directed graph with potentially multiple parallel edges sharing the same source and destination vertex. The @@ -123,7 +112,7 @@ identifiers. The property graph is parameterized over the vertex (`VD`) and edge (`ED`) types. These are the types of the objects associated with each vertex and edge respectively. -> GraphX optimizes the representation of vertex and edge types when they are plain old data-types +> GraphX optimizes the representation of vertex and edge types when they are primitive data types > (e.g., int, double, etc...) reducing the in memory footprint by storing them in specialized > arrays. @@ -142,8 +131,8 @@ var graph: Graph[VertexProperty, String] = null Like RDDs, property graphs are immutable, distributed, and fault-tolerant. Changes to the values or structure of the graph are accomplished by producing a new graph with the desired changes. Note that substantial parts of the original graph (i.e., unaffected structure, attributes, and indicies) -are reused in the new graph reducing the cost of this inherently functional data-structure. The -graph is partitioned across the executors using a range of vertex-partitioning heuristics. As with +are reused in the new graph reducing the cost of this inherently functional data structure. 
The +graph is partitioned across the executors using a range of vertex partitioning heuristics. As with RDDs, each partition of the graph can be recreated on a different machine in the event of a failure. Logically the property graph corresponds to a pair of typed collections (RDDs) encoding the @@ -153,12 +142,12 @@ the vertices and edges of the graph: {% highlight scala %} class Graph[VD, ED] { val vertices: VertexRDD[VD] - val edges: EdgeRDD[ED, VD] + val edges: EdgeRDD[ED] } {% endhighlight %} -The classes `VertexRDD[VD]` and `EdgeRDD[ED, VD]` extend and are optimized versions of `RDD[(VertexID, -VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED, VD]` provide additional +The classes `VertexRDD[VD]` and `EdgeRDD[ED]` extend and are optimized versions of `RDD[(VertexID, +VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED]` provide additional functionality built around graph computation and leverage internal optimizations. We discuss the `VertexRDD` and `EdgeRDD` API in greater detail in the section on [vertex and edge RDDs](#vertex_and_edge_rdds) but for now they can be thought of as simply RDDs of the form: @@ -211,7 +200,6 @@ In the above example we make use of the [`Edge`][Edge] case class. Edges have a `dstId` corresponding to the source and destination vertex identifiers. In addition, the `Edge` class has an `attr` member which stores the edge property. -[Edge]: api/scala/index.html#org.apache.spark.graphx.Edge We can deconstruct a graph into the respective vertex and edge views by using the `graph.vertices` and `graph.edges` members respectively. @@ -237,7 +225,6 @@ The triplet view logically joins the vertex and edge properties yielding an `RDD[EdgeTriplet[VD, ED]]` containing instances of the [`EdgeTriplet`][EdgeTriplet] class. This *join* can be expressed in the following SQL expression: -[EdgeTriplet]: api/scala/index.html#org.apache.spark.graphx.EdgeTriplet {% highlight sql %} SELECT src.id, dst.id, src.attr, e.attr, dst.attr @@ -278,9 +265,6 @@ core operators are defined in [`GraphOps`][GraphOps]. However, thanks to Scala operators in `GraphOps` are automatically available as members of `Graph`. 
For example, we can compute the in-degree of each vertex (defined in `GraphOps`) by the following: -[Graph]: api/scala/index.html#org.apache.spark.graphx.Graph -[GraphOps]: api/scala/index.html#org.apache.spark.graphx.GraphOps - {% highlight scala %} val graph: Graph[(String, String), String] // Use the implicit GraphOps.inDegrees operator @@ -310,7 +294,7 @@ class Graph[VD, ED] { val degrees: VertexRDD[Int] // Views of the graph as collections ============================================================= val vertices: VertexRDD[VD] - val edges: EdgeRDD[ED, VD] + val edges: EdgeRDD[ED] val triplets: RDD[EdgeTriplet[VD, ED]] // Functions for caching graphs ================================================================== def persist(newLevel: StorageLevel = StorageLevel.MEMORY_ONLY): Graph[VD, ED] @@ -341,10 +325,10 @@ class Graph[VD, ED] { // Aggregate information about adjacent triplets ================================================= def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexID]] def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexID, VD)]] - def mapReduceTriplets[A: ClassTag]( - mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexID, A)], - reduceFunc: (A, A) => A, - activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None) + def aggregateMessages[Msg: ClassTag]( + sendMsg: EdgeContext[VD, ED, Msg] => Unit, + mergeMsg: (Msg, Msg) => Msg, + tripletFields: TripletFields = TripletFields.All) : VertexRDD[A] // Iterative graph-parallel computation ========================================================== def pregel[A](initialMsg: A, maxIterations: Int, activeDirection: EdgeDirection)( @@ -363,8 +347,7 @@ class Graph[VD, ED] { ## Property Operators -In direct analogy to the RDD `map` operator, the property -graph contains the following: +Like the RDD `map` operator, the property graph contains the following: {% highlight scala %} class Graph[VD, ED] { @@ -377,7 +360,7 @@ class Graph[VD, ED] { Each of these operators yields a new graph with the vertex or edge properties modified by the user defined `map` function. -> Note that in all cases the graph structure is unaffected. This is a key feature of these operators +> Note that in each case the graph structure is unaffected. This is a key feature of these operators > which allows the resulting graph to reuse the structural indices of the original graph. The > following snippets are logically equivalent, but the first one does not preserve the structural > indices and would not benefit from the GraphX system optimizations: @@ -390,14 +373,13 @@ val newGraph = Graph(newVertices, graph.edges) val newGraph = graph.mapVertices((id, attr) => mapUdf(id, attr)) {% endhighlight %} -[Graph.mapVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@mapVertices[VD2]((VertexId,VD)⇒VD2)(ClassTag[VD2]):Graph[VD2,ED] These operators are often used to initialize the graph for a particular computation or project away -unnecessary properties. For example, given a graph with the out-degrees as the vertex properties +unnecessary properties. 
For example, given a graph with the out degrees as the vertex properties (we describe how to construct such a graph later), we initialize it for PageRank: {% highlight scala %} -// Given a graph where the vertex property is the out-degree +// Given a graph where the vertex property is the out degree val inputGraph: Graph[Int, String] = graph.outerJoinVertices(graph.outDegrees)((vid, _, degOpt) => degOpt.getOrElse(0)) // Construct a graph where each edge contains the weight @@ -406,9 +388,10 @@ val outputGraph: Graph[Double, Double] = inputGraph.mapTriplets(triplet => 1.0 / triplet.srcAttr).mapVertices((id, _) => 1.0) {% endhighlight %} -## Structural Operators +## Structural Operators + Currently GraphX supports only a simple set of commonly used structural operators and we expect to add more in the future. The following is a list of the basic structural operators. @@ -425,9 +408,8 @@ class Graph[VD, ED] { The [`reverse`][Graph.reverse] operator returns a new graph with all the edge directions reversed. This can be useful when, for example, trying to compute the inverse PageRank. Because the reverse operation does not modify vertex or edge properties or change the number of edges, it can be -implemented efficiently without data-movement or duplication. +implemented efficiently without data movement or duplication. -[Graph.reverse]: api/scala/index.html#org.apache.spark.graphx.Graph@reverse:Graph[VD,ED] The [`subgraph`][Graph.subgraph] operator takes vertex and edge predicates and returns the graph containing only the vertices that satisfy the vertex predicate (evaluate to true) and edges that @@ -435,7 +417,6 @@ satisfy the edge predicate *and connect vertices that satisfy the vertex predica operator can be used in number of situations to restrict the graph to the vertices and edges of interest or eliminate broken links. For example in the following code we remove broken links: -[Graph.subgraph]: api/scala/index.html#org.apache.spark.graphx.Graph@subgraph((EdgeTriplet[VD,ED])⇒Boolean,(VertexId,VD)⇒Boolean):Graph[VD,ED] {% highlight scala %} // Create an RDD for the vertices @@ -469,13 +450,12 @@ validGraph.triplets.map( > Note in the above example only the vertex predicate is provided. The `subgraph` operator defaults > to `true` if the vertex or edge predicates are not provided. -The [`mask`][Graph.mask] operator also constructs a subgraph by returning a graph that contains the +The [`mask`][Graph.mask] operator constructs a subgraph by returning a graph that contains the vertices and edges that are also found in the input graph. This can be used in conjunction with the `subgraph` operator to restrict a graph based on the properties in another related graph. For example, we might run connected components using the graph with missing vertices and then restrict the answer to the valid subgraph. -[Graph.mask]: api/scala/index.html#org.apache.spark.graphx.Graph@mask[VD2,ED2](Graph[VD2,ED2])(ClassTag[VD2],ClassTag[ED2]):Graph[VD,ED] {% highlight scala %} // Run Connected Components @@ -490,10 +470,9 @@ The [`groupEdges`][Graph.groupEdges] operator merges parallel edges (i.e., dupli pairs of vertices) in the multigraph. In many numerical applications, parallel edges can be *added* (their weights combined) into a single edge thereby reducing the size of the graph. -[Graph.groupEdges]: api/scala/index.html#org.apache.spark.graphx.Graph@groupEdges((ED,ED)⇒ED):Graph[VD,ED] + ## Join Operators - In many cases it is necessary to join data from external collections (RDDs) with graphs. 
For example, we might have extra user properties that we want to merge with an existing graph or we @@ -514,10 +493,8 @@ returns a new graph with the vertex properties obtained by applying the user def to the result of the joined vertices. Vertices without a matching value in the RDD retain their original value. -[GraphOps.joinVertices]: api/scala/index.html#org.apache.spark.graphx.GraphOps@joinVertices[U](RDD[(VertexId,U)])((VertexId,VD,U)⇒VD)(ClassTag[U]):Graph[VD,ED] - -> Note that if the RDD contains more than one value for a given vertex only one will be used. It -> is therefore recommended that the input RDD be first made unique using the following which will +> Note that if the RDD contains more than one value for a given vertex only one will be used. It +> is therefore recommended that the input RDD be made unique using the following which will > also *pre-index* the resulting values to substantially accelerate the subsequent join. > {% highlight scala %} val nonUniqueCosts: RDD[(VertexID, Double)] @@ -533,8 +510,6 @@ property type. Because not all vertices may have a matching value in the input function takes an `Option` type. For example, we can setup a graph for PageRank by initializing vertex properties with their `outDegree`. -[Graph.outerJoinVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@outerJoinVertices[U,VD2](RDD[(VertexId,U)])((VertexId,VD,Option[U])⇒VD2)(ClassTag[U],ClassTag[VD2]):Graph[VD2,ED] - {% highlight scala %} val outDegrees: VertexRDD[Int] = graph.outDegrees @@ -555,65 +530,76 @@ val joinedGraph = graph.joinVertices(uniqueCosts, (id: VertexID, oldCost: Double, extraCost: Double) => oldCost + extraCost) {% endhighlight %} +> + + ## Neighborhood Aggregation -A key part of graph computation is aggregating information about the neighborhood of each vertex. -For example we might want to know the number of followers each user has or the average age of the +A key step in may graph analytics tasks is aggregating information about the neighborhood of each +vertex. +For example, we might want to know the number of followers each user has or the average age of the the followers of each user. Many iterative graph algorithms (e.g., PageRank, Shortest Path, and connected components) repeatedly aggregate properties of neighboring vertices (e.g., current PageRank Value, shortest path to the source, and smallest reachable vertex id). -### Map Reduce Triplets (mapReduceTriplets) - +> To improve performance the primary aggregation operator changed from +`graph.mapReduceTriplets` to the new `graph.AggregateMessages`. While the changes in the API are +relatively small, we provide a transition guide below. -[Graph.mapReduceTriplets]: api/scala/index.html#org.apache.spark.graphx.Graph@mapReduceTriplets[A](mapFunc:org.apache.spark.graphx.EdgeTriplet[VD,ED]=>Iterator[(org.apache.spark.graphx.VertexId,A)],reduceFunc:(A,A)=>A,activeSetOpt:Option[(org.apache.spark.graphx.VertexRDD[_],org.apache.spark.graphx.EdgeDirection)])(implicitevidence$10:scala.reflect.ClassTag[A]):org.apache.spark.graphx.VertexRDD[A] + -The core (heavily optimized) aggregation primitive in GraphX is the -[`mapReduceTriplets`][Graph.mapReduceTriplets] operator: +### Aggregate Messages (aggregateMessages) + +The core aggregation operation in GraphX is [`aggregateMessages`][Graph.aggregateMessages]. +This operator applies a user defined `sendMsg` function to each edge triplet in the graph +and then uses the `mergeMsg` function to aggregate those messages at their destination vertex. 
{% highlight scala %} class Graph[VD, ED] { - def mapReduceTriplets[A]( - map: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], - reduce: (A, A) => A) - : VertexRDD[A] + def aggregateMessages[Msg: ClassTag]( + sendMsg: EdgeContext[VD, ED, Msg] => Unit, + mergeMsg: (Msg, Msg) => Msg, + tripletFields: TripletFields = TripletFields.All) + : VertexRDD[Msg] } {% endhighlight %} -The [`mapReduceTriplets`][Graph.mapReduceTriplets] operator takes a user defined map function which -is applied to each triplet and can yield *messages* destined to either (none or both) vertices in -the triplet. To facilitate optimized pre-aggregation, we currently only support messages destined -to the source or destination vertex of the triplet. The user defined `reduce` function combines the -messages destined to each vertex. The `mapReduceTriplets` operator returns a `VertexRDD[A]` -containing the aggregate message (of type `A`) destined to each vertex. Vertices that do not +The user defined `sendMsg` function takes an [`EdgeContext`][EdgeContext], which exposes the +source and destination attributes along with the edge attribute and functions +([`sendToSrc`][EdgeContext.sendToSrc], and [`sendToDst`][EdgeContext.sendToDst]) to send +messages to the source and destination attributes. Think of `sendMsg` as the map +function in map-reduce. +The user defined `mergeMsg` function takes two messages destined to the same vertex and +yields a single message. Think of `mergeMsg` as the reduce function in map-reduce. +The [`aggregateMessages`][Graph.aggregateMessages] operator returns a `VertexRDD[Msg]` +containing the aggregate message (of type `Msg`) destined to each vertex. Vertices that did not receive a message are not included in the returned `VertexRDD`. -
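For illustration, a minimal sketch of the operator just described (assuming the `org.apache.spark.graphx._` import from the Getting Started section and a graph `graph: Graph[Double, Int]` like the one constructed in the example further below): counting in-neighbors needs only a constant message, an integer sum, and no vertex attributes at all.

{% highlight scala %}
// Send the constant 1 to the destination of every edge and add the messages up per vertex.
// TripletFields.None tells GraphX that sendMsg reads neither vertex attribute.
val inDegrees: VertexRDD[Int] =
  graph.aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _, TripletFields.None)
{% endhighlight %}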
-> Note that `mapReduceTriplets` takes an additional optional `activeSet` (not shown above; see the API docs
-> for details) which restricts the map phase to edges adjacent to the vertices in the provided `VertexRDD`:
-
-{% highlight scala %}
-  activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None
-{% endhighlight %}
-
-> The `EdgeDirection` specifies which edges adjacent to the vertex set are included in the map phase. If the
-> direction is `In`, the user defined map function is run only on edges whose destination vertex is in the
-> active set. If the direction is `Out`, it is run only on edges originating from vertices in the active set.
-> If the direction is `Either`, it is run on edges with either vertex in the active set, and if the direction
-> is `Both`, only on edges with both vertices in the active set. The active set must be derived from the set
-> of vertices in the graph. Restricting computation to triplets adjacent to a subset of the vertices is often
-> necessary in incremental iterative computation and is a key part of the GraphX implementation of Pregel.
- -In the following example we use the `mapReduceTriplets` operator to compute the average age of the -more senior followers of each user. + + +In addition, [`aggregateMessages`][Graph.aggregateMessages] takes an optional +`tripletsFields` which indicates what data is accessed in the [`EdgeContext`][EdgeContext] +(i.e., the source vertex attribute but not the destination vertex attribute). +The possible options for the `tripletsFields` are defined in [`TripletFields`][TripletFields] and +the default value is [`TripletFields.All`][TripletFields.All] which indicates that the user +defined `sendMsg` function may access any of the fields in the [`EdgeContext`][EdgeContext]. +The `tripletFields` argument can be used to notify GraphX that only part of the +[`EdgeContext`][EdgeContext] will be needed allowing GraphX to select an optimized join strategy. +For example if we are computing the average age of the followers of each user we would only require +the source field and so we would use [`TripletFields.Src`][TripletFields.Src] to indicate that we +only require the source field + +> In earlier versions of GraphX we used byte code inspection to infer the +[`TripletFields`][TripletFields] however we have found that bytecode inspection to be +slightly unreliable and instead opted for more explicit user control. + +In the following example we use the [`aggregateMessages`][Graph.aggregateMessages] operator to +compute the average age of the more senior followers of each user. {% highlight scala %} // Import random graph generation library @@ -622,14 +608,11 @@ import org.apache.spark.graphx.util.GraphGenerators val graph: Graph[Double, Int] = GraphGenerators.logNormalGraph(sc, numVertices = 100).mapVertices( (id, _) => id.toDouble ) // Compute the number of older followers and their total age -val olderFollowers: VertexRDD[(Int, Double)] = graph.mapReduceTriplets[(Int, Double)]( +val olderFollowers: VertexRDD[(Int, Double)] = graph.aggregateMessages[(Int, Double)]( triplet => { // Map Function if (triplet.srcAttr > triplet.dstAttr) { // Send message to destination vertex containing counter and age - Iterator((triplet.dstId, (1, triplet.srcAttr))) - } else { - // Don't send a message for this triplet - Iterator.empty + triplet.sendToDst(1, triplet.srcAttr) } }, // Add counter and age @@ -642,10 +625,57 @@ val avgAgeOfOlderFollowers: VertexRDD[Double] = avgAgeOfOlderFollowers.collect.foreach(println(_)) {% endhighlight %} -> Note that the `mapReduceTriplets` operation performs optimally when the messages (and the sums of -> messages) are constant sized (e.g., floats and addition instead of lists and concatenation). More -> precisely, the result of `mapReduceTriplets` should ideally be sub-linear in the degree of each -> vertex. +> The `aggregateMessages` operation performs optimally when the messages (and the sums of +> messages) are constant sized (e.g., floats and addition instead of lists and concatenation). 
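As a small illustration of the `tripletFields` hint described above (a sketch assuming the same `graph: Graph[Double, Int]` as in the example just shown), computing the average age of *all* followers reads only the source attribute, so `TripletFields.Src` can be passed explicitly:

{% highlight scala %}
// Only ctx.srcAttr is read, so TripletFields.Src lets GraphX avoid shipping
// destination vertex attributes to the edge partitions.
val followerStats: VertexRDD[(Int, Double)] = graph.aggregateMessages[(Int, Double)](
  ctx => ctx.sendToDst((1, ctx.srcAttr)),   // one follower and its age
  (a, b) => (a._1 + b._1, a._2 + b._2),     // add up counts and ages
  TripletFields.Src)
val avgFollowerAge: VertexRDD[Double] =
  followerStats.mapValues((id, value) => value match { case (count, totalAge) => totalAge / count })
{% endhighlight %}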
+ + + +### Map Reduce Triplets Transition Guide (Legacy) + +In earlier versions of GraphX we neighborhood aggregation was accomplished using the +[`mapReduceTriplets`][Graph.mapReduceTriplets] operator: + +{% highlight scala %} +class Graph[VD, ED] { + def mapReduceTriplets[Msg]( + map: EdgeTriplet[VD, ED] => Iterator[(VertexId, Msg)], + reduce: (Msg, Msg) => Msg) + : VertexRDD[Msg] +} +{% endhighlight %} + +The [`mapReduceTriplets`][Graph.mapReduceTriplets] operator takes a user defined map function which +is applied to each triplet and can yield *messages* which are aggregated using the user defined +`reduce` function. +However, we found the user of the returned iterator to be expensive and it inhibited our ability to +apply additional optimizations (e.g., local vertex renumbering). +In [`aggregateMessages`][Graph.aggregateMessages] we introduced the EdgeContext which exposes the +triplet fields and also functions to explicitly send messages to the source and destination vertex. +Furthermore we removed bytecode inspection and instead require the user to indicate what fields +in the triplet are actually required. + +The following code block using `mapReduceTriplets`: + +{% highlight scala %} +val graph: Graph[Int, Float] = ... +def msgFun(triplet: Triplet[Int, Float]): Iterator[(Int, String)] = { + Iterator((triplet.dstId, "Hi")) +} +def reduceFun(a: Int, b: Int): Int = a + b +val result = graph.mapReduceTriplets[String](msgFun, reduceFun) +{% endhighlight %} + +can be rewritten using `aggregateMessages` as: + +{% highlight scala %} +val graph: Graph[Int, Float] = ... +def msgFun(triplet: EdgeContext[Int, Float, String]) { + triplet.sendToDst("Hi") +} +def reduceFun(a: Int, b: Int): Int = a + b +val result = graph.aggregateMessages[String](msgFun, reduceFun) +{% endhighlight %} + ### Computing Degree Information @@ -673,10 +703,6 @@ attributes at each vertex. This can be easily accomplished using the [`collectNeighborIds`][GraphOps.collectNeighborIds] and the [`collectNeighbors`][GraphOps.collectNeighbors] operators. -[GraphOps.collectNeighborIds]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighborIds(EdgeDirection):VertexRDD[Array[VertexId]] -[GraphOps.collectNeighbors]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighbors(EdgeDirection):VertexRDD[Array[(VertexId,VD)]] - - {% highlight scala %} class GraphOps[VD, ED] { def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] @@ -684,36 +710,34 @@ class GraphOps[VD, ED] { } {% endhighlight %} -> Note that these operators can be quite costly as they duplicate information and require +> These operators can be quite costly as they duplicate information and require > substantial communication. If possible try expressing the same computation using the -> `mapReduceTriplets` operator directly. +> [`aggregateMessages`][Graph.aggregateMessages] operator directly. ## Caching and Uncaching In Spark, RDDs are not persisted in memory by default. To avoid recomputation, they must be explicitly cached when using them multiple times (see the [Spark Programming Guide][RDD Persistence]). Graphs in GraphX behave the same way. **When using a graph multiple times, make sure to call [`Graph.cache()`][Graph.cache] on it first.** -[RDD Persistence]: programming-guide.html#rdd-persistence -[Graph.cache]: api/scala/index.html#org.apache.spark.graphx.Graph@cache():Graph[VD,ED] In iterative computations, *uncaching* may also be necessary for best performance. 
By default, cached RDDs and graphs will remain in memory until memory pressure forces them to be evicted in LRU order. For iterative computation, intermediate results from previous iterations will fill up the cache. Though they will eventually be evicted, the unnecessary data stored in memory will slow down garbage collection. It would be more efficient to uncache intermediate results as soon as they are no longer necessary. This involves materializing (caching and forcing) a graph or RDD every iteration, uncaching all other datasets, and only using the materialized dataset in future iterations. However, because graphs are composed of multiple RDDs, it can be difficult to unpersist them correctly. **For iterative computation we recommend using the Pregel API, which correctly unpersists intermediate results.** -# Pregel API -Graphs are inherently recursive data-structures as properties of vertices depend on properties of +# Pregel API + +Graphs are inherently recursive data structures as properties of vertices depend on properties of their neighbors which in turn depend on properties of *their* neighbors. As a consequence many important graph algorithms iteratively recompute the properties of each vertex until a fixed-point condition is reached. A range of graph-parallel abstractions have been proposed -to express these iterative algorithms. GraphX exposes a Pregel-like operator which is a fusion of -the widely used Pregel and GraphLab abstractions. +to express these iterative algorithms. GraphX exposes a variant of the Pregel API. -At a high-level the Pregel operator in GraphX is a bulk-synchronous parallel messaging abstraction -*constrained to the topology of the graph*. The Pregel operator executes in a series of super-steps -in which vertices receive the *sum* of their inbound messages from the previous super- step, compute +At a high level the Pregel operator in GraphX is a bulk-synchronous parallel messaging abstraction +*constrained to the topology of the graph*. The Pregel operator executes in a series of super steps +in which vertices receive the *sum* of their inbound messages from the previous super step, compute a new value for the vertex property, and then send messages to neighboring vertices in the next -super-step. Unlike Pregel and instead more like GraphLab messages are computed in parallel as a +super step. Unlike Pregel, messages are computed in parallel as a function of the edge triplet and the message computation has access to both the source and -destination vertex attributes. Vertices that do not receive a message are skipped within a super- +destination vertex attributes. Vertices that do not receive a message are skipped within a super step. The Pregel operators terminates iteration and returns the final graph when there are no messages remaining. @@ -724,8 +748,6 @@ messages remaining. 
The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch* of its implementation (note calls to graph.cache have been removed): -[GraphOps.pregel]: api/scala/index.html#org.apache.spark.graphx.GraphOps@pregel[A](A,Int,EdgeDirection)((VertexId,VD,A)⇒VD,(EdgeTriplet[VD,ED])⇒Iterator[(VertexId,A)],(A,A)⇒A)(ClassTag[A]):Graph[VD,ED] - {% highlight scala %} class GraphOps[VD, ED] { def pregel[A] @@ -795,9 +817,10 @@ val sssp = initialGraph.pregel(Double.PositiveInfinity)( println(sssp.vertices.collect.mkString("\n")) {% endhighlight %} -# Graph Builders +# Graph Builders + GraphX provides several ways of building a graph from a collection of vertices and edges in an RDD or on disk. None of the graph builders repartitions the graph's edges by default; instead, edges are left in their default partitions (such as their original blocks in HDFS). [`Graph.groupEdges`][Graph.groupEdges] requires the graph to be repartitioned because it assumes identical edges will be colocated on the same partition, so you must call [`Graph.partitionBy`][Graph.partitionBy] before calling `groupEdges`. {% highlight scala %} @@ -848,18 +871,12 @@ object Graph { [`Graph.fromEdgeTuples`][Graph.fromEdgeTuples] allows creating a graph from only an RDD of edge tuples, assigning the edges the value 1, and automatically creating any vertices mentioned by edges and assigning them the default value. It also supports deduplicating the edges; to deduplicate, pass `Some` of a [`PartitionStrategy`][PartitionStrategy] as the `uniqueEdges` parameter (for example, `uniqueEdges = Some(PartitionStrategy.RandomVertexCut)`). A partition strategy is necessary to colocate identical edges on the same partition so they can be deduplicated. -[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy$ - -[GraphLoader.edgeListFile]: api/scala/index.html#org.apache.spark.graphx.GraphLoader$@edgeListFile(SparkContext,String,Boolean,Int):Graph[Int,Int] -[Graph.apply]: api/scala/index.html#org.apache.spark.graphx.Graph$@apply[VD,ED](RDD[(VertexId,VD)],RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] -[Graph.fromEdgeTuples]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdgeTuples[VD](RDD[(VertexId,VertexId)],VD,Option[PartitionStrategy])(ClassTag[VD]):Graph[VD,Int] -[Graph.fromEdges]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdges[VD,ED](RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] + # Vertex and Edge RDDs - GraphX exposes `RDD` views of the vertices and edges stored within the graph. However, because -GraphX maintains the vertices and edges in optimized data-structures and these data-structures +GraphX maintains the vertices and edges in optimized data structures and these data structures provide additional functionality, the vertices and edges are returned as `VertexRDD` and `EdgeRDD` respectively. In this section we review some of the additional useful functionality in these types. @@ -870,7 +887,7 @@ The `VertexRDD[A]` extends `RDD[(VertexID, A)]` and adds the additional constrai attribute of type `A`. Internally, this is achieved by storing the vertex attributes in a reusable hash-map data-structure. As a consequence if two `VertexRDD`s are derived from the same base `VertexRDD` (e.g., by `filter` or `mapValues`) they can be joined in constant time without hash -evaluations. To leverage this indexed data-structure, the `VertexRDD` exposes the following +evaluations. 
To leverage this indexed data structure, the `VertexRDD` exposes the following additional functionality: {% highlight scala %} @@ -893,7 +910,7 @@ class VertexRDD[VD] extends RDD[(VertexID, VD)] { Notice, for example, how the `filter` operator returns an `VertexRDD`. Filter is actually implemented using a `BitSet` thereby reusing the index and preserving the ability to do fast joins with other `VertexRDD`s. Likewise, the `mapValues` operators do not allow the `map` function to -change the `VertexID` thereby enabling the same `HashMap` data-structures to be reused. Both the +change the `VertexID` thereby enabling the same `HashMap` data structures to be reused. Both the `leftJoin` and `innerJoin` are able to identify when joining two `VertexRDD`s derived from the same `HashMap` and implement the join by linear scan rather than costly point lookups. @@ -916,21 +933,19 @@ val setC: VertexRDD[Double] = setA.innerJoin(setB)((id, a, b) => a + b) ## EdgeRDDs -The `EdgeRDD[ED, VD]`, which extends `RDD[Edge[ED]]` organizes the edges in blocks partitioned using one +The `EdgeRDD[ED]`, which extends `RDD[Edge[ED]]` organizes the edges in blocks partitioned using one of the various partitioning strategies defined in [`PartitionStrategy`][PartitionStrategy]. Within each partition, edge attributes and adjacency structure, are stored separately enabling maximum reuse when changing attribute values. -[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy - The three additional functions exposed by the `EdgeRDD` are: {% highlight scala %} // Transform the edge attributes while preserving the structure -def mapValues[ED2](f: Edge[ED] => ED2): EdgeRDD[ED2, VD] +def mapValues[ED2](f: Edge[ED] => ED2): EdgeRDD[ED2] // Revere the edges reusing both attributes and structure -def reverse: EdgeRDD[ED, VD] +def reverse: EdgeRDD[ED] // Join two `EdgeRDD`s partitioned using the same partitioning strategy. -def innerJoin[ED2, ED3](other: EdgeRDD[ED2, VD])(f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3, VD] +def innerJoin[ED2, ED3](other: EdgeRDD[ED2])(f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3] {% endhighlight %} In most applications we have found that operations on the `EdgeRDD` are accomplished through the @@ -960,7 +975,6 @@ the [`Graph.partitionBy`][Graph.partitionBy] operator. The default partitioning the initial partitioning of the edges as provided on graph construction. However, users can easily switch to 2D-partitioning or other heuristics included in GraphX. -[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph$@partitionBy(partitionStrategy:org.apache.spark.graphx.PartitionStrategy):org.apache.spark.graphx.Graph[VD,ED]
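To make the partitioning discussion above concrete, a minimal sketch (assuming an existing `SparkContext` named `sc` and the sample `graphx/data/followers.txt` file used elsewhere in this guide) that switches a graph to the 2D edge-partitioning heuristic at load time:

{% highlight scala %}
import org.apache.spark.graphx.{GraphLoader, PartitionStrategy}

// Repartition the edges with the 2D heuristic; do this before iterative algorithms
// such as PageRank or triangle counting that benefit from balanced edge partitions.
val partitionedGraph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
  .partitionBy(PartitionStrategy.EdgePartition2D)
{% endhighlight %}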

+# Graph Algorithms + GraphX includes a set of graph algorithms to simplify analytics tasks. The algorithms are contained in the `org.apache.spark.graphx.lib` package and can be accessed directly as methods on `Graph` via [`GraphOps`][GraphOps]. This section describes the algorithms and how they are used. -## PageRank +## PageRank + PageRank measures the importance of each vertex in a graph, assuming an edge from *u* to *v* represents an endorsement of *v*'s importance by *u*. For example, if a Twitter user is followed by many others, the user will be ranked highly. GraphX comes with static and dynamic implementations of PageRank as methods on the [`PageRank` object][PageRank]. Static PageRank runs for a fixed number of iterations, while dynamic PageRank runs until the ranks converge (i.e., stop changing by more than a specified tolerance). [`GraphOps`][GraphOps] allows calling these algorithms directly as methods on `Graph`. GraphX also includes an example social network dataset that we can run PageRank on. A set of users is given in `graphx/data/users.txt`, and a set of relationships between users is given in `graphx/data/followers.txt`. We compute the PageRank of each user as follows: -[PageRank]: api/scala/index.html#org.apache.spark.graphx.lib.PageRank$ - {% highlight scala %} // Load the edges as a graph val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt") @@ -1014,8 +1028,6 @@ println(ranksByUsername.collect().mkString("\n")) The connected components algorithm labels each connected component of the graph with the ID of its lowest-numbered vertex. For example, in a social network, connected components can approximate clusters. GraphX contains an implementation of the algorithm in the [`ConnectedComponents` object][ConnectedComponents], and we compute the connected components of the example social network dataset from the [PageRank section](#pagerank) as follows: -[ConnectedComponents]: api/scala/index.html#org.apache.spark.graphx.lib.ConnectedComponents$ - {% highlight scala %} // Load the graph as in the PageRank example val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt") @@ -1037,9 +1049,6 @@ println(ccByUsername.collect().mkString("\n")) A vertex is part of a triangle when it has two adjacent vertices with an edge between them. GraphX implements a triangle counting algorithm in the [`TriangleCount` object][TriangleCount] that determines the number of triangles passing through each vertex, providing a measure of clustering. We compute the triangle count of the social network dataset from the [PageRank section](#pagerank). 
*Note that `TriangleCount` requires the edges to be in canonical orientation (`srcId < dstId`) and the graph to be partitioned using [`Graph.partitionBy`][Graph.partitionBy].* -[TriangleCount]: api/scala/index.html#org.apache.spark.graphx.lib.TriangleCount$ -[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph@partitionBy(PartitionStrategy):Graph[VD,ED] - {% highlight scala %} // Load the edges in canonical order and partition the graph for triangle count val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt", true).partitionBy(PartitionStrategy.RandomVertexCut) diff --git a/docs/img/data_parallel_vs_graph_parallel.png b/docs/img/data_parallel_vs_graph_parallel.png deleted file mode 100644 index d3918f01d8f3b..0000000000000 Binary files a/docs/img/data_parallel_vs_graph_parallel.png and /dev/null differ diff --git a/docs/img/graph_analytics_pipeline.png b/docs/img/graph_analytics_pipeline.png deleted file mode 100644 index 6d606e01894ae..0000000000000 Binary files a/docs/img/graph_analytics_pipeline.png and /dev/null differ diff --git a/docs/img/tables_and_graphs.png b/docs/img/tables_and_graphs.png deleted file mode 100644 index ec37bb45a62f0..0000000000000 Binary files a/docs/img/tables_and_graphs.png and /dev/null differ diff --git a/docs/monitoring.md b/docs/monitoring.md index e3f81a76acdbb..f32cdef240d31 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -79,7 +79,7 @@ follows: spark.history.fs.logDirectory - (none) + file:/tmp/spark-events Directory that contains application event logs to be loaded by the history server diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 9de2f914b8b4c..49f319ba775e5 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -117,6 +117,8 @@ The first thing a Spark program must do is to create a [SparkContext](api/scala/ how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) object that contains information about your application. +Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before creating a new one. + {% highlight scala %} val conf = new SparkConf().setAppName(appName).setMaster(master) new SparkContext(conf) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 48e8267ac072c..5500da83b2b66 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -14,7 +14,7 @@ title: Spark SQL Programming Guide Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using Spark. At the core of this component is a new type of RDD, [SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed of -[Row](api/scala/index.html#org.apache.spark.sql.catalyst.expressions.Row) objects, along with +[Row](api/scala/index.html#org.apache.spark.sql.package@Row:org.apache.spark.sql.catalyst.expressions.Row.type) objects, along with a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). 
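To make the SchemaRDD description above concrete, a minimal sketch (assuming an existing `SparkContext` named `sc` and the `people.json` sample file shipped with the Spark examples):

{% highlight scala %}
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
// Create a SchemaRDD from a JSON dataset; the schema is inferred automatically.
val people = sqlContext.jsonFile("examples/src/main/resources/people.json")
people.registerTempTable("people")
// Queries return SchemaRDDs as well, composed of Row objects.
val adults = sqlContext.sql("SELECT name FROM people WHERE age >= 18")
adults.collect().foreach(println)
{% endhighlight %}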
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index a5396c2375915..a4ab844a56adf 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -138,7 +138,7 @@ def parse_args(): help="The SSH user you want to connect as (default: %default)") parser.add_option( "--delete-groups", action="store_true", default=False, - help="When destroying a cluster, delete the security groups that were created.") + help="When destroying a cluster, delete the security groups that were created") parser.add_option( "--use-existing-master", action="store_true", default=False, help="Launch fresh slaves, but use an existing stopped master if possible") @@ -152,9 +152,6 @@ def parse_args(): parser.add_option( "--user-data", type="string", default="", help="Path to a user-data file (most AMI's interpret this as an initialization script)") - parser.add_option( - "--security-group-prefix", type="string", default=None, - help="Use this prefix for the security group rather than the cluster name.") parser.add_option( "--authorized-address", type="string", default="0.0.0.0/0", help="Address to authorize on created security groups (default: %default)") @@ -305,12 +302,8 @@ def launch_cluster(conn, opts, cluster_name): user_data_content = user_data_file.read() print "Setting up security groups..." - if opts.security_group_prefix is None: - master_group = get_or_make_group(conn, cluster_name + "-master") - slave_group = get_or_make_group(conn, cluster_name + "-slaves") - else: - master_group = get_or_make_group(conn, opts.security_group_prefix + "-master") - slave_group = get_or_make_group(conn, opts.security_group_prefix + "-slaves") + master_group = get_or_make_group(conn, cluster_name + "-master") + slave_group = get_or_make_group(conn, cluster_name + "-slaves") authorized_address = opts.authorized_address if master_group.rules == []: # Group was just now created master_group.authorize(src_group=master_group) @@ -335,11 +328,12 @@ def launch_cluster(conn, opts, cluster_name): slave_group.authorize('tcp', 60060, 60060, authorized_address) slave_group.authorize('tcp', 60075, 60075, authorized_address) - # Check if instances are already running with the cluster name + # Check if instances are already running in our groups existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, die_on_error=False) if existing_slaves or (existing_masters and not opts.use_existing_master): - print >> stderr, ("ERROR: There are already instances for name: %s " % cluster_name) + print >> stderr, ("ERROR: There are already instances running in " + + "group %s or %s" % (master_group.name, slave_group.name)) sys.exit(1) # Figure out Spark AMI @@ -413,13 +407,9 @@ def launch_cluster(conn, opts, cluster_name): for r in reqs: id_to_req[r.id] = r active_instance_ids = [] - outstanding_request_ids = [] for i in my_req_ids: - if i in id_to_req: - if id_to_req[i].state == "active": - active_instance_ids.append(id_to_req[i].instance_id) - else: - outstanding_request_ids.append(i) + if i in id_to_req and id_to_req[i].state == "active": + active_instance_ids.append(id_to_req[i].instance_id) if len(active_instance_ids) == opts.slaves: print "All %d slaves granted" % opts.slaves reservations = conn.get_all_instances(active_instance_ids) @@ -428,8 +418,8 @@ def launch_cluster(conn, opts, cluster_name): slave_nodes += r.instances break else: - print "%d of %d slaves granted, waiting longer for request ids including %s" % ( - len(active_instance_ids), opts.slaves, outstanding_request_ids[0:10]) + print "%d of %d slaves granted, waiting 
longer" % ( + len(active_instance_ids), opts.slaves) except: print "Canceling spot instance requests" conn.cancel_spot_instance_requests(my_req_ids) @@ -488,59 +478,34 @@ def launch_cluster(conn, opts, cluster_name): # Give the instances descriptive names for master in master_nodes: - name = '{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id) - tag_instance(master, name) - + master.add_tag( + key='Name', + value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) for slave in slave_nodes: - name = '{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id) - tag_instance(slave, name) + slave.add_tag( + key='Name', + value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) # Return all the instances return (master_nodes, slave_nodes) -def tag_instance(instance, name): - for i in range(0, 5): - try: - instance.add_tag(key='Name', value=name) - break - except: - print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) - if i == 5: - raise "Error - failed max attempts to add name tag" - time.sleep(5) - # Get the EC2 instances in an existing cluster if available. # Returns a tuple of lists of EC2 instance objects for the masters and slaves def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): print "Searching for existing cluster " + cluster_name + "..." - # Search all the spot instance requests, and copy any tags from the spot - # instance request to the cluster. - spot_instance_requests = conn.get_all_spot_instance_requests() - for req in spot_instance_requests: - if req.state != u'active': - continue - name = req.tags.get(u'Name', "") - if name.startswith(cluster_name): - reservations = conn.get_all_instances(instance_ids=[req.instance_id]) - for res in reservations: - active = [i for i in res.instances if is_active(i)] - for instance in active: - if instance.tags.get(u'Name') is None: - tag_instance(instance, name) - # Now proceed to detect master and slaves instances. reservations = conn.get_all_instances() master_nodes = [] slave_nodes = [] for res in reservations: active = [i for i in res.instances if is_active(i)] for inst in active: - name = inst.tags.get(u'Name', "") - if name.startswith(cluster_name + "-master"): + group_names = [g.name for g in inst.groups] + if group_names == [cluster_name + "-master"]: master_nodes.append(inst) - elif name.startswith(cluster_name + "-slave"): + elif group_names == [cluster_name + "-slaves"]: slave_nodes.append(inst) if any((master_nodes, slave_nodes)): print "Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes)) @@ -548,12 +513,12 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): return (master_nodes, slave_nodes) else: if master_nodes == [] and slave_nodes != []: - print >> sys.stderr, "ERROR: Could not find master in with name " + \ - cluster_name + "-master" + print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master" else: print >> sys.stderr, "ERROR: Could not find any existing cluster" sys.exit(1) + # Deploy configuration files and run setup scripts on a newly launched # or started EC2 cluster. @@ -984,11 +949,7 @@ def real_main(): # Delete security groups as well if opts.delete_groups: print "Deleting security groups (this will take some time)..." 
- if opts.security_group_prefix is None: - group_names = [cluster_name + "-master", cluster_name + "-slaves"] - else: - group_names = [opts.security_group_prefix + "-master", - opts.security_group_prefix + "-slaves"] + group_names = [cluster_name + "-master", cluster_name + "-slaves"] wait_for_cluster_state( cluster_instances=(master_nodes + slave_nodes), cluster_state='terminated', diff --git a/examples/pom.xml b/examples/pom.xml index 2752ce3ca9821..8713230e1e8ed 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml @@ -106,6 +106,11 @@ hbase-testing-util ${hbase.version} + + + org.apache.hbase + hbase-annotations + org.jruby jruby-complete @@ -121,12 +126,24 @@ org.apache.hbase hbase-common ${hbase.version} + + + + org.apache.hbase + hbase-annotations + + org.apache.hbase hbase-client ${hbase.version} + + + org.apache.hbase + hbase-annotations + io.netty netty @@ -158,6 +175,11 @@ org.apache.hadoop hadoop-auth + + + org.apache.hbase + hbase-annotations + org.apache.hadoop hadoop-annotations @@ -217,6 +239,11 @@ org.apache.commons commons-math3 + + com.twitter + algebird-core_${scala.binary.version} + 0.8.1 + org.scalatest scalatest_${scala.binary.version} @@ -389,8 +416,8 @@ - + scala-2.10 !scala-2.11 @@ -401,11 +428,6 @@ spark-streaming-kafka_${scala.binary.version} ${project.version} - - com.twitter - algebird-core_${scala.binary.version} - 0.1.11 - diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java similarity index 88% rename from examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java rename to examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java index 1af2067b2b929..4a5ac404ea5ea 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java @@ -27,18 +27,18 @@ import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoosting; +import org.apache.spark.mllib.tree.GradientBoostedTrees; import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel; +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; import org.apache.spark.mllib.util.MLUtils; /** * Classification and regression using gradient-boosted decision trees. */ -public final class JavaGradientBoostedTrees { +public final class JavaGradientBoostedTreesRunner { private static void usage() { - System.err.println("Usage: JavaGradientBoostedTrees " + + System.err.println("Usage: JavaGradientBoostedTreesRunner " + " "); System.exit(-1); } @@ -55,7 +55,7 @@ public static void main(String[] args) { if (args.length > 2) { usage(); } - SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); + SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner"); JavaSparkContext sc = new JavaSparkContext(sparkConf); JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); @@ -64,7 +64,7 @@ public static void main(String[] args) { // Note: All features are treated as continuous. 
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); boostingStrategy.setNumIterations(10); - boostingStrategy.weakLearnerParams().setMaxDepth(5); + boostingStrategy.treeStrategy().setMaxDepth(5); if (algo.equals("Classification")) { // Compute the number of classes from the data. @@ -73,10 +73,10 @@ public static void main(String[] args) { return p.label(); } }).countByValue().size(); - boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression + boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses); // Train a GradientBoosting model for classification. - final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy); + final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); // Evaluate model on training instances and compute training error JavaPairRDD predictionAndLabel = @@ -95,7 +95,7 @@ public static void main(String[] args) { System.out.println("Learned classification tree model:\n" + model); } else if (algo.equals("Regression")) { // Train a GradientBoosting model for classification. - final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy); + final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); // Evaluate model on training instances and compute training error JavaPairRDD predictionAndLabel = diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 63f02cf7b98b9..98f9d1689c8e7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -22,11 +22,11 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity} +import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -349,24 +349,14 @@ object DecisionTreeRunner { sc.stop() } - /** - * Calculates the mean squared error for regression. - */ - private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { - data.map { y => - val err = tree.predict(y.features) - y.label - err * err - }.mean() - } - /** * Calculates the mean squared error for regression. 
*/ private[mllib] def meanSquaredError( - tree: WeightedEnsembleModel, + model: { def predict(features: Vector): Double }, data: RDD[LabeledPoint]): Double = { data.map { y => - val err = tree.predict(y.features) - y.label + val err = model.predict(y.features) - y.label err * err }.mean() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala similarity index 91% rename from examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala rename to examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index 9b6db01448be0..1def8b45a230c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -21,21 +21,21 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.tree.GradientBoosting +import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} import org.apache.spark.util.Utils /** * An example runner for Gradient Boosting using decision trees as weak learners. Run with * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options] + * ./bin/run-example mllib.GradientBoostedTreesRunner [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. * * Note: This script treats all features as real-valued (not categorical). * To include categorical features, modify categoricalFeaturesInfo. */ -object GradientBoostedTrees { +object GradientBoostedTreesRunner { case class Params( input: String = null, @@ -93,24 +93,24 @@ object GradientBoostedTrees { def run(params: Params) { - val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params") + val conf = new SparkConf().setAppName(s"GradientBoostedTreesRunner with $params") val sc = new SparkContext(conf) - println(s"GradientBoostedTrees with parameters:\n$params") + println(s"GradientBoostedTreesRunner with parameters:\n$params") // Load training and test data and cache it. 
val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input, params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest) val boostingStrategy = BoostingStrategy.defaultParams(params.algo) - boostingStrategy.numClassesForClassification = numClasses + boostingStrategy.treeStrategy.numClassesForClassification = numClasses boostingStrategy.numIterations = params.numIterations - boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth + boostingStrategy.treeStrategy.maxDepth = params.maxDepth val randomSeed = Utils.random.nextInt() if (params.algo == "Classification") { val startTime = System.nanoTime() - val model = GradientBoosting.trainClassifier(training, boostingStrategy) + val model = GradientBoostedTrees.train(training, boostingStrategy) val elapsedTime = (System.nanoTime() - startTime) / 1e9 println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { @@ -127,7 +127,7 @@ object GradientBoostedTrees { println(s"Test accuracy = $testAccuracy") } else if (params.algo == "Regression") { val startTime = System.nanoTime() - val model = GradientBoosting.trainRegressor(training, boostingStrategy) + val model = GradientBoostedTrees.train(training, boostingStrategy) val elapsedTime = (System.nanoTime() - startTime) / 1e9 println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 0c52ef8ed96ac..227acc117502d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -27,6 +27,7 @@ object HiveFromSpark { def main(args: Array[String]) { val sparkConf = new SparkConf().setAppName("HiveFromSpark") val sc = new SparkContext(sparkConf) + val path = s"${System.getenv("SPARK_HOME")}/examples/src/main/resources/kv1.txt" // A local hive context creates an instance of the Hive Metastore in process, storing // the warehouse data in the current directory. This location can be overridden by @@ -35,7 +36,7 @@ object HiveFromSpark { import hiveContext._ sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src") + sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE src") // Queries are expressed in HiveQL println("Result of 'SELECT *': ") diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 514252b89e74e..ed186ea5650c4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -64,7 +64,7 @@ object StatefulNetworkWordCount { // Initial RDD input to updateStateByKey val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1))) - // Create a NetworkInputDStream on target ip:port and count the + // Create a ReceiverInputDStream on target ip:port and count the // words in input stream of \n delimited test (eg. 
generated by 'nc') val lines = ssc.socketTextStream(args(0), args(1).toInt) val words = lines.flatMap(_.split(" ")) diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala similarity index 100% rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala rename to examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala similarity index 100% rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala rename to examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index d9b886eff77cc..55226c0a6df60 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -50,7 +50,7 @@ object PageViewStream { val ssc = new StreamingContext("local[2]", "PageViewStream", Seconds(1), System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) - // Create a NetworkInputDStream on target host:port and convert each line to a PageView + // Create a ReceiverInputDStream on target host:port and convert each line to a PageView val pageViews = ssc.socketTextStream(host, port) .flatMap(_.split("\n")) .map(PageView.fromString(_)) diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index ac291bd4fde20..72618b6515f83 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 7d31e32283d88..a682f0e8471d8 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,19 +39,13 @@ org.apache.spark spark-streaming_${scala.binary.version} ${project.version} + provided org.apache.spark spark-streaming-flume-sink_${scala.binary.version} ${project.version} - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test - org.apache.flume flume-ng-sdk diff --git a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java new file mode 100644 index 0000000000000..6e1f01900071b --- /dev/null +++ b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.junit.After; +import org.junit.Before; + +public abstract class LocalJavaStreamingContext { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala new file mode 100644 index 0000000000000..1a900007b696b --- /dev/null +++ b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.{IOException, ObjectInputStream} + +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.{DStream, ForEachDStream} +import org.apache.spark.util.Utils + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +/** + * This is an output stream just for the test suites. All the output is collected into an + * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ * + * The buffer contains a sequence of RDD's, each containing a sequence of items + */ +class TestOutputStream[T: ClassTag](parent: DStream[T], + val output: ArrayBuffer[Seq[T]] = ArrayBuffer[Seq[T]]()) + extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { + val collected = rdd.collect() + output += collected + }) { + + // This is to clear the output buffer every it is read from a checkpoint + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { + ois.defaultReadObject() + output.clear() + } +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 475026e8eb140..b57a1c71e35b9 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -20,9 +20,6 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} -import java.util.Random - -import org.apache.spark.TestUtils import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} @@ -32,20 +29,35 @@ import org.apache.flume.channel.MemoryChannel import org.apache.flume.conf.Configurables import org.apache.flume.event.EventBuilder +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.streaming.{TestSuiteBase, TestOutputStream, StreamingContext} +import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} import org.apache.spark.streaming.flume.sink._ import org.apache.spark.util.Utils -class FlumePollingStreamSuite extends TestSuiteBase { +class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging { val batchCount = 5 val eventsPerBatch = 100 val totalEventsPerChannel = batchCount * eventsPerBatch val channelCapacity = 5000 val maxAttempts = 5 + val batchDuration = Seconds(1) + + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName(this.getClass.getSimpleName) + + def beforeFunction() { + logInfo("Using manual clock") + conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + } + + before(beforeFunction()) test("flume polling test") { testMultipleTimes(testFlumePolling) @@ -229,4 +241,5 @@ class FlumePollingStreamSuite extends TestSuiteBase { null } } + } diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 2067c473f0e3f..b3f44471cd326 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,13 +39,7 @@ org.apache.spark spark-streaming_${scala.binary.version} ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test + provided org.apache.kafka diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 362a76e515938..703806735b3ff 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,13 +39,7 @@ org.apache.spark 
spark-streaming_${scala.binary.version} ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test + provided org.eclipse.paho diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java new file mode 100644 index 0000000000000..6e1f01900071b --- /dev/null +++ b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.junit.After; +import org.junit.Before; + +public abstract class LocalJavaStreamingContext { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } +} diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 467fd263e2d64..84595acf45ccb 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -17,11 +17,19 @@ package org.apache.spark.streaming.mqtt -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.scalatest.FunSuite + +import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -class MQTTStreamSuite extends TestSuiteBase { +class MQTTStreamSuite extends FunSuite { + + val batchDuration = Seconds(1) + + private val master: String = "local[2]" + + private val framework: String = this.getClass.getSimpleName test("mqtt input stream") { val ssc = new StreamingContext(master, framework, batchDuration) diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 1d7dd49d15c22..000ace1446e5e 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,13 +39,7 @@ org.apache.spark spark-streaming_${scala.binary.version} ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test + provided org.twitter4j diff --git 
a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java new file mode 100644 index 0000000000000..6e1f01900071b --- /dev/null +++ b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.junit.After; +import org.junit.Before; + +public abstract class LocalJavaStreamingContext { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } +} diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala index 93741e0375164..9ee57d7581d85 100644 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala @@ -17,13 +17,23 @@ package org.apache.spark.streaming.twitter -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} -import org.apache.spark.storage.StorageLevel + +import org.scalatest.{BeforeAndAfter, FunSuite} +import twitter4j.Status import twitter4j.auth.{NullAuthorization, Authorization} + +import org.apache.spark.Logging +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -import twitter4j.Status -class TwitterStreamSuite extends TestSuiteBase { +class TwitterStreamSuite extends FunSuite with BeforeAndAfter with Logging { + + val batchDuration = Seconds(1) + + private val master: String = "local[2]" + + private val framework: String = this.getClass.getSimpleName test("twitter input stream") { val ssc = new StreamingContext(master, framework, batchDuration) diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 7e48968feb3bc..29c452093502e 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,13 +39,7 @@ org.apache.spark spark-streaming_${scala.binary.version} ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - 
test + provided ${akka.group} diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java new file mode 100644 index 0000000000000..6e1f01900071b --- /dev/null +++ b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.junit.After; +import org.junit.Before; + +public abstract class LocalJavaStreamingContext { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } +} diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala index cc10ff6ae03cd..a7566e733d891 100644 --- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala +++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala @@ -20,12 +20,19 @@ package org.apache.spark.streaming.zeromq import akka.actor.SupervisorStrategy import akka.util.ByteString import akka.zeromq.Subscribe +import org.scalatest.FunSuite import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream -class ZeroMQStreamSuite extends TestSuiteBase { +class ZeroMQStreamSuite extends FunSuite { + + val batchDuration = Seconds(1) + + private val master: String = "local[2]" + + private val framework: String = this.getClass.getSimpleName test("zeromq input stream") { val ssc = new StreamingContext(master, framework, batchDuration) diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 7e478bed62da7..c8477a6566311 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 560244ad93369..c0d3a61119113 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git 
a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 71a078d58a8d8..d1427f6a0c6e9 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 3f49b1d63b6e1..9982b36f9b62f 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index 869ef15893eb9..cc70b396a8dd4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.graphx +import scala.language.existentials import scala.reflect.ClassTag import org.apache.spark.Dependency @@ -36,16 +37,16 @@ import org.apache.spark.graphx.impl.EdgeRDDImpl * edge to provide the triplet view. Shipping of the vertex attributes is managed by * `impl.ReplicatedVertexView`. */ -abstract class EdgeRDD[ED, VD]( +abstract class EdgeRDD[ED]( @transient sc: SparkContext, @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { - private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] + private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } override protected def getPartitions: Array[Partition] = partitionsRDD.partitions override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { - val p = firstParent[(PartitionID, EdgePartition[ED, VD])].iterator(part, context) + val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context) if (p.hasNext) { p.next._2.iterator.map(_.copy()) } else { @@ -60,19 +61,14 @@ abstract class EdgeRDD[ED, VD]( * @param f the function from an edge to a new edge value * @return a new EdgeRDD containing the new edge values */ - def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2, VD] + def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2] /** * Reverse all the edges in this RDD. * * @return a new EdgeRDD containing all the edges reversed */ - def reverse: EdgeRDD[ED, VD] - - /** Removes all edges but those matching `epred` and where both vertices match `vpred`. */ - def filter( - epred: EdgeTriplet[VD, ED] => Boolean, - vpred: (VertexId, VD) => Boolean): EdgeRDD[ED, VD] + def reverse: EdgeRDD[ED] /** * Inner joins this EdgeRDD with another EdgeRDD, assuming both are partitioned using the same @@ -84,15 +80,8 @@ abstract class EdgeRDD[ED, VD]( * with values supplied by `f` */ def innerJoin[ED2: ClassTag, ED3: ClassTag] - (other: EdgeRDD[ED2, _]) - (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3, VD] - - private[graphx] def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag]( - f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDD[ED2, VD2] - - /** Replaces the edge partitions while preserving all other properties of the EdgeRDD. 
*/ - private[graphx] def withPartitionsRDD[ED2: ClassTag, VD2: ClassTag]( - partitionsRDD: RDD[(PartitionID, EdgePartition[ED2, VD2])]): EdgeRDD[ED2, VD2] + (other: EdgeRDD[ED2]) + (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3] /** * Changes the target storage level while preserving all other properties of the @@ -101,7 +90,7 @@ abstract class EdgeRDD[ED, VD]( * This does not actually trigger a cache; to do this, call * [[org.apache.spark.graphx.EdgeRDD#cache]] on the returned EdgeRDD. */ - private[graphx] def withTargetStorageLevel(targetStorageLevel: StorageLevel): EdgeRDD[ED, VD] + private[graphx] def withTargetStorageLevel(targetStorageLevel: StorageLevel): EdgeRDD[ED] } object EdgeRDD { @@ -111,7 +100,7 @@ object EdgeRDD { * @tparam ED the edge attribute type * @tparam VD the type of the vertex attributes that may be joined with the returned EdgeRDD */ - def fromEdges[ED: ClassTag, VD: ClassTag](edges: RDD[Edge[ED]]): EdgeRDD[ED, VD] = { + def fromEdges[ED: ClassTag, VD: ClassTag](edges: RDD[Edge[ED]]): EdgeRDDImpl[ED, VD] = { val edgePartitions = edges.mapPartitionsWithIndex { (pid, iter) => val builder = new EdgePartitionBuilder[ED, VD] iter.foreach { e => @@ -128,8 +117,8 @@ object EdgeRDD { * @tparam ED the edge attribute type * @tparam VD the type of the vertex attributes that may be joined with the returned EdgeRDD */ - def fromEdgePartitions[ED: ClassTag, VD: ClassTag]( - edgePartitions: RDD[(Int, EdgePartition[ED, VD])]): EdgeRDD[ED, VD] = { + private[graphx] def fromEdgePartitions[ED: ClassTag, VD: ClassTag]( + edgePartitions: RDD[(Int, EdgePartition[ED, VD])]): EdgeRDDImpl[ED, VD] = { new EdgeRDDImpl(edgePartitions) } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 2c1b9518a3d16..637791543514c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -59,7 +59,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * along with their vertex data. * */ - @transient val edges: EdgeRDD[ED, VD] + @transient val edges: EdgeRDD[ED] /** * An RDD containing the edge triplets, which are edges along with the vertex data associated with diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index d5150382d599b..116d1ea700175 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -129,15 +129,15 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))) ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))) }, - (a, b) => a ++ b, TripletFields.SrcDstOnly) + (a, b) => a ++ b, TripletFields.All) case EdgeDirection.In => graph.aggregateMessages[Array[(VertexId,VD)]]( ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))), - (a, b) => a ++ b, TripletFields.SrcOnly) + (a, b) => a ++ b, TripletFields.Src) case EdgeDirection.Out => graph.aggregateMessages[Array[(VertexId,VD)]]( ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))), - (a, b) => a ++ b, TripletFields.DstOnly) + (a, b) => a ++ b, TripletFields.Dst) case EdgeDirection.Both => throw new SparkException("collectEdges does not support EdgeDirection.Both. 
Use" + "EdgeDirection.Either instead.") diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java index 34df4b7ee7a06..7eb4ae0f44602 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java +++ b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java @@ -24,10 +24,17 @@ * system to populate only those fields for efficiency. */ public class TripletFields implements Serializable { + + /** Indicates whether the source vertex attribute is included. */ public final boolean useSrc; + + /** Indicates whether the destination vertex attribute is included. */ public final boolean useDst; + + /** Indicates whether the edge attribute is included. */ public final boolean useEdge; + /** Constructs a default TripletFields in which all fields are included. */ public TripletFields() { this(true, true, true); } @@ -38,14 +45,28 @@ public TripletFields(boolean useSrc, boolean useDst, boolean useEdge) { this.useEdge = useEdge; } + /** + * None of the triplet fields are exposed. + */ public static final TripletFields None = new TripletFields(false, false, false); + + /** + * Expose only the edge field and not the source or destination field. + */ public static final TripletFields EdgeOnly = new TripletFields(false, false, true); - public static final TripletFields SrcOnly = new TripletFields(true, false, false); - public static final TripletFields DstOnly = new TripletFields(false, true, false); - public static final TripletFields SrcDstOnly = new TripletFields(true, true, false); - public static final TripletFields SrcAndEdge = new TripletFields(true, false, true); - public static final TripletFields Src = SrcAndEdge; - public static final TripletFields DstAndEdge = new TripletFields(false, true, true); - public static final TripletFields Dst = DstAndEdge; + + /** + * Expose the source and edge fields but not the destination field. (Same as Src) + */ + public static final TripletFields Src = new TripletFields(true, false, true); + + /** + * Expose the destination and edge fields but not the source field. (Same as Dst) + */ + public static final TripletFields Dst = new TripletFields(false, true, true); + + /** + * Expose all the fields (source, edge, and destination). + */ public static final TripletFields All = new TripletFields(true, true, true); } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index f8be17669d892..1db3df03c8052 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -207,7 +207,7 @@ abstract class VertexRDD[VD]( def reverseRoutingTables(): VertexRDD[VD] /** Prepares this VertexRDD for efficient joins with the given EdgeRDD. */ - def withEdges(edges: EdgeRDD[_, _]): VertexRDD[VD] + def withEdges(edges: EdgeRDD[_]): VertexRDD[VD] /** Replaces the vertex partitions while preserving all other properties of the VertexRDD. 
*/ private[graphx] def withPartitionsRDD[VD2: ClassTag]( @@ -269,7 +269,7 @@ object VertexRDD { * @param defaultVal the vertex attribute to use when creating missing vertices */ def apply[VD: ClassTag]( - vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD): VertexRDD[VD] = { + vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_], defaultVal: VD): VertexRDD[VD] = { VertexRDD(vertices, edges, defaultVal, (a, b) => a) } @@ -286,7 +286,7 @@ object VertexRDD { * @param mergeFunc the commutative, associative duplicate vertex attribute merge function */ def apply[VD: ClassTag]( - vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD, mergeFunc: (VD, VD) => VD + vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_], defaultVal: VD, mergeFunc: (VD, VD) => VD ): VertexRDD[VD] = { val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { case Some(p) => vertices @@ -314,7 +314,7 @@ object VertexRDD { * @param defaultVal the vertex attribute to use when creating missing vertices */ def fromEdges[VD: ClassTag]( - edges: EdgeRDD[_, _], numPartitions: Int, defaultVal: VD): VertexRDD[VD] = { + edges: EdgeRDD[_], numPartitions: Int, defaultVal: VD): VertexRDD[VD] = { val routingTables = createRoutingTables(edges, new HashPartitioner(numPartitions)) val vertexPartitions = routingTables.mapPartitions({ routingTableIter => val routingTable = @@ -325,7 +325,7 @@ object VertexRDD { } private[graphx] def createRoutingTables( - edges: EdgeRDD[_, _], vertexPartitioner: Partitioner): RDD[RoutingTablePartition] = { + edges: EdgeRDD[_], vertexPartitioner: Partitioner): RDD[RoutingTablePartition] = { // Determine which vertices each edge partition needs by creating a mapping from vid to pid. val vid2pid = edges.partitionsRDD.mapPartitions(_.flatMap( Function.tupled(RoutingTablePartition.edgePartitionToMsgs))) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index 4100a85d17ee3..a8169613b4fd2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -28,7 +28,7 @@ import org.apache.spark.graphx._ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( override val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])], val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) - extends EdgeRDD[ED, VD](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { + extends EdgeRDD[ED](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { override def setName(_name: String): this.type = { if (partitionsRDD.name != null) { @@ -75,20 +75,20 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( partitionsRDD.map(_._2.size.toLong).reduce(_ + _) } - override def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2, VD] = + override def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDDImpl[ED2, VD] = mapEdgePartitions((pid, part) => part.map(f)) - override def reverse: EdgeRDD[ED, VD] = mapEdgePartitions((pid, part) => part.reverse) + override def reverse: EdgeRDDImpl[ED, VD] = mapEdgePartitions((pid, part) => part.reverse) - override def filter( + def filter( epred: EdgeTriplet[VD, ED] => Boolean, - vpred: (VertexId, VD) => Boolean): EdgeRDD[ED, VD] = { + vpred: (VertexId, VD) => Boolean): EdgeRDDImpl[ED, VD] = { mapEdgePartitions((pid, part) => part.filter(epred, vpred)) } override def innerJoin[ED2: ClassTag, ED3: ClassTag] - 
(other: EdgeRDD[ED2, _]) - (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3, VD] = { + (other: EdgeRDD[ED2]) + (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDDImpl[ED3, VD] = { val ed2Tag = classTag[ED2] val ed3Tag = classTag[ED3] this.withPartitionsRDD[ED3, VD](partitionsRDD.zipPartitions(other.partitionsRDD, true) { @@ -99,8 +99,8 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( }) } - override private[graphx] def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag]( - f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDD[ED2, VD2] = { + def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag]( + f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDDImpl[ED2, VD2] = { this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter => if (iter.hasNext) { val (pid, ep) = iter.next() @@ -111,13 +111,13 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( }, preservesPartitioning = true)) } - override private[graphx] def withPartitionsRDD[ED2: ClassTag, VD2: ClassTag]( - partitionsRDD: RDD[(PartitionID, EdgePartition[ED2, VD2])]): EdgeRDD[ED2, VD2] = { + private[graphx] def withPartitionsRDD[ED2: ClassTag, VD2: ClassTag]( + partitionsRDD: RDD[(PartitionID, EdgePartition[ED2, VD2])]): EdgeRDDImpl[ED2, VD2] = { new EdgeRDDImpl(partitionsRDD, this.targetStorageLevel) } override private[graphx] def withTargetStorageLevel( - targetStorageLevel: StorageLevel): EdgeRDD[ED, VD] = { + targetStorageLevel: StorageLevel): EdgeRDDImpl[ED, VD] = { new EdgeRDDImpl(this.partitionsRDD, targetStorageLevel) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 2b4636a6c6ddf..0eae2a673874a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -43,7 +43,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( /** Default constructor is provided to support serialization */ protected def this() = this(null, null) - @transient override val edges: EdgeRDD[ED, VD] = replicatedVertexView.edges + @transient override val edges: EdgeRDDImpl[ED, VD] = replicatedVertexView.edges /** Return a RDD that brings edges together with their source and destination vertices. */ @transient override lazy val triplets: RDD[EdgeTriplet[VD, ED]] = { @@ -323,9 +323,10 @@ object GraphImpl { */ def apply[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], - edges: EdgeRDD[ED, _]): GraphImpl[VD, ED] = { + edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { // Convert the vertex partitions in edges to the correct type - val newEdges = edges.mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD]) + val newEdges = edges.asInstanceOf[EdgeRDDImpl[ED, _]] + .mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD]) GraphImpl.fromExistingRDDs(vertices, newEdges) } @@ -336,8 +337,8 @@ object GraphImpl { */ def fromExistingRDDs[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], - edges: EdgeRDD[ED, VD]): GraphImpl[VD, ED] = { - new GraphImpl(vertices, new ReplicatedVertexView(edges)) + edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { + new GraphImpl(vertices, new ReplicatedVertexView(edges.asInstanceOf[EdgeRDDImpl[ED, VD]])) } /** @@ -345,7 +346,7 @@ object GraphImpl { * `defaultVertexAttr`. The vertices will have the same number of partitions as the EdgeRDD. 
*/ private def fromEdgeRDD[VD: ClassTag, ED: ClassTag]( - edges: EdgeRDD[ED, VD], + edges: EdgeRDDImpl[ED, VD], defaultVertexAttr: VD, edgeStorageLevel: StorageLevel, vertexStorageLevel: StorageLevel): GraphImpl[VD, ED] = { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index 86b366eb9202b..8ab255bd4038c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -33,7 +33,7 @@ import org.apache.spark.graphx._ */ private[impl] class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( - var edges: EdgeRDD[ED, VD], + var edges: EdgeRDDImpl[ED, VD], var hasSrcId: Boolean = false, var hasDstId: Boolean = false) { @@ -42,7 +42,7 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * shipping level. */ def withEdges[VD2: ClassTag, ED2: ClassTag]( - edges_ : EdgeRDD[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { + edges_ : EdgeRDDImpl[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { new ReplicatedVertexView(edges_, hasSrcId, hasDstId) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 08405629bc052..d92a55a189298 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -172,7 +172,7 @@ class VertexRDDImpl[VD] private[graphx] ( override def reverseRoutingTables(): VertexRDD[VD] = this.mapVertexPartitions(vPart => vPart.withRoutingTable(vPart.routingTable.reverse)) - override def withEdges(edges: EdgeRDD[_, _]): VertexRDD[VD] = { + override def withEdges(edges: EdgeRDD[_]): VertexRDD[VD] = { val routingTables = VertexRDD.createRoutingTables(edges, this.partitioner.get) val vertexPartitions = partitionsRDD.zipPartitions(routingTables, true) { (partIter, routingTableIter) => diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index e40ae0d615466..e139959c3f5c1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -85,7 +85,7 @@ object PageRank extends Logging { // Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) } // Set the weight on the edges based on the degree - .mapTriplets( e => 1.0 / e.srcAttr, TripletFields.SrcOnly ) + .mapTriplets( e => 1.0 / e.srcAttr, TripletFields.Src ) // Set the vertex attributes to the initial pagerank values .mapVertices( (id, attr) => resetProb ) @@ -97,7 +97,7 @@ object PageRank extends Logging { // Compute the outgoing rank contributions of each vertex, perform local preaggregation, and // do the final aggregation at the receiving vertices. Requires a shuffle for aggregation. val rankUpdates = rankGraph.aggregateMessages[Double]( - ctx => ctx.sendToDst(ctx.srcAttr * ctx.attr), _ + _, TripletFields.SrcAndEdge) + ctx => ctx.sendToDst(ctx.srcAttr * ctx.attr), _ + _, TripletFields.Src) // Apply the final rank updates to get the new ranks, using join to preserve ranks of vertices // that didn't receive a message. 
Requires a shuffle for broadcasting updated ranks to the diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index df773db6e4326..a05d1ddb21295 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -328,7 +328,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { "expected ctx.dstAttr to be null due to TripletFields, but it was " + ctx.dstAttr) } ctx.sendToDst(ctx.srcAttr) - }, _ + _, TripletFields.SrcOnly) + }, _ + _, TripletFields.Src) assert(agg.collect().toSet === (1 to n).map(x => (x: VertexId, "v")).toSet) } } diff --git a/make-distribution.sh b/make-distribution.sh index 2267b1aa08a6c..7c0fb8992a155 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -119,7 +119,7 @@ VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v " SPARK_HADOOP_VERSION=$(mvn help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ | grep -v "INFO"\ | tail -n 1) -SPARK_HIVE=$(mvn help:evaluate -Dexpression=project.activeProfiles $@ 2>/dev/null\ +SPARK_HIVE=$(mvn help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ | grep -v "INFO"\ | fgrep --count "hive";\ # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ diff --git a/mllib/pom.xml b/mllib/pom.xml index dd68b27a78bdc..0a6dda0ab8c80 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index c8476a5370b6c..9f20cd5d00dcd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.api.python import java.io.OutputStream +import java.nio.{ByteBuffer, ByteOrder} import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ @@ -40,10 +41,10 @@ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.test.ChiSqTestResult -import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.{RandomForest, DecisionTree} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.impurity._ -import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -74,10 +75,28 @@ class PythonMLLibAPI extends Serializable { learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], initialWeights: Vector): JList[Object] = { - // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. 
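A small usage sketch of the consolidated TripletFields.Src constant adopted above; `rankGraph` (a Graph[Double, Double]) is assumed to exist, mirroring the PageRank code.

    import org.apache.spark.graphx.TripletFields
    val rankUpdates = rankGraph.aggregateMessages[Double](
      ctx => ctx.sendToDst(ctx.srcAttr * ctx.attr), // only source and edge attributes are read,
      _ + _,                                        // so TripletFields.Src is sufficient
      TripletFields.Src) // replaces the removed SrcOnly / SrcAndEdge constants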
- learner.disableUncachedWarning() - val model = learner.run(data.rdd, initialWeights) - List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + try { + val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights) + List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + } finally { + data.rdd.unpersist(blocking = false) + } + } + + /** + * Return the Updater from string + */ + def getUpdaterFromString(regType: String): Updater = { + if (regType == "l2") { + new SquaredL2Updater + } else if (regType == "l1") { + new L1Updater + } else if (regType == null || regType == "none") { + new SimpleUpdater + } else { + throw new IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: ['l1', 'l2', None].") + } } /** @@ -99,16 +118,7 @@ class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) - if (regType == "l2") { - lrAlg.optimizer.setUpdater(new SquaredL2Updater) - } else if (regType == "l1") { - lrAlg.optimizer.setUpdater(new L1Updater) - } else if (regType == null) { - lrAlg.optimizer.setUpdater(new SimpleUpdater) - } else { - throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." - + " Can only be initialized using the following string values: ['l1', 'l2', None].") - } + lrAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( lrAlg, data, @@ -178,16 +188,7 @@ class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) - if (regType == "l2") { - SVMAlg.optimizer.setUpdater(new SquaredL2Updater) - } else if (regType == "l1") { - SVMAlg.optimizer.setUpdater(new L1Updater) - } else if (regType == null) { - SVMAlg.optimizer.setUpdater(new SimpleUpdater) - } else { - throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." - + " Can only be initialized using the following string values: ['l1', 'l2', None].") - } + SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( SVMAlg, data, @@ -213,16 +214,33 @@ class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) - if (regType == "l2") { - LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) - } else if (regType == "l1") { - LogRegAlg.optimizer.setUpdater(new L1Updater) - } else if (regType == null) { - LogRegAlg.optimizer.setUpdater(new SimpleUpdater) - } else { - throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." 
- + " Can only be initialized using the following string values: ['l1', 'l2', None].") - } + LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType)) + trainRegressionModel( + LogRegAlg, + data, + initialWeights) + } + + /** + * Java stub for Python mllib LogisticRegressionWithLBFGS.train() + */ + def trainLogisticRegressionModelWithLBFGS( + data: JavaRDD[LabeledPoint], + numIterations: Int, + initialWeights: Vector, + regParam: Double, + regType: String, + intercept: Boolean, + corrections: Int, + tolerance: Double): JList[Object] = { + val LogRegAlg = new LogisticRegressionWithLBFGS() + LogRegAlg.setIntercept(intercept) + LogRegAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setNumCorrections(corrections) + .setConvergenceTol(tolerance) + LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( LogRegAlg, data, @@ -254,9 +272,11 @@ class PythonMLLibAPI extends Serializable { .setMaxIterations(maxIterations) .setRuns(runs) .setInitializationMode(initializationMode) - // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. - .disableUncachedWarning() - kMeansAlg.run(data.rdd) + try { + kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) + } finally { + data.rdd.unpersist(blocking = false) + } } /** @@ -390,16 +410,18 @@ class PythonMLLibAPI extends Serializable { numPartitions: Int, numIterations: Int, seed: Long): Word2VecModelWrapper = { - val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER) val word2vec = new Word2Vec() .setVectorSize(vectorSize) .setLearningRate(learningRate) .setNumPartitions(numPartitions) .setNumIterations(numIterations) .setSeed(seed) - val model = word2vec.fit(data) - data.unpersist() - new Word2VecModelWrapper(model) + try { + val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)) + new Word2VecModelWrapper(model) + } finally { + dataJRDD.rdd.unpersist(blocking = false) + } } private[python] class Word2VecModelWrapper(model: Word2VecModel) { @@ -460,8 +482,50 @@ class PythonMLLibAPI extends Serializable { categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap, minInstancesPerNode = minInstancesPerNode, minInfoGain = minInfoGain) + try { + DecisionTree.train(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), strategy) + } finally { + data.rdd.unpersist(blocking = false) + } + } - DecisionTree.train(data.rdd, strategy) + /** + * Java stub for Python mllib RandomForest.train(). + * This stub returns a handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on exit; + * see the Py4J documentation. 
+ */ + def trainRandomForestModel( + data: JavaRDD[LabeledPoint], + algoStr: String, + numClasses: Int, + categoricalFeaturesInfo: JMap[Int, Int], + numTrees: Int, + featureSubsetStrategy: String, + impurityStr: String, + maxDepth: Int, + maxBins: Int, + seed: Int): RandomForestModel = { + + val algo = Algo.fromString(algoStr) + val impurity = Impurities.fromString(impurityStr) + val strategy = new Strategy( + algo = algo, + impurity = impurity, + maxDepth = maxDepth, + numClassesForClassification = numClasses, + maxBins = maxBins, + categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap) + val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) + try { + if (algo == Algo.Classification) { + RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, seed) + } else { + RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, seed) + } + } finally { + cached.unpersist(blocking = false) + } } /** @@ -621,6 +685,7 @@ class PythonMLLibAPI extends Serializable { private[spark] object SerDe extends Serializable { val PYSPARK_PACKAGE = "pyspark.mllib" + val LATIN1 = "ISO-8859-1" /** * Base class used for pickle @@ -642,7 +707,7 @@ private[spark] object SerDe extends Serializable { def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { if (obj == this) { out.write(Opcodes.GLOBAL) - out.write((module + "\n" + name + "\n").getBytes()) + out.write((module + "\n" + name + "\n").getBytes) } else { pickler.save(this) // it will be memorized by Pickler saveState(obj, out, pickler) @@ -672,7 +737,16 @@ private[spark] object SerDe extends Serializable { def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { val vector: DenseVector = obj.asInstanceOf[DenseVector] - saveObjects(out, pickler, vector.toArray) + val bytes = new Array[Byte](8 * vector.size) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + val db = bb.asDoubleBuffer() + db.put(vector.values) + + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(bytes.length)) + out.write(bytes) + out.write(Opcodes.TUPLE1) } def construct(args: Array[Object]): Object = { @@ -680,7 +754,13 @@ private[spark] object SerDe extends Serializable { if (args.length != 1) { throw new PickleException("should be 1") } - new DenseVector(args(0).asInstanceOf[Array[Double]]) + val bytes = args(0).asInstanceOf[String].getBytes(LATIN1) + val bb = ByteBuffer.wrap(bytes, 0, bytes.length) + bb.order(ByteOrder.nativeOrder()) + val db = bb.asDoubleBuffer() + val ans = new Array[Double](bytes.length / 8) + db.get(ans) + Vectors.dense(ans) } } @@ -689,15 +769,30 @@ private[spark] object SerDe extends Serializable { def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { val m: DenseMatrix = obj.asInstanceOf[DenseMatrix] - saveObjects(out, pickler, m.numRows, m.numCols, m.values) + val bytes = new Array[Byte](8 * m.values.size) + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values) + + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(m.numRows)) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(m.numCols)) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(bytes.length)) + out.write(bytes) + out.write(Opcodes.TUPLE3) } def construct(args: Array[Object]): Object = { if (args.length != 3) { throw new PickleException("should be 3") } - new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], - args(2).asInstanceOf[Array[Double]]) + 
val bytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val n = bytes.length / 8 + val values = new Array[Double](n) + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values) + new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values) } } @@ -706,15 +801,40 @@ private[spark] object SerDe extends Serializable { def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { val v: SparseVector = obj.asInstanceOf[SparseVector] - saveObjects(out, pickler, v.size, v.indices, v.values) + val n = v.indices.size + val indiceBytes = new Array[Byte](4 * n) + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices) + val valueBytes = new Array[Byte](8 * n) + ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(v.values) + + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(v.size)) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(indiceBytes.length)) + out.write(indiceBytes) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(valueBytes.length)) + out.write(valueBytes) + out.write(Opcodes.TUPLE3) } def construct(args: Array[Object]): Object = { if (args.length != 3) { throw new PickleException("should be 3") } - new SparseVector(args(0).asInstanceOf[Int], args(1).asInstanceOf[Array[Int]], - args(2).asInstanceOf[Array[Double]]) + val size = args(0).asInstanceOf[Int] + val indiceBytes = args(1).asInstanceOf[String].getBytes(LATIN1) + val valueBytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val n = indiceBytes.length / 4 + val indices = new Array[Int](n) + val values = new Array[Double](n) + if (n > 0) { + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices) + ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values) + } + new SparseVector(size, indices, values) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 18b95f1edc0b0..94d757bc317ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -64,7 +64,7 @@ class LogisticRegressionModel ( val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { - case Some(t) => if (score < t) 0.0 else 1.0 + case Some(t) => if (score > t) 1.0 else 0.0 case None => score } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index ab9515b2a6db8..dd514ff8a37f2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -65,7 +65,7 @@ class SVMModel ( intercept: Double) = { val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept threshold match { - case Some(t) => if (margin < t) 0.0 else 1.0 + case Some(t) => if (margin > t) 1.0 else 0.0 case None => margin } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 7443f232ec3e7..34ea0de706f08 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -113,22 +113,13 @@ class KMeans private ( this } - /** Whether a warning should be logged if the input RDD is uncached. */ - private var warnOnUncachedInput = true - - /** Disable warnings about uncached input. */ - private[spark] def disableUncachedWarning(): this.type = { - warnOnUncachedInput = false - this - } - /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. */ def run(data: RDD[Vector]): KMeansModel = { - if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) { + if (data.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } @@ -143,7 +134,7 @@ class KMeans private ( norms.unpersist() // Warn at the end of the run as well, for increased visibility. - if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) { + if (data.getStorageLevel == StorageLevel.NONE) { logWarning("The input data was not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index dfad25d57c947..a9c2e23717896 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -17,10 +17,10 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => brzNorm} +import breeze.linalg.{norm => brzNorm} import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} /** * :: Experimental :: @@ -47,22 +47,31 @@ class Normalizer(p: Double) extends VectorTransformer { * @return normalized vector. If the norm of the input is zero, it will return the input vector. */ override def transform(vector: Vector): Vector = { - var norm = brzNorm(vector.toBreeze, p) + val norm = brzNorm(vector.toBreeze, p) if (norm != 0.0) { // For dense vector, we've to allocate new memory for new output vector. // However, for sparse vector, the `index` array will not be changed, // so we can re-use it to save memory. 
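With disableUncachedWarning() gone, Scala callers of KMeans are simply expected to cache the input themselves; a minimal sketch, assuming an existing SparkContext `sc`.

    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.storage.StorageLevel
    val points = sc.parallelize(Seq(
        Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0), Vectors.dense(9.0, 8.0)))
      .persist(StorageLevel.MEMORY_AND_DISK) // avoids the uncached-input warning above
    val model = new KMeans().setK(2).setMaxIterations(10).run(points)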
- vector.toBreeze match { - case dv: BDV[Double] => Vectors.fromBreeze(dv :/ norm) - case sv: BSV[Double] => - val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + vector match { + case dv: DenseVector => + val values = dv.values.clone() + val size = values.size var i = 0 - while (i < output.data.length) { - output.data(i) /= norm + while (i < size) { + values(i) /= norm i += 1 } - Vectors.fromBreeze(output) + Vectors.dense(values) + case sv: SparseVector => + val values = sv.values.clone() + val nnz = values.size + var i = 0 + while (i < nnz) { + values(i) /= norm + i += 1 + } + Vectors.sparse(sv.size, sv.indices, values) case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 4dfd1f0ab8134..8c4c5db5258d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -17,11 +17,9 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} - import org.apache.spark.Logging import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD @@ -77,8 +75,8 @@ class StandardScalerModel private[mllib] ( require(mean.size == variance.size) - private lazy val factor: BDV[Double] = { - val f = BDV.zeros[Double](variance.size) + private lazy val factor: Array[Double] = { + val f = Array.ofDim[Double](variance.size) var i = 0 while (i < f.size) { f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0 @@ -87,6 +85,11 @@ class StandardScalerModel private[mllib] ( f } + // Since `shift` will be only used in `withMean` branch, we have it as + // `lazy val` so it will be evaluated in that branch. Note that we don't + // want to create this array multiple times in `transform` function. + private lazy val shift: Array[Double] = mean.toArray + /** * Applies standardization transformation on a vector. * @@ -97,30 +100,57 @@ class StandardScalerModel private[mllib] ( override def transform(vector: Vector): Vector = { require(mean.size == vector.size) if (withMean) { - vector.toBreeze match { - case dv: BDV[Double] => - val output = vector.toBreeze.copy - var i = 0 - while (i < output.length) { - output(i) = (output(i) - mean(i)) * (if (withStd) factor(i) else 1.0) - i += 1 + // By default, Scala generates Java methods for member variables. So every time when + // the member variables are accessed, `invokespecial` will be called which is expensive. + // This can be avoid by having a local reference of `shift`. + val localShift = shift + vector match { + case dv: DenseVector => + val values = dv.values.clone() + val size = values.size + if (withStd) { + // Having a local reference of `factor` to avoid overhead as the comment before. 
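For context, a short sketch of how the rewritten transform() above is typically reached, assuming an RDD[Vector] named `data`; note that the withMean branch shown above only handles dense vectors.

    import org.apache.spark.mllib.feature.StandardScaler
    val scaler = new StandardScaler(withMean = true, withStd = true).fit(data)
    val scaled = data.map(v => scaler.transform(v)) // per-vector standardization via the code above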
+ val localFactor = factor + var i = 0 + while (i < size) { + values(i) = (values(i) - localShift(i)) * localFactor(i) + i += 1 + } + } else { + var i = 0 + while (i < size) { + values(i) -= localShift(i) + i += 1 + } } - Vectors.fromBreeze(output) + Vectors.dense(values) case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else if (withStd) { - vector.toBreeze match { - case dv: BDV[Double] => Vectors.fromBreeze(dv :* factor) - case sv: BSV[Double] => + // Having a local reference of `factor` to avoid overhead as the comment before. + val localFactor = factor + vector match { + case dv: DenseVector => + val values = dv.values.clone() + val size = values.size + var i = 0 + while(i < size) { + values(i) *= localFactor(i) + i += 1 + } + Vectors.dense(values) + case sv: SparseVector => // For sparse vector, the `index` array inside sparse vector object will not be changed, // so we can re-use it to save memory. - val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + val indices = sv.indices + val values = sv.values.clone() + val nnz = values.size var i = 0 - while (i < output.data.length) { - output.data(i) *= factor(output.index(i)) + while (i < nnz) { + values(i) *= localFactor(indices(i)) i += 1 } - Vectors.fromBreeze(output) + Vectors.sparse(sv.size, indices, values) case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f5f7ad613d4c4..7960f3cab576f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -461,4 +461,11 @@ class Word2VecModel private[mllib] ( .tail .toArray } + + /** + * Returns a map of words to their vector representations. + */ + def getVectors: Map[String, Array[Float]] = { + model + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index a6f3343f56b86..0fa2f600721ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -22,7 +22,7 @@ import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM} import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.util.Utils -import java.util.Arrays +import java.util.{Random, Arrays} import scala.collection.mutable.ArrayBuffer /** @@ -69,14 +69,14 @@ sealed trait Matrix extends Serializable { } /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */ - def transposeMultiply(y: DenseMatrix): DenseMatrix = { + private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = { val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix] BLAS.gemm(true, false, 1.0, this, y, 0.0, C) C } /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */ - def transposeMultiply(y: DenseVector): DenseVector = { + private[mllib] def transposeMultiply(y: DenseVector): DenseVector = { val output = new DenseVector(new Array[Double](numCols)) BLAS.gemv(true, 1.0, this, y, 0.0, output) output @@ -189,44 +189,22 @@ object DenseMatrix { * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. 
* @param numRows number of rows of the matrix * @param numCols number of columns of the matrix - * @param seed the seed seed for the random number generator + * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ - def rand(numRows: Int, numCols: Int, seed: Long): DenseMatrix = { - val rand = new XORShiftRandom(seed) - new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextDouble())) - } - - /** - * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. - * @param numRows number of rows of the matrix - * @param numCols number of columns of the matrix - * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) - */ - def rand(numRows: Int, numCols: Int): DenseMatrix = { - rand(numRows, numCols, Utils.random.nextLong()) + def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) } /** * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix - * @param seed the seed seed for the random number generator + * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ - def randn(numRows: Int, numCols: Int, seed: Long): DenseMatrix = { - val rand = new XORShiftRandom(seed) - new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextGaussian())) - } - - /** - * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. - * @param numRows number of rows of the matrix - * @param numCols number of columns of the matrix - * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) - */ - def randn(numRows: Int, numCols: Int): DenseMatrix = { - randn(numRows, numCols, Utils.random.nextLong()) + def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) } /** @@ -389,81 +367,49 @@ object SparseMatrix { * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix - * @param seed the seed for the random generator + * @param rng a random number generator * @return `SparseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ - def sprand( - numRows: Int, - numCols: Int, - density: Double, - seed: Long): SparseMatrix = { + def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { require(density > 0.0 && density < 1.0, "density must be a double in the range " + s"0.0 < d < 1.0. Currently, density: $density") - val rand = new XORShiftRandom(seed) val length = numRows * numCols val rawA = new Array[Double](length) var nnz = 0 for (i <- 0 until length) { - val p = rand.nextDouble() + val p = rng.nextDouble() if (p <= density) { - rawA.update(i, rand.nextDouble()) + rawA.update(i, rng.nextDouble()) nnz += 1 } } genRand(numRows, numCols, rawA, nnz) } - /** - * Generate a `SparseMatrix` consisting of i.i.d. uniform random numbers. 
- * @param numRows number of rows of the matrix - * @param numCols number of columns of the matrix - * @param density the desired density for the matrix - * @return `SparseMatrix` with size `numRows` x `numCols` and values in U(0, 1) - */ - def sprand(numRows: Int, numCols: Int, density: Double): SparseMatrix = { - sprand(numRows, numCols, density, Utils.random.nextLong()) - } - /** * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix - * @param seed the seed for the random generator + * @param rng a random number generator * @return `SparseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ - def sprandn( - numRows: Int, - numCols: Int, - density: Double, - seed: Long): SparseMatrix = { + def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { require(density > 0.0 && density < 1.0, "density must be a double in the range " + s"0.0 < d < 1.0. Currently, density: $density") - val rand = new XORShiftRandom(seed) val length = numRows * numCols val rawA = new Array[Double](length) var nnz = 0 for (i <- 0 until length) { - val p = rand.nextDouble() + val p = rng.nextDouble() if (p <= density) { - rawA.update(i, rand.nextGaussian()) + rawA.update(i, rng.nextGaussian()) nnz += 1 } } genRand(numRows, numCols, rawA, nnz) } - /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. - * @param numRows number of rows of the matrix - * @param numCols number of columns of the matrix - * @param density the desired density for the matrix - * @return `SparseMatrix` with size `numRows` x `numCols` and values in N(0, 1) - */ - def sprandn(numRows: Int, numCols: Int, density: Double): SparseMatrix = { - sprandn(numRows, numCols, density, Utils.random.nextLong()) - } - /** * Generate a diagonal matrix in `DenseMatrix` format from the supplied values. * @param vector a `Vector` that will form the values on the diagonal of the matrix @@ -606,79 +552,43 @@ object Matrices { * Generate a dense `Matrix` consisting of i.i.d. uniform random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix - * @param seed the seed for the random generator + * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ - def rand(numRows: Int, numCols: Int, seed: Long): Matrix = - DenseMatrix.rand(numRows, numCols, seed) - - /** - * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. - * @param numRows number of rows of the matrix - * @param numCols number of columns of the matrix - * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) - */ - def rand(numRows: Int, numCols: Int): Matrix = DenseMatrix.rand(numRows, numCols) - - /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. 
- * @param numRows number of rows of the matrix - * @param numCols number of columns of the matrix - * @param density the desired density for the matrix - * @param seed the seed for the random generator - * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) - */ - def sprand(numRows: Int, numCols: Int, density: Double, seed: Long): Matrix = - SparseMatrix.sprand(numRows, numCols, density, seed) + def rand(numRows: Int, numCols: Int, rng: Random): Matrix = + DenseMatrix.rand(numRows, numCols, rng) /** * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix + * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) */ - def sprand(numRows: Int, numCols: Int, density: Double): Matrix = - SparseMatrix.sprand(numRows, numCols, density) - - /** - * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. - * @param numRows number of rows of the matrix - * @param numCols number of columns of the matrix - * @param seed the seed for the random generator - * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) - */ - def randn(numRows: Int, numCols: Int, seed: Long): Matrix = - DenseMatrix.randn(numRows, numCols, seed) + def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = + SparseMatrix.sprand(numRows, numCols, density, rng) /** * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix + * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ - def randn(numRows: Int, numCols: Int): Matrix = DenseMatrix.randn(numRows, numCols) - - /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. - * @param numRows number of rows of the matrix - * @param numCols number of columns of the matrix - * @param density the desired density for the matrix - * @param seed the seed for the random generator - * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) - */ - def sprandn(numRows: Int, numCols: Int, density: Double, seed: Long): Matrix = - SparseMatrix.sprandn(numRows, numCols, density, seed) + def randn(numRows: Int, numCols: Int, rng: Random): Matrix = + DenseMatrix.randn(numRows, numCols, rng) /** * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix + * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) */ - def sprandn(numRows: Int, numCols: Int, density: Double): Matrix = - SparseMatrix.sprandn(numRows, numCols, density) + def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = + SparseMatrix.sprandn(numRows, numCols, density, rng) /** * Generate a diagonal matrix in `DenseMatrix` format from the supplied values. 
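A brief sketch of the new RNG-based factories, which replace the seed-based overloads removed above; the caller now owns the java.util.Random instance.

    import java.util.Random
    import org.apache.spark.mllib.linalg.Matrices
    val rng = new Random(42L)
    val dense  = Matrices.rand(3, 4, rng)            // 3 x 4, entries ~ U(0, 1)
    val normal = Matrices.randn(3, 4, rng)           // 3 x 4, entries ~ N(0, 1)
    val sparse = Matrices.sprandn(10, 10, 0.1, rng)  // ~10% density, entries ~ N(0, 1)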
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 9fccd6341ba7d..c6d5fe5bc678c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -76,6 +76,15 @@ sealed trait Vector extends Serializable { def copy: Vector = { throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") } + + /** + * Applies a function `f` to all the active elements of dense and sparse vector. + * + * @param f the function takes two parameters where the first parameter is the index of + * the vector with type `Int`, and the second parameter is the corresponding value + * with type `Double`. + */ + private[spark] def foreachActive(f: (Int, Double) => Unit) } /** @@ -237,7 +246,7 @@ object Vectors { private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = { breezeVector match { case v: BDV[Double] => - if (v.offset == 0 && v.stride == 1) { + if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { new DenseVector(v.data) } else { new DenseVector(v.toArray) // Can't use underlying array directly, so make a new one @@ -273,6 +282,17 @@ class DenseVector(val values: Array[Double]) extends Vector { override def copy: DenseVector = { new DenseVector(values.clone()) } + + private[spark] override def foreachActive(f: (Int, Double) => Unit) = { + var i = 0 + val localValuesSize = values.size + val localValues = values + + while (i < localValuesSize) { + f(i, localValues(i)) + i += 1 + } + } } /** @@ -309,4 +329,16 @@ class SparseVector( } private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + + private[spark] override def foreachActive(f: (Int, Double) => Unit) = { + var i = 0 + val localValuesSize = values.size + val localIndices = indices + val localValues = values + + while (i < localValuesSize) { + f(localIndices(i), localValues(i)) + i += 1 + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index a6912056395d7..0857877951c82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -160,14 +160,15 @@ object GradientDescent extends Logging { val stochasticLossHistory = new ArrayBuffer[Double](numIterations) val numExamples = data.count() - val miniBatchSize = numExamples * miniBatchFraction // if no data, return initial weights to avoid NaNs if (numExamples == 0) { - - logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found") + logWarning("GradientDescent.runMiniBatchSGD returning initial weights, no data found") return (initialWeights, stochasticLossHistory.toArray) + } + if (numExamples * miniBatchFraction < 1) { + logWarning("The miniBatchFraction is too small") } // Initialize weights as a column vector @@ -185,25 +186,31 @@ object GradientDescent extends Logging { val bcWeights = data.context.broadcast(weights) // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) - val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i) - .treeAggregate((BDV.zeros[Double](n), 0.0))( - seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => - val l = 
gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad)) - (grad, loss + l) + val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i) + .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))( + seqOp = (c, v) => { + // c: (grad, loss, count), v: (label, features) + val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1)) + (c._1, c._2 + l, c._3 + 1) }, - combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => - (grad1 += grad2, loss1 + loss2) + combOp = (c1, c2) => { + // c: (grad, loss, count) + (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3) }) - /** - * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration - * and regVal is the regularization value computed in the previous iteration as well. - */ - stochasticLossHistory.append(lossSum / miniBatchSize + regVal) - val update = updater.compute( - weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam) - weights = update._1 - regVal = update._2 + if (miniBatchSize > 0) { + /** + * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration + * and regVal is the regularization value computed in the previous iteration as well. + */ + stochasticLossHistory.append(lossSum / miniBatchSize + regVal) + val update = updater.compute( + weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam) + weights = update._1 + regVal = update._2 + } else { + logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero") + } } logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 038edc3521f14..90ac252226006 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -746,7 +746,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into - * @param alpha confidence parameter (only applies when immplicitPrefs = true) + * @param alpha confidence parameter * @param seed random seed */ def trainImplicit( @@ -773,7 +773,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into - * @param alpha confidence parameter (only applies when immplicitPrefs = true) + * @param alpha confidence parameter */ def trainImplicit( ratings: RDD[Rating], @@ -797,6 +797,7 @@ object ALS { * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) + * @param alpha confidence parameter */ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double) : MatrixFactorizationModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 969e23be21623..ed2f8b41bcae5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -21,23 +21,45 @@ import java.lang.{Integer => JavaInteger} import org.jblas.DoubleMatrix -import org.apache.spark.SparkContext._ +import org.apache.spark.Logging import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel /** * Model representing the result of matrix factorization. * + * Note: If you create the model directly using constructor, please be aware that fast prediction + * requires cached user/product features and their associated partitioners. + * * @param rank Rank for the features in this model. * @param userFeatures RDD of tuples where each tuple represents the userId and * the features computed for this user. * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. */ -class MatrixFactorizationModel private[mllib] ( +class MatrixFactorizationModel( val rank: Int, val userFeatures: RDD[(Int, Array[Double])], - val productFeatures: RDD[(Int, Array[Double])]) extends Serializable { + val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging { + + require(rank > 0) + validateFeatures("User", userFeatures) + validateFeatures("Product", productFeatures) + + /** Validates factors and warns users if there are performance concerns. */ + private def validateFeatures(name: String, features: RDD[(Int, Array[Double])]): Unit = { + require(features.first()._2.size == rank, + s"$name feature dimension does not match the rank $rank.") + if (features.partitioner.isEmpty) { + logWarning(s"$name factor does not have a partitioner. " + + "Prediction on individual records could be slow.") + } + if (features.getStorageLevel == StorageLevel.NONE) { + logWarning(s"$name factor is not cached. Prediction could be slow.") + } + } + /** Predict the rating of one user for one product. */ def predict(user: Int, product: Int): Double = { val userVector = new DoubleMatrix(userFeatures.lookup(user).head) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 00dfc86c9e0bd..0287f04e2c777 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -136,15 +136,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] this } - /** Whether a warning should be logged if the input RDD is uncached. */ - private var warnOnUncachedInput = true - - /** Disable warnings about uncached input. */ - private[spark] def disableUncachedWarning(): this.type = { - warnOnUncachedInput = false - this - } - /** * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. @@ -161,7 +152,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { - if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) { + if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } @@ -241,7 +232,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] } // Warn at the end of the run as well, for increased visibility. 
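As with KMeans, the warning above is addressed by caching the training data before calling run(); a minimal sketch, assuming an RDD[LabeledPoint] named `training`.

    import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
    import org.apache.spark.storage.StorageLevel
    val cached = training.persist(StorageLevel.MEMORY_AND_DISK)
    val model = new LogisticRegressionWithLBFGS().run(cached)
    cached.unpersist(blocking = false)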
- if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) { + if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data was not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 654479ac2dd4f..fcc2a148791bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -17,10 +17,8 @@ package org.apache.spark.mllib.stat -import breeze.linalg.{DenseVector => BDV} - import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * :: DeveloperApi :: @@ -40,37 +38,14 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { private var n = 0 - private var currMean: BDV[Double] = _ - private var currM2n: BDV[Double] = _ - private var currM2: BDV[Double] = _ - private var currL1: BDV[Double] = _ + private var currMean: Array[Double] = _ + private var currM2n: Array[Double] = _ + private var currM2: Array[Double] = _ + private var currL1: Array[Double] = _ private var totalCnt: Long = 0 - private var nnz: BDV[Double] = _ - private var currMax: BDV[Double] = _ - private var currMin: BDV[Double] = _ - - /** - * Adds input value to position i. - */ - private[this] def add(i: Int, value: Double) = { - if (value != 0.0) { - if (currMax(i) < value) { - currMax(i) = value - } - if (currMin(i) > value) { - currMin(i) = value - } - - val prevMean = currMean(i) - val diff = value - prevMean - currMean(i) = prevMean + diff / (nnz(i) + 1.0) - currM2n(i) += (value - currMean(i)) * diff - currM2(i) += value * value - currL1(i) += math.abs(value) - - nnz(i) += 1.0 - } - } + private var nnz: Array[Double] = _ + private var currMax: Array[Double] = _ + private var currMin: Array[Double] = _ /** * Add a new sample to this summarizer, and update the statistical summary. @@ -83,33 +58,36 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(sample.size > 0, s"Vector should have dimension larger than zero.") n = sample.size - currMean = BDV.zeros[Double](n) - currM2n = BDV.zeros[Double](n) - currM2 = BDV.zeros[Double](n) - currL1 = BDV.zeros[Double](n) - nnz = BDV.zeros[Double](n) - currMax = BDV.fill(n)(Double.MinValue) - currMin = BDV.fill(n)(Double.MaxValue) + currMean = Array.ofDim[Double](n) + currM2n = Array.ofDim[Double](n) + currM2 = Array.ofDim[Double](n) + currL1 = Array.ofDim[Double](n) + nnz = Array.ofDim[Double](n) + currMax = Array.fill[Double](n)(Double.MinValue) + currMin = Array.fill[Double](n)(Double.MaxValue) } require(n == sample.size, s"Dimensions mismatch when adding new sample." 
+ s" Expecting $n but got ${sample.size}.") - sample match { - case dv: DenseVector => { - var j = 0 - while (j < dv.size) { - add(j, dv.values(j)) - j += 1 + sample.foreachActive { (index, value) => + if (value != 0.0) { + if (currMax(index) < value) { + currMax(index) = value } - } - case sv: SparseVector => - var j = 0 - while (j < sv.indices.size) { - add(sv.indices(j), sv.values(j)) - j += 1 + if (currMin(index) > value) { + currMin(index) = value } - case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + + val prevMean = currMean(index) + val diff = value - prevMean + currMean(index) = prevMean + diff / (nnz(index) + 1.0) + currM2n(index) += (value - currMean(index)) * diff + currM2(index) += value * value + currL1(index) += math.abs(value) + + nnz(index) += 1.0 + } } totalCnt += 1 @@ -152,14 +130,14 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } } else if (totalCnt == 0 && other.totalCnt != 0) { this.n = other.n - this.currMean = other.currMean.copy - this.currM2n = other.currM2n.copy - this.currM2 = other.currM2.copy - this.currL1 = other.currL1.copy + this.currMean = other.currMean.clone + this.currM2n = other.currM2n.clone + this.currM2 = other.currM2.clone + this.currL1 = other.currL1.clone this.totalCnt = other.totalCnt - this.nnz = other.nnz.copy - this.currMax = other.currMax.copy - this.currMin = other.currMin.copy + this.nnz = other.nnz.clone + this.currMax = other.currMax.clone + this.currMin = other.currMin.clone } this } @@ -167,19 +145,19 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S override def mean: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val realMean = BDV.zeros[Double](n) + val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { realMean(i) = currMean(i) * (nnz(i) / totalCnt) i += 1 } - Vectors.fromBreeze(realMean) + Vectors.dense(realMean) } override def variance: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val realVariance = BDV.zeros[Double](n) + val realVariance = Array.ofDim[Double](n) val denominator = totalCnt - 1.0 @@ -194,8 +172,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S i += 1 } } - - Vectors.fromBreeze(realVariance) + Vectors.dense(realVariance) } override def count: Long = totalCnt @@ -203,7 +180,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S override def numNonzeros: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - Vectors.fromBreeze(nnz) + Vectors.dense(nnz) } override def max: Vector = { @@ -214,7 +191,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } - Vectors.fromBreeze(currMax) + Vectors.dense(currMax) } override def min: Vector = { @@ -225,25 +202,25 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } - Vectors.fromBreeze(currMin) + Vectors.dense(currMin) } override def normL2: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val realMagnitude = BDV.zeros[Double](n) + val realMagnitude = Array.ofDim[Double](n) var i = 0 while (i < currM2.size) { realMagnitude(i) = math.sqrt(currM2(i)) i += 1 } - - Vectors.fromBreeze(realMagnitude) + Vectors.dense(realMagnitude) } override def normL1: Vector = { 
require(totalCnt > 0, s"Nothing has been added to this summarizer.") - Vectors.fromBreeze(currL1) + + Vectors.dense(currL1) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 78acc17f901c1..3d91867c896d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -58,13 +58,19 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @return DecisionTreeModel that can be used for prediction */ - def train(input: RDD[LabeledPoint]): DecisionTreeModel = { + def run(input: RDD[LabeledPoint]): DecisionTreeModel = { // Note: random seed will not be used since numTrees = 1. val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) - val rfModel = rf.train(input) - rfModel.weakHypotheses(0) + val rfModel = rf.run(input) + rfModel.trees(0) } + /** + * Trains a decision tree model over an RDD. This is deprecated because it hides the static + * methods with the same name in Java. + */ + @deprecated("Please use DecisionTree.run instead.", "1.2.0") + def train(input: RDD[LabeledPoint]): DecisionTreeModel = run(input) } object DecisionTree extends Serializable with Logging { @@ -86,7 +92,7 @@ object DecisionTree extends Serializable with Logging { * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** @@ -112,7 +118,7 @@ object DecisionTree extends Serializable with Logging { impurity: Impurity, maxDepth: Int): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth) - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** @@ -140,7 +146,7 @@ object DecisionTree extends Serializable with Logging { maxDepth: Int, numClassesForClassification: Int): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification) - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** @@ -177,7 +183,7 @@ object DecisionTree extends Serializable with Logging { categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala similarity index 51% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala rename to mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index f729344a682e2..61f6b1313f82e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -21,170 +21,117 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Algo._ import 
org.apache.spark.mllib.tree.configuration.BoostingStrategy -import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum +import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.impl.TimeTracker -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} +import org.apache.spark.mllib.tree.impurity.Variance +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel /** * :: Experimental :: - * A class that implements Stochastic Gradient Boosting - * for regression and binary classification problems. + * A class that implements + * [[http://en.wikipedia.org/wiki/Gradient_boosting Stochastic Gradient Boosting]] + * for regression and binary classification. * * The implementation is based upon: * J.H. Friedman. "Stochastic Gradient Boosting." 1999. * - * Notes: - * - This currently can be run with several loss functions. However, only SquaredError is - * fully supported. Specifically, the loss function should be used to compute the gradient - * (to re-label training instances on each iteration) and to weight weak hypotheses. - * Currently, gradients are computed correctly for the available loss functions, - * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError. - * Running with those losses will likely behave reasonably, but lacks the same guarantees. + * Notes on Gradient Boosting vs. TreeBoost: + * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost. + * - Both algorithms learn tree ensembles by minimizing loss functions. + * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes + * based on the loss function, whereas the original gradient boosting method does not. + * - When the loss is SquaredError, these methods give the same result, but they could differ + * for other loss functions. * - * @param boostingStrategy Parameters for the gradient boosting algorithm + * @param boostingStrategy Parameters for the gradient boosting algorithm. */ @Experimental -class GradientBoosting ( - private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { - - boostingStrategy.weakLearnerParams.algo = Regression - boostingStrategy.weakLearnerParams.impurity = impurity.Variance - - // Ensure values for weak learner are the same as what is provided to the boosting algorithm. - boostingStrategy.weakLearnerParams.numClassesForClassification = - boostingStrategy.numClassesForClassification - - boostingStrategy.assertValid() +class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) + extends Serializable with Logging { /** * Method to train a gradient boosting model * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return WeightedEnsembleModel that can be used for prediction + * @return a gradient boosted trees model that can be used for prediction */ - def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = { - val algo = boostingStrategy.algo + def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { + val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoosting.boost(input, boostingStrategy) + case Regression => GradientBoostedTrees.boost(input, boostingStrategy) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. 
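The remapping mentioned in the comment above sends binary labels {0, 1} to {-1, +1} so the regression machinery can be reused; spelled out, assuming an RDD[LabeledPoint] named `input`.

    import org.apache.spark.mllib.regression.LabeledPoint
    val remapped = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
    // label 0.0 => (0.0 * 2) - 1 = -1.0;  label 1.0 => (1.0 * 2) - 1 = +1.0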
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoosting.boost(remappedInput, boostingStrategy) + GradientBoostedTrees.boost(remappedInput, boostingStrategy) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } } + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]]. + */ + def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { + run(input.rdd) + } } -object GradientBoosting extends Logging { +object GradientBoostedTrees extends Logging { /** * Method to train a gradient boosting model. * - * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] - * is recommended to clearly specify regression. - * Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] - * is recommended to clearly specify regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. * @param boostingStrategy Configuration options for the boosting algorithm. - * @return WeightedEnsembleModel that can be used for prediction + * @return a gradient boosted trees model that can be used for prediction */ def train( input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - new GradientBoosting(boostingStrategy).train(input) + boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { + new GradientBoostedTrees(boostingStrategy).run(input) } /** - * Method to train a gradient boosting classification model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param boostingStrategy Configuration options for the boosting algorithm. - * @return WeightedEnsembleModel that can be used for prediction - */ - def trainClassifier( - input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - val algo = boostingStrategy.algo - require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.") - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Method to train a gradient boosting regression model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param boostingStrategy Configuration options for the boosting algorithm. - * @return WeightedEnsembleModel that can be used for prediction - */ - def trainRegressor( - input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - val algo = boostingStrategy.algo - require(algo == Regression, s"Only Regression algo supported. 
Provided algo is $algo.") - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]] + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]] */ def train( - input: JavaRDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - train(input.rdd, boostingStrategy) - } - - /** - * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] - */ - def trainClassifier( - input: JavaRDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - trainClassifier(input.rdd, boostingStrategy) - } - - /** - * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] - */ - def trainRegressor( input: JavaRDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - trainRegressor(input.rdd, boostingStrategy) + boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { + train(input.rdd, boostingStrategy) } /** * Internal method for performing regression using trees as base learners. * @param input training dataset * @param boostingStrategy boosting parameters - * @return + * @return a gradient boosted trees model that can be used for prediction */ private def boost( input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { val timer = new TimeTracker() timer.start("total") timer.start("init") + boostingStrategy.assertValid() + // Initialize gradient boosting parameters val numIterations = boostingStrategy.numIterations val baseLearners = new Array[DecisionTreeModel](numIterations) val baseLearnerWeights = new Array[Double](numIterations) val loss = boostingStrategy.loss val learningRate = boostingStrategy.learningRate - val strategy = boostingStrategy.weakLearnerParams + // Prepare strategy for individual trees, which use regression with variance impurity. + val treeStrategy = boostingStrategy.treeStrategy.copy + treeStrategy.algo = Regression + treeStrategy.impurity = Variance + treeStrategy.assertValid() // Cache input if (input.getStorageLevel == StorageLevel.NONE) { @@ -200,11 +147,10 @@ object GradientBoosting extends Logging { // Initialize tree timer.start("building tree 0") - val firstTreeModel = new DecisionTree(strategy).train(data) + val firstTreeModel = new DecisionTree(treeStrategy).run(data) baseLearners(0) = firstTreeModel baseLearnerWeights(0) = 1.0 - val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression, - Sum) + val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0)) logDebug("error of gbt = " + loss.computeError(startingModel, input)) // Note: A model of type regression is used since we require raw prediction timer.stop("building tree 0") @@ -219,7 +165,7 @@ object GradientBoosting extends Logging { logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") - val model = new DecisionTree(strategy).train(data) + val model = new DecisionTree(treeStrategy).run(data) timer.stop(s"building tree $m") // Create partial model baseLearners(m) = model @@ -228,8 +174,8 @@ object GradientBoosting extends Logging { // However, the behavior should be reasonable, though not optimal. 
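      // Shrinkage: each new tree enters the ensemble with weight equal to the learning rate,
      // so a smaller learning rate (e.g. 0.1) damps each tree's contribution.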
baseLearnerWeights(m) = learningRate // Note: A model of type regression is used since we require raw prediction - val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1), - baseLearnerWeights.slice(0, m + 1), Regression, Sum) + val partialModel = new GradientBoostedTreesModel( + Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1)) logDebug("error of gbt = " + loss.computeError(partialModel, input)) // Update data with pseudo-residuals data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), @@ -242,8 +188,7 @@ object GradientBoosting extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") - new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum) - + new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 9683916d9b3f1..482d3395516e7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -17,18 +17,18 @@ package org.apache.spark.mllib.tree -import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.collection.JavaConverters._ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average -import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache } +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache, + TimeTracker, TreePoint} import org.apache.spark.mllib.tree.impurity.Impurities import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD @@ -37,7 +37,8 @@ import org.apache.spark.util.Utils /** * :: Experimental :: - * A class which implements a random forest learning algorithm for classification and regression. + * A class that implements a [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] + * learning algorithm for classification and regression. * It supports both continuous and categorical features. * * The settings for featureSubsetStrategy are based on the following references: @@ -70,6 +71,47 @@ private class RandomForest ( private val seed: Int) extends Serializable with Logging { + /* + ALGORITHM + This is a sketch of the algorithm to help new developers. + + The algorithm partitions data by instances (rows). + On each iteration, the algorithm splits a set of nodes. In order to choose the best split + for a given node, sufficient statistics are collected from the distributed data. + For each node, the statistics are collected to some worker node, and that worker selects + the best split. + + This setup requires discretization of continuous features. This binning is done in the + findSplitsBins() method during initialization, after which each continuous feature becomes + an ordered discretized feature with at most maxBins possible values. 
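+
+    As an illustration, with maxBins = 4 a continuous feature is mapped to one of at most
+    4 ordered values (bins) defined by 3 split candidates chosen in findSplitsBins().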
+ + The main loop in the algorithm operates on a queue of nodes (nodeQueue). These nodes + lie at the periphery of the tree being trained. If multiple trees are being trained at once, + then this queue contains nodes from all of them. Each iteration works roughly as follows: + On the master node: + - Some number of nodes are pulled off of the queue (based on the amount of memory + required for their sufficient statistics). + - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate + features are chosen for each node. See method selectNodesToSplit(). + On worker nodes, via method findBestSplits(): + - The worker makes one pass over its subset of instances. + - For each (tree, node, feature, split) tuple, the worker collects statistics about + splitting. Note that the set of (tree, node) pairs is limited to the nodes selected + from the queue for this iteration. The set of features considered can also be limited + based on featureSubsetStrategy. + - For each node, the statistics for that node are aggregated to a particular worker + via reduceByKey(). The designated worker chooses the best (feature, split) pair, + or chooses to stop splitting if the stopping criteria are met. + On the master node: + - The master collects all decisions about splitting nodes and updates the model. + - The updated model is passed to the workers on the next iteration. + This process continues until the node queue is empty. + + Most of the methods in this implementation support the statistics aggregation, which is + the heaviest part of the computation. In general, this implementation is bound by either + the cost of statistics computation on workers or by communicating the sufficient statistics. + */ + strategy.assertValid() require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy), @@ -79,9 +121,9 @@ private class RandomForest ( /** * Method to train a decision tree model over an RDD * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ - def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = { + def run(input: RDD[LabeledPoint]): RandomForestModel = { val timer = new TimeTracker() @@ -212,8 +254,7 @@ private class RandomForest ( } val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) - val treeWeights = Array.fill[Double](numTrees)(1.0) - new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average) + new RandomForestModel(strategy.algo, trees) } } @@ -231,21 +272,20 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "sqrt". * @param seed Random seed for bootstrapping and choosing feature subsets. 
- * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainClassifier( input: RDD[LabeledPoint], strategy: Strategy, numTrees: Int, featureSubsetStrategy: String, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { require(strategy.algo == Classification, s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}") val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) - rf.train(input) + rf.run(input) } /** @@ -262,8 +302,7 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "sqrt". * @param impurity Criterion used for information gain calculation. * Supported values: "gini" (recommended) or "entropy". * @param maxDepth Maximum depth of the tree. @@ -272,7 +311,7 @@ object RandomForest extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainClassifier( input: RDD[LabeledPoint], @@ -283,7 +322,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = { + seed: Int = Utils.random.nextInt()): RandomForestModel = { val impurityType = Impurities.fromString(impurity) val strategy = new Strategy(Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo) @@ -302,7 +341,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { trainClassifier(input.rdd, numClassesForClassification, categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) @@ -319,21 +358,20 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "onethird". * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainRegressor( input: RDD[LabeledPoint], strategy: Strategy, numTrees: Int, featureSubsetStrategy: String, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { require(strategy.algo == Regression, s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}") val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) - rf.train(input) + rf.run(input) } /** @@ -349,8 +387,7 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". 
* If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "onethird". * @param impurity Criterion used for information gain calculation. * Supported values: "variance". * @param maxDepth Maximum depth of the tree. @@ -359,7 +396,7 @@ object RandomForest extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainRegressor( input: RDD[LabeledPoint], @@ -369,7 +406,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = { + seed: Int = Utils.random.nextInt()): RandomForestModel = { val impurityType = Impurities.fromString(impurity) val strategy = new Strategy(Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo) @@ -387,7 +424,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { trainRegressor(input.rdd, categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) @@ -479,5 +516,4 @@ object RandomForest extends Serializable with Logging { 3 * totalBins } } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index abbda040bd528..e703adbdbfbb3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -25,57 +25,39 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} /** * :: Experimental :: - * Stores all the configuration options for the boosting algorithms - * @param algo Learning goal. Supported: - * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], - * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]]. + * + * @param treeStrategy Parameters for the tree algorithm. We support regression and binary + * classification for boosting. Impurity setting will be ignored. + * @param loss Loss function used for minimization during gradient boosting. * @param numIterations Number of iterations of boosting. In other words, the number of * weak hypotheses used in the final model. - * @param loss Loss function used for minimization during gradient boosting. * @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] - * @param numClassesForClassification Number of classes for classification. - * (Ignored for regression.) - * This setting overrides any setting in [[weakLearnerParams]]. - * Default value is 2 (binary classification). - * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are - * supported. 
*/ @Experimental case class BoostingStrategy( // Required boosting parameters - @BeanProperty var algo: Algo, - @BeanProperty var numIterations: Int, + @BeanProperty var treeStrategy: Strategy, @BeanProperty var loss: Loss, // Optional boosting parameters - @BeanProperty var learningRate: Double = 0.1, - @BeanProperty var numClassesForClassification: Int = 2, - @BeanProperty var weakLearnerParams: Strategy) extends Serializable { - - // Ensure values for weak learner are the same as what is provided to the boosting algorithm. - weakLearnerParams.numClassesForClassification = numClassesForClassification - - /** - * Sets Algorithm using a String. - */ - def setAlgo(algo: String): Unit = algo match { - case "Classification" => setAlgo(Classification) - case "Regression" => setAlgo(Regression) - } + @BeanProperty var numIterations: Int = 100, + @BeanProperty var learningRate: Double = 0.1) extends Serializable { /** * Check validity of parameters. * Throws exception if invalid. */ private[tree] def assertValid(): Unit = { - algo match { + treeStrategy.algo match { case Classification => - require(numClassesForClassification == 2) + require(treeStrategy.numClassesForClassification == 2, + "Only binary classification is supported for boosting.") case Regression => // nothing case _ => throw new IllegalArgumentException( - s"BoostingStrategy given invalid algo parameter: $algo." + + s"BoostingStrategy given invalid algo parameter: ${treeStrategy.algo}." + s" Valid settings are: Classification, Regression.") } require(learningRate > 0 && learningRate <= 1, @@ -94,14 +76,14 @@ object BoostingStrategy { * @return Configuration for boosting algorithm */ def defaultParams(algo: String): BoostingStrategy = { - val treeStrategy = Strategy.defaultStrategy("Regression") + val treeStrategy = Strategy.defaultStrategy(algo) treeStrategy.maxDepth = 3 algo match { case "Classification" => - new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy) + treeStrategy.numClassesForClassification = 2 + new BoostingStrategy(treeStrategy, LogLoss) case "Regression" => - new BoostingStrategy(Algo.withName(algo), 100, SquaredError, - weakLearnerParams = treeStrategy) + new BoostingStrategy(treeStrategy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by the boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala index 82889dc00cdad..b5bf732d1b33a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala @@ -17,14 +17,10 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.DeveloperApi - /** - * :: Experimental :: * Enum to select ensemble combining strategy for base learners */ -@DeveloperApi -object EnsembleCombiningStrategy extends Enumeration { +private[tree] object EnsembleCombiningStrategy extends Enumeration { type EnsembleCombiningStrategy = Value - val Sum, Average = Value + val Average, Sum, Vote = Value } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index b5b1f82177edc..d75f38433c081 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -157,6 +157,13 @@ class Strategy ( require(maxMemoryInMB <= 10240, s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") } + + /** Returns a shallow copy of this instance. */ + def copy: Strategy = { + new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, + quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, + maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval) + } } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index d111ffe30ed9e..d1bde15e6b150 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -17,19 +17,18 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.SparkContext._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: - * Class for least absolute error loss calculation. - * The features x and the corresponding label y is predicted using the function F. - * For each instance: - * Loss: |y - F| - * Negative gradient: sign(y - F) + * Class for absolute error loss calculation (for regression). + * + * The absolute (L1) error is defined as: + * |y - F(x)| + * where y is the label and F(x) is the model prediction for features x. */ @DeveloperApi object AbsoluteError extends Loss { @@ -37,30 +36,29 @@ object AbsoluteError extends Loss { /** * Method to calculate the gradients for the gradient boosting calculation for least * absolute error calculation. - * @param model Model of the weak learner + * The gradient with respect to F(x) is: sign(F(x) - y) + * @param model Ensemble model * @param point Instance of the training dataset * @return Loss gradient */ override def gradient( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, point: LabeledPoint): Double = { if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0 } /** - * Method to calculate error of the base learner for the gradient boosting calculation. + * Method to calculate loss of the base learner for the gradient boosting calculation. * Note: This method is not used by the gradient boosting algorithm but is useful for debugging * purposes. - * @param model Model of the weak learner. + * @param model Ensemble model * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
- * @return + * @return Mean absolute error of model on data */ - override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { - val sumOfAbsolutes = data.map { y => + override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { + data.map { y => val err = model.predict(y.features) - y.label math.abs(err) - }.sum() - sumOfAbsolutes / data.count() + }.mean() } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 6f3d4340f0d3b..7ce9fa6f86c42 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -19,17 +19,17 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: - * Class for least squares error loss calculation. + * Class for log loss calculation (for classification). + * This uses twice the binomial negative log likelihood, called "deviance" in Friedman (1999). * - * The features x and the corresponding label y is predicted using the function F. - * For each instance: - * Loss: log(1 + exp(-2yF)), y in {-1, 1} - * Negative gradient: 2y / ( 1 + exp(2yF)) + * The log loss is defined as: + * 2 log(1 + exp(-2 y F(x))) + * where y is a label in {-1, 1} and F(x) is the model prediction for features x. */ @DeveloperApi object LogLoss extends Loss { @@ -37,27 +37,37 @@ object LogLoss extends Loss { /** * Method to calculate the loss gradients for the gradient boosting calculation for binary * classification - * @param model Model of the weak learner + * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x))) + * @param model Ensemble model * @param point Instance of the training dataset * @return Loss gradient */ override def gradient( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, point: LabeledPoint): Double = { val prediction = model.predict(point.features) - 1.0 / (1.0 + math.exp(-prediction)) - point.label + - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction)) } /** - * Method to calculate error of the base learner for the gradient boosting calculation. + * Method to calculate loss of the base learner for the gradient boosting calculation. * Note: This method is not used by the gradient boosting algorithm but is useful for debugging * purposes. - * @param model Model of the weak learner. + * @param model Ensemble model * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return + * @return Mean log loss of model on data */ - override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { - val wrongPredictions = data.filter(lp => model.predict(lp.features) != lp.label).count() - wrongPredictions / data.count + override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { + data.map { case point => + val prediction = model.predict(point.features) + val margin = 2.0 * point.label * prediction + // The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically + // stable. 
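+        // When margin is very negative, exp(-margin) would overflow, so that branch uses the
+        // identity log1p(exp(-margin)) = -margin + log1p(exp(margin)), which keeps the argument
+        // of exp non-positive.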
+ if (margin >= 0) { + 2.0 * math.log1p(math.exp(-margin)) + } else { + 2.0 * (-margin + math.log1p(math.exp(margin))) + } + }.mean() } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 5580866c879e2..4bca9039ebe1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD /** @@ -36,7 +36,7 @@ trait Loss extends Serializable { * @return Loss gradient. */ def gradient( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, point: LabeledPoint): Double /** @@ -47,6 +47,6 @@ trait Loss extends Serializable { * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return */ - def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double + def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index 4349fefef2c74..50ecaa2f86f35 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -17,20 +17,18 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.SparkContext._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: - * Class for least squares error loss calculation. + * Class for squared error loss calculation. * - * The features x and the corresponding label y is predicted using the function F. - * For each instance: - * Loss: (y - F)**2/2 - * Negative gradient: y - F + * The squared (L2) error is defined as: + * (y - F(x))**2 + * where y is the label and F(x) is the model prediction for features x. */ @DeveloperApi object SquaredError extends Loss { @@ -38,29 +36,29 @@ object SquaredError extends Loss { /** * Method to calculate the gradients for the gradient boosting calculation for least * squares error calculation. - * @param model Model of the weak learner + * The gradient with respect to F(x) is: - 2 (y - F(x)) + * @param model Ensemble model * @param point Instance of the training dataset * @return Loss gradient */ override def gradient( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, point: LabeledPoint): Double = { - model.predict(point.features) - point.label + 2.0 * (model.predict(point.features) - point.label) } /** - * Method to calculate error of the base learner for the gradient boosting calculation. + * Method to calculate loss of the base learner for the gradient boosting calculation. * Note: This method is not used by the gradient boosting algorithm but is useful for debugging * purposes. - * @param model Model of the weak learner. + * @param model Ensemble model * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
- * @return + * @return Mean squared error of model on data */ - override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { data.map { y => val err = model.predict(y.features) - y.label err * err }.mean() } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index ac4d02ee3928b..a5760963068c3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -17,11 +17,11 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.api.java.JavaRDD import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.Vector /** * :: Experimental :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala deleted file mode 100644 index 7b052d9163a13..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.model - -import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ -import org.apache.spark.rdd.RDD - -import scala.collection.mutable - -@Experimental -class WeightedEnsembleModel( - val weakHypotheses: Array[DecisionTreeModel], - val weakHypothesisWeights: Array[Double], - val algo: Algo, - val combiningStrategy: EnsembleCombiningStrategy) extends Serializable { - - require(numWeakHypotheses > 0, s"WeightedEnsembleModel cannot be created without weakHypotheses" + - s". Number of weakHypotheses = $weakHypotheses") - - /** - * Predict values for a single data point using the model trained. 
- * - * @param features array representing a single data point - * @return predicted category from the trained model - */ - private def predictRaw(features: Vector): Double = { - val treePredictions = weakHypotheses.map(learner => learner.predict(features)) - if (numWeakHypotheses == 1){ - treePredictions(0) - } else { - var prediction = treePredictions(0) - var index = 1 - while (index < numWeakHypotheses) { - prediction += weakHypothesisWeights(index) * treePredictions(index) - index += 1 - } - prediction - } - } - - /** - * Predict values for a single data point using the model trained. - * - * @param features array representing a single data point - * @return predicted category from the trained model - */ - private def predictBySumming(features: Vector): Double = { - algo match { - case Regression => predictRaw(features) - case Classification => { - // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info. - if (predictRaw(features) > 0 ) 1.0 else 0.0 - } - case _ => throw new IllegalArgumentException( - s"WeightedEnsembleModel given unknown algo parameter: $algo.") - } - } - - /** - * Predict values for a single data point. - * - * @param features array representing a single data point - * @return Double prediction from the trained model - */ - private def predictByAveraging(features: Vector): Double = { - algo match { - case Classification => - val predictionToCount = new mutable.HashMap[Int, Int]() - weakHypotheses.foreach { learner => - val prediction = learner.predict(features).toInt - predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1 - } - predictionToCount.maxBy(_._2)._1 - case Regression => - weakHypotheses.map(_.predict(features)).sum / weakHypotheses.size - } - } - - - /** - * Predict values for a single data point using the model trained. - * - * @param features array representing a single data point - * @return predicted category from the trained model - */ - def predict(features: Vector): Double = { - combiningStrategy match { - case Sum => predictBySumming(features) - case Average => predictByAveraging(features) - case _ => throw new IllegalArgumentException( - s"WeightedEnsembleModel given unknown combining parameter: $combiningStrategy.") - } - } - - /** - * Predict values for the given data set. - * - * @param features RDD representing data points to be predicted - * @return RDD[Double] where each entry contains the corresponding prediction - */ - def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x)) - - /** - * Print a summary of the model. - */ - override def toString: String = { - algo match { - case Classification => - s"WeightedEnsembleModel classifier with $numWeakHypotheses trees\n" - case Regression => - s"WeightedEnsembleModel regressor with $numWeakHypotheses trees\n" - case _ => throw new IllegalArgumentException( - s"WeightedEnsembleModel given unknown algo parameter: $algo.") - } - } - - /** - * Print the full model to a string. - */ - def toDebugString: String = { - val header = toString + "\n" - header + weakHypotheses.zipWithIndex.map { case (tree, treeIndex) => - s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) - }.fold("")(_ + _) - } - - /** - * Get number of trees in forest. - */ - def numWeakHypotheses: Int = weakHypotheses.size - - // TODO: Remove these helpers methods once class is generalized to support any base learning - // algorithms. - - /** - * Get total number of nodes, summed over all trees in the forest. 
- */ - def totalNumNodes: Int = weakHypotheses.map(tree => tree.numNodes).sum - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala new file mode 100644 index 0000000000000..22997110de8dd --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import scala.collection.mutable + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Represents a random forest model. + * + * @param algo algorithm for the ensemble model, either Classification or Regression + * @param trees tree ensembles + */ +@Experimental +class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) + extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0), + combiningStrategy = if (algo == Classification) Vote else Average) { + + require(trees.forall(_.algo == algo)) +} + +/** + * :: Experimental :: + * Represents a gradient boosted trees model. + * + * @param algo algorithm for the ensemble model, either Classification or Regression + * @param trees tree ensembles + * @param treeWeights tree ensemble weights + */ +@Experimental +class GradientBoostedTreesModel( + override val algo: Algo, + override val trees: Array[DecisionTreeModel], + override val treeWeights: Array[Double]) + extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) { + + require(trees.size == treeWeights.size) +} + +/** + * Represents a tree ensemble model. + * + * @param algo algorithm for the ensemble model, either Classification or Regression + * @param trees tree ensembles + * @param treeWeights tree ensemble weights + * @param combiningStrategy strategy for combining the predictions, not used for regression. + */ +private[tree] sealed class TreeEnsembleModel( + protected val algo: Algo, + protected val trees: Array[DecisionTreeModel], + protected val treeWeights: Array[Double], + protected val combiningStrategy: EnsembleCombiningStrategy) extends Serializable { + + require(numTrees > 0, "TreeEnsembleModel cannot be created without trees.") + + private val sumWeights = math.max(treeWeights.sum, 1e-15) + + /** + * Predicts for a single data point using the weighted sum of ensemble predictions. 
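+   * That is, the prediction is the dot product of the per-tree predictions with treeWeights
+   * (computed via blas.ddot).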
+ * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + private def predictBySumming(features: Vector): Double = { + val treePredictions = trees.map(_.predict(features)) + blas.ddot(numTrees, treePredictions, 1, treeWeights, 1) + } + + /** + * Classifies a single data point based on (weighted) majority votes. + */ + private def predictByVoting(features: Vector): Double = { + val votes = mutable.Map.empty[Int, Double] + trees.view.zip(treeWeights).foreach { case (tree, weight) => + val prediction = tree.predict(features).toInt + votes(prediction) = votes.getOrElse(prediction, 0.0) + weight + } + votes.maxBy(_._2)._1 + } + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + def predict(features: Vector): Double = { + (algo, combiningStrategy) match { + case (Regression, Sum) => + predictBySumming(features) + case (Regression, Average) => + predictBySumming(features) / sumWeights + case (Classification, Sum) => // binary classification + val prediction = predictBySumming(features) + // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info. + if (prediction > 0.0) 1.0 else 0.0 + case (Classification, Vote) => + predictByVoting(features) + case _ => + throw new IllegalArgumentException( + "TreeEnsembleModel given unsupported (algo, combiningStrategy) combination: " + + s"($algo, $combiningStrategy).") + } + } + + /** + * Predict values for the given data set. + * + * @param features RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x)) + + /** + * Java-friendly version of [[org.apache.spark.mllib.tree.model.TreeEnsembleModel#predict]]. + */ + def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = { + predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] + } + + /** + * Print a summary of the model. + */ + override def toString: String = { + algo match { + case Classification => + s"TreeEnsembleModel classifier with $numTrees trees\n" + case Regression => + s"TreeEnsembleModel regressor with $numTrees trees\n" + case _ => throw new IllegalArgumentException( + s"TreeEnsembleModel given unknown algo parameter: $algo.") + } + } + + /** + * Print the full model to a string. + */ + def toDebugString: String = { + val header = toString + "\n" + header + trees.zipWithIndex.map { case (tree, treeIndex) => + s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) + }.fold("")(_ + _) + } + + /** + * Get number of trees in forest. + */ + def numTrees: Int = trees.size + + /** + * Get total number of nodes, summed over all trees in the forest. 
+ */ + def totalNumNodes: Int = trees.map(_.numNodes).sum +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java index 2c281a1ee7157..9925aae441af9 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -74,7 +74,7 @@ public void runDTUsingConstructor() { maxBins, categoricalFeaturesInfo); DecisionTree learner = new DecisionTree(strategy); - DecisionTreeModel model = learner.train(rdd.rdd()); + DecisionTreeModel model = learner.run(rdd.rdd()); int numCorrect = validatePrediction(arr, model); Assert.assertTrue(numCorrect == rdd.count()); diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index ef6c4a3974bf3..6526bd614fe40 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -17,7 +17,11 @@ package org.apache.spark.mllib.linalg +import java.util.Random + +import org.mockito.Mockito.when import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar._ class MatricesSuite extends FunSuite { test("dense matrix construction") { @@ -221,4 +225,50 @@ class MatricesSuite extends FunSuite { Matrices.vertcat(Array(deMat1, spMat2)) } } + + test("zeros") { + val mat = Matrices.zeros(2, 3).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 3) + assert(mat.values.forall(_ == 0.0)) + } + + test("ones") { + val mat = Matrices.ones(2, 3).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 3) + assert(mat.values.forall(_ == 1.0)) + } + + test("eye") { + val mat = Matrices.eye(2).asInstanceOf[DenseMatrix] + assert(mat.numCols === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 0.0, 0.0, 1.0)) + } + + test("rand") { + val rng = mock[Random] + when(rng.nextDouble()).thenReturn(1.0, 2.0, 3.0, 4.0) + val mat = Matrices.rand(2, 2, rng).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) + } + + test("randn") { + val rng = mock[Random] + when(rng.nextGaussian()).thenReturn(1.0, 2.0, 3.0, 4.0) + val mat = Matrices.randn(2, 2, rng).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) + } + + test("diag") { + val mat = Matrices.diag(Vectors.dense(1.0, 2.0)).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 0.0, 0.0, 2.0)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 93a84fe07b32a..9492f604af4d5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.linalg +import breeze.linalg.{DenseMatrix => BDM} import org.scalatest.FunSuite import org.apache.spark.SparkException @@ -166,4 +167,34 @@ class VectorsSuite extends FunSuite { assert(v === udt.deserialize(udt.serialize(v))) } } + + test("fromBreeze") { + val x = BDM.zeros[Double](10, 10) + val v = Vectors.fromBreeze(x(::, 0)) + assert(v.size === 
x.rows) + } + + test("foreachActive") { + val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0) + val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0))) + + val dvMap = scala.collection.mutable.Map[Int, Double]() + dv.foreachActive { (index, value) => + dvMap.put(index, value) + } + assert(dvMap.size === 4) + assert(dvMap.get(0) === Some(0.0)) + assert(dvMap.get(1) === Some(1.2)) + assert(dvMap.get(2) === Some(3.1)) + assert(dvMap.get(3) === Some(0.0)) + + val svMap = scala.collection.mutable.Map[Int, Double]() + sv.foreachActive { (index, value) => + svMap.put(index, value) + } + assert(svMap.size === 3) + assert(svMap.get(1) === Some(1.2)) + assert(svMap.get(2) === Some(3.1)) + assert(svMap.get(3) === Some(0.0)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala new file mode 100644 index 0000000000000..b9caecc904a23 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.recommendation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD + +class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { + + val rank = 2 + var userFeatures: RDD[(Int, Array[Double])] = _ + var prodFeatures: RDD[(Int, Array[Double])] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + userFeatures = sc.parallelize(Seq((0, Array(1.0, 2.0)), (1, Array(3.0, 4.0)))) + prodFeatures = sc.parallelize(Seq((2, Array(5.0, 6.0)))) + } + + test("constructor") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + assert(model.predict(0, 2) ~== 17.0 relTol 1e-14) + + intercept[IllegalArgumentException] { + new MatrixFactorizationModel(1, userFeatures, prodFeatures) + } + + val userFeatures1 = sc.parallelize(Seq((0, Array(1.0)), (1, Array(3.0)))) + intercept[IllegalArgumentException] { + new MatrixFactorizationModel(rank, userFeatures1, prodFeatures) + } + + val prodFeatures1 = sc.parallelize(Seq((2, Array(5.0)))) + intercept[IllegalArgumentException] { + new MatrixFactorizationModel(rank, userFeatures, prodFeatures1) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala index effb7b8259ffb..8972c229b7ecb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.util.StatCounter import scala.collection.mutable @@ -48,7 +48,7 @@ object EnsembleTestHelper { } def validateClassifier( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, input: Seq[LabeledPoint], requiredAccuracy: Double) { val predictions = input.map(x => model.predict(x.features)) @@ -60,17 +60,27 @@ object EnsembleTestHelper { s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") } + /** + * Validates a tree ensemble model for regression. 
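+   * Supported metricName values: "mse" (mean squared error, the default) and "mae"
+   * (mean absolute error).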
+ */ def validateRegressor( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, input: Seq[LabeledPoint], - requiredMSE: Double) { + required: Double, + metricName: String = "mse") { val predictions = input.map(x => model.predict(x.features)) - val squaredError = predictions.zip(input).map { case (prediction, expected) => - val err = prediction - expected.label - err * err - }.sum - val mse = squaredError / input.length - assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") + val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) => + prediction - label + } + val metric = metricName match { + case "mse" => + errors.map(err => err * err).sum / errors.size + case "mae" => + errors.map(math.abs).sum / errors.size + } + + assert(metric <= required, + s"validateRegressor calculated $metricName $metric but required $required.") } def generateOrderedLabeledPoints(numFeatures: Int, numInstances: Int): Array[LabeledPoint] = { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala new file mode 100644 index 0000000000000..d4d54cf4c9e2a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} +import org.apache.spark.mllib.tree.impurity.Variance +import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss} + +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suite for [[GradientBoostedTrees]]. 
+ */ +class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { + + test("Regression with continuous features: SquaredError") { + GradientBoostedTreesSuite.testCombinations.foreach { + case (numIterations, learningRate, subsamplingRate) => + GradientBoostedTreesSuite.randomSeeds.foreach { randomSeed => + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + assert(gbt.trees.size === numIterations) + try { + EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06) + } catch { + case e: java.lang.AssertionError => + println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + s" subsamplingRate=$subsamplingRate") + throw e + } + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val dt = DecisionTree.train(remappedInput, treeStrategy) + + // Make sure trees are the same. + assert(gbt.trees.head.toString == dt.toString) + } + } + } + + test("Regression with continuous features: Absolute Error") { + GradientBoostedTreesSuite.testCombinations.foreach { + case (numIterations, learningRate, subsamplingRate) => + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, AbsoluteError, numIterations, learningRate) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + assert(gbt.trees.size === numIterations) + try { + EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.85, "mae") + } catch { + case e: java.lang.AssertionError => + println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + s" subsamplingRate=$subsamplingRate") + throw e + } + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val dt = DecisionTree.train(remappedInput, treeStrategy) + + // Make sure trees are the same. 
+ assert(gbt.trees.head.toString == dt.toString) + } + } + + test("Binary classification with continuous features: Log Loss") { + GradientBoostedTreesSuite.testCombinations.foreach { + case (numIterations, learningRate, subsamplingRate) => + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = Map.empty, + subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, LogLoss, numIterations, learningRate) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + assert(gbt.trees.size === numIterations) + try { + EnsembleTestHelper.validateClassifier(gbt, GradientBoostedTreesSuite.data, 0.9) + } catch { + case e: java.lang.AssertionError => + println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + s" subsamplingRate=$subsamplingRate") + throw e + } + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val ensembleStrategy = treeStrategy.copy + ensembleStrategy.algo = Regression + ensembleStrategy.impurity = Variance + val dt = DecisionTree.train(remappedInput, ensembleStrategy) + + // Make sure trees are the same. + assert(gbt.trees.head.toString == dt.toString) + } + } + +} + +object GradientBoostedTreesSuite { + + // Combinations for estimators, learning rates and subsamplingRate + val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) + + val randomSeeds = Array(681283, 4398) + + val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala deleted file mode 100644 index 84de40103d8aa..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree - -import org.scalatest.FunSuite - -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} -import org.apache.spark.mllib.tree.impurity.Variance -import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss} - -import org.apache.spark.mllib.util.MLlibTestSparkContext - -/** - * Test suite for [[GradientBoosting]]. 
- */ -class GradientBoostingSuite extends FunSuite with MLlibTestSparkContext { - - test("Regression with continuous features: SquaredError") { - GradientBoostingSuite.testCombinations.foreach { - case (numIterations, learningRate, subsamplingRate) => - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) - val rdd = sc.parallelize(arr) - val categoricalFeaturesInfo = Map.empty[Int, Int] - - val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, - subsamplingRate = subsamplingRate) - - val dt = DecisionTree.train(remappedInput, treeStrategy) - - val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, - learningRate, 1, treeStrategy) - - val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numIterations) - val gbtTree = gbt.weakHypotheses(0) - - EnsembleTestHelper.validateRegressor(gbt, arr, 0.03) - - // Make sure trees are the same. - assert(gbtTree.toString == dt.toString) - } - } - - test("Regression with continuous features: Absolute Error") { - GradientBoostingSuite.testCombinations.foreach { - case (numIterations, learningRate, subsamplingRate) => - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) - val rdd = sc.parallelize(arr) - val categoricalFeaturesInfo = Map.empty[Int, Int] - - val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, - subsamplingRate = subsamplingRate) - - val dt = DecisionTree.train(remappedInput, treeStrategy) - - val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, - learningRate, numClassesForClassification = 2, treeStrategy) - - val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numIterations) - val gbtTree = gbt.weakHypotheses(0) - - EnsembleTestHelper.validateRegressor(gbt, arr, 0.03) - - // Make sure trees are the same. - assert(gbtTree.toString == dt.toString) - } - } - - test("Binary classification with continuous features: Log Loss") { - GradientBoostingSuite.testCombinations.foreach { - case (numIterations, learningRate, subsamplingRate) => - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) - val rdd = sc.parallelize(arr) - val categoricalFeaturesInfo = Map.empty[Int, Int] - - val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, - subsamplingRate = subsamplingRate) - - val dt = DecisionTree.train(remappedInput, treeStrategy) - - val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss, - learningRate, numClassesForClassification = 2, treeStrategy) - - val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numIterations) - val gbtTree = gbt.weakHypotheses(0) - - EnsembleTestHelper.validateClassifier(gbt, arr, 0.9) - - // Make sure trees are the same. 
- assert(gbtTree.toString == dt.toString) - } - } - -} - -object GradientBoostingSuite { - - // Combinations for estimators, learning rates and subsamplingRate - val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75)) - -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index 2734e089d62e6..90a8c2dfdab80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -41,8 +41,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) - assert(rf.weakHypotheses.size === 1) - val rfTree = rf.weakHypotheses(0) + assert(rf.trees.size === 1) + val rfTree = rf.trees(0) val dt = DecisionTree.train(rdd, strategy) @@ -65,7 +65,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { " comparing DecisionTree vs. RandomForest(numTrees = 1)") { val categoricalFeaturesInfo = Map.empty[Int, Int] val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true) + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + useNodeIdCache = true) binaryClassificationTestWithContinuousFeatures(strategy) } @@ -76,8 +77,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) - assert(rf.weakHypotheses.size === 1) - val rfTree = rf.weakHypotheses(0) + assert(rf.trees.size === 1) + val rfTree = rf.trees(0) val dt = DecisionTree.train(rdd, strategy) @@ -175,7 +176,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { test("Binary classification with continuous features and node Id cache: subsampling features") { val categoricalFeaturesInfo = Map.empty[Int, Int] val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true) + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + useNodeIdCache = true) binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) } diff --git a/network/common/pom.xml b/network/common/pom.xml index 2bd0a7d2945dd..baca859fa5011 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 76bce8592816a..9afd5decd5e6b 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -19,7 +19,6 @@ import java.io.Closeable; import java.io.IOException; -import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.List; @@ -37,7 +36,6 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import 
io.netty.channel.socket.SocketChannel; -import io.netty.util.internal.PlatformDependent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -67,6 +65,7 @@ public class TransportClientFactory implements Closeable { private final Class socketChannelClass; private EventLoopGroup workerGroup; + private PooledByteBufAllocator pooledAllocator; public TransportClientFactory( TransportContext context, @@ -80,6 +79,8 @@ public TransportClientFactory( this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); // TODO: Make thread pool name configurable. this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client"); + this.pooledAllocator = NettyUtils.createPooledByteBufAllocator( + conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads()); } /** @@ -115,11 +116,8 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO // Disable Nagle's Algorithm since we don't want packets to wait .option(ChannelOption.TCP_NODELAY, true) .option(ChannelOption.SO_KEEPALIVE, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()); - - // Use pooled buffers to reduce temporary buffer allocation - bootstrap.option(ChannelOption.ALLOCATOR, NettyUtils.createPooledByteBufAllocator( - conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads())); + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()) + .option(ChannelOption.ALLOCATOR, pooledAllocator); final AtomicReference clientRef = new AtomicReference(); diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 5c654a6fd6ebe..b3991a6577cfe 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -109,9 +109,9 @@ public static String getRemoteAddress(Channel channel) { /** * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches - * are disabled because the ByteBufs are allocated by the event loop thread, but released by the - * executor thread rather than the event loop thread. Those thread-local caches actually delay - * the recycling of buffers, leading to larger memory usage. + * are disabled for TransportClients because the ByteBufs are allocated by the event loop thread, + * but released by the executor thread rather than the event loop thread. Those thread-local + * caches actually delay the recycling of buffers, leading to larger memory usage. 
*/ public static PooledByteBufAllocator createPooledByteBufAllocator( boolean allowDirectBufs, diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 12ff034cfe588..12468567c3aed 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index 7845011ec3200..acec8f18f2b5c 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git a/pom.xml b/pom.xml index 639ea22a1fda3..6c1c1214a7d3e 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -254,6 +254,18 @@ false + + + spark-staging-1038 + Spark 1.2.0 Staging (1038) + https://repository.apache.org/content/repositories/orgapachespark-1038/ + + true + + + false + + @@ -269,7 +281,7 @@ - @@ -278,7 +290,7 @@ unused 1.0.0 - | + * ignore threshold --->| |<--- current batch time + * |____.____.____.____.____.____| + * | | | | | | | + * ---------------------|----|----|----|----|----|----|-----------------------> Time + * |____|____|____|____|____|____| + * remembered batches + * + * The trailing end of the window is the "ignore threshold" and all files whose mod times + * are less than this threshold are assumed to have already been selected and are therefore + * ignored. Files whose mod times are within the "remember window" are checked against files + * that have already been selected. At a high level, this is how new files are identified in + * each batch - files whose mod times are greater than the ignore threshold and + * have not been considered within the remember window. See the documentation on the method + * `isNewFile` for more details. + * + * This makes some assumptions from the underlying file system that the system is monitoring. + * - The clock of the file system is assumed to synchronized with the clock of the machine running + * the streaming app. + * - If a file is to be visible in the directory listings, it must be visible within a certain + * duration of the mod time of the file. This duration is the "remember window", which is set to + * 1 minute (see `FileInputDStream.MIN_REMEMBER_DURATION`). Otherwise, the file will never be + * selected as the mod time will be less than the ignore threshold when it becomes visible. + * - Once a file is visible, the mod time cannot change. If it does due to appends, then the + * processing semantics are undefined. 
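Editor's note: to summarize the file-selection scheme described in the FileInputDStream class comment above, here is a condensed Scala sketch of the per-file decision. It is an illustration of the documented rule, not the actual implementation (which appears later in this patch as isNewFile), and it leaves out the user-supplied path filter that the real code also applies; all names are hypothetical:

    // Decide whether a file with modification time `modTime` should be selected
    // for the batch at `currentTime`.
    // `initialIgnoreThreshold` is the app start time when newFilesOnly = true, else 0.
    // `rememberWindowMs` is the remember window (at least 1 minute).
    def shouldSelect(
        modTime: Long,
        currentTime: Long,
        initialIgnoreThreshold: Long,
        rememberWindowMs: Long,
        recentlySelected: Set[String],
        pathStr: String): Boolean = {
      // Ignore threshold: the trailing end of the remember window, or the initial
      // threshold, whichever is later in time.
      val ignoreThreshold = math.max(initialIgnoreThreshold, currentTime - rememberWindowMs)
      modTime > ignoreThreshold &&          // not older than the remembered window
        modTime <= currentTime &&           // not newer than the batch being generated
        !recentlySelected.contains(pathStr) // not already selected in a remembered batch
    }
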
+ */ private[streaming] class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : ClassTag]( @transient ssc_ : StreamingContext, @@ -37,22 +74,37 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas newFilesOnly: Boolean = true) extends InputDStream[(K, V)](ssc_) { + // Data to be saved as part of the streaming checkpoints protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData - // files found in the last interval - private val lastFoundFiles = new HashSet[String] + // Initial ignore threshold based on which old, existing files in the directory (at the time of + // starting the streaming application) will be ignored or considered + private val initialModTimeIgnoreThreshold = if (newFilesOnly) System.currentTimeMillis() else 0L + + /* + * Make sure that the information of files selected in the last few batches are remembered. + * This would allow us to filter away not-too-old files which have already been recently + * selected and processed. + */ + private val numBatchesToRemember = FileInputDStream.calculateNumBatchesToRemember(slideDuration) + private val durationToRemember = slideDuration * numBatchesToRemember + remember(durationToRemember) - // Files with mod time earlier than this is ignored. This is updated every interval - // such that in the current interval, files older than any file found in the - // previous interval will be ignored. Obviously this time keeps moving forward. - private var ignoreTime = if (newFilesOnly) System.currentTimeMillis() else 0L + // Map of batch-time to selected file info for the remembered batches + @transient private[streaming] var batchTimeToSelectedFiles = + new mutable.HashMap[Time, Array[String]] + + // Set of files that were selected in the remembered batches + @transient private var recentlySelectedFiles = new mutable.HashSet[String]() + + // Read-through cache of file mod times, used to speed up mod time lookups + @transient private var fileToModTime = new TimeStampedHashMap[String, Long](true) + + // Timestamp of the last round of finding files + @transient private var lastNewFileFindingTime = 0L - // Latest file mod time seen till any point of time @transient private var path_ : Path = null @transient private var fs_ : FileSystem = null - @transient private[streaming] var files = new HashMap[Time, Array[String]] - @transient private var fileModTimes = new TimeStampedHashMap[String, Long](true) - @transient private var lastNewFileFindingTime = 0L override def start() { } @@ -68,54 +120,113 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas * the previous call. 
*/ override def compute(validTime: Time): Option[RDD[(K, V)]] = { - assert(validTime.milliseconds >= ignoreTime, - "Trying to get new files for a really old time [" + validTime + " < " + ignoreTime + "]") - // Find new files - val (newFiles, minNewFileModTime) = findNewFiles(validTime.milliseconds) + val newFiles = findNewFiles(validTime.milliseconds) logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n")) - if (!newFiles.isEmpty) { - lastFoundFiles.clear() - lastFoundFiles ++= newFiles - ignoreTime = minNewFileModTime - } - files += ((validTime, newFiles.toArray)) + batchTimeToSelectedFiles += ((validTime, newFiles)) + recentlySelectedFiles ++= newFiles Some(filesToRDD(newFiles)) } /** Clear the old time-to-files mappings along with old RDDs */ protected[streaming] override def clearMetadata(time: Time) { super.clearMetadata(time) - val oldFiles = files.filter(_._1 < (time - rememberDuration)) - files --= oldFiles.keys + val oldFiles = batchTimeToSelectedFiles.filter(_._1 < (time - rememberDuration)) + batchTimeToSelectedFiles --= oldFiles.keys + recentlySelectedFiles --= oldFiles.values.flatten logInfo("Cleared " + oldFiles.size + " old files that were older than " + (time - rememberDuration) + ": " + oldFiles.keys.mkString(", ")) logDebug("Cleared files are:\n" + oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n")) // Delete file mod times that weren't accessed in the last round of getting new files - fileModTimes.clearOldValues(lastNewFileFindingTime - 1) + fileToModTime.clearOldValues(lastNewFileFindingTime - 1) } /** - * Find files which have modification timestamp <= current time and return a 3-tuple of - * (new files found, latest modification time among them, files with latest modification time) + * Find new files for the batch of `currentTime`. This is done by first calculating the + * ignore threshold for file mod times, and then getting a list of files filtered based on + * the current batch time and the ignore threshold. The ignore threshold is the max of + * initial ignore threshold and the trailing end of the remember window (that is, which ever + * is later in time). */ - private def findNewFiles(currentTime: Long): (Seq[String], Long) = { - logDebug("Trying to get new files for time " + currentTime) - lastNewFileFindingTime = System.currentTimeMillis - val filter = new CustomPathFilter(currentTime) - val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString) - val timeTaken = System.currentTimeMillis - lastNewFileFindingTime - logInfo("Finding new files took " + timeTaken + " ms") - logDebug("# cached file times = " + fileModTimes.size) - if (timeTaken > slideDuration.milliseconds) { - logWarning( - "Time taken to find new files exceeds the batch size. " + - "Consider increasing the batch size or reduceing the number of " + - "files in the monitored directory." 
+ private def findNewFiles(currentTime: Long): Array[String] = { + try { + lastNewFileFindingTime = System.currentTimeMillis + + // Calculate ignore threshold + val modTimeIgnoreThreshold = math.max( + initialModTimeIgnoreThreshold, // initial threshold based on newFilesOnly setting + currentTime - durationToRemember.milliseconds // trailing end of the remember window ) + logDebug(s"Getting new files for time $currentTime, " + + s"ignoring files older than $modTimeIgnoreThreshold") + val filter = new PathFilter { + def accept(path: Path): Boolean = isNewFile(path, currentTime, modTimeIgnoreThreshold) + } + val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString) + val timeTaken = System.currentTimeMillis - lastNewFileFindingTime + logInfo("Finding new files took " + timeTaken + " ms") + logDebug("# cached file times = " + fileToModTime.size) + if (timeTaken > slideDuration.milliseconds) { + logWarning( + "Time taken to find new files exceeds the batch size. " + + "Consider increasing the batch size or reducing the number of " + + "files in the monitored directory." + ) + } + newFiles + } catch { + case e: Exception => + logWarning("Error finding new files", e) + reset() + Array.empty + } + } + + /** + * Identify whether the given `path` is a new file for the batch of `currentTime`. For it to be + * accepted, it has to pass the following criteria. + * - It must pass the user-provided file filter. + * - It must be newer than the ignore threshold. It is assumed that files older than the ignore + * threshold have already been considered or are existing files before start + * (when newFileOnly = true). + * - It must not be present in the recently selected files that this class remembers. + * - It must not be newer than the time of the batch (i.e. `currentTime` for which this + * file is being tested. This can occur if the driver was recovered, and the missing batches + * (during downtime) are being generated. In that case, a batch of time T may be generated + * at time T+x. Say x = 5. If that batch T contains file of mod time T+5, then bad things can + * happen. Let's say the selected files are remembered for 60 seconds. At time t+61, + * the batch of time t is forgotten, and the ignore threshold is still T+1. + * The files with mod time T+5 are not remembered and cannot be ignored (since, t+5 > t+1). + * Hence they can get selected as new files again. To prevent this, files whose mod time is more + * than current batch time are not considered. 
+ */ + private def isNewFile(path: Path, currentTime: Long, modTimeIgnoreThreshold: Long): Boolean = { + val pathStr = path.toString + // Reject file if it does not satisfy filter + if (!filter(path)) { + logDebug(s"$pathStr rejected by filter") + return false + } + // Reject file if it was created before the ignore time + val modTime = getFileModTime(path) + if (modTime <= modTimeIgnoreThreshold) { + // Use <= instead of < to avoid SPARK-4518 + logDebug(s"$pathStr ignored as mod time $modTime <= ignore time $modTimeIgnoreThreshold") + return false } - (newFiles, filter.minNewFileModTime) + // Reject file if mod time > current batch time + if (modTime > currentTime) { + logDebug(s"$pathStr not selected as mod time $modTime > current time $currentTime") + return false + } + // Reject file if it was considered earlier + if (recentlySelectedFiles.contains(pathStr)) { + logDebug(s"$pathStr already considered") + return false + } + logDebug(s"$pathStr accepted with mod time $modTime") + return true } /** Generate one RDD from an array of files */ @@ -132,21 +243,21 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas new UnionRDD(context.sparkContext, fileRDDs) } + /** Get file mod time from cache or fetch it from the file system */ + private def getFileModTime(path: Path) = { + fileToModTime.getOrElseUpdate(path.toString, fs.getFileStatus(path).getModificationTime()) + } + private def directoryPath: Path = { if (path_ == null) path_ = new Path(directory) path_ } private def fs: FileSystem = { - if (fs_ == null) fs_ = directoryPath.getFileSystem(new Configuration()) + if (fs_ == null) fs_ = directoryPath.getFileSystem(ssc.sparkContext.hadoopConfiguration) fs_ } - private def getFileModTime(path: Path) = { - // Get file mod time from cache or fetch it from the file system - fileModTimes.getOrElseUpdate(path.toString, fs.getFileStatus(path).getModificationTime()) - } - private def reset() { fs_ = null } @@ -155,9 +266,10 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() - generatedRDDs = new HashMap[Time, RDD[(K,V)]] () - files = new HashMap[Time, Array[String]] - fileModTimes = new TimeStampedHashMap[String, Long](true) + generatedRDDs = new mutable.HashMap[Time, RDD[(K,V)]] () + batchTimeToSelectedFiles = new mutable.HashMap[Time, Array[String]]() + recentlySelectedFiles = new mutable.HashSet[String]() + fileToModTime = new TimeStampedHashMap[String, Long](true) } /** @@ -167,11 +279,11 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas private[streaming] class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { - def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]] + def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]] override def update(time: Time) { hadoopFiles.clear() - hadoopFiles ++= files + hadoopFiles ++= batchTimeToSelectedFiles } override def cleanup(time: Time) { } @@ -182,7 +294,8 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas // Restore the metadata in both files and generatedRDDs logInfo("Restoring files for time " + t + " - " + f.mkString("[", ", ", "]") ) - files += ((t, f)) + batchTimeToSelectedFiles += ((t, f)) + recentlySelectedFiles ++= f generatedRDDs += ((t, filesToRDD(f))) } } @@ -193,57 +306,25 @@ class FileInputDStream[K: 
ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas hadoopFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n") + "\n]" } } +} + +private[streaming] +object FileInputDStream { /** - * Custom PathFilter class to find new files that - * ... have modification time more than ignore time - * ... have not been seen in the last interval - * ... have modification time less than maxModTime + * Minimum duration of remembering the information of selected files. Files with mod times + * older than this "window" of remembering will be ignored. So if new files are visible + * within this window, then the file will get selected in the next batch. */ - private[streaming] - class CustomPathFilter(maxModTime: Long) extends PathFilter { + private val MIN_REMEMBER_DURATION = Minutes(1) - // Minimum of the mod times of new files found in the current interval - var minNewFileModTime = -1L + def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".") - def accept(path: Path): Boolean = { - try { - if (!filter(path)) { // Reject file if it does not satisfy filter - logDebug("Rejected by filter " + path) - return false - } - // Reject file if it was found in the last interval - if (lastFoundFiles.contains(path.toString)) { - logDebug("Mod time equal to last mod time, but file considered already") - return false - } - val modTime = getFileModTime(path) - logDebug("Mod time for " + path + " is " + modTime) - if (modTime < ignoreTime) { - // Reject file if it was created before the ignore time (or, before last interval) - logDebug("Mod time " + modTime + " less than ignore time " + ignoreTime) - return false - } else if (modTime > maxModTime) { - // Reject file if it is too new that considering it may give errors - logDebug("Mod time more than ") - return false - } - if (minNewFileModTime < 0 || modTime < minNewFileModTime) { - minNewFileModTime = modTime - } - logDebug("Accepted " + path) - } catch { - case fnfe: java.io.FileNotFoundException => - logWarning("Error finding new files", fnfe) - reset() - return false - } - true - } + /** + * Calculate the number of last batches to remember, such that all the files selected in + * at least last MIN_REMEMBER_DURATION duration can be remembered. 
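Editor's note: the calculation described here (and implemented just below) rounds up so that the remembered batches always cover at least MIN_REMEMBER_DURATION, i.e. one minute. A quick worked example of that formula, math.ceil(60000.0 / batchDurationMs).toInt:

    // batch duration -> batches remembered -> effective remember window
    //  1 second  -> ceil(60000 / 1000)   = 60 batches -> 60 seconds
    //  25 seconds -> ceil(60000 / 25000) = 3 batches  -> 75 seconds (>= 1 minute)
    //  2 minutes  -> ceil(60000 / 120000) = 1 batch   -> 2 minutes
    val minRememberMs = 60 * 1000L
    def numBatchesToRemember(batchDurationMs: Long): Int =
      math.ceil(minRememberMs.toDouble / batchDurationMs).toInt
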
+ */ + def calculateNumBatchesToRemember(batchDuration: Duration): Int = { + math.ceil(MIN_REMEMBER_DURATION.milliseconds.toDouble / batchDuration.milliseconds).toInt } } - -private[streaming] -object FileInputDStream { - def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".") -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index 905bc723f69a9..1361c30395b57 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -38,6 +38,7 @@ class ForEachDStream[T: ClassTag] ( parent.getOrCompute(time) match { case Some(rdd) => val jobFunc = () => { + ssc.sparkContext.setCallSite(creationSite) foreachFunc(rdd, time) } Some(new Job(time, jobFunc)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index b39f47f04a38b..98539e06b4e29 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -17,20 +17,17 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.StreamingContext._ - -import org.apache.spark.{Partitioner, HashPartitioner} -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.hadoop.mapred.OutputFormat import org.apache.hadoop.conf.Configuration -import org.apache.spark.streaming.{Time, Duration} +import org.apache.hadoop.mapred.{JobConf, OutputFormat} +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} + +import org.apache.spark.{HashPartitioner, Partitioner, SerializableWritable} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} +import org.apache.spark.streaming.StreamingContext._ /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. @@ -398,10 +395,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. * org.apache.spark.Partitioner is used to control the partitioning of each RDD. - * @param updateFunc State update function. If `this` function returns None, then - * corresponding state key-value pair will be eliminated. Note, that - * this function may generate a different a tuple with a different key - * than the input key. It is up to the developer to decide whether to + * @param updateFunc State update function. Note, that this function may generate a different + * tuple with a different key than the input key. Therefore keys may be removed + * or added in this way. It is up to the developer to decide whether to * remember the partitioner despite the key being changed. 
* @param partitioner Partitioner for controlling the partitioning of each RDD in the new * DStream @@ -442,11 +438,10 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. * org.apache.spark.Partitioner is used to control the partitioning of each RDD. - * @param updateFunc State update function. If `this` function returns None, then - * corresponding state key-value pair will be eliminated. Note, that - * this function may generate a different a tuple with a different key - * than the input key. It is up to the developer to decide whether to - * remember the partitioner despite the key being changed. + * @param updateFunc State update function. Note, that this function may generate a different + * tuple with a different key than the input key. Therefore keys may be removed + * or added in this way. It is up to the developer to decide whether to + * remember the partitioner despite the key being changed. * @param partitioner Partitioner for controlling the partitioning of each RDD in the new * DStream * @param rememberPartitioner Whether to remember the paritioner object in the generated RDDs. @@ -673,11 +668,13 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf + conf: JobConf = new JobConf(ssc.sparkContext.hadoopConfiguration) ) { + // Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints + val serializableConf = new SerializableWritable(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) - rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) + rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, serializableConf.value) } self.foreachRDD(saveFunc) } @@ -704,11 +701,14 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration = new Configuration + conf: Configuration = ssc.sparkContext.hadoopConfiguration ) { + // Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints + val serializableConf = new SerializableWritable(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) - rdd.saveAsNewAPIHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) + rdd.saveAsNewAPIHadoopFile( + file, keyClass, valueClass, outputFormatClass, serializableConf.value) } self.foreachRDD(saveFunc) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 3e67161363e50..c834744631e02 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -29,7 +29,7 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] * that has to start a receiver on worker nodes to receive external data. 
- * Specific implementations of NetworkInputDStream must + * Specific implementations of ReceiverInputDStream must * define `the getReceiver()` function that gets the receiver object of type * [[org.apache.spark.streaming.receiver.Receiver]] that will be sent * to the workers to receive data. @@ -39,17 +39,17 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { - /** This is an unique identifier for the network input stream. */ + /** This is an unique identifier for the receiver input stream. */ val id = ssc.getNewReceiverStreamId() /** * Gets the receiver object that will be sent to the worker nodes * to receive data. This method needs to defined by any specific implementation - * of a NetworkInputDStream. + * of a ReceiverInputDStream. */ def getReceiver(): Receiver[T] - // Nothing to start or stop as both taken care of by the ReceiverInputTracker. + // Nothing to start or stop as both taken care of by the ReceiverTracker. def start() {} def stop() {} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index 57429a15329a1..abbc40befa95b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -28,17 +28,10 @@ private[streaming] class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) extends DStream[T](parents.head.ssc) { - if (parents.length == 0) { - throw new IllegalArgumentException("Empty array of parents") - } - - if (parents.map(_.ssc).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different StreamingContexts") - } - - if (parents.map(_.slideDuration).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different slide times") - } + require(parents.length > 0, "List of DStreams to union is empty") + require(parents.map(_.ssc).distinct.size == 1, "Some of the DStreams have different contexts") + require(parents.map(_.slideDuration).distinct.size == 1, + "Some of the DStreams have different slide durations") override def dependencies = parents.toList diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 5f5e1909908d5..02758e0bca6c5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -70,18 +70,7 @@ private[streaming] class ReceivedBlockTracker( private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue] private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks] - - private val logManagerRollingIntervalSecs = conf.getInt( - "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60) - private val logManagerOption = checkpointDirOption.map { checkpointDir => - new WriteAheadLogManager( - ReceivedBlockTracker.checkpointDirToLogDir(checkpointDir), - hadoopConf, - rollingIntervalSecs = logManagerRollingIntervalSecs, - callerName = "ReceivedBlockHandlerMaster", - clock = clock - ) - } + private val logManagerOption = createLogManager() private var lastAllocatedBatchTime: Time = null @@ -221,6 +210,30 
@@ private[streaming] class ReceivedBlockTracker( private def getReceivedBlockQueue(streamId: Int): ReceivedBlockQueue = { streamIdToUnallocatedBlockQueues.getOrElseUpdate(streamId, new ReceivedBlockQueue) } + + /** Optionally create the write ahead log manager only if the feature is enabled */ + private def createLogManager(): Option[WriteAheadLogManager] = { + if (conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + if (checkpointDirOption.isEmpty) { + throw new SparkException( + "Cannot enable receiver write-ahead log without checkpoint directory set. " + + "Please use streamingContext.checkpoint() to set the checkpoint directory. " + + "See documentation for more details.") + } + val logDir = ReceivedBlockTracker.checkpointDirToLogDir(checkpointDirOption.get) + val rollingIntervalSecs = conf.getInt( + "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60) + val logManager = new WriteAheadLogManager(logDir, hadoopConf, + rollingIntervalSecs = rollingIntervalSecs, clock = clock, + callerName = "ReceivedBlockHandlerMaster") + Some(logManager) + } else { + None + } + } + + /** Check if the log manager is enabled. This is only used for testing purposes. */ + private[streaming] def isLogManagerEnabled: Boolean = logManagerOption.nonEmpty } private[streaming] object ReceivedBlockTracker { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 1c3984d968d20..32e481dabc8ca 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -46,7 +46,7 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, err extends ReceiverTrackerMessage /** - * This class manages the execution of the receivers of NetworkInputDStreams. Instance of + * This class manages the execution of the receivers of ReceiverInputDStreams. Instance of * this class must be created after all input streams have been added and StreamingContext.start() * has been called because it needs the final set of input streams at the time of instantiation. 
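Editor's note: the createLogManager change above makes the receiver write-ahead log opt-in. It is created only when spark.streaming.receiver.writeAheadLog.enable is set to true, and it requires a checkpoint directory (the log files live under it); otherwise a SparkException is thrown. A minimal sketch of how an application would enable it (master, app name, and the checkpoint path are placeholders):

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val conf = new SparkConf()
      .setMaster("local[2]")                                          // placeholder master
      .setAppName("wal-example")                                      // placeholder app name
      .set("spark.streaming.receiver.writeAheadLog.enable", "true")   // opt in to the WAL
    val ssc = new StreamingContext(conf, Seconds(1))
    // Required: the write-ahead log is stored under the checkpoint directory.
    ssc.checkpoint("hdfs:///tmp/streaming-checkpoint")                 // placeholder path
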
* diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 30a359677cc74..86b96785d7b87 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -470,32 +470,31 @@ class BasicOperationsSuite extends TestSuiteBase { } test("slice") { - val ssc = new StreamingContext(conf, Seconds(1)) - val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) - val stream = new TestInputStream[Int](ssc, input, 2) - stream.foreachRDD(_ => {}) // Dummy output stream - ssc.start() - Thread.sleep(2000) - def getInputFromSlice(fromMillis: Long, toMillis: Long) = { - stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet - } + withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc => + val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) + val stream = new TestInputStream[Int](ssc, input, 2) + stream.foreachRDD(_ => {}) // Dummy output stream + ssc.start() + Thread.sleep(2000) + def getInputFromSlice(fromMillis: Long, toMillis: Long) = { + stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet + } - assert(getInputFromSlice(0, 1000) == Set(1)) - assert(getInputFromSlice(0, 2000) == Set(1, 2)) - assert(getInputFromSlice(1000, 2000) == Set(1, 2)) - assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4)) - ssc.stop() - Thread.sleep(1000) + assert(getInputFromSlice(0, 1000) == Set(1)) + assert(getInputFromSlice(0, 2000) == Set(1, 2)) + assert(getInputFromSlice(1000, 2000) == Set(1, 2)) + assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4)) + } } - test("slice - has not been initialized") { - val ssc = new StreamingContext(conf, Seconds(1)) - val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) - val stream = new TestInputStream[Int](ssc, input, 2) - val thrown = intercept[SparkException] { - stream.slice(new Time(0), new Time(1000)) + withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc => + val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) + val stream = new TestInputStream[Int](ssc, input, 2) + val thrown = intercept[SparkException] { + stream.slice(new Time(0), new Time(1000)) + } + assert(thrown.getMessage.contains("has not been initialized")) } - assert(thrown.getMessage.contains("has not been initialized")) } val cleanupTestInput = (0 until 10).map(x => Seq(x, x + 1)).toSeq @@ -555,73 +554,72 @@ class BasicOperationsSuite extends TestSuiteBase { test("rdd cleanup - input blocks and persisted RDDs") { // Actually receive data over through receiver to create BlockRDDs - // Start the server - val testServer = new TestServer() - testServer.start() - - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) - val mappedStream = networkStream.map(_ + ".").persist() - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(mappedStream, outputBuffer) - - outputStream.register() - ssc.start() - - // Feed data to the server to send to the network receiver - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = Seq(1, 2, 3, 4, 5, 6) + withTestServer(new TestServer()) { testServer => + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + testServer.start() + // Set up 
the streaming context and input streams + val networkStream = + ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) + val mappedStream = networkStream.map(_ + ".").persist() + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(mappedStream, outputBuffer) + + outputStream.register() + ssc.start() + + // Feed data to the server to send to the network receiver + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = Seq(1, 2, 3, 4, 5, 6) + + val blockRdds = new mutable.HashMap[Time, BlockRDD[_]] + val persistentRddIds = new mutable.HashMap[Time, Int] + + def collectRddInfo() { // get all RDD info required for verification + networkStream.generatedRDDs.foreach { case (time, rdd) => + blockRdds(time) = rdd.asInstanceOf[BlockRDD[_]] + } + mappedStream.generatedRDDs.foreach { case (time, rdd) => + persistentRddIds(time) = rdd.id + } + } - val blockRdds = new mutable.HashMap[Time, BlockRDD[_]] - val persistentRddIds = new mutable.HashMap[Time, Int] + Thread.sleep(200) + for (i <- 0 until input.size) { + testServer.send(input(i).toString + "\n") + Thread.sleep(200) + clock.addToTime(batchDuration.milliseconds) + collectRddInfo() + } - def collectRddInfo() { // get all RDD info required for verification - networkStream.generatedRDDs.foreach { case (time, rdd) => - blockRdds(time) = rdd.asInstanceOf[BlockRDD[_]] - } - mappedStream.generatedRDDs.foreach { case (time, rdd) => - persistentRddIds(time) = rdd.id + Thread.sleep(200) + collectRddInfo() + logInfo("Stopping server") + testServer.stop() + + // verify data has been received + assert(outputBuffer.size > 0) + assert(blockRdds.size > 0) + assert(persistentRddIds.size > 0) + + import Time._ + + val latestPersistedRddId = persistentRddIds(persistentRddIds.keySet.max) + val earliestPersistedRddId = persistentRddIds(persistentRddIds.keySet.min) + val latestBlockRdd = blockRdds(blockRdds.keySet.max) + val earliestBlockRdd = blockRdds(blockRdds.keySet.min) + // verify that the latest mapped RDD is persisted but the earliest one has been unpersisted + assert(ssc.sparkContext.persistentRdds.contains(latestPersistedRddId)) + assert(!ssc.sparkContext.persistentRdds.contains(earliestPersistedRddId)) + + // verify that the latest input blocks are present but the earliest blocks have been removed + assert(latestBlockRdd.isValid) + assert(latestBlockRdd.collect != null) + assert(!earliestBlockRdd.isValid) + earliestBlockRdd.blockIds.foreach { blockId => + assert(!ssc.sparkContext.env.blockManager.master.contains(blockId)) + } } } - - Thread.sleep(200) - for (i <- 0 until input.size) { - testServer.send(input(i).toString + "\n") - Thread.sleep(200) - clock.addToTime(batchDuration.milliseconds) - collectRddInfo() - } - - Thread.sleep(200) - collectRddInfo() - logInfo("Stopping server") - testServer.stop() - logInfo("Stopping context") - - // verify data has been received - assert(outputBuffer.size > 0) - assert(blockRdds.size > 0) - assert(persistentRddIds.size > 0) - - import Time._ - - val latestPersistedRddId = persistentRddIds(persistentRddIds.keySet.max) - val earliestPersistedRddId = persistentRddIds(persistentRddIds.keySet.min) - val latestBlockRdd = blockRdds(blockRdds.keySet.max) - val earliestBlockRdd = blockRdds(blockRdds.keySet.min) - // verify that the latest mapped RDD is persisted but the earliest one has been unpersisted - assert(ssc.sparkContext.persistentRdds.contains(latestPersistedRddId)) - 
assert(!ssc.sparkContext.persistentRdds.contains(earliestPersistedRddId)) - - // verify that the latest input blocks are present but the earliest blocks have been removed - assert(latestBlockRdd.isValid) - assert(latestBlockRdd.collect != null) - assert(!earliestBlockRdd.isValid) - earliestBlockRdd.blockIds.foreach { blockId => - assert(!ssc.sparkContext.env.blockManager.master.contains(blockId)) - } - ssc.stop() } /** Test cleanup of RDDs in DStream metadata */ @@ -635,13 +633,15 @@ class BasicOperationsSuite extends TestSuiteBase { // Setup the stream computation assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second, check cleanup tests") - val ssc = setupStreams(cleanupTestInput, operation) - val operatedStream = ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]] - if (rememberDuration != null) ssc.remember(rememberDuration) - val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput) - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - assert(clock.time === Seconds(10).milliseconds) - assert(output.size === numExpectedOutput) - operatedStream + withStreamingContext(setupStreams(cleanupTestInput, operation)) { ssc => + val operatedStream = + ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]] + if (rememberDuration != null) ssc.remember(rememberDuration) + val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput) + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + assert(clock.time === Seconds(10).milliseconds) + assert(output.size === numExpectedOutput) + operatedStream + } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index e5592e52b0d2d..c97998add8ffa 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -22,9 +22,14 @@ import java.nio.charset.Charset import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag + import com.google.common.io.Files -import org.apache.hadoop.fs.{Path, FileSystem} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.io.{IntWritable, Text} +import org.apache.hadoop.mapred.TextOutputFormat +import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} + import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} import org.apache.spark.streaming.util.ManualClock @@ -205,6 +210,51 @@ class CheckpointSuite extends TestSuiteBase { testCheckpointedOperation(input, operation, output, 7) } + test("recovery with saveAsHadoopFiles operation") { + val tempDir = Files.createTempDir() + try { + testCheckpointedOperation( + Seq(Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq()), + (s: DStream[String]) => { + val output = s.map(x => (x, 1)).reduceByKey(_ + _) + output.saveAsHadoopFiles( + tempDir.toURI.toString, + "result", + classOf[Text], + classOf[IntWritable], + classOf[TextOutputFormat[Text, IntWritable]]) + output + }, + Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + 3 + ) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("recovery with saveAsNewAPIHadoopFiles operation") { + val tempDir = Files.createTempDir() + try { + 
testCheckpointedOperation( + Seq(Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq()), + (s: DStream[String]) => { + val output = s.map(x => (x, 1)).reduceByKey(_ + _) + output.saveAsNewAPIHadoopFiles( + tempDir.toURI.toString, + "result", + classOf[Text], + classOf[IntWritable], + classOf[NewTextOutputFormat[Text, IntWritable]]) + output + }, + Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + 3 + ) + } finally { + Utils.deleteRecursively(tempDir) + } + } // This tests whether the StateDStream's RDD checkpoints works correctly such // that the system can recover from a master failure. This assumes as reliable, @@ -265,7 +315,7 @@ class CheckpointSuite extends TestSuiteBase { // Verify whether files created have been recorded correctly or not var fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]] - def recordedFiles = fileInputDStream.files.values.flatMap(x => x) + def recordedFiles = fileInputDStream.batchTimeToSelectedFiles.values.flatten assert(!recordedFiles.filter(_.endsWith("1")).isEmpty) assert(!recordedFiles.filter(_.endsWith("2")).isEmpty) assert(!recordedFiles.filter(_.endsWith("3")).isEmpty) @@ -391,7 +441,9 @@ class CheckpointSuite extends TestSuiteBase { logInfo("Manual clock after advancing = " + clock.time) Thread.sleep(batchDuration.milliseconds) - val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + val outputStream = ssc.graph.getOutputStreams.filter { dstream => + dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] + }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] outputStream.output.map(_.flatten) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index fa04fa326e370..307052a4a9cbb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -28,9 +28,12 @@ import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQueue} +import scala.concurrent.duration._ +import scala.language.postfixOps import com.google.common.io.Files import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually._ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel @@ -38,6 +41,9 @@ import org.apache.spark.streaming.util.ManualClock import org.apache.spark.util.Utils import org.apache.spark.streaming.receiver.{ActorHelper, Receiver} import org.apache.spark.rdd.RDD +import org.apache.hadoop.io.{Text, LongWritable} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.hadoop.fs.Path class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { @@ -91,54 +97,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } - test("file input stream") { - // Disable manual clock as FileInputDStream does not work with manual clock - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") - - // Set up the streaming context and input streams - val testDir = Utils.createTempDir() - val ssc = new StreamingContext(conf, batchDuration) - val fileStream = ssc.textFileStream(testDir.toString) - val outputBuffer = new ArrayBuffer[Seq[String]] with 
SynchronizedBuffer[Seq[String]] - def output = outputBuffer.flatMap(x => x) - val outputStream = new TestOutputStream(fileStream, outputBuffer) - outputStream.register() - ssc.start() - - // Create files in the temporary directory so that Spark Streaming can read data from it - val input = Seq(1, 2, 3, 4, 5) - val expectedOutput = input.map(_.toString) - Thread.sleep(1000) - for (i <- 0 until input.size) { - val file = new File(testDir, i.toString) - Files.write(input(i) + "\n", file, Charset.forName("UTF-8")) - logInfo("Created file " + file) - Thread.sleep(batchDuration.milliseconds) - Thread.sleep(1000) - } - val startTime = System.currentTimeMillis() - Thread.sleep(1000) - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received by Spark Streaming was as expected - logInfo("--------------------------------") - logInfo("output, size = " + outputBuffer.size) - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output, size = " + expectedOutput.size) - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.toList === expectedOutput.toList) - - Utils.deleteRecursively(testDir) + test("file input stream - newFilesOnly = true") { + testFileStream(newFilesOnly = true) + } - - // Enable manual clock back again for other tests - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + test("file input stream - newFilesOnly = false") { + testFileStream(newFilesOnly = false) + } test("multi-thread receiver") { @@ -180,7 +144,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { assert(output.sum === numTotalRecords) } - test("queue input stream - oneAtATime=true") { + test("queue input stream - oneAtATime = true") { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) val queue = new SynchronizedQueue[RDD[String]]() @@ -223,7 +187,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } - test("queue input stream - oneAtATime=false") { + test("queue input stream - oneAtATime = false") { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) val queue = new SynchronizedQueue[RDD[String]]() @@ -268,6 +232,50 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { assert(output(i) === expectedOutput(i)) } } + + def testFileStream(newFilesOnly: Boolean) { + var ssc: StreamingContext = null + var testDir: File = null + try { + testDir = Utils.createTempDir() + val existingFile = new File(testDir, "0") + Files.write("0\n", existingFile, Charset.forName("UTF-8")) + + Thread.sleep(1000) + // Set up the streaming context and input streams + val newConf = conf.clone.set( + "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") + ssc = new StreamingContext(newConf, batchDuration) + val fileStream = ssc.fileStream[LongWritable, Text, TextInputFormat]( + testDir.toString, (x: Path) => true, newFilesOnly = newFilesOnly).map(_._2.toString) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(fileStream, outputBuffer) + 
outputStream.register() + ssc.start() + + // Create files in the directory + val input = Seq(1, 2, 3, 4, 5) + input.foreach { i => + Thread.sleep(batchDuration.milliseconds) + val file = new File(testDir, i.toString) + Files.write(i + "\n", file, Charset.forName("UTF-8")) + logInfo("Created file " + file) + } + + // Verify that all the files have been read + val expectedOutput = if (newFilesOnly) { + input.map(_.toString).toSet + } else { + (Seq(0) ++ input).map(_.toString).toSet + } + eventually(timeout(maxWaitTimeMillis milliseconds), interval(100 milliseconds)) { + assert(outputBuffer.flatten.toSet === expectedOutput) + } + } finally { + if (ssc != null) ssc.stop() + if (testDir != null) Utils.deleteRecursively(testDir) + } + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index fd9c97f551c62..01a09b67b99dc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -41,17 +41,16 @@ import org.apache.spark.util.Utils class ReceivedBlockTrackerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { - val conf = new SparkConf().setMaster("local[2]").setAppName("ReceivedBlockTrackerSuite") - conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1") - val hadoopConf = new Configuration() val akkaTimeout = 10 seconds val streamId = 1 var allReceivedBlockTrackers = new ArrayBuffer[ReceivedBlockTracker]() var checkpointDirectory: File = null + var conf: SparkConf = null before { + conf = new SparkConf().setMaster("local[2]").setAppName("ReceivedBlockTrackerSuite") checkpointDirectory = Files.createTempDir() } @@ -64,7 +63,8 @@ class ReceivedBlockTrackerSuite } test("block addition, and block to batch allocation") { - val receivedBlockTracker = createTracker(enableCheckpoint = false) + val receivedBlockTracker = createTracker(setCheckpointDir = false) + receivedBlockTracker.isLogManagerEnabled should be (false) // should be disable by default receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual Seq.empty val blockInfos = generateBlockInfos() @@ -95,13 +95,11 @@ class ReceivedBlockTrackerSuite test("block addition, block to batch allocation and cleanup with write ahead log") { val manualClock = new ManualClock - conf.getInt( - "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", -1) should be (1) - // Set the time increment level to twice the rotation interval so that every increment creates // a new log file - val timeIncrementMillis = 2000L + def incrementTime() { + val timeIncrementMillis = 2000L manualClock.addToTime(timeIncrementMillis) } @@ -121,7 +119,11 @@ class ReceivedBlockTrackerSuite } // Start tracker and add blocks - val tracker1 = createTracker(enableCheckpoint = true, clock = manualClock) + conf.set("spark.streaming.receiver.writeAheadLog.enable", "true") + conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1") + val tracker1 = createTracker(clock = manualClock) + tracker1.isLogManagerEnabled should be (true) + val blockInfos1 = addBlockInfos(tracker1) tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 @@ -132,7 +134,7 @@ class ReceivedBlockTrackerSuite // Restart tracker and verify recovered list of unallocated blocks incrementTime() - val tracker2 = createTracker(enableCheckpoint = 
true, clock = manualClock) + val tracker2 = createTracker(clock = manualClock) tracker2.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 // Allocate blocks to batch and verify whether the unallocated blocks got allocated @@ -156,7 +158,7 @@ class ReceivedBlockTrackerSuite // Restart tracker and verify recovered state incrementTime() - val tracker3 = createTracker(enableCheckpoint = true, clock = manualClock) + val tracker3 = createTracker(clock = manualClock) tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 tracker3.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 tracker3.getUnallocatedBlocks(streamId) shouldBe empty @@ -179,18 +181,38 @@ class ReceivedBlockTrackerSuite // Restart tracker and verify recovered state, specifically whether info about the first // batch has been removed, but not the second batch incrementTime() - val tracker4 = createTracker(enableCheckpoint = true, clock = manualClock) + val tracker4 = createTracker(clock = manualClock) tracker4.getUnallocatedBlocks(streamId) shouldBe empty tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty // should be cleaned tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 } + + test("enabling write ahead log but not setting checkpoint dir") { + conf.set("spark.streaming.receiver.writeAheadLog.enable", "true") + intercept[SparkException] { + createTracker(setCheckpointDir = false) + } + } + + test("setting checkpoint dir but not enabling write ahead log") { + // When WAL config is not set, log manager should not be enabled + val tracker1 = createTracker(setCheckpointDir = true) + tracker1.isLogManagerEnabled should be (false) + + // When WAL is explicitly disabled, log manager should not be enabled + conf.set("spark.streaming.receiver.writeAheadLog.enable", "false") + val tracker2 = createTracker(setCheckpointDir = true) + tracker2.isLogManagerEnabled should be(false) + } /** * Create tracker object with the optional provided clock. Use fake clock if you * want to control time by manually incrementing it to test log cleanup. 
*/ - def createTracker(enableCheckpoint: Boolean, clock: Clock = new SystemClock): ReceivedBlockTracker = { - val cpDirOption = if (enableCheckpoint) Some(checkpointDirectory.toString) else None + def createTracker( + setCheckpointDir: Boolean = true, + clock: Clock = new SystemClock): ReceivedBlockTracker = { + val cpDirOption = if (setCheckpointDir) Some(checkpointDirectory.toString) else None val tracker = new ReceivedBlockTracker(conf, hadoopConf, Seq(streamId), clock, cpDirOption) allReceivedBlockTrackers += tracker tracker diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 4b49c4d251645..9f352bdcb0893 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -336,16 +336,20 @@ package object testPackage extends Assertions { // Verify creation site of generated RDDs var rddGenerated = false - var rddCreationSiteCorrect = true + var rddCreationSiteCorrect = false + var foreachCallSiteCorrect = false inputStream.foreachRDD { rdd => rddCreationSiteCorrect = rdd.creationSite == creationSite + foreachCallSiteCorrect = + rdd.sparkContext.getCallSite().shortForm.contains("StreamingContextSuite") rddGenerated = true } ssc.start() eventually(timeout(10000 millis), interval(10 millis)) { assert(rddGenerated && rddCreationSiteCorrect, "RDD creation site was not correct") + assert(rddGenerated && foreachCallSiteCorrect, "Call site in foreachRDD was not correct") } } finally { ssc.stop() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 2154c24abda3a..52972f63c6c5c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -163,6 +163,40 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { before(beforeFunction) after(afterFunction) + /** + * Run a block of code with the given StreamingContext and automatically + * stop the context when the block completes or when an exception is thrown. + */ + def withStreamingContext[R](ssc: StreamingContext)(block: StreamingContext => R): R = { + try { + block(ssc) + } finally { + try { + ssc.stop(stopSparkContext = true) + } catch { + case e: Exception => + logError("Error stopping StreamingContext", e) + } + } + } + + /** + * Run a block of code with the given TestServer and automatically + * stop the server when the block completes or when an exception is thrown. + */ + def withTestServer[R](testServer: TestServer)(block: TestServer => R): R = { + try { + block(testServer) + } finally { + try { + testServer.stop() + } catch { + case e: Exception => + logError("Error stopping TestServer", e) + } + } + } + /** * Set up required DStreams to test the DStream operation using the two sequences * of input collections. 
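The withStreamingContext and withTestServer helpers introduced into TestSuiteBase above follow the loan pattern: the context or server is stopped in a finally block even when the test body throws. A minimal sketch of how a suite built on TestSuiteBase might lean on the first helper; the suite name, the doubling operation, and the expected outputs are illustrative assumptions, not part of this patch, while setupStreams, runStreams, and verifyOutput are the existing TestSuiteBase utilities.

// Hypothetical suite, assumed to live alongside the other suites under
// streaming/src/test/scala/org/apache/spark/streaming/.
import org.apache.spark.streaming.dstream.DStream

class LoanPatternUsageSuite extends TestSuiteBase {
  test("map operation with automatic context shutdown") {
    val input = Seq(Seq(1, 2), Seq(3, 4))
    val operation = (s: DStream[Int]) => s.map(_ * 2)
    // withStreamingContext stops the StreamingContext (and its SparkContext)
    // in a finally block, so a failing assertion cannot leak a running context.
    withStreamingContext(setupStreams(input, operation)) { ssc =>
      val output = runStreams[Int](ssc, input.size, input.size)
      verifyOutput[Int](output, Seq(Seq(2, 4), Seq(6, 8)), useSet = true)
    }
  }
}

The same shape applies to withTestServer for the socket-based input tests: acquire in the argument, use in the closure, and rely on the helper for cleanup.
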
@@ -282,10 +316,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") Thread.sleep(100) // Give some time for the forgetting old RDDs to complete - } catch { - case e: Exception => {e.printStackTrace(); throw e} } finally { - ssc.stop() + ssc.stop(stopSparkContext = true) } output } @@ -351,9 +383,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { useSet: Boolean ) { val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size - val ssc = setupStreams[U, V](input, operation) - val output = runStreams[V](ssc, numBatches_, expectedOutput.size) - verifyOutput[V](output, expectedOutput, useSet) + withStreamingContext(setupStreams[U, V](input, operation)) { ssc => + val output = runStreams[V](ssc, numBatches_, expectedOutput.size) + verifyOutput[V](output, expectedOutput, useSet) + } } /** @@ -389,8 +422,9 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { useSet: Boolean ) { val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size - val ssc = setupStreams[U, V, W](input1, input2, operation) - val output = runStreams[W](ssc, numBatches_, expectedOutput.size) - verifyOutput[W](output, expectedOutput, useSet) + withStreamingContext(setupStreams[U, V, W](input1, input2, operation)) { ssc => + val output = runStreams[W](ssc, numBatches_, expectedOutput.size) + verifyOutput[W](output, expectedOutput, useSet) + } } } diff --git a/tools/pom.xml b/tools/pom.xml index b90eb0ca250c5..c0bc6e2a2af9d 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml index 7dadbba58fd82..40e9e99c6f855 100644 --- a/yarn/alpha/pom.xml +++ b/yarn/alpha/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index 2885e6607ec24..bba73648c7abe 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index fe55d70ccc370..8b6521ad7f859 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml
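
The new ReceivedBlockTrackerSuite cases earlier in this patch pin down a configuration contract: the receiver write-ahead log is opt-in through spark.streaming.receiver.writeAheadLog.enable and only works together with a checkpoint directory. A minimal application-level sketch of that contract follows; the master URL, app name, batch interval, and checkpoint path are illustrative assumptions rather than anything prescribed by the patch.

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object WriteAheadLogConfigSketch {
  def main(args: Array[String]): Unit = {
    // Opt in to the receiver write-ahead log. The tracker-level tests above
    // expect creation to fail with a SparkException when no checkpoint
    // directory has been set, so the two settings go together.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("WriteAheadLogConfigSketch")
      .set("spark.streaming.receiver.writeAheadLog.enable", "true")
    val ssc = new StreamingContext(conf, Seconds(1))
    ssc.checkpoint("/tmp/wal-sketch-checkpoint") // hypothetical path
    // ... define receiver-based input streams and output operations here ...
    ssc.stop(stopSparkContext = true)
  }
}

Leaving the flag at its default keeps the log manager disabled even when a checkpoint directory is set, which is exactly what the "setting checkpoint dir but not enabling write ahead log" test asserts.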