From ca7910d6dd7693be2a675a0d6a6fcc9eb0aaeb5d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 21 Jan 2015 21:20:31 -0800 Subject: [PATCH 01/33] [SPARK-3424][MLLIB] cache point distances during k-means|| init This PR ports the following feature implemented in #2634 by derrickburns: * During k-means|| initialization, we should cache costs (squared distances) previously computed. It also contains the following optimization: * aggregate sumCosts directly * ran multiple (#runs) k-means++ in parallel I compared the performance locally on mnist-digit. Before this patch: ![before](https://cloud.githubusercontent.com/assets/829644/5845647/93080862-a172-11e4-9a35-044ec711afc4.png) with this patch: ![after](https://cloud.githubusercontent.com/assets/829644/5845653/a47c29e8-a172-11e4-8e9f-08db57fe3502.png) It is clear that each k-means|| iteration takes about the same amount of time with this patch. Authors: Derrick Burns Xiangrui Meng Closes #4144 from mengxr/SPARK-3424-kmeans-parallel and squashes the following commits: 0a875ec [Xiangrui Meng] address comments 4341bb8 [Xiangrui Meng] do not re-compute point distances during k-means|| --- .../spark/mllib/clustering/KMeans.scala | 65 ++++++++++++++----- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 6b5c934f015ba..fc46da3a93425 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -279,45 +279,80 @@ class KMeans private ( */ private def initKMeansParallel(data: RDD[VectorWithNorm]) : Array[Array[VectorWithNorm]] = { - // Initialize each run's center to a random point + // Initialize empty centers and point costs. + val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm]) + var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache() + + // Initialize each run's first center to a random point. val seed = new XORShiftRandom(this.seed).nextInt() val sample = data.takeSample(true, runs, seed).toSeq - val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) + val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) + + /** Merges new centers to centers. */ + def mergeNewCenters(): Unit = { + var r = 0 + while (r < runs) { + centers(r) ++= newCenters(r) + newCenters(r).clear() + r += 1 + } + } // On each step, sample 2 * k points on average for each run with probability proportional - // to their squared distance from that run's current centers + // to their squared distance from that run's centers. Note that only distances between points + // and new centers are computed in each iteration. 
var step = 0 while (step < initializationSteps) { - val bcCenters = data.context.broadcast(centers) - val sumCosts = data.flatMap { point => - (0 until runs).map { r => - (r, KMeans.pointCost(bcCenters.value(r), point)) - } - }.reduceByKey(_ + _).collectAsMap() - val chosen = data.mapPartitionsWithIndex { (index, points) => + val bcNewCenters = data.context.broadcast(newCenters) + val preCosts = costs + costs = data.zip(preCosts).map { case (point, cost) => + Vectors.dense( + Array.tabulate(runs) { r => + math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r)) + }) + }.cache() + val sumCosts = costs + .aggregate(Vectors.zeros(runs))( + seqOp = (s, v) => { + // s += v + axpy(1.0, v, s) + s + }, + combOp = (s0, s1) => { + // s0 += s1 + axpy(1.0, s1, s0) + s0 + } + ) + preCosts.unpersist(blocking = false) + val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) - points.flatMap { p => + pointsWithCosts.flatMap { case (p, c) => (0 until runs).filter { r => - rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r) + rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) }.map((_, p)) } }.collect() + mergeNewCenters() chosen.foreach { case (r, p) => - centers(r) += p.toDense + newCenters(r) += p.toDense } step += 1 } + mergeNewCenters() + costs.unpersist(blocking = false) + // Finally, we might have a set of more than k candidate centers for each run; weigh each // candidate by the number of points in the dataset mapping to it and run a local k-means++ // on the weighted centers to pick just k of them val bcCenters = data.context.broadcast(centers) val weightMap = data.flatMap { p => - (0 until runs).map { r => + Iterator.tabulate(runs) { r => ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0) } }.reduceByKey(_ + _).collectAsMap() - val finalCenters = (0 until runs).map { r => + val finalCenters = (0 until runs).par.map { r => val myCenters = centers(r).toArray val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30) From fcb3e1862ffe784f39bde467e8d24c1b7ed3afbb Mon Sep 17 00:00:00 2001 From: Basin Date: Wed, 21 Jan 2015 23:06:34 -0800 Subject: [PATCH 02/33] [SPARK-5317]Set BoostingStrategy.defaultParams With Enumeration Algo.Classification or Algo.Regression JIRA Issue: https://issues.apache.org/jira/browse/SPARK-5317 When setting the BoostingStrategy.defaultParams("Classification"), It's more straightforward to set it with the Enumeration Algo.Classification, just like BoostingStragety.defaultParams(Algo.Classification). I overload the method BoostingStragety.defaultParams(). Author: Basin Closes #4103 from Peishen-Jia/stragetyAlgo and squashes the following commits: 87bab1c [Basin] Docs and Code documentations updated. 3b72875 [Basin] defaultParams(algoStr: String) call defaultParams(algo: Algo). 7c1e6ee [Basin] Doc of Java updated. algo -> algoStr instead. d5c8a2e [Basin] Merge branch 'stragetyAlgo' of github.com:Peishen-Jia/spark into stragetyAlgo 65f96ce [Basin] mllib-ensembles doc modified. e04a5aa [Basin] boostingstrategy.defaultParam string algo to enumeration. 68cf544 [Basin] mllib-ensembles doc modified. a4aea51 [Basin] boostingstrategy.defaultParam string algo to enumeration. 
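For illustration, a minimal sketch of the two equivalent calls this change enables (imports shown; the surrounding training setup is omitted and assumed):

    import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy}

    // New overload: pass the Algo enumeration value directly.
    val byEnum = BoostingStrategy.defaultParams(Algo.Classification)
    // The existing String-based call still works; it now delegates to the enum overload.
    val byString = BoostingStrategy.defaultParams("Classification")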
--- .../tree/configuration/BoostingStrategy.scala | 25 +++++++++++++------ .../mllib/tree/configuration/Strategy.scala | 14 ++++++++--- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index cf51d041c65a9..ed8e6a796f8c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -68,6 +68,15 @@ case class BoostingStrategy( @Experimental object BoostingStrategy { + /** + * Returns default configuration for the boosting algorithm + * @param algo Learning goal. Supported: "Classification" or "Regression" + * @return Configuration for boosting algorithm + */ + def defaultParams(algo: String): BoostingStrategy = { + defaultParams(Algo.fromString(algo)) + } + /** * Returns default configuration for the boosting algorithm * @param algo Learning goal. Supported: @@ -75,15 +84,15 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: String): BoostingStrategy = { - val treeStrategy = Strategy.defaultStrategy(algo) - treeStrategy.maxDepth = 3 + def defaultParams(algo: Algo): BoostingStrategy = { + val treeStragtegy = Strategy.defaultStategy(algo) + treeStragtegy.maxDepth = 3 algo match { - case "Classification" => - treeStrategy.numClasses = 2 - new BoostingStrategy(treeStrategy, LogLoss) - case "Regression" => - new BoostingStrategy(treeStrategy, SquaredError) + case Algo.Classification => + treeStragtegy.numClasses = 2 + new BoostingStrategy(treeStragtegy, LogLoss) + case Algo.Regression => + new BoostingStrategy(treeStragtegy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d5cd89ab94e81..972959885f396 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -173,11 +173,19 @@ object Strategy { * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo "Classification" or "Regression" */ - def defaultStrategy(algo: String): Strategy = algo match { - case "Classification" => + def defaultStrategy(algo: String): Strategy = { + defaultStategy(Algo.fromString(algo)) + } + + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo Algo.Classification or Algo.Regression + */ + def defaultStategy(algo: Algo): Strategy = algo match { + case Algo.Classification => new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, numClasses = 2) - case "Regression" => + case Algo.Regression => new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, numClasses = 0) } From 3027f06b4127ab23a43c5ce8cebf721e3b6766e5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 21 Jan 2015 23:41:44 -0800 Subject: [PATCH 03/33] [SPARK-5147][Streaming] Delete the received data WAL log periodically This is a refactored fix based on jerryshao 's PR #4037 This enabled deletion of old WAL files containing the received block data. 
Improvements over #4037 - Respecting the rememberDuration of all receiver streams. In #4037, if there were two receiver streams with multiple remember durations, the deletion would have delete based on the shortest remember duration, thus deleting data prematurely for the receiver stream with longer remember duration. - Added unit test to test creation of receiver WAL, automatic deletion, and respecting of remember duration. jerryshao I am going to merge this ASAP to make it 1.2.1 Thanks for the initial draft of this PR. Made my job much easier. Author: Tathagata Das Author: jerryshao Closes #4149 from tdas/SPARK-5147 and squashes the following commits: 730798b [Tathagata Das] Added comments. c4cf067 [Tathagata Das] Minor fixes 2579b27 [Tathagata Das] Refactored the fix to make sure that the cleanup respects the remember duration of all the receiver streams 2736fd1 [jerryshao] Delete the old WAL log periodically --- .../apache/spark/streaming/DStreamGraph.scala | 8 + .../dstream/ReceiverInputDStream.scala | 11 -- .../streaming/receiver/ReceiverMessage.scala | 5 +- .../receiver/ReceiverSupervisorImpl.scala | 9 + .../streaming/scheduler/JobGenerator.scala | 11 +- .../scheduler/ReceivedBlockTracker.scala | 1 - .../streaming/scheduler/ReceiverTracker.scala | 18 +- .../spark/streaming/util/HdfsUtils.scala | 2 +- .../spark/streaming/ReceiverSuite.scala | 157 ++++++++++++++---- 9 files changed, 172 insertions(+), 50 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index e59c24adb84af..0e285d6088ec1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -160,6 +160,14 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } + /** + * Get the maximum remember duration across all the input streams. This is a conservative but + * safe remember duration which can be used to perform cleanup operations. + */ + def getMaxInputStreamRememberDuration(): Duration = { + inputStreams.map { _.rememberDuration }.maxBy { _.milliseconds } + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { logDebug("DStreamGraph.writeObject used") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index afd3c4bc4c4fe..8be04314c4285 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -94,15 +94,4 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont } Some(blockRDD) } - - /** - * Clear metadata that are older than `rememberDuration` of this DStream. - * This is an internal method that should not be called directly. This - * implementation overrides the default implementation to clear received - * block information. 
- */ - private[streaming] override def clearMetadata(time: Time) { - super.clearMetadata(time) - ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration) - } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala index ab9fa192191aa..7bf3c33319491 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala @@ -17,7 +17,10 @@ package org.apache.spark.streaming.receiver -/** Messages sent to the NetworkReceiver. */ +import org.apache.spark.streaming.Time + +/** Messages sent to the Receiver. */ private[streaming] sealed trait ReceiverMessage extends Serializable private[streaming] object StopReceiver extends ReceiverMessage +private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index d7229c2b96d0b..716cf2c7f32fc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -82,6 +83,9 @@ private[streaming] class ReceiverSupervisorImpl( case StopReceiver => logInfo("Received stop signal") stop("Stopped by driver", None) + case CleanupOldBlocks(threshTime) => + logDebug("Received delete old batch signal") + cleanupOldBlocks(threshTime) } def ref = self @@ -193,4 +197,9 @@ private[streaming] class ReceiverSupervisorImpl( /** Generate new block ID */ private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement) + + private def cleanupOldBlocks(cleanupThreshTime: Time): Unit = { + logDebug(s"Cleaning up blocks older then $cleanupThreshTime") + receivedBlockHandler.cleanupOldBlocks(cleanupThreshTime.milliseconds) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 39b66e1130768..d86f852aba97e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -238,13 +238,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) - jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration) // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed if (shouldCheckpoint) { eventActor ! DoCheckpoint(time) } else { + // If checkpointing is not enabled, then delete metadata information about + // received blocks (block data not saved in any case). Otherwise, wait for + // checkpointing of this batch to complete. 
+ val maxRememberDuration = graph.getMaxInputStreamRememberDuration() + jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) markBatchFullyProcessed(time) } } @@ -252,6 +256,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream checkpoint data for the given `time`. */ private def clearCheckpointData(time: Time) { ssc.graph.clearCheckpointData(time) + + // All the checkpoint information about which batches have been processed, etc have + // been saved to checkpoints, so its safe to delete block metadata and data WAL files + val maxRememberDuration = graph.getMaxInputStreamRememberDuration() + jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) markBatchFullyProcessed(time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index c3d9d7b6813d3..ef23b5c79f2e1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -150,7 +150,6 @@ private[streaming] class ReceivedBlockTracker( writeToLog(BatchCleanupEvent(timesToCleanup)) timeToAllocatedBlocks --= timesToCleanup logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion)) - log } /** Stop the block tracker. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 8dbb42a86e3bd..4f998869731ed 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -24,9 +24,8 @@ import scala.language.existentials import akka.actor._ import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} -import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver} +import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver} /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -119,9 +118,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } - /** Clean up metadata older than the given threshold time */ - def cleanupOldMetadata(cleanupThreshTime: Time) { + /** + * Clean up the data and metadata of blocks and batches that are strictly + * older than the threshold time. Note that this does not + */ + def cleanupOldBlocksAndBatches(cleanupThreshTime: Time) { + // Clean up old block and batch metadata receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false) + + // Signal the receivers to delete old block data + if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + logInfo(s"Cleanup old received batch data: $cleanupThreshTime") + receiverInfo.values.flatMap { info => Option(info.actor) } + .foreach { _ ! 
CleanupOldBlocks(cleanupThreshTime) } + } } /** Register a receiver */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala index 27a28bab83ed5..858ba3c9eb4e5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala @@ -63,7 +63,7 @@ private[streaming] object HdfsUtils { } def getFileSystemForPath(path: Path, conf: Configuration): FileSystem = { - // For local file systems, return the raw loca file system, such calls to flush() + // For local file systems, return the raw local file system, such calls to flush() // actually flushes the stream. val fs = path.getFileSystem(conf) fs match { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index e26c0c6859e57..e8c34a9ee40b9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -17,21 +17,26 @@ package org.apache.spark.streaming +import java.io.File import java.nio.ByteBuffer import java.util.concurrent.Semaphore +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.SparkConf -import org.apache.spark.storage.{StorageLevel, StreamBlockId} -import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver, ReceiverSupervisor} -import org.scalatest.FunSuite +import com.google.common.io.Files import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.receiver._ +import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ + /** Testsuite for testing the network receiver behavior */ -class ReceiverSuite extends FunSuite with Timeouts { +class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { test("receiver life cycle") { @@ -192,7 +197,6 @@ class ReceiverSuite extends FunSuite with Timeouts { val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3 val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1 val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",") - println(minExpectedMessagesPerBlock, maxExpectedMessagesPerBlock, ":", receivedBlockSizes) assert( // the first and last block may be incomplete, so we slice them out recordedBlocks.drop(1).dropRight(1).forall { block => @@ -203,39 +207,91 @@ class ReceiverSuite extends FunSuite with Timeouts { ) } - /** - * An implementation of NetworkReceiver that is used for testing a receiver's life cycle. + * Test whether write ahead logs are generated by received, + * and automatically cleaned up. The clean up must be aware of the + * remember duration of the input streams. E.g., input streams on which window() + * has been applied must remember the data for longer, and hence corresponding + * WALs should be cleaned later. 
*/ - class FakeReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) { - @volatile var otherThread: Thread = null - @volatile var receiving = false - @volatile var onStartCalled = false - @volatile var onStopCalled = false - - def onStart() { - otherThread = new Thread() { - override def run() { - receiving = true - while(!isStopped()) { - Thread.sleep(10) - } + test("write ahead log - generating and cleaning") { + val sparkConf = new SparkConf() + .setMaster("local[4]") // must be at least 3 as we are going to start 2 receivers + .setAppName(framework) + .set("spark.ui.enabled", "true") + .set("spark.streaming.receiver.writeAheadLog.enable", "true") + .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + val batchDuration = Milliseconds(500) + val tempDirectory = Files.createTempDir() + val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0)) + val logDirectory2 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 1)) + val allLogFiles1 = new mutable.HashSet[String]() + val allLogFiles2 = new mutable.HashSet[String]() + logInfo("Temp checkpoint directory = " + tempDirectory) + + def getBothCurrentLogFiles(): (Seq[String], Seq[String]) = { + (getCurrentLogFiles(logDirectory1), getCurrentLogFiles(logDirectory2)) + } + + def getCurrentLogFiles(logDirectory: File): Seq[String] = { + try { + if (logDirectory.exists()) { + logDirectory1.listFiles().filter { _.getName.startsWith("log") }.map { _.toString } + } else { + Seq.empty } + } catch { + case e: Exception => + Seq.empty } - onStartCalled = true - otherThread.start() - } - def onStop() { - onStopCalled = true - otherThread.join() + def printLogFiles(message: String, files: Seq[String]) { + logInfo(s"$message (${files.size} files):\n" + files.mkString("\n")) } - def reset() { - receiving = false - onStartCalled = false - onStopCalled = false + withStreamingContext(new StreamingContext(sparkConf, batchDuration)) { ssc => + tempDirectory.deleteOnExit() + val receiver1 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) + val receiver2 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) + val receiverStream1 = ssc.receiverStream(receiver1) + val receiverStream2 = ssc.receiverStream(receiver2) + receiverStream1.register() + receiverStream2.window(batchDuration * 6).register() // 3 second window + ssc.checkpoint(tempDirectory.getAbsolutePath()) + ssc.start() + + // Run until sufficient WAL files have been generated and + // the first WAL files has been deleted + eventually(timeout(20 seconds), interval(batchDuration.milliseconds millis)) { + val (logFiles1, logFiles2) = getBothCurrentLogFiles() + allLogFiles1 ++= logFiles1 + allLogFiles2 ++= logFiles2 + if (allLogFiles1.size > 0) { + assert(!logFiles1.contains(allLogFiles1.toSeq.sorted.head)) + } + if (allLogFiles2.size > 0) { + assert(!logFiles2.contains(allLogFiles2.toSeq.sorted.head)) + } + assert(allLogFiles1.size >= 7) + assert(allLogFiles2.size >= 7) + } + ssc.stop(stopSparkContext = true, stopGracefully = true) + + val sortedAllLogFiles1 = allLogFiles1.toSeq.sorted + val sortedAllLogFiles2 = allLogFiles2.toSeq.sorted + val (leftLogFiles1, leftLogFiles2) = getBothCurrentLogFiles() + + printLogFiles("Receiver 0: all", sortedAllLogFiles1) + printLogFiles("Receiver 0: left", leftLogFiles1) + printLogFiles("Receiver 1: all", sortedAllLogFiles2) + printLogFiles("Receiver 1: left", leftLogFiles2) + + // Verify that necessary latest log files are not deleted + // receiverStream1 needs to retain just the last batch = 1 log 
file + // receiverStream2 needs to retain 3 seconds (3-seconds window) = 3 log files + assert(sortedAllLogFiles1.takeRight(1).forall(leftLogFiles1.contains)) + assert(sortedAllLogFiles2.takeRight(3).forall(leftLogFiles2.contains)) } } @@ -315,3 +371,42 @@ class ReceiverSuite extends FunSuite with Timeouts { } } +/** + * An implementation of Receiver that is used for testing a receiver's life cycle. + */ +class FakeReceiver(sendData: Boolean = false) extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + @volatile var otherThread: Thread = null + @volatile var receiving = false + @volatile var onStartCalled = false + @volatile var onStopCalled = false + + def onStart() { + otherThread = new Thread() { + override def run() { + receiving = true + var count = 0 + while(!isStopped()) { + if (sendData) { + store(count) + count += 1 + } + Thread.sleep(10) + } + } + } + onStartCalled = true + otherThread.start() + } + + def onStop() { + onStopCalled = true + otherThread.join() + } + + def reset() { + receiving = false + onStartCalled = false + onStopCalled = false + } +} + From 246111d179a2f3f6b97a5c2b121d8ddbfd1c9aad Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Jan 2015 08:16:35 -0800 Subject: [PATCH 04/33] [SPARK-5365][MLlib] Refactor KMeans to reduce redundant data If a point is selected as new centers for many runs, it would collect many redundant data. This pr refactors it. Author: Liang-Chi Hsieh Closes #4159 from viirya/small_refactor_kmeans and squashes the following commits: 25487e6 [Liang-Chi Hsieh] Refactor codes to reduce redundant data. --- .../scala/org/apache/spark/mllib/clustering/KMeans.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index fc46da3a93425..11633e8242313 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -328,14 +328,15 @@ class KMeans private ( val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) pointsWithCosts.flatMap { case (p, c) => - (0 until runs).filter { r => + val rs = (0 until runs).filter { r => rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) - }.map((_, p)) + } + if (rs.length > 0) Some(p, rs) else None } }.collect() mergeNewCenters() - chosen.foreach { case (r, p) => - newCenters(r) += p.toDense + chosen.foreach { case (p, rs) => + rs.foreach(newCenters(_) += p.toDense) } step += 1 } From 820ce03597350257abe0c5c96435c555038e3e6c Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Thu, 22 Jan 2015 13:49:35 -0600 Subject: [PATCH 05/33] SPARK-5370. [YARN] Remove some unnecessary synchronization in YarnAlloca... ...tor Author: Sandy Ryza Closes #4164 from sryza/sandy-spark-5370 and squashes the following commits: 0c8d736 [Sandy Ryza] SPARK-5370. 
[YARN] Remove some unnecessary synchronization in YarnAllocator --- .../spark/deploy/yarn/YarnAllocator.scala | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 4c35b60c57df3..d00f29665a58f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -60,7 +60,6 @@ private[yarn] class YarnAllocator( import YarnAllocator._ - // These two complementary data structures are locked on allocatedHostToContainersMap. // Visible for testing. val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]] @@ -355,20 +354,18 @@ private[yarn] class YarnAllocator( } } - allocatedHostToContainersMap.synchronized { - if (allocatedContainerToHostMap.containsKey(containerId)) { - val host = allocatedContainerToHostMap.get(containerId).get - val containerSet = allocatedHostToContainersMap.get(host).get + if (allocatedContainerToHostMap.containsKey(containerId)) { + val host = allocatedContainerToHostMap.get(containerId).get + val containerSet = allocatedHostToContainersMap.get(host).get - containerSet.remove(containerId) - if (containerSet.isEmpty) { - allocatedHostToContainersMap.remove(host) - } else { - allocatedHostToContainersMap.update(host, containerSet) - } - - allocatedContainerToHostMap.remove(containerId) + containerSet.remove(containerId) + if (containerSet.isEmpty) { + allocatedHostToContainersMap.remove(host) + } else { + allocatedHostToContainersMap.update(host, containerSet) } + + allocatedContainerToHostMap.remove(containerId) } } } From 3c3fa632e6ba45ce536065aa1145698385301fb2 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 22 Jan 2015 21:58:53 -0800 Subject: [PATCH 06/33] [SPARK-5233][Streaming] Fix error replaying of WAL introduced bug Because of lacking of `BlockAllocationEvent` in WAL recovery, the dangled event will mix into the new batch, which will lead to the wrong result. Details can be seen in [SPARK-5233](https://issues.apache.org/jira/browse/SPARK-5233). 
Author: jerryshao Closes #4032 from jerryshao/SPARK-5233 and squashes the following commits: f0b0c0b [jerryshao] Further address the comments a237c75 [jerryshao] Address the comments e356258 [jerryshao] Fix bug in unit test 558bdc3 [jerryshao] Correctly replay the WAL log when recovering from failure --- .../examples/streaming/KafkaWordCount.scala | 2 +- .../streaming/scheduler/JobGenerator.scala | 18 +++++++++++------ .../scheduler/ReceivedBlockTracker.scala | 12 ++++++++--- .../streaming/ReceivedBlockTrackerSuite.scala | 20 +++++++++---------- 4 files changed, 32 insertions(+), 20 deletions(-) diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index 2adc63f7ff30e..387c0e421334b 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -76,7 +76,7 @@ object KafkaWordCountProducer { val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args - // Zookeper connection properties + // Zookeeper connection properties val props = new Properties() props.put("metadata.broker.list", brokers) props.put("serializer.class", "kafka.serializer.StringEncoder") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index d86f852aba97e..8632c94349bf9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -17,12 +17,14 @@ package org.apache.spark.streaming.scheduler -import akka.actor.{ActorRef, ActorSystem, Props, Actor} -import org.apache.spark.{SparkException, SparkEnv, Logging} -import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter} -import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock} import scala.util.{Failure, Success, Try} +import akka.actor.{ActorRef, Props, Actor} + +import org.apache.spark.{SparkEnv, Logging} +import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} +import org.apache.spark.streaming.util.{Clock, ManualClock, RecurringTimer} + /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent @@ -206,9 +208,13 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", ")) - timesToReschedule.foreach(time => + timesToReschedule.foreach { time => + // Allocate the related blocks when recovering from failure, because some blocks that were + // added but not allocated, are dangling in the queue after recovering, we have to allocate + // those blocks to the next batch, which is the batch they were supposed to go. 
+ jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch jobScheduler.submitJobSet(JobSet(time, graph.generateJobs(time))) - ) + } // Restart the timer timer.start(restartTime.milliseconds) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index ef23b5c79f2e1..e19ac939f9ac5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -67,7 +67,7 @@ private[streaming] class ReceivedBlockTracker( extends Logging { private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo] - + private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue] private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks] private val logManagerOption = createLogManager() @@ -107,8 +107,14 @@ private[streaming] class ReceivedBlockTracker( lastAllocatedBatchTime = batchTime allocatedBlocks } else { - throw new SparkException(s"Unexpected allocation of blocks, " + - s"last batch = $lastAllocatedBatchTime, batch time to allocate = $batchTime ") + // This situation occurs when: + // 1. WAL is ended with BatchAllocationEvent, but without BatchCleanupEvent, + // possibly processed batch job or half-processed batch job need to be processed again, + // so the batchTime will be equal to lastAllocatedBatchTime. + // 2. Slow checkpointing makes recovered batch time older than WAL recovered + // lastAllocatedBatchTime. + // This situation will only occurs in recovery time. + logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery") } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index de7e9d624bf6b..fbb7b0bfebafc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -82,15 +82,15 @@ class ReceivedBlockTrackerSuite receivedBlockTracker.allocateBlocksToBatch(2) receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty - // Verify that batch 2 cannot be allocated again - intercept[SparkException] { - receivedBlockTracker.allocateBlocksToBatch(2) - } + // Verify that older batches have no operation on batch allocation, + // will return the same blocks as previously allocated. 
+ receivedBlockTracker.allocateBlocksToBatch(1) + receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos - // Verify that older batches cannot be allocated again - intercept[SparkException] { - receivedBlockTracker.allocateBlocksToBatch(1) - } + blockInfos.map(receivedBlockTracker.addBlock) + receivedBlockTracker.allocateBlocksToBatch(2) + receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos } test("block addition, block to batch allocation and cleanup with write ahead log") { @@ -186,14 +186,14 @@ class ReceivedBlockTrackerSuite tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty // should be cleaned tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 } - + test("enabling write ahead log but not setting checkpoint dir") { conf.set("spark.streaming.receiver.writeAheadLog.enable", "true") intercept[SparkException] { createTracker(setCheckpointDir = false) } } - + test("setting checkpoint dir but not enabling write ahead log") { // When WAL config is not set, log manager should not be enabled val tracker1 = createTracker(setCheckpointDir = true) From e0f7fb7f9f497b34d42f9ba147197cf9ffc51607 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 22 Jan 2015 22:04:21 -0800 Subject: [PATCH 07/33] [SPARK-5315][Streaming] Fix reduceByWindow Java API not work bug `reduceByWindow` for Java API is actually not Java compatible, change to make it Java compatible. Current solution is to deprecate the old one and add a new API, but since old API actually is not correct, so is keeping the old one meaningful? just to keep the binary compatible? Also even adding new API still need to add to Mima exclusion, I'm not sure to change the API, or deprecate the old API and add a new one, which is the best solution? 
Author: jerryshao Closes #4104 from jerryshao/SPARK-5315 and squashes the following commits: 5bc8987 [jerryshao] Address the comment c7aa1b4 [jerryshao] Deprecate the old one to keep binary compatible 8e9dc67 [jerryshao] Fix JavaDStream reduceByWindow signature error --- project/MimaExcludes.scala | 4 ++++ .../streaming/api/java/JavaDStreamLike.scala | 20 +++++++++++++++++++ .../apache/spark/streaming/JavaAPISuite.java | 20 +++++++++++++++++-- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 127973b658190..bc5d81f12d746 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -90,6 +90,10 @@ object MimaExcludes { // SPARK-5297 Java FileStream do not work with custom key/values ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream") + ) ++ Seq( + // SPARK-5315 Spark Streaming Java API returns Scala DStream + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow") ) case v if v.startsWith("1.2") => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index e0542eda1383f..c382a12f4d099 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -211,7 +211,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval + * @deprecated As this API is not Java compatible. */ + @deprecated("Use Java-compatible version of reduceByWindow", "1.3.0") def reduceByWindow( reduceFunc: (T, T) => T, windowDuration: Duration, @@ -220,6 +222,24 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration) } + /** + * Return a new DStream in which each RDD has a single element generated by reducing all + * elements in a sliding window over this DStream. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def reduceByWindow( + reduceFunc: JFunction2[T, T, T], + windowDuration: Duration, + slideDuration: Duration + ): JavaDStream[T] = { + dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration) + } + /** * Return a new DStream in which each RDD has a single element generated by reducing all * elements in a sliding window over this DStream. 
However, the reduction is done incrementally diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index d92e7fe899a09..d4c40745658c2 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -306,7 +306,17 @@ public void testReduce() { @SuppressWarnings("unchecked") @Test - public void testReduceByWindow() { + public void testReduceByWindowWithInverse() { + testReduceByWindow(true); + } + + @SuppressWarnings("unchecked") + @Test + public void testReduceByWindowWithoutInverse() { + testReduceByWindow(false); + } + + private void testReduceByWindow(boolean withInverse) { List> inputData = Arrays.asList( Arrays.asList(1,2,3), Arrays.asList(4,5,6), @@ -319,8 +329,14 @@ public void testReduceByWindow() { Arrays.asList(24)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(), + JavaDStream reducedWindowed = null; + if (withInverse) { + reducedWindowed = stream.reduceByWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000)); + } else { + reducedWindowed = stream.reduceByWindow(new IntegerSum(), + new Duration(2000), new Duration(1000)); + } JavaTestUtils.attachTestOutputStream(reducedWindowed); List> result = JavaTestUtils.runStreams(ssc, 4, 4); From ea74365b7c5a3ac29cae9ba66f140f1fa5e8d312 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 22 Jan 2015 22:09:13 -0800 Subject: [PATCH 08/33] [SPARK-3541][MLLIB] New ALS implementation with improved storage This PR adds a new ALS implementation to `spark.ml` using the pipeline API, which should be able to scale to billions of ratings. Compared with the ALS under `spark.mllib`, the new implementation 1. uses the same algorithm, 2. uses float type for ratings, 3. uses primitive arrays to avoid GC, 4. sorts and compresses ratings on each block so that we can solve least squares subproblems one by one using only one normal equation instance. The following figure shows performance comparison on copies of the Amazon Reviews dataset using a 16-node (m3.2xlarge) EC2 cluster (the same setup as in http://databricks.com/blog/2014/07/23/scalable-collaborative-filtering-with-spark-mllib.html): ![als-wip](https://cloud.githubusercontent.com/assets/829644/5659447/4c4ff8e0-96c7-11e4-87a9-73c1c63d07f3.png) I keep the `spark.mllib`'s ALS untouched for easy comparison. If the new implementation works well, I'm going to match the features of the ALS under `spark.mllib` and then make it a wrapper of the new implementation, in a separate PR. TODO: - [X] Add unit tests for implicit preferences. 
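For a quick sense of the new API, a hedged usage sketch that mirrors the MovieLensALS example added in this patch (`training` and `test` are assumed SchemaRDDs with userId/movieId/rating columns; the parameter values are illustrative only):

    import org.apache.spark.ml.recommendation.ALS

    val als = new ALS()
      .setUserCol("userId")
      .setItemCol("movieId")
      .setRatingCol("rating")
      .setRank(10)
      .setMaxIter(15)
      .setRegParam(0.1)
    // Fit the two factor matrices, then append a "prediction" column to the test data.
    val model = als.fit(training)
    val predictions = model.transform(test)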
Author: Xiangrui Meng Closes #3720 from mengxr/SPARK-3541 and squashes the following commits: 1b9e852 [Xiangrui Meng] fix compile 5129be9 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-3541 dd0d0e8 [Xiangrui Meng] simplify test code c627de3 [Xiangrui Meng] add tests for implicit feedback b84f41c [Xiangrui Meng] address comments a76da7b [Xiangrui Meng] update ALS tests 2a8deb3 [Xiangrui Meng] add some ALS tests 857e876 [Xiangrui Meng] add tests for rating block and encoded block d3c1ac4 [Xiangrui Meng] rename some classes for better code readability add more doc and comments 213d163 [Xiangrui Meng] org imports 771baf3 [Xiangrui Meng] chol doc update ca9ad9d [Xiangrui Meng] add unit tests for chol b4fd17c [Xiangrui Meng] add unit tests for NormalEquation d0f99d3 [Xiangrui Meng] add tests for LocalIndexEncoder 80b8e61 [Xiangrui Meng] fix imports 4937fd4 [Xiangrui Meng] update ALS example 56c253c [Xiangrui Meng] rename product to item bce8692 [Xiangrui Meng] doc for parameters and project the output columns 3f2d81a [Xiangrui Meng] add doc 1efaecf [Xiangrui Meng] add example code 8ae86b5 [Xiangrui Meng] add a working copy of the new ALS implementation --- .../spark/examples/ml/MovieLensALS.scala | 174 ++++ .../apache/spark/ml/recommendation/ALS.scala | 973 ++++++++++++++++++ .../spark/mllib/recommendation/ALS.scala | 2 +- .../spark/ml/recommendation/ALSSuite.scala | 435 ++++++++ .../spark/mllib/recommendation/ALSSuite.scala | 2 +- 5 files changed, 1584 insertions(+), 2 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala new file mode 100644 index 0000000000000..cf62772b92651 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.recommendation.ALS +import org.apache.spark.sql.{Row, SQLContext} + +/** + * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/). 
+ * Run with + * {{{ + * bin/run-example ml.MovieLensALS + * }}} + */ +object MovieLensALS { + + case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long) + + object Rating { + def parseRating(str: String): Rating = { + val fields = str.split("::") + assert(fields.size == 4) + Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong) + } + } + + case class Movie(movieId: Int, title: String, genres: Seq[String]) + + object Movie { + def parseMovie(str: String): Movie = { + val fields = str.split("::") + assert(fields.size == 3) + Movie(fields(0).toInt, fields(1), fields(2).split("|")) + } + } + + case class Params( + ratings: String = null, + movies: String = null, + maxIter: Int = 10, + regParam: Double = 0.1, + rank: Int = 10, + numBlocks: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("MovieLensALS") { + head("MovieLensALS: an example app for ALS on MovieLens data.") + opt[String]("ratings") + .required() + .text("path to a MovieLens dataset of ratings") + .action((x, c) => c.copy(ratings = x)) + opt[String]("movies") + .required() + .text("path to a MovieLens dataset of movies") + .action((x, c) => c.copy(movies = x)) + opt[Int]("rank") + .text(s"rank, default: ${defaultParams.rank}}") + .action((x, c) => c.copy(rank = x)) + opt[Int]("maxIter") + .text(s"max number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Int]("numBlocks") + .text(s"number of blocks, default: ${defaultParams.numBlocks}") + .action((x, c) => c.copy(numBlocks = x)) + note( + """ + |Example command line to run this app: + | + | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \ + | examples/target/scala-*/spark-examples-*.jar \ + | --rank 10 --maxIter 15 --regParam 0.1 \ + | --movies path/to/movielens/movies.dat \ + | --ratings path/to/movielens/ratings.dat + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + System.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"MovieLensALS with $params") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ + + val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache() + + val numRatings = ratings.count() + val numUsers = ratings.map(_.userId).distinct().count() + val numMovies = ratings.map(_.movieId).distinct().count() + + println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + + val splits = ratings.randomSplit(Array(0.8, 0.2), 0L) + val training = splits(0).cache() + val test = splits(1).cache() + + val numTraining = training.count() + val numTest = test.count() + println(s"Training: $numTraining, test: $numTest.") + + ratings.unpersist(blocking = false) + + val als = new ALS() + .setUserCol("userId") + .setItemCol("movieId") + .setRank(params.rank) + .setMaxIter(params.maxIter) + .setRegParam(params.regParam) + .setNumBlocks(params.numBlocks) + + val model = als.fit(training) + + val predictions = model.transform(test).cache() + + // Evaluate the model. + // TODO: Create an evaluator to compute RMSE. 
+ val mse = predictions.select('rating, 'prediction) + .flatMap { case Row(rating: Float, prediction: Float) => + val err = rating.toDouble - prediction + val err2 = err * err + if (err2.isNaN) { + None + } else { + Some(err2) + } + }.mean() + val rmse = math.sqrt(mse) + println(s"Test RMSE = $rmse.") + + // Inspect false positives. + predictions.registerTempTable("prediction") + sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie") + sqlContext.sql( + """ + |SELECT userId, prediction.movieId, title, rating, prediction + | FROM prediction JOIN movie ON prediction.movieId = movie.movieId + | WHERE rating <= 1 AND prediction >= 4 + | LIMIT 100 + """.stripMargin) + .collect() + .foreach(println) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala new file mode 100644 index 0000000000000..2d89e76a4c8b2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -0,0 +1,973 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.recommendation + +import java.{util => ju} + +import scala.collection.mutable + +import com.github.fommil.netlib.BLAS.{getInstance => blas} +import com.github.fommil.netlib.LAPACK.{getInstance => lapack} +import org.netlib.util.intW + +import org.apache.spark.{HashPartitioner, Logging, Partitioner} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} +import org.apache.spark.util.Utils +import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} +import org.apache.spark.util.random.XORShiftRandom + +/** + * Common params for ALS. + */ +private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam + with HasPredictionCol { + + /** Param for rank of the matrix factorization. */ + val rank = new IntParam(this, "rank", "rank of the factorization", Some(10)) + def getRank: Int = get(rank) + + /** Param for number of user blocks. */ + val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10)) + def getNumUserBlocks: Int = get(numUserBlocks) + + /** Param for number of item blocks. */ + val numItemBlocks = + new IntParam(this, "numItemBlocks", "number of item blocks", Some(10)) + def getNumItemBlocks: Int = get(numItemBlocks) + + /** Param to decide whether to use implicit preference. 
*/ + val implicitPrefs = + new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false)) + def getImplicitPrefs: Boolean = get(implicitPrefs) + + /** Param for the alpha parameter in the implicit preference formulation. */ + val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0)) + def getAlpha: Double = get(alpha) + + /** Param for the column name for user ids. */ + val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user")) + def getUserCol: String = get(userCol) + + /** Param for the column name for item ids. */ + val itemCol = + new Param[String](this, "itemCol", "column name for item ids", Some("item")) + def getItemCol: String = get(itemCol) + + /** Param for the column name for ratings. */ + val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating")) + def getRatingCol: String = get(ratingCol) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @param paramMap extra params + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + assert(schema(map(userCol)).dataType == IntegerType) + assert(schema(map(itemCol)).dataType== IntegerType) + val ratingType = schema(map(ratingCol)).dataType + assert(ratingType == FloatType || ratingType == DoubleType) + val predictionColName = map(predictionCol) + assert(!schema.fieldNames.contains(predictionColName), + s"Prediction column $predictionColName already exists.") + val newFields = schema.fields :+ StructField(map(predictionCol), FloatType, nullable = false) + StructType(newFields) + } +} + +/** + * Model fitted by ALS. + */ +class ALSModel private[ml] ( + override val parent: ALS, + override val fittingParamMap: ParamMap, + k: Int, + userFactors: RDD[(Int, Array[Float])], + itemFactors: RDD[(Int, Array[Float])]) + extends Model[ALSModel] with ALSParams { + + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + import dataset.sqlContext._ + import org.apache.spark.ml.recommendation.ALSModel.Factor + val map = this.paramMap ++ paramMap + // TODO: Add DSL to simplify the code here. 
+ val instanceTable = s"instance_$uid" + val userTable = s"user_$uid" + val itemTable = s"item_$uid" + val instances = dataset.as(Symbol(instanceTable)) + val users = userFactors.map { case (id, features) => + Factor(id, features) + }.as(Symbol(userTable)) + val items = itemFactors.map { case (id, features) => + Factor(id, features) + }.as(Symbol(itemTable)) + val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => { + if (userFeatures != null && itemFeatures != null) { + blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) + } else { + Float.NaN + } + } + val inputColumns = dataset.schema.fieldNames + val prediction = + predict.call(s"$userTable.features".attr, s"$itemTable.features".attr) as map(predictionCol) + val outputColumns = inputColumns.map(f => s"$instanceTable.$f".attr as f) :+ prediction + instances + .join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr)) + .join(items, LeftOuter, Some(map(itemCol).attr === s"$itemTable.id".attr)) + .select(outputColumns: _*) + } + + override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +private object ALSModel { + /** Case class to convert factors to SchemaRDDs */ + private case class Factor(id: Int, features: Seq[Float]) +} + +/** + * Alternating Least Squares (ALS) matrix factorization. + * + * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices, + * `X` and `Y`, i.e. `X * Yt = R`. Typically these approximations are called 'factor' matrices. + * The general approach is iterative. During each iteration, one of the factor matrices is held + * constant, while the other is solved for using least squares. The newly-solved factor matrix is + * then held constant while solving for the other factor matrix. + * + * This is a blocked implementation of the ALS factorization algorithm that groups the two sets + * of factors (referred to as "users" and "products") into blocks and reduces communication by only + * sending one copy of each user vector to each product block on each iteration, and only for the + * product blocks that need that user's feature vector. This is achieved by pre-computing some + * information about the ratings matrix to determine the "out-links" of each user (which blocks of + * products it will contribute to) and "in-link" information for each product (which of the feature + * vectors it receives from each user block it will depend on). This allows us to send only an + * array of feature vectors between each user block and product block, and have the product block + * find the users' ratings and update the products based on these messages. + * + * For implicit preference data, the algorithm used is based on + * "Collaborative Filtering for Implicit Feedback Datasets", available at + * [[http://dx.doi.org/10.1109/ICDM.2008.22]], adapted for the blocked approach used here. + * + * Essentially instead of finding the low-rank approximations to the rating matrix `R`, + * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if + * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of + * indicated user + * preferences rather than explicit ratings given to items. 
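+ *
+ * Note that in this implementation the confidence used for implicit feedback is
+ * c = 1 + alpha * |r|, an extension of the original paper so that negative ratings also
+ * yield nonnegative confidences.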
+ */ +class ALS extends Estimator[ALSModel] with ALSParams { + + import org.apache.spark.ml.recommendation.ALS.Rating + + def setRank(value: Int): this.type = set(rank, value) + def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value) + def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value) + def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value) + def setAlpha(value: Double): this.type = set(alpha, value) + def setUserCol(value: String): this.type = set(userCol, value) + def setItemCol(value: String): this.type = set(itemCol, value) + def setRatingCol(value: String): this.type = set(ratingCol, value) + def setPredictionCol(value: String): this.type = set(predictionCol, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) + def setRegParam(value: Double): this.type = set(regParam, value) + + /** Sets both numUserBlocks and numItemBlocks to the specific value. */ + def setNumBlocks(value: Int): this.type = { + setNumUserBlocks(value) + setNumItemBlocks(value) + this + } + + setMaxIter(20) + setRegParam(1.0) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): ALSModel = { + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val ratings = + dataset.select(map(userCol).attr, map(itemCol).attr, Cast(map(ratingCol).attr, FloatType)) + .map { row => + new Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) + } + val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank), + numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks), + maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs), + alpha = map(alpha)) + val model = new ALSModel(this, map, map(rank), userFactors, itemFactors) + Params.inheritValues(map, this, model) + model + } + + override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +private[recommendation] object ALS extends Logging { + + /** Rating class for better code readability. */ + private[recommendation] case class Rating(user: Int, item: Int, rating: Float) + + /** Cholesky solver for least square problems. */ + private[recommendation] class CholeskySolver { + + private val upper = "U" + private val info = new intW(0) + + /** + * Solves a least squares problem with L2 regularization: + * + * min norm(A x - b)^2^ + lambda * n * norm(x)^2^ + * + * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances) + * @param lambda regularization constant, which will be scaled by n + * @return the solution x + */ + def solve(ne: NormalEquation, lambda: Double): Array[Float] = { + val k = ne.k + // Add scaled lambda to the diagonals of AtA. + val scaledlambda = lambda * ne.n + var i = 0 + var j = 2 + while (i < ne.triK) { + ne.ata(i) += scaledlambda + i += j + j += 1 + } + lapack.dppsv(upper, k, 1, ne.ata, ne.atb, k, info) + val code = info.`val` + assert(code == 0, s"lapack.dppsv returned $code.") + val x = new Array[Float](k) + i = 0 + while (i < k) { + x(i) = ne.atb(i).toFloat + i += 1 + } + ne.reset() + x + } + } + + /** Representing a normal equation (ALS' subproblem). */ + private[recommendation] class NormalEquation(val k: Int) extends Serializable { + + /** Number of entries in the upper triangular part of a k-by-k matrix. */ + val triK = k * (k + 1) / 2 + /** A^T^ * A */ + val ata = new Array[Double](triK) + /** A^T^ * b */ + val atb = new Array[Double](k) + /** Number of observations. 
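+   * Only [[add]] increments this counter; [[addImplicit]] leaves it unchanged.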
*/ + var n = 0 + + private val da = new Array[Double](k) + private val upper = "U" + + private def copyToDouble(a: Array[Float]): Unit = { + var i = 0 + while (i < k) { + da(i) = a(i) + i += 1 + } + } + + /** Adds an observation. */ + def add(a: Array[Float], b: Float): this.type = { + require(a.size == k) + copyToDouble(a) + blas.dspr(upper, k, 1.0, da, 1, ata) + blas.daxpy(k, b.toDouble, da, 1, atb, 1) + n += 1 + this + } + + /** + * Adds an observation with implicit feedback. Note that this does not increment the counter. + */ + def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = { + require(a.size == k) + // Extension to the original paper to handle b < 0. confidence is a function of |b| instead + // so that it is never negative. + val confidence = 1.0 + alpha * math.abs(b) + copyToDouble(a) + blas.dspr(upper, k, confidence - 1.0, da, 1, ata) + // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0. + if (b > 0) { + blas.daxpy(k, confidence, da, 1, atb, 1) + } + this + } + + /** Merges another normal equation object. */ + def merge(other: NormalEquation): this.type = { + require(other.k == k) + blas.daxpy(ata.size, 1.0, other.ata, 1, ata, 1) + blas.daxpy(atb.size, 1.0, other.atb, 1, atb, 1) + n += other.n + this + } + + /** Resets everything to zero, which should be called after each solve. */ + def reset(): Unit = { + ju.Arrays.fill(ata, 0.0) + ju.Arrays.fill(atb, 0.0) + n = 0 + } + } + + /** + * Implementation of the ALS algorithm. + */ + private def train( + ratings: RDD[Rating], + rank: Int = 10, + numUserBlocks: Int = 10, + numItemBlocks: Int = 10, + maxIter: Int = 10, + regParam: Double = 1.0, + implicitPrefs: Boolean = false, + alpha: Double = 1.0): (RDD[(Int, Array[Float])], RDD[(Int, Array[Float])]) = { + val userPart = new HashPartitioner(numUserBlocks) + val itemPart = new HashPartitioner(numItemBlocks) + val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) + val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) + val blockRatings = partitionRatings(ratings, userPart, itemPart).cache() + val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart) + // materialize blockRatings and user blocks + userOutBlocks.count() + val swappedBlockRatings = blockRatings.map { + case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) => + ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings)) + } + val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart) + // materialize item blocks + itemOutBlocks.count() + var userFactors = initialize(userInBlocks, rank) + var itemFactors = initialize(itemInBlocks, rank) + if (implicitPrefs) { + for (iter <- 1 to maxIter) { + userFactors.setName(s"userFactors-$iter").persist() + val previousItemFactors = itemFactors + itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, + userLocalIndexEncoder, implicitPrefs, alpha) + previousItemFactors.unpersist() + itemFactors.setName(s"itemFactors-$iter").persist() + val previousUserFactors = userFactors + userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, + itemLocalIndexEncoder, implicitPrefs, alpha) + previousUserFactors.unpersist() + } + } else { + for (iter <- 0 until maxIter) { + itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, + userLocalIndexEncoder) + userFactors = computeFactors(itemFactors, 
itemOutBlocks, userInBlocks, rank, regParam, + itemLocalIndexEncoder) + } + } + val userIdAndFactors = userInBlocks + .mapValues(_.srcIds) + .join(userFactors) + .values + .setName("userFactors") + .cache() + userIdAndFactors.count() + itemFactors.unpersist() + val itemIdAndFactors = itemInBlocks + .mapValues(_.srcIds) + .join(itemFactors) + .values + .setName("itemFactors") + .cache() + itemIdAndFactors.count() + userInBlocks.unpersist() + userOutBlocks.unpersist() + itemInBlocks.unpersist() + itemOutBlocks.unpersist() + blockRatings.unpersist() + val userOutput = userIdAndFactors.flatMap { case (ids, factors) => + ids.view.zip(factors) + } + val itemOutput = itemIdAndFactors.flatMap { case (ids, factors) => + ids.view.zip(factors) + } + (userOutput, itemOutput) + } + + /** + * Factor block that stores factors (Array[Float]) in an Array. + */ + private type FactorBlock = Array[Array[Float]] + + /** + * Out-link block that stores, for each dst (item/user) block, which src (user/item) factors to + * send. For example, outLinkBlock(0) contains the local indices (not the original src IDs) of the + * src factors in this block to send to dst block 0. + */ + private type OutBlock = Array[Array[Int]] + + /** + * In-link block for computing src (user/item) factors. This includes the original src IDs + * of the elements within this block as well as encoded dst (item/user) indices and corresponding + * ratings. The dst indices are in the form of (blockId, localIndex), which are not the original + * dst IDs. To compute src factors, we expect receiving dst factors that match the dst indices. + * For example, if we have an in-link record + * + * {srcId: 0, dstBlockId: 2, dstLocalIndex: 3, rating: 5.0}, + * + * and assume that the dst factors are stored as dstFactors: Map[Int, Array[Array[Float]]], which + * is a blockId to dst factors map, the corresponding dst factor of the record is dstFactor(2)(3). + * + * We use a CSC-like (compressed sparse column) format to store the in-link information. So we can + * compute src factors one after another using only one normal equation instance. + * + * @param srcIds src ids (ordered) + * @param dstPtrs dst pointers. Elements in range [dstPtrs(i), dstPtrs(i+1)) of dst indices and + * ratings are associated with srcIds(i). + * @param dstEncodedIndices encoded dst indices + * @param ratings ratings + * + * @see [[LocalIndexEncoder]] + */ + private[recommendation] case class InBlock( + srcIds: Array[Int], + dstPtrs: Array[Int], + dstEncodedIndices: Array[Int], + ratings: Array[Float]) { + /** Size of the block. */ + val size: Int = ratings.size + + require(dstEncodedIndices.size == size) + require(dstPtrs.size == srcIds.size + 1) + } + + /** + * Initializes factors randomly given the in-link blocks. + * + * @param inBlocks in-link blocks + * @param rank rank + * @return initialized factor blocks + */ + private def initialize(inBlocks: RDD[(Int, InBlock)], rank: Int): RDD[(Int, FactorBlock)] = { + // Choose a unit vector uniformly at random from the unit sphere, but from the + // "first quadrant" where all elements are nonnegative. This can be done by choosing + // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing. + // This appears to create factorizations that have a slightly better reconstruction + // (<1%) compared picking elements uniformly at random in [0,1]. 
+ inBlocks.map { case (srcBlockId, inBlock) => + val random = new XORShiftRandom(srcBlockId) + val factors = Array.fill(inBlock.srcIds.size) { + val factor = Array.fill(rank)(random.nextGaussian().toFloat) + val nrm = blas.snrm2(rank, factor, 1) + blas.sscal(rank, 1.0f / nrm, factor, 1) + factor + } + (srcBlockId, factors) + } + } + + /** + * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays. + */ + private[recommendation] + case class RatingBlock(srcIds: Array[Int], dstIds: Array[Int], ratings: Array[Float]) { + /** Size of the block. */ + val size: Int = srcIds.size + require(dstIds.size == size) + require(ratings.size == size) + } + + /** + * Builder for [[RatingBlock]]. [[mutable.ArrayBuilder]] is used to avoid boxing/unboxing. + */ + private[recommendation] class RatingBlockBuilder extends Serializable { + + private val srcIds = mutable.ArrayBuilder.make[Int] + private val dstIds = mutable.ArrayBuilder.make[Int] + private val ratings = mutable.ArrayBuilder.make[Float] + var size = 0 + + /** Adds a rating. */ + def add(r: Rating): this.type = { + size += 1 + srcIds += r.user + dstIds += r.item + ratings += r.rating + this + } + + /** Merges another [[RatingBlockBuilder]]. */ + def merge(other: RatingBlock): this.type = { + size += other.srcIds.size + srcIds ++= other.srcIds + dstIds ++= other.dstIds + ratings ++= other.ratings + this + } + + /** Builds a [[RatingBlock]]. */ + def build(): RatingBlock = { + RatingBlock(srcIds.result(), dstIds.result(), ratings.result()) + } + } + + /** + * Partitions raw ratings into blocks. + * + * @param ratings raw ratings + * @param srcPart partitioner for src IDs + * @param dstPart partitioner for dst IDs + * + * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock) + */ + private def partitionRatings( + ratings: RDD[Rating], + srcPart: Partitioner, + dstPart: Partitioner): RDD[((Int, Int), RatingBlock)] = { + + /* The implementation produces the same result as the following but generates less objects. + + ratings.map { r => + ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r) + }.aggregateByKey(new RatingBlockBuilder)( + seqOp = (b, r) => b.add(r), + combOp = (b0, b1) => b0.merge(b1.build())) + .mapValues(_.build()) + */ + + val numPartitions = srcPart.numPartitions * dstPart.numPartitions + ratings.mapPartitions { iter => + val builders = Array.fill(numPartitions)(new RatingBlockBuilder) + iter.flatMap { r => + val srcBlockId = srcPart.getPartition(r.user) + val dstBlockId = dstPart.getPartition(r.item) + val idx = srcBlockId + srcPart.numPartitions * dstBlockId + val builder = builders(idx) + builder.add(r) + if (builder.size >= 2048) { // 2048 * (3 * 4) = 24k + builders(idx) = new RatingBlockBuilder + Iterator.single(((srcBlockId, dstBlockId), builder.build())) + } else { + Iterator.empty + } + } ++ { + builders.view.zipWithIndex.filter(_._1.size > 0).map { case (block, idx) => + val srcBlockId = idx % srcPart.numPartitions + val dstBlockId = idx / srcPart.numPartitions + ((srcBlockId, dstBlockId), block.build()) + } + } + }.groupByKey().mapValues { blocks => + val builder = new RatingBlockBuilder + blocks.foreach(builder.merge) + builder.build() + }.setName("ratingBlocks") + } + + /** + * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples. 
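+   * Entries are added one dst block at a time and later compressed into the CSC-like
+   * [[InBlock]] format via [[UncompressedInBlock#compress]].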
+ * @param encoder encoder for dst indices + */ + private[recommendation] class UncompressedInBlockBuilder(encoder: LocalIndexEncoder) { + + private val srcIds = mutable.ArrayBuilder.make[Int] + private val dstEncodedIndices = mutable.ArrayBuilder.make[Int] + private val ratings = mutable.ArrayBuilder.make[Float] + + /** + * Adds a dst block of (srcId, dstLocalIndex, rating) tuples. + * + * @param dstBlockId dst block ID + * @param srcIds original src IDs + * @param dstLocalIndices dst local indices + * @param ratings ratings + */ + def add( + dstBlockId: Int, + srcIds: Array[Int], + dstLocalIndices: Array[Int], + ratings: Array[Float]): this.type = { + val sz = srcIds.size + require(dstLocalIndices.size == sz) + require(ratings.size == sz) + this.srcIds ++= srcIds + this.ratings ++= ratings + var j = 0 + while (j < sz) { + this.dstEncodedIndices += encoder.encode(dstBlockId, dstLocalIndices(j)) + j += 1 + } + this + } + + /** Builds a [[UncompressedInBlock]]. */ + def build(): UncompressedInBlock = { + new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result()) + } + } + + /** + * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays. + */ + private[recommendation] class UncompressedInBlock( + val srcIds: Array[Int], + val dstEncodedIndices: Array[Int], + val ratings: Array[Float]) { + + /** Size the of block. */ + def size: Int = srcIds.size + + /** + * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a + * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format. + * Sorting is done using Spark's built-in Timsort to avoid generating too many objects. + */ + def compress(): InBlock = { + val sz = size + assert(sz > 0, "Empty in-link block should not exist.") + sort() + val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Int] + val dstCountsBuilder = mutable.ArrayBuilder.make[Int] + var preSrcId = srcIds(0) + uniqueSrcIdsBuilder += preSrcId + var curCount = 1 + var i = 1 + var j = 0 + while (i < sz) { + val srcId = srcIds(i) + if (srcId != preSrcId) { + uniqueSrcIdsBuilder += srcId + dstCountsBuilder += curCount + preSrcId = srcId + j += 1 + curCount = 0 + } + curCount += 1 + i += 1 + } + dstCountsBuilder += curCount + val uniqueSrcIds = uniqueSrcIdsBuilder.result() + val numUniqueSrdIds = uniqueSrcIds.size + val dstCounts = dstCountsBuilder.result() + val dstPtrs = new Array[Int](numUniqueSrdIds + 1) + var sum = 0 + i = 0 + while (i < numUniqueSrdIds) { + sum += dstCounts(i) + i += 1 + dstPtrs(i) = sum + } + InBlock(uniqueSrcIds, dstPtrs, dstEncodedIndices, ratings) + } + + private def sort(): Unit = { + val sz = size + // Since there might be interleaved log messages, we insert a unique id for easy pairing. + val sortId = Utils.random.nextInt() + logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)") + val start = System.nanoTime() + val sorter = new Sorter(new UncompressedInBlockSort) + sorter.sort(this, 0, size, Ordering[IntWrapper]) + val duration = (System.nanoTime() - start) / 1e9 + logDebug(s"Sorting took $duration seconds. (sortId = $sortId)") + } + } + + /** + * A wrapper that holds a primitive integer key. + * + * @see [[UncompressedInBlockSort]] + */ + private class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] { + override def compare(that: IntWrapper): Int = { + key.compare(that.key) + } + } + + /** + * [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]]. 
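+   * Sorting is keyed on the primitive src IDs (wrapped in [[IntWrapper]]) to avoid boxing.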
+ */ + private class UncompressedInBlockSort extends SortDataFormat[IntWrapper, UncompressedInBlock] { + + override def newKey(): IntWrapper = new IntWrapper() + + override def getKey( + data: UncompressedInBlock, + pos: Int, + reuse: IntWrapper): IntWrapper = { + if (reuse == null) { + new IntWrapper(data.srcIds(pos)) + } else { + reuse.key = data.srcIds(pos) + reuse + } + } + + override def getKey( + data: UncompressedInBlock, + pos: Int): IntWrapper = { + getKey(data, pos, null) + } + + private def swapElements[@specialized(Int, Float) T]( + data: Array[T], + pos0: Int, + pos1: Int): Unit = { + val tmp = data(pos0) + data(pos0) = data(pos1) + data(pos1) = tmp + } + + override def swap(data: UncompressedInBlock, pos0: Int, pos1: Int): Unit = { + swapElements(data.srcIds, pos0, pos1) + swapElements(data.dstEncodedIndices, pos0, pos1) + swapElements(data.ratings, pos0, pos1) + } + + override def copyRange( + src: UncompressedInBlock, + srcPos: Int, + dst: UncompressedInBlock, + dstPos: Int, + length: Int): Unit = { + System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length) + System.arraycopy(src.dstEncodedIndices, srcPos, dst.dstEncodedIndices, dstPos, length) + System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length) + } + + override def allocate(length: Int): UncompressedInBlock = { + new UncompressedInBlock( + new Array[Int](length), new Array[Int](length), new Array[Float](length)) + } + + override def copyElement( + src: UncompressedInBlock, + srcPos: Int, + dst: UncompressedInBlock, + dstPos: Int): Unit = { + dst.srcIds(dstPos) = src.srcIds(srcPos) + dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos) + dst.ratings(dstPos) = src.ratings(srcPos) + } + } + + /** + * Creates in-blocks and out-blocks from rating blocks. 
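+   * The returned in-blocks contain, for each src (user/item) block, the ratings needed to solve
+   * its factors locally, while the out-blocks record which of those factors have to be shipped
+   * to each dst block.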
+ * @param prefix prefix for in/out-block names + * @param ratingBlocks rating blocks + * @param srcPart partitioner for src IDs + * @param dstPart partitioner for dst IDs + * @return (in-blocks, out-blocks) + */ + private def makeBlocks( + prefix: String, + ratingBlocks: RDD[((Int, Int), RatingBlock)], + srcPart: Partitioner, + dstPart: Partitioner): (RDD[(Int, InBlock)], RDD[(Int, OutBlock)]) = { + val inBlocks = ratingBlocks.map { + case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) => + // The implementation is a faster version of + // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap + val start = System.nanoTime() + val dstIdSet = new OpenHashSet[Int](1 << 20) + dstIds.foreach(dstIdSet.add) + val sortedDstIds = new Array[Int](dstIdSet.size) + var i = 0 + var pos = dstIdSet.nextPos(0) + while (pos != -1) { + sortedDstIds(i) = dstIdSet.getValue(pos) + pos = dstIdSet.nextPos(pos + 1) + i += 1 + } + assert(i == dstIdSet.size) + ju.Arrays.sort(sortedDstIds) + val dstIdToLocalIndex = new OpenHashMap[Int, Int](sortedDstIds.size) + i = 0 + while (i < sortedDstIds.size) { + dstIdToLocalIndex.update(sortedDstIds(i), i) + i += 1 + } + logDebug( + "Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.") + val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply) + (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings)) + }.groupByKey(new HashPartitioner(srcPart.numPartitions)) + .mapValues { iter => + val builder = + new UncompressedInBlockBuilder(new LocalIndexEncoder(dstPart.numPartitions)) + iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) => + builder.add(dstBlockId, srcIds, dstLocalIndices, ratings) + } + builder.build().compress() + }.setName(prefix + "InBlocks").cache() + val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) => + val encoder = new LocalIndexEncoder(dstPart.numPartitions) + val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int]) + var i = 0 + val seen = new Array[Boolean](dstPart.numPartitions) + while (i < srcIds.size) { + var j = dstPtrs(i) + ju.Arrays.fill(seen, false) + while (j < dstPtrs(i + 1)) { + val dstBlockId = encoder.blockId(dstEncodedIndices(j)) + if (!seen(dstBlockId)) { + activeIds(dstBlockId) += i // add the local index in this out-block + seen(dstBlockId) = true + } + j += 1 + } + i += 1 + } + activeIds.map { x => + x.result() + } + }.setName(prefix + "OutBlocks").cache() + (inBlocks, outBlocks) + } + + /** + * Compute dst factors by constructing and solving least square problems. 
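+   * For each dst ID, a [[NormalEquation]] is accumulated over the received src factors and then
+   * solved with a [[CholeskySolver]]; with implicit feedback the precomputed Y^T^Y term is
+   * merged in first.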
+ * + * @param srcFactorBlocks src factors + * @param srcOutBlocks src out-blocks + * @param dstInBlocks dst in-blocks + * @param rank rank + * @param regParam regularization constant + * @param srcEncoder encoder for src local indices + * @param implicitPrefs whether to use implicit preference + * @param alpha the alpha constant in the implicit preference formulation + * + * @return dst factors + */ + private def computeFactors( + srcFactorBlocks: RDD[(Int, FactorBlock)], + srcOutBlocks: RDD[(Int, OutBlock)], + dstInBlocks: RDD[(Int, InBlock)], + rank: Int, + regParam: Double, + srcEncoder: LocalIndexEncoder, + implicitPrefs: Boolean = false, + alpha: Double = 1.0): RDD[(Int, FactorBlock)] = { + val numSrcBlocks = srcFactorBlocks.partitions.size + val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None + val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap { + case (srcBlockId, (srcOutBlock, srcFactors)) => + srcOutBlock.view.zipWithIndex.map { case (activeIndices, dstBlockId) => + (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx)))) + } + } + val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.size)) + dstInBlocks.join(merged).mapValues { + case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) => + val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks) + srcFactors.foreach { case (srcBlockId, factors) => + sortedSrcFactors(srcBlockId) = factors + } + val dstFactors = new Array[Array[Float]](dstIds.size) + var j = 0 + val ls = new NormalEquation(rank) + val solver = new CholeskySolver // TODO: add NNLS solver + while (j < dstIds.size) { + ls.reset() + if (implicitPrefs) { + ls.merge(YtY.get) + } + var i = srcPtrs(j) + while (i < srcPtrs(j + 1)) { + val encoded = srcEncodedIndices(i) + val blockId = srcEncoder.blockId(encoded) + val localIndex = srcEncoder.localIndex(encoded) + val srcFactor = sortedSrcFactors(blockId)(localIndex) + val rating = ratings(i) + if (implicitPrefs) { + ls.addImplicit(srcFactor, rating, alpha) + } else { + ls.add(srcFactor, rating) + } + i += 1 + } + dstFactors(j) = solver.solve(ls, regParam) + j += 1 + } + dstFactors + } + } + + /** + * Computes the Gramian matrix of user or item factors, which is only used in implicit preference. + * Caching of the input factors is handled in [[ALS#train]]. + */ + private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = { + factorBlocks.values.aggregate(new NormalEquation(rank))( + seqOp = (ne, factors) => { + factors.foreach(ne.add(_, 0.0f)) + ne + }, + combOp = (ne1, ne2) => ne1.merge(ne2)) + } + + /** + * Encoder for storing (blockId, localIndex) into a single integer. + * + * We use the leading bits (including the sign bit) to store the block id and the rest to store + * the local index. This is based on the assumption that users/items are approximately evenly + * partitioned. With this assumption, we should be able to encode two billion distinct values. + * + * @param numBlocks number of blocks + */ + private[recommendation] class LocalIndexEncoder(numBlocks: Int) extends Serializable { + + require(numBlocks > 0, s"numBlocks must be positive but found $numBlocks.") + + private[this] final val numLocalIndexBits = + math.min(java.lang.Integer.numberOfLeadingZeros(numBlocks - 1), 31) + private[this] final val localIndexMask = (1 << numLocalIndexBits) - 1 + + /** Encodes a (blockId, localIndex) into a single integer. 
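+     * For example, with numBlocks = 10 there are 28 bits for the local index, so
+     * encode(3, 5) = (3 << 28) | 5.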
*/ + def encode(blockId: Int, localIndex: Int): Int = { + require(blockId < numBlocks) + require((localIndex & ~localIndexMask) == 0) + (blockId << numLocalIndexBits) | localIndex + } + + /** Gets the block id from an encoded index. */ + @inline + def blockId(encoded: Int): Int = { + encoded >>> numLocalIndexBits + } + + /** Gets the local index from an encoded index. */ + @inline + def localIndex(encoded: Int): Int = { + encoded & localIndexMask + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index bee951a2e5e26..5f84677be238d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -90,7 +90,7 @@ case class Rating(user: Int, product: Int, rating: Double) * * Essentially instead of finding the low-rank approximations to the rating matrix `R`, * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if - * r > 0 and 0 if r = 0. The ratings then act as 'confidence' values related to strength of + * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of * indicated user * preferences rather than explicit ratings given to items. */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala new file mode 100644 index 0000000000000..cdd4db1b5b7dc --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.recommendation + +import java.util.Random + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.scalatest.FunSuite + +import org.apache.spark.Logging +import org.apache.spark.ml.recommendation.ALS._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} + +class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { + + private var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("LocalIndexEncoder") { + val random = new Random + for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) { + val encoder = new LocalIndexEncoder(numBlocks) + val maxLocalIndex = Int.MaxValue / numBlocks + val tests = Seq.fill(5)((random.nextInt(numBlocks), random.nextInt(maxLocalIndex))) ++ + Seq((0, 0), (numBlocks - 1, maxLocalIndex)) + tests.foreach { case (blockId, localIndex) => + val err = s"Failed with numBlocks=$numBlocks, blockId=$blockId, and localIndex=$localIndex." + val encoded = encoder.encode(blockId, localIndex) + assert(encoder.blockId(encoded) === blockId, err) + assert(encoder.localIndex(encoded) === localIndex, err) + } + } + } + + test("normal equation construction with explict feedback") { + val k = 2 + val ne0 = new NormalEquation(k) + .add(Array(1.0f, 2.0f), 3.0f) + .add(Array(4.0f, 5.0f), 6.0f) + assert(ne0.k === k) + assert(ne0.triK === k * (k + 1) / 2) + assert(ne0.n === 2) + // NumPy code that computes the expected values: + // A = np.matrix("1 2; 4 5") + // b = np.matrix("3; 6") + // ata = A.transpose() * A + // atb = A.transpose() * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8) + + val ne1 = new NormalEquation(2) + .add(Array(7.0f, 8.0f), 9.0f) + ne0.merge(ne1) + assert(ne0.n === 3) + // NumPy code that computes the expected values: + // A = np.matrix("1 2; 4 5; 7 8") + // b = np.matrix("3; 6; 9") + // ata = A.transpose() * A + // atb = A.transpose() * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8) + + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f), 2.0f) + } + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f) + } + intercept[IllegalArgumentException] { + val ne2 = new NormalEquation(3) + ne0.merge(ne2) + } + + ne0.reset() + assert(ne0.n === 0) + assert(ne0.ata.forall(_ == 0.0)) + assert(ne0.atb.forall(_ == 0.0)) + } + + test("normal equation construction with implicit feedback") { + val k = 2 + val alpha = 0.5 + val ne0 = new NormalEquation(k) + .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha) + .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha) + .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha) + assert(ne0.k === k) + assert(ne0.triK === k * (k + 1) / 2) + assert(ne0.n === 0) // addImplicit doesn't increase the count. 
+ // NumPy code that computes the expected values: + // alpha = 0.5 + // A = np.matrix("-5 -4; -2 -1; 1 2") + // b = np.matrix("-3; 0; 3") + // b1 = b > 0 + // c = 1.0 + alpha * np.abs(b) + // C = np.diag(c.A1) + // I = np.eye(3) + // ata = A.transpose() * (C - I) * A + // atb = A.transpose() * C * b1 + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8) + } + + test("CholeskySolver") { + val k = 2 + val ne0 = new NormalEquation(k) + .add(Array(1.0f, 2.0f), 4.0f) + .add(Array(1.0f, 3.0f), 9.0f) + .add(Array(1.0f, 4.0f), 16.0f) + val ne1 = new NormalEquation(k) + .merge(ne0) + + val chol = new CholeskySolver + val x0 = chol.solve(ne0, 0.0).map(_.toDouble) + // NumPy code that computes the expected solution: + // A = np.matrix("1 2; 1 3; 1 4") + // b = b = np.matrix("3; 6") + // x0 = np.linalg.lstsq(A, b)[0] + assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6) + + assert(ne0.n === 0) + assert(ne0.ata.forall(_ == 0.0)) + assert(ne0.atb.forall(_ == 0.0)) + + val x1 = chol.solve(ne1, 0.5).map(_.toDouble) + // NumPy code that computes the expected solution, where lambda is scaled by n: + // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b) + assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6) + } + + test("RatingBlockBuilder") { + val emptyBuilder = new RatingBlockBuilder() + assert(emptyBuilder.size === 0) + val emptyBlock = emptyBuilder.build() + assert(emptyBlock.srcIds.isEmpty) + assert(emptyBlock.dstIds.isEmpty) + assert(emptyBlock.ratings.isEmpty) + + val builder0 = new RatingBlockBuilder() + .add(Rating(0, 1, 2.0f)) + .add(Rating(3, 4, 5.0f)) + assert(builder0.size === 2) + val builder1 = new RatingBlockBuilder() + .add(Rating(6, 7, 8.0f)) + .merge(builder0.build()) + assert(builder1.size === 3) + val block = builder1.build() + val ratings = Seq.tabulate(block.size) { i => + (block.srcIds(i), block.dstIds(i), block.ratings(i)) + }.toSet + assert(ratings === Set((0, 1, 2.0f), (3, 4, 5.0f), (6, 7, 8.0f))) + } + + test("UncompressedInBlock") { + val encoder = new LocalIndexEncoder(10) + val uncompressed = new UncompressedInBlockBuilder(encoder) + .add(0, Array(1, 0, 2), Array(0, 1, 4), Array(1.0f, 2.0f, 3.0f)) + .add(1, Array(3, 0), Array(2, 5), Array(4.0f, 5.0f)) + .build() + assert(uncompressed.size === 5) + val records = Seq.tabulate(uncompressed.size) { i => + val dstEncodedIndex = uncompressed.dstEncodedIndices(i) + val dstBlockId = encoder.blockId(dstEncodedIndex) + val dstLocalIndex = encoder.localIndex(dstEncodedIndex) + (uncompressed.srcIds(i), dstBlockId, dstLocalIndex, uncompressed.ratings(i)) + }.toSet + val expected = + Set((1, 0, 0, 1.0f), (0, 0, 1, 2.0f), (2, 0, 4, 3.0f), (3, 1, 2, 4.0f), (0, 1, 5, 5.0f)) + assert(records === expected) + + val compressed = uncompressed.compress() + assert(compressed.size === 5) + assert(compressed.srcIds.toSeq === Seq(0, 1, 2, 3)) + assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5)) + var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)] + var i = 0 + while (i < compressed.srcIds.size) { + var j = compressed.dstPtrs(i) + while (j < compressed.dstPtrs(i + 1)) { + val dstEncodedIndex = compressed.dstEncodedIndices(j) + val dstBlockId = encoder.blockId(dstEncodedIndex) + val dstLocalIndex = encoder.localIndex(dstEncodedIndex) + decompressed += ((compressed.srcIds(i), dstBlockId, dstLocalIndex, compressed.ratings(j))) + j += 1 + } + i += 1 + } + 
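+    // Walking the compressed (CSC-like) structure back out should reproduce the original
+    // records, ignoring order.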
assert(decompressed.toSet === expected) + } + + /** + * Generates an explicit feedback dataset for testing ALS. + * @param numUsers number of users + * @param numItems number of items + * @param rank rank + * @param noiseStd the standard deviation of additive Gaussian noise on training data + * @param seed random seed + * @return (training, test) + */ + def genExplicitTestData( + numUsers: Int, + numItems: Int, + rank: Int, + noiseStd: Double = 0.0, + seed: Long = 11L): (RDD[Rating], RDD[Rating]) = { + val trainingFraction = 0.6 + val testFraction = 0.3 + val totalFraction = trainingFraction + testFraction + val random = new Random(seed) + val userFactors = genFactors(numUsers, rank, random) + val itemFactors = genFactors(numItems, rank, random) + val training = ArrayBuffer.empty[Rating] + val test = ArrayBuffer.empty[Rating] + for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { + val x = random.nextDouble() + if (x < totalFraction) { + val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1) + if (x < trainingFraction) { + val noise = noiseStd * random.nextGaussian() + training += Rating(userId, itemId, rating + noise.toFloat) + } else { + test += Rating(userId, itemId, rating) + } + } + } + logInfo(s"Generated an explicit feedback dataset with ${training.size} ratings for training " + + s"and ${test.size} for test.") + (sc.parallelize(training, 2), sc.parallelize(test, 2)) + } + + /** + * Generates an implicit feedback dataset for testing ALS. + * @param numUsers number of users + * @param numItems number of items + * @param rank rank + * @param noiseStd the standard deviation of additive Gaussian noise on training data + * @param seed random seed + * @return (training, test) + */ + def genImplicitTestData( + numUsers: Int, + numItems: Int, + rank: Int, + noiseStd: Double = 0.0, + seed: Long = 11L): (RDD[Rating], RDD[Rating]) = { + // The assumption of the implicit feedback model is that unobserved ratings are more likely to + // be negatives. + val positiveFraction = 0.8 + val negativeFraction = 1.0 - positiveFraction + val trainingFraction = 0.6 + val testFraction = 0.3 + val totalFraction = trainingFraction + testFraction + val random = new Random(seed) + val userFactors = genFactors(numUsers, rank, random) + val itemFactors = genFactors(numItems, rank, random) + val training = ArrayBuffer.empty[Rating] + val test = ArrayBuffer.empty[Rating] + for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { + val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1) + val threshold = if (rating > 0) positiveFraction else negativeFraction + val observed = random.nextDouble() < threshold + if (observed) { + val x = random.nextDouble() + if (x < totalFraction) { + if (x < trainingFraction) { + val noise = noiseStd * random.nextGaussian() + training += Rating(userId, itemId, rating + noise.toFloat) + } else { + test += Rating(userId, itemId, rating) + } + } + } + } + logInfo(s"Generated an implicit feedback dataset with ${training.size} ratings for training " + + s"and ${test.size} for test.") + (sc.parallelize(training, 2), sc.parallelize(test, 2)) + } + + /** + * Generates random user/item factors, with i.i.d. values drawn from U(a, b). 
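+   * IDs are sampled without replacement from the full Int range and returned in sorted order.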
+ * @param size number of users/items + * @param rank number of features + * @param random random number generator + * @param a min value of the support (default: -1) + * @param b max value of the support (default: 1) + * @return a sequence of (ID, factors) pairs + */ + private def genFactors( + size: Int, + rank: Int, + random: Random, + a: Float = -1.0f, + b: Float = 1.0f): Seq[(Int, Array[Float])] = { + require(size > 0 && size < Int.MaxValue / 3) + require(b > a) + val ids = mutable.Set.empty[Int] + while (ids.size < size) { + ids += random.nextInt() + } + val width = b - a + ids.toSeq.sorted.map(id => (id, Array.fill(rank)(a + random.nextFloat() * width))) + } + + /** + * Test ALS using the given training/test splits and parameters. + * @param training training dataset + * @param test test dataset + * @param rank rank of the matrix factorization + * @param maxIter max number of iterations + * @param regParam regularization constant + * @param implicitPrefs whether to use implicit preference + * @param numUserBlocks number of user blocks + * @param numItemBlocks number of item blocks + * @param targetRMSE target test RMSE + */ + def testALS( + training: RDD[Rating], + test: RDD[Rating], + rank: Int, + maxIter: Int, + regParam: Double, + implicitPrefs: Boolean = false, + numUserBlocks: Int = 2, + numItemBlocks: Int = 3, + targetRMSE: Double = 0.05): Unit = { + val sqlContext = this.sqlContext + import sqlContext.{createSchemaRDD, symbolToUnresolvedAttribute} + val als = new ALS() + .setRank(rank) + .setRegParam(regParam) + .setImplicitPrefs(implicitPrefs) + .setNumUserBlocks(numUserBlocks) + .setNumItemBlocks(numItemBlocks) + val alpha = als.getAlpha + val model = als.fit(training) + val predictions = model.transform(test) + .select('rating, 'prediction) + .map { case Row(rating: Float, prediction: Float) => + (rating.toDouble, prediction.toDouble) + } + val rmse = + if (implicitPrefs) { + // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. + // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE + // with the confidence scores as weights. 
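+        // i.e. weighted RMSE = sqrt(sum(c_i * err_i^2) / sum(c_i)), where c_i = 1 + alpha * |rating_i|.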
+ val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => + val confidence = 1.0 + alpha * math.abs(rating) + val rating01 = math.max(math.min(rating, 1.0), 0.0) + val prediction01 = math.max(math.min(prediction, 1.0), 0.0) + val err = prediction01 - rating01 + (confidence, confidence * err * err) + }.reduce { case ((c0, e0), (c1, e1)) => + (c0 + c1, e0 + e1) + } + math.sqrt(weightedSumSq / totalWeight) + } else { + val mse = predictions.map { case (rating, prediction) => + val err = rating - prediction + err * err + }.mean() + math.sqrt(mse) + } + logInfo(s"Test RMSE is $rmse.") + assert(rmse < targetRMSE) + } + + test("exact rank-1 matrix") { + val (training, test) = genExplicitTestData(numUsers = 20, numItems = 40, rank = 1) + testALS(training, test, maxIter = 1, rank = 1, regParam = 1e-5, targetRMSE = 0.001) + testALS(training, test, maxIter = 1, rank = 2, regParam = 1e-5, targetRMSE = 0.001) + } + + test("approximate rank-1 matrix") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 1, noiseStd = 0.01) + testALS(training, test, maxIter = 2, rank = 1, regParam = 0.01, targetRMSE = 0.02) + testALS(training, test, maxIter = 2, rank = 2, regParam = 0.01, targetRMSE = 0.02) + } + + test("approximate rank-2 matrix") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03) + testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03) + } + + test("different block settings") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + for ((numUserBlocks, numItemBlocks) <- Seq((1, 1), (1, 2), (2, 1), (2, 2))) { + testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03, + numUserBlocks = numUserBlocks, numItemBlocks = numItemBlocks) + } + } + + test("more blocks than ratings") { + val (training, test) = + genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + testALS(training, test, maxIter = 2, rank = 1, regParam = 1e-4, targetRMSE = 0.002, + numItemBlocks = 5, numUserBlocks = 5) + } + + test("implicit feedback") { + val (training, test) = + genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, implicitPrefs = true, + targetRMSE = 0.3) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index f3b7bfda788fa..e9fc37e000526 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -215,7 +215,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { * @param samplingRate what fraction of the user-product pairs are known * @param matchThreshold max difference allowed to consider a predicted rating correct * @param implicitPrefs flag to test implicit feedback - * @param bulkPredict flag to test bulk prediciton + * @param bulkPredict flag to test bulk predicition * @param negativeWeights whether the generated data can contain negative values * @param numUserBlocks number of user blocks to partition users into * @param numProductBlocks number of product blocks to partition products into From cef1f092a628ac20709857b4388bb10e0b5143b0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 23 
Jan 2015 17:53:15 -0800 Subject: [PATCH 09/33] [SPARK-5063] More helpful error messages for several invalid operations This patch adds more helpful error messages for invalid programs that define nested RDDs, broadcast RDDs, perform actions inside of transformations (e.g. calling `count()` from inside of `map()`), and call certain methods on stopped SparkContexts. Currently, these invalid programs lead to confusing NullPointerExceptions at runtime and have been a major source of questions on the mailing list and StackOverflow. In a few cases, I chose to log warnings instead of throwing exceptions in order to avoid any chance that this patch breaks programs that worked "by accident" in earlier Spark releases (e.g. programs that define nested RDDs but never run any jobs with them). In SparkContext, the new `assertNotStopped()` method is used to check whether methods are being invoked on a stopped SparkContext. In some cases, user programs will not crash in spite of calling methods on stopped SparkContexts, so I've only added `assertNotStopped()` calls to methods that always throw exceptions when called on stopped contexts (e.g. by dereferencing a null `dagScheduler` pointer). Author: Josh Rosen Closes #3884 from JoshRosen/SPARK-5063 and squashes the following commits: a38774b [Josh Rosen] Fix spelling typo a943e00 [Josh Rosen] Convert two exceptions into warnings in order to avoid breaking user programs in some edge-cases. 2d0d7f7 [Josh Rosen] Fix test to reflect 1.2.1 compatibility 3f0ea0c [Josh Rosen] Revert two unintentional formatting changes 8e5da69 [Josh Rosen] Remove assertNotStopped() calls for methods that were sometimes safe to call on stopped SC's in Spark 1.2 8cff41a [Josh Rosen] IllegalStateException fix 6ef68d0 [Josh Rosen] Fix Python line length issues. 9f6a0b8 [Josh Rosen] Add improved error messages to PySpark. 13afd0f [Josh Rosen] SparkException -> IllegalStateException 8d404f3 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-5063 b39e041 [Josh Rosen] Fix BroadcastSuite test which broadcasted an RDD 99cc09f [Josh Rosen] Guard against calling methods on stopped SparkContexts. 34833e8 [Josh Rosen] Add more descriptive error message. 57cc8a1 [Josh Rosen] Add error message when directly broadcasting RDD. 15b2e6b [Josh Rosen] [SPARK-5063] Useful error messages for nested RDDs and actions inside of transformations --- .../scala/org/apache/spark/SparkContext.scala | 62 +++++++++++++++---- .../main/scala/org/apache/spark/rdd/RDD.scala | 19 +++++- .../spark/broadcast/BroadcastSuite.scala | 12 +++- .../scala/org/apache/spark/rdd/RDDSuite.scala | 40 ++++++++++++ python/pyspark/context.py | 8 +++ python/pyspark/rdd.py | 11 ++++ 6 files changed, 138 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6a354ed4d1486..8175d175b13df 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -85,6 +85,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() + @volatile private var stopped: Boolean = false + + private def assertNotStopped(): Unit = { + if (stopped) { + throw new IllegalStateException("Cannot call methods on a stopped SparkContext") + } + } + /** * Create a SparkContext that loads settings from system properties (for instance, when * launching with ./bin/spark-submit). 
@@ -525,6 +533,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * modified collection. Pass a copy of the argument to avoid this. */ def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + assertNotStopped() new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } @@ -540,6 +549,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. */ def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { + assertNotStopped() val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) } @@ -549,6 +559,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Hadoop-supported file system URI, and return it as an RDD of Strings. */ def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = { + assertNotStopped() hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minPartitions).map(pair => pair._2.toString).setName(path) } @@ -582,6 +593,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, String)] = { + assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) NewFileInputFormat.addInputPath(job, new Path(path)) val updateConf = job.getConfiguration @@ -627,6 +639,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @Experimental def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, PortableDataStream)] = { + assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) NewFileInputFormat.addInputPath(job, new Path(path)) val updateConf = job.getConfiguration @@ -651,6 +664,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @Experimental def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration) : RDD[Array[Byte]] = { + assertNotStopped() conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path, classOf[FixedLengthBinaryInputFormat], @@ -684,6 +698,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions ): RDD[(K, V)] = { + assertNotStopped() // Add necessary security credentials to the JobConf before broadcasting it. SparkHadoopUtil.get.addCredentials(conf) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions) @@ -703,6 +718,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions ): RDD[(K, V)] = { + assertNotStopped() // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. 
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) @@ -782,6 +798,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli kClass: Class[K], vClass: Class[V], conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { + assertNotStopped() val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration @@ -802,6 +819,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli fClass: Class[F], kClass: Class[K], vClass: Class[V]): RDD[(K, V)] = { + assertNotStopped() new NewHadoopRDD(this, fClass, kClass, vClass, conf) } @@ -817,6 +835,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int ): RDD[(K, V)] = { + assertNotStopped() val inputFormatClass = classOf[SequenceFileInputFormat[K, V]] hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions) } @@ -828,9 +847,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. * */ - def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V] - ): RDD[(K, V)] = + def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = { + assertNotStopped() sequenceFile(path, keyClass, valueClass, defaultMinPartitions) + } /** * Version of sequenceFile() for types implicitly convertible to Writables through a @@ -858,6 +878,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli (implicit km: ClassTag[K], vm: ClassTag[V], kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) : RDD[(K, V)] = { + assertNotStopped() val kc = kcf() val vc = vcf() val format = classOf[SequenceFileInputFormat[Writable, Writable]] @@ -879,6 +900,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli path: String, minPartitions: Int = defaultMinPartitions ): RDD[T] = { + assertNotStopped() sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions) .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader)) } @@ -954,6 +976,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * The variable will be sent to each cluster only once. */ def broadcast[T: ClassTag](value: T): Broadcast[T] = { + assertNotStopped() + if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) { + // This is a warning instead of an exception in order to avoid breaking user programs that + // might have created RDD broadcast variables but not used them: + logWarning("Can not directly broadcast RDDs; instead, call collect() and " + + "broadcast the result (see SPARK-5063)") + } val bc = env.broadcastManager.newBroadcast[T](value, isLocal) val callSite = getCallSite logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm) @@ -1046,6 +1075,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * memory available for caching. 
*/ def getExecutorMemoryStatus: Map[String, (Long, Long)] = { + assertNotStopped() env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => (blockManagerId.host + ":" + blockManagerId.port, mem) } @@ -1058,6 +1088,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getRDDStorageInfo: Array[RDDInfo] = { + assertNotStopped() val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) rddInfos.filter(_.isCached) @@ -1075,6 +1106,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getExecutorStorageStatus: Array[StorageStatus] = { + assertNotStopped() env.blockManager.master.getStorageStatus } @@ -1084,6 +1116,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getAllPools: Seq[Schedulable] = { + assertNotStopped() // TODO(xiajunluan): We should take nested pools into account taskScheduler.rootPool.schedulableQueue.toSeq } @@ -1094,6 +1127,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getPoolForName(pool: String): Option[Schedulable] = { + assertNotStopped() Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) } @@ -1101,6 +1135,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Return current scheduling mode */ def getSchedulingMode: SchedulingMode.SchedulingMode = { + assertNotStopped() taskScheduler.schedulingMode } @@ -1206,16 +1241,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { postApplicationEnd() ui.foreach(_.stop()) - // Do this only if not stopped already - best case effort. - // prevent NPE if stopped more than once. - val dagSchedulerCopy = dagScheduler - dagScheduler = null - if (dagSchedulerCopy != null) { + if (!stopped) { + stopped = true env.metricsSystem.report() metadataCleaner.cancel() env.actorSystem.stop(heartbeatReceiver) cleaner.foreach(_.stop()) - dagSchedulerCopy.stop() + dagScheduler.stop() + dagScheduler = null taskScheduler = null // TODO: Cache.stop()? env.stop() @@ -1289,8 +1322,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - if (dagScheduler == null) { - throw new SparkException("SparkContext has been shutdown") + if (stopped) { + throw new IllegalStateException("SparkContext has been shutdown") } val callSite = getCallSite val cleanedFunc = clean(func) @@ -1377,6 +1410,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], timeout: Long): PartialResult[R] = { + assertNotStopped() val callSite = getCallSite logInfo("Starting job: " + callSite.shortForm) val start = System.nanoTime @@ -1399,6 +1433,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli resultHandler: (Int, U) => Unit, resultFunc: => R): SimpleFutureAction[R] = { + assertNotStopped() val cleanF = clean(processPartition) val callSite = getCallSite val waiter = dagScheduler.submitJob( @@ -1417,11 +1452,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * for more information. 
*/ def cancelJobGroup(groupId: String) { + assertNotStopped() dagScheduler.cancelJobGroup(groupId) } /** Cancel all jobs that have been scheduled or are running. */ def cancelAllJobs() { + assertNotStopped() dagScheduler.cancelAllJobs() } @@ -1468,7 +1505,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def getCheckpointDir = checkpointDir /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ - def defaultParallelism: Int = taskScheduler.defaultParallelism + def defaultParallelism: Int = { + assertNotStopped() + taskScheduler.defaultParallelism + } /** Default min number of partitions for Hadoop RDDs when not given by user */ @deprecated("use defaultMinPartitions", "1.0.0") diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 97012c7033f9f..ab7410a1f7f99 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -76,10 +76,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli * on RDD internals. */ abstract class RDD[T: ClassTag]( - @transient private var sc: SparkContext, + @transient private var _sc: SparkContext, @transient private var deps: Seq[Dependency[_]] ) extends Serializable with Logging { + if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) { + // This is a warning instead of an exception in order to avoid breaking user programs that + // might have defined nested RDDs without running jobs with them. + logWarning("Spark does not support nested RDDs (see SPARK-5063)") + } + + private def sc: SparkContext = { + if (_sc == null) { + throw new SparkException( + "RDD transformations and actions can only be invoked by the driver, not inside of other " + + "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " + + "the values transformation and count action cannot be performed inside of the rdd1.map " + + "transformation. For more information, see SPARK-5063.") + } + _sc + } + /** Construct an RDD with just a one-to-one dependency on one parent */ def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index b0a70f012f1f3..af3272692d7a1 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -170,6 +170,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { testPackage.runCallSiteTest(sc) } + test("Broadcast variables cannot be created after SparkContext is stopped (SPARK-5065)") { + sc = new SparkContext("local", "test") + sc.stop() + val thrown = intercept[IllegalStateException] { + sc.broadcast(Seq(1, 2, 3)) + } + assert(thrown.getMessage.toLowerCase.contains("stopped")) + } + /** * Verify the persistence of state associated with an HttpBroadcast in either local mode or * local-cluster mode (when distributed = true). 
@@ -349,8 +358,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { package object testPackage extends Assertions { def runCallSiteTest(sc: SparkContext) { - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) - val broadcast = sc.broadcast(rdd) + val broadcast = sc.broadcast(Array(1, 2, 3, 4)) broadcast.destroy() val thrown = intercept[SparkException] { broadcast.value } assert(thrown.getMessage.contains("BroadcastSuite.scala")) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 381ee2d45630f..e33b4bbbb8e4c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -927,4 +927,44 @@ class RDDSuite extends FunSuite with SharedSparkContext { mutableDependencies += dep } } + + test("nested RDDs are not supported (SPARK-5063)") { + val rdd: RDD[Int] = sc.parallelize(1 to 100) + val rdd2: RDD[Int] = sc.parallelize(1 to 100) + val thrown = intercept[SparkException] { + val nestedRDD: RDD[RDD[Int]] = rdd.mapPartitions { x => Seq(rdd2.map(x => x)).iterator } + nestedRDD.count() + } + assert(thrown.getMessage.contains("SPARK-5063")) + } + + test("actions cannot be performed inside of transformations (SPARK-5063)") { + val rdd: RDD[Int] = sc.parallelize(1 to 100) + val rdd2: RDD[Int] = sc.parallelize(1 to 100) + val thrown = intercept[SparkException] { + rdd.map(x => x * rdd2.count).collect() + } + assert(thrown.getMessage.contains("SPARK-5063")) + } + + test("cannot run actions after SparkContext has been stopped (SPARK-5063)") { + val existingRDD = sc.parallelize(1 to 100) + sc.stop() + val thrown = intercept[IllegalStateException] { + existingRDD.count() + } + assert(thrown.getMessage.contains("shutdown")) + } + + test("cannot call methods on a stopped SparkContext (SPARK-5063)") { + sc.stop() + def assertFails(block: => Any): Unit = { + val thrown = intercept[IllegalStateException] { + block + } + assert(thrown.getMessage.contains("stopped")) + } + assertFails { sc.parallelize(1 to 100) } + assertFails { sc.textFile("/nonexistent-path") } + } } diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 64f6a3ca6bf4c..568e21f3803bf 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -229,6 +229,14 @@ def _ensure_initialized(cls, instance=None, gateway=None): else: SparkContext._active_spark_context = instance + def __getnewargs__(self): + # This method is called when attempting to pickle SparkContext, which is always an error: + raise Exception( + "It appears that you are attempting to reference SparkContext from a broadcast " + "variable, action, or transforamtion. SparkContext can only be used on the driver, " + "not in code that it run on workers. For more information, see SPARK-5063." + ) + def __enter__(self): """ Enable 'with SparkContext(...) as sc: app(sc)' syntax. diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4977400ac1c05..f4cfe4845dc20 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -141,6 +141,17 @@ def id(self): def __repr__(self): return self._jrdd.toString() + def __getnewargs__(self): + # This method is called when attempting to pickle an RDD, which is always an error: + raise Exception( + "It appears that you are attempting to broadcast an RDD or reference an RDD from an " + "action or transformation. 
RDD transformations and actions can only be invoked by the " + "driver, not inside of other transformations; for example, " + "rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values " + "transformation and count action cannot be performed inside of the rdd1.map " + "transformation. For more information, see SPARK-5063." + ) + @property def context(self): """ From e224dbb011789297cd6c6ba095f702c042869ed6 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 23 Jan 2015 19:25:15 -0800 Subject: [PATCH 10/33] [SPARK-5351][GraphX] Do not use Partitioner.defaultPartitioner as a partitioner of EdgeRDDImp... If the value of 'spark.default.parallelism' does not match the number of partitoins in EdgePartition(EdgeRDDImpl), the following error occurs in ReplicatedVertexView.scala:72; object GraphTest extends Logging { def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): VertexRDD[Int] = { graph.aggregateMessages( ctx => { ctx.sendToSrc(1) ctx.sendToDst(2) }, _ + _) } } val g = GraphLoader.edgeListFile(sc, "graph.txt") val rdd = GraphTest.run(g) java.lang.IllegalArgumentException: Can't zip RDDs with unequal numbers of partitions at org.apache.spark.rdd.ZippedPartitionsBaseRDD.getPartitions(ZippedPartitionsRDD.scala:57) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:206) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:204) at scala.Option.getOrElse(Option.scala:120) at org.apache.spark.rdd.RDD.partitions(RDD.scala:204) at org.apache.spark.rdd.MapPartitionsRDD.getPartitions(MapPartitionsRDD.scala:32) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:206) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:204) at scala.Option.getOrElse(Option.scala:120) at org.apache.spark.rdd.RDD.partitions(RDD.scala:204) at org.apache.spark.ShuffleDependency.(Dependency.scala:82) at org.apache.spark.rdd.ShuffledRDD.getDependencies(ShuffledRDD.scala:80) at org.apache.spark.rdd.RDD$$anonfun$dependencies$2.apply(RDD.scala:193) at org.apache.spark.rdd.RDD$$anonfun$dependencies$2.apply(RDD.scala:191) ... Author: Takeshi Yamamuro Closes #4136 from maropu/EdgePartitionBugFix and squashes the following commits: 0cd8942 [Ankur Dave] Use more concise getOrElse aad4a2c [Ankur Dave] Add unit test for non-default number of edge partitions 0a2f32b [Takeshi Yamamuro] Do not use Partitioner.defaultPartitioner as a partitioner of EdgeRDDImpl --- .../spark/graphx/impl/EdgeRDDImpl.scala | 4 ++-- .../org/apache/spark/graphx/GraphSuite.scala | 20 +++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index 897c7ee12a436..f1550ac2e18ad 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -19,7 +19,7 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} -import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext} +import org.apache.spark.{OneToOneDependency, HashPartitioner, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -46,7 +46,7 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( * partitioner that allows co-partitioning with `partitionsRDD`. 
*/ override val partitioner = - partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) + partitionsRDD.partitioner.orElse(Some(new HashPartitioner(partitions.size))) override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 9da0064104fb6..ed9876b8dc21c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -386,4 +386,24 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } + test("non-default number of edge partitions") { + val n = 10 + val defaultParallelism = 3 + val numEdgePartitions = 4 + assert(defaultParallelism != numEdgePartitions) + val conf = new org.apache.spark.SparkConf() + .set("spark.default.parallelism", defaultParallelism.toString) + val sc = new SparkContext("local", "test", conf) + try { + val edges = sc.parallelize((1 to n).map(x => (x: VertexId, 0: VertexId)), + numEdgePartitions) + val graph = Graph.fromEdgeTuples(edges, 1) + val neighborAttrSums = graph.mapReduceTriplets[Int]( + et => Iterator((et.dstId, et.srcAttr)), _ + _) + assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n))) + } finally { + sc.stop() + } + } + } From 09e09c548e7722fca1cdc89bd37de2cee58f4ce9 Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Fri, 23 Jan 2015 23:34:11 -0800 Subject: [PATCH 11/33] [SPARK-5058] Part 2. Typos and broken URL - Also fixed java link Author: Jongyoul Lee Closes #4172 from jongyoul/SPARK-FIXDOC and squashes the following commits: 6be03e5 [Jongyoul Lee] [SPARK-5058] Part 2. Typos and broken URL - Also fixed java link --- docs/streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 0e38fe2144e9f..77c0abbbacbd0 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -29,7 +29,7 @@ title: Spark Streaming + Kafka Integration Guide streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]); See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). 
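For context on the example re-linked above: the receiver-based API that the guide documents takes the streaming context, ZooKeeper quorum, consumer group id, and per-topic number of Kafka partitions to consume. A minimal Scala sketch of that usage follows; the quorum, group id, and topic map are placeholder values, not taken from this patch.

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

// Receiver-based Kafka word count, mirroring the linked example.
val conf = new SparkConf().setAppName("KafkaWordCountSketch")
val ssc = new StreamingContext(conf, Seconds(2))
val messages = KafkaUtils.createStream(
  ssc,
  "zkhost:2181",             // ZooKeeper quorum (placeholder)
  "example-consumer-group",  // consumer group id (placeholder)
  Map("topicA" -> 1))        // per-topic number of Kafka partitions to consume (placeholder)
messages.map(_._2)           // keep the message payload, drop the key
  .flatMap(_.split(" "))
  .map(word => (word, 1L))
  .reduceByKey(_ + _)
  .print()
ssc.start()
ssc.awaitTermination()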
From 0d1e67ee9b29b51bccfc8a319afe9f9b4581afc8 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 24 Jan 2015 11:00:35 -0800 Subject: [PATCH 12/33] [SPARK-5214][Test] Add a test to demonstrate EventLoop can be stopped in the event thread Author: zsxwing Closes #4174 from zsxwing/SPARK-5214-unittest and squashes the following commits: 443e564 [zsxwing] Change the check interval to 5ms 7aaa2d7 [zsxwing] Add a test to demonstrate EventLoop can be stopped in the event thread --- .../apache/spark/util/EventLoopSuite.scala | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index 10541f878476c..1026cb2aa7cae 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -41,7 +41,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() (1 to 100).foreach(eventLoop.post) - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert((1 to 100) === buffer.toSeq) } eventLoop.stop() @@ -76,7 +76,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() eventLoop.post(1) - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(e === receivedError) } eventLoop.stop() @@ -98,7 +98,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() eventLoop.post(1) - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(e === receivedError) assert(eventLoop.isActive) } @@ -153,7 +153,7 @@ class EventLoopSuite extends FunSuite with Timeouts { }.start() } - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(threadNum * eventsFromEachThread === receivedEventsCount) } eventLoop.stop() @@ -185,4 +185,22 @@ class EventLoopSuite extends FunSuite with Timeouts { } assert(false === eventLoop.isActive) } + + test("EventLoop: stop in eventThread") { + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + stop() + } + + override def onError(e: Throwable): Unit = { + } + + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + } } From d22ca1e921d792e49600c4c484a846e739c340ca Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 25 Jan 2015 00:24:59 -0800 Subject: [PATCH 13/33] Closes #4157 From 412a58e118ef083ea1d1d6daccd9c531852baf53 Mon Sep 17 00:00:00 2001 From: Idan Zalzberg Date: Sun, 25 Jan 2015 11:28:05 -0800 Subject: [PATCH 14/33] Add comment about defaultMinPartitions Added a comment about using math.min for choosing default partition count Author: Idan Zalzberg Closes #4102 from idanz/patch-2 and squashes the following commits: 50e9d58 [Idan Zalzberg] Update SparkContext.scala --- core/src/main/scala/org/apache/spark/SparkContext.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8175d175b13df..4c4ee04cc515e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1514,7 +1514,11 @@ class SparkContext(config: SparkConf) extends 
Logging with ExecutorAllocationCli @deprecated("use defaultMinPartitions", "1.0.0") def defaultMinSplits: Int = math.min(defaultParallelism, 2) - /** Default min number of partitions for Hadoop RDDs when not given by user */ + /** + * Default min number of partitions for Hadoop RDDs when not given by user + * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2. + * The reasons for this are discussed in https://github.com/mesos/spark/pull/718 + */ def defaultMinPartitions: Int = math.min(defaultParallelism, 2) private val nextShuffleId = new AtomicInteger(0) From 2d9887bae52a1b4d01765c28dfe333396338bfe6 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sun, 25 Jan 2015 14:17:59 -0800 Subject: [PATCH 15/33] [SPARK-5401] set executor ID before creating MetricsSystem Author: Ryan Williams Closes #4194 from ryan-williams/metrics and squashes the following commits: 7c5a33f [Ryan Williams] set executor ID before creating MetricsSystem --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 4 ++++ .../main/scala/org/apache/spark/metrics/MetricsSystem.scala | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4d418037bd33f..1264a8126153b 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -326,6 +326,10 @@ object SparkEnv extends Logging { // Then we can start the metrics system. MetricsSystem.createMetricsSystem("driver", conf, securityManager) } else { + // We need to set the executor ID before the MetricsSystem is created because sources and + // sinks specified in the metrics configuration file will want to incorporate this executor's + // ID into the metrics they report. + conf.set("spark.executor.id", executorId) val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager) ms.start() ms diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 45633e3de01dd..83e8eb71260eb 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -130,8 +130,8 @@ private[spark] class MetricsSystem private ( if (appId.isDefined && executorId.isDefined) { MetricRegistry.name(appId.get, executorId.get, source.sourceName) } else { - // Only Driver and Executor are set spark.app.id and spark.executor.id. - // For instance, Master and Worker are not related to a specific application. + // Only Driver and Executor set spark.app.id and spark.executor.id. + // Other instance types, e.g. Master and Worker, are not related to a specific application. val warningMsg = s"Using default name $defaultName for source because %s is not set." 
if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) } if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) } From aea25482c370fbcf712a464501605bc16ee4ed5d Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sun, 25 Jan 2015 14:20:02 -0800 Subject: [PATCH 16/33] [SPARK-5402] log executor ID at executor-construction time also rename "slaveHostname" to "executorHostname" Author: Ryan Williams Closes #4195 from ryan-williams/exec and squashes the following commits: e60a7bb [Ryan Williams] log executor ID at executor-construction time --- .../scala/org/apache/spark/executor/Executor.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 42566d1a14093..d8c2e41a7c715 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -41,11 +41,14 @@ import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils} */ private[spark] class Executor( executorId: String, - slaveHostname: String, + executorHostname: String, env: SparkEnv, isLocal: Boolean = false) extends Logging { + + logInfo(s"Starting executor ID $executorId on host $executorHostname") + // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() @@ -58,12 +61,12 @@ private[spark] class Executor( @volatile private var isStopped = false // No ip or host:port - just hostname - Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") + Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") // must not have port specified. - assert (0 == Utils.parseHostPort(slaveHostname)._2) + assert (0 == Utils.parseHostPort(executorHostname)._2) // Make sure the local hostname we report matches the cluster scheduler's name for this host - Utils.setCustomHostname(slaveHostname) + Utils.setCustomHostname(executorHostname) if (!isLocal) { // Setup an uncaught exception handler for non-local mode. From c586b45dd25b50be7f195df2ce91b307e1ed71a9 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 25 Jan 2015 15:08:05 -0800 Subject: [PATCH 17/33] SPARK-3852 [DOCS] Document spark.driver.extra* configs As per the JIRA. I copied the `spark.executor.extra*` text, but removed info that appears to be specific to the `executor` config and not `driver`. 
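For reference, the entries added to configuration.md by this patch cover spark.driver.extraClassPath, spark.driver.extraJavaOptions, and spark.driver.extraLibraryPath. A rough Scala sketch with made-up values is below; since in client mode the driver JVM is already running by the time user code executes, these are normally supplied through spark-defaults.conf or spark-submit flags rather than set programmatically.

import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical values, shown only to illustrate the three driver-side properties.
val conf = new SparkConf()
  .setAppName("DriverExtrasExample")
  .set("spark.driver.extraJavaOptions", "-XX:+UseG1GC")        // extra JVM options for the driver
  .set("spark.driver.extraClassPath", "/opt/libs/extra.jar")   // extra classpath entries for the driver
  .set("spark.driver.extraLibraryPath", "/opt/native")         // native library path for the driver
val sc = new SparkContext(conf)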
Author: Sean Owen Closes #4185 from srowen/SPARK-3852 and squashes the following commits: f60a8a1 [Sean Owen] Document spark.driver.extra* configs --- docs/configuration.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index efbab4085317a..7c5b6d011cfd3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -197,6 +197,27 @@ Apart from these, the following properties are also available, and may be useful #### Runtime Environment + + + + + + + + + + + + + + + From 383425ab70aef4d4961f60309309ba5cb663db5f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 25 Jan 2015 15:11:57 -0800 Subject: [PATCH 18/33] SPARK-3782 [CORE] Direct use of log4j in AkkaUtils interferes with certain logging configurations Although the underlying issue can I think be solved by having user code use slf4j 1.7.6+, it might be helpful and consistent to update Spark's slf4j too. I see no reason to believe it would be incompatible with other 1.7.x releases: http://www.slf4j.org/news.html Lots of different version of slf4j are in use in the wild and anecdotally I have never seen an issue mixing them. Author: Sean Owen Closes #4184 from srowen/SPARK-3782 and squashes the following commits: 5608d28 [Sean Owen] Update slf4j to 1.7.10 --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index b993391b15042..05cb3797fc55b 100644 --- a/pom.xml +++ b/pom.xml @@ -117,7 +117,7 @@ 2.0.1 0.21.0 shaded-protobuf - 1.7.5 + 1.7.10 1.2.17 1.0.4 2.4.1 From 1c30afdf94b27e1ad65df0735575306e65d148a1 Mon Sep 17 00:00:00 2001 From: Jacek Lewandowski Date: Sun, 25 Jan 2015 15:15:09 -0800 Subject: [PATCH 19/33] SPARK-5382: Use SPARK_CONF_DIR in spark-class if it is defined Author: Jacek Lewandowski Closes #4179 from jacek-lewandowski/SPARK-5382-1.3 and squashes the following commits: 55d7791 [Jacek Lewandowski] SPARK-5382: Use SPARK_CONF_DIR in spark-class if it is defined --- bin/spark-class | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bin/spark-class b/bin/spark-class index 1b945461fabc8..2f0441bb3c1c2 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -29,6 +29,7 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" +export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}" . "$FWDIR"/bin/load-spark-env.sh @@ -120,8 +121,8 @@ fi JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM" # Load extra JAVA_OPTS from conf/java-opts, if it exists -if [ -e "$FWDIR/conf/java-opts" ] ; then - JAVA_OPTS="$JAVA_OPTS `cat "$FWDIR"/conf/java-opts`" +if [ -e "$SPARK_CONF_DIR/java-opts" ] ; then + JAVA_OPTS="$JAVA_OPTS `cat "$SPARK_CONF_DIR"/java-opts`" fi # Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! From 9f6435763d173d2abf82d16b5878983fa8bf3419 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 25 Jan 2015 15:25:05 -0800 Subject: [PATCH 20/33] SPARK-4506 [DOCS] Addendum: Update more docs to reflect that standalone works in cluster mode This is a trivial addendum to SPARK-4506, which was already resolved. noted by Asim Jalis in SPARK-4506. 
Author: Sean Owen Closes #4160 from srowen/SPARK-4506 and squashes the following commits: 5f5f7df [Sean Owen] Update more docs to reflect that standalone works in cluster mode --- docs/submitting-applications.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 3bd1deaccfafe..14a87f8436984 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -58,8 +58,8 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between -the drivers and the executors. Note that `cluster` mode is currently not supported for standalone -clusters, Mesos clusters, or Python applications. +the drivers and the executors. Note that `cluster` mode is currently not supported for +Mesos clusters or Python applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`. From 8f5c827b01026bf45fc774ed7387f11a941abea8 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sun, 25 Jan 2015 15:34:20 -0800 Subject: [PATCH 21/33] [SPARK-5344][WebUI] HistoryServer cannot recognize that inprogress file was renamed to completed file `FsHistoryProvider` tries to update application status but if `checkForLogs` is called before `.inprogress` file is renamed to completed file, the file is not recognized as completed. Author: Kousuke Saruta Closes #4132 from sarutak/SPARK-5344 and squashes the following commits: 9658008 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-5344 d2c72b6 [Kousuke Saruta] Fixed update issue of FsHistoryProvider --- .../deploy/history/FsHistoryProvider.scala | 4 +++- .../history/FsHistoryProviderSuite.scala | 23 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 2b084a2d73b78..0ae45f4ad9130 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -203,7 +203,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (!logInfos.isEmpty) { val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() def addIfAbsent(info: FsApplicationHistoryInfo) = { - if (!newApps.contains(info.id)) { + if (!newApps.contains(info.id) || + newApps(info.id).logPath.endsWith(EventLoggingListener.IN_PROGRESS) && + !info.logPath.endsWith(EventLoggingListener.IN_PROGRESS)) { newApps += (info.id -> info) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 8379883e065e7..3fbc1a21d10ed 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -167,6 +167,29 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers list.size should be (1) } + test("history file is renamed from inprogress to completed") { + val conf = new SparkConf() + 
.set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + .set("spark.testing", "true") + val provider = new FsHistoryProvider(conf) + + val logFile1 = new File(testDir, "app1" + EventLoggingListener.IN_PROGRESS) + writeFile(logFile1, true, None, + SparkListenerApplicationStart("app1", Some("app1"), 1L, "test"), + SparkListenerApplicationEnd(2L) + ) + provider.checkForLogs() + val appListBeforeRename = provider.getListing() + appListBeforeRename.size should be (1) + appListBeforeRename.head.logPath should endWith(EventLoggingListener.IN_PROGRESS) + + logFile1.renameTo(new File(testDir, "app1")) + provider.checkForLogs() + val appListAfterRename = provider.getListing() + appListAfterRename.size should be (1) + appListAfterRename.head.logPath should not endWith(EventLoggingListener.IN_PROGRESS) + } + private def writeFile(file: File, isNewFormat: Boolean, codec: Option[CompressionCodec], events: SparkListenerEvent*) = { val out = From fc2168f04e9b2c7ce45f59db0bf632a26d56c72b Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Sun, 25 Jan 2015 16:48:26 -0800 Subject: [PATCH 22/33] [SPARK-5326] Show fetch wait time as optional metric in the UI With this change, here's what the UI looks like: ![image](https://cloud.githubusercontent.com/assets/1108612/5809994/1ec8a904-9ff4-11e4-8f24-6a59a1a858f7.png) If you want to locally test this, you need to spin up multiple executors, because the shuffle read metrics are only shown for data read remotely. Author: Kay Ousterhout Closes #4110 from kayousterhout/SPARK-5326 and squashes the following commits: 610051e [Kay Ousterhout] Josh style comments 5feaa28 [Kay Ousterhout] What is the difference here?? aa129cb [Kay Ousterhout] Removed inadvertent change 721c742 [Kay Ousterhout] Improved tooltip f3a7111 [Kay Ousterhout] Style fix 679b4e9 [Kay Ousterhout] [SPARK-5326] Show fetch wait time as optional metric in the UI --- .../org/apache/spark/ui/static/webui.css | 3 +- .../scala/org/apache/spark/ui/ToolTips.scala | 6 ++- .../org/apache/spark/ui/jobs/StagePage.scala | 40 ++++++++++++++++++- .../spark/ui/jobs/TaskDetailsClassNames.scala | 1 + 4 files changed, 45 insertions(+), 5 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index a1f7133f897ee..f23ba9dba167f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -190,6 +190,7 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. 
*/ -.scheduler_delay, .deserialization_time, .serialization_time, .getting_result_time { +.scheduler_delay, .deserialization_time, .fetch_wait_time, .serialization_time, +.getting_result_time { display: none; } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 6f446c5a95a0a..4307029d44fbb 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -24,8 +24,10 @@ private[spark] object ToolTips { scheduler delay is large, consider decreasing the size of tasks or decreasing the size of task results.""" - val TASK_DESERIALIZATION_TIME = - """Time spent deserializating the task closure on the executor.""" + val TASK_DESERIALIZATION_TIME = "Time spent deserializing the task closure on the executor." + + val SHUFFLE_READ_BLOCKED_TIME = + "Time that the task spent blocked waiting for shuffle data to be read from remote machines." val INPUT = "Bytes read from Hadoop or from Spark storage." diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 09a936c2234c0..d8be1b20b3acd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -132,6 +132,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Task Deserialization Time + {if (hasShuffleRead) { +
<li>
+ <span data-toggle="tooltip" title={ToolTips.SHUFFLE_READ_BLOCKED_TIME} data-placement="right">
+ <input type="checkbox" name={TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME}/>
+ <span class="additional-metric-title">Shuffle Read Blocked Time</span>
+ </span>
+ </li>
+ }}
  • @@ -167,7 +176,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ {if (hasInput) Seq(("Input", "")) else Nil} ++ {if (hasOutput) Seq(("Output", "")) else Nil} ++ - {if (hasShuffleRead) Seq(("Shuffle Read", "")) else Nil} ++ + {if (hasShuffleRead) { + Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + ("Shuffle Read", "")) + } else { + Nil + }} ++ {if (hasShuffleWrite) Seq(("Write Time", ""), ("Shuffle Write", "")) else Nil} ++ {if (hasBytesSpilled) Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) else Nil} ++ @@ -271,6 +285,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } val outputQuantiles =
  • +: getFormattedSizeQuantiles(outputSizes) + val shuffleReadBlockedTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble + } + val shuffleReadBlockedQuantiles = +: + getFormattedTimeQuantiles(shuffleReadBlockedTimes) + val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } @@ -308,7 +328,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {gettingResultQuantiles}, if (hasInput) {inputQuantiles} else Nil, if (hasOutput) {outputQuantiles} else Nil, - if (hasShuffleRead) {shuffleReadQuantiles} else Nil, + if (hasShuffleRead) { + + {shuffleReadBlockedQuantiles} + + {shuffleReadQuantiles} + } else { + Nil + }, if (hasShuffleWrite) {shuffleWriteQuantiles} else Nil, if (hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, if (hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) @@ -377,6 +404,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { .map(m => s"${Utils.bytesToString(m.bytesWritten)}") .getOrElse("") + val maybeShuffleReadBlockedTime = metrics.flatMap(_.shuffleReadMetrics).map(_.fetchWaitTime) + val shuffleReadBlockedTimeSortable = maybeShuffleReadBlockedTime.map(_.toString).getOrElse("") + val shuffleReadBlockedTimeReadable = + maybeShuffleReadBlockedTime.map(ms => UIUtils.formatDuration(ms)).getOrElse("") + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead) val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("") @@ -449,6 +481,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { }} {if (hasShuffleRead) { + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 2d13bb6ddde42..37cf2c207ba40 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -27,6 +27,7 @@ package org.apache.spark.ui.jobs private[spark] object TaskDetailsClassNames { val SCHEDULER_DELAY = "scheduler_delay" val TASK_DESERIALIZATION_TIME = "deserialization_time" + val SHUFFLE_READ_BLOCKED_TIME = "fetch_wait_time" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" } From 0528b85cf96f9c9c074b5fbb5b9c5dd8071c0bc7 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 25 Jan 2015 19:16:44 -0800 Subject: [PATCH 23/33] SPARK-4430 [STREAMING] [TEST] Apache RAT Checks fail spuriously on test files Another trivial one. The RAT failure was due to temp files from `FailureSuite` not being cleaned up. This just makes the cleanup more reliable by using the standard temp dir mechanism. 
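The "standard temp dir mechanism" mentioned above is Spark's Utils helper, which both creates the directory and registers it for deletion at JVM shutdown, so leftover files cannot survive an aborted run. A minimal sketch of the pattern, outside the actual suite:

import java.io.File
import org.apache.spark.util.Utils

// Acquire a managed temp directory; it is also cleaned up at shutdown.
val directory: String = Utils.createTempDir().getAbsolutePath

// The suite still deletes it explicitly once its tests finish.
def afterFunction(): Unit = {
  Utils.deleteRecursively(new File(directory))
}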
Author: Sean Owen Closes #4189 from srowen/SPARK-4430 and squashes the following commits: 9ea63ff [Sean Owen] Properly acquire a temp directory to ensure it is cleaned up at shutdown, which helps avoid a RAT check failure --- .../scala/org/apache/spark/streaming/FailureSuite.scala | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 40434b1f9b709..6500608bba87c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -28,21 +28,16 @@ import java.io.File */ class FailureSuite extends TestSuiteBase with Logging { - var directory = "FailureSuite" + val directory = Utils.createTempDir().getAbsolutePath val numBatches = 30 override def batchDuration = Milliseconds(1000) override def useManualClock = false - override def beforeFunction() { - super.beforeFunction() - Utils.deleteRecursively(new File(directory)) - } - override def afterFunction() { - super.afterFunction() Utils.deleteRecursively(new File(directory)) + super.afterFunction() } test("multiple failures with map") { From 8df9435512e4b4df12b55143a910d9ce3d4bd62c Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 25 Jan 2015 19:28:53 -0800 Subject: [PATCH 24/33] [SPARK-5268] don't stop CoarseGrainedExecutorBackend for irrelevant DisassociatedEvent https://issues.apache.org/jira/browse/SPARK-5268 In CoarseGrainedExecutorBackend, we subscribe DisassociatedEvent in executor backend actor and exit the program upon receive such event... let's consider the following case The user may develop an Akka-based program which starts the actor with Spark's actor system and communicate with an external actor system (e.g. an Akka-based receiver in spark streaming which communicates with an external system) If the external actor system fails or disassociates with the actor within spark's system with purpose, we may receive DisassociatedEvent and the executor is restarted. This is not the expected behavior..... ---- This is a simple fix to check the event before making the quit decision Author: CodingCat Closes #4063 from CodingCat/SPARK-5268 and squashes the following commits: 4d7d48e [CodingCat] simplify the log 18c36f4 [CodingCat] more descriptive log f299e0b [CodingCat] clean log 1632e79 [CodingCat] check whether DisassociatedEvent is relevant before quit --- .../spark/executor/CoarseGrainedExecutorBackend.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 9a4adfbbb3d71..823825302658c 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -84,8 +84,12 @@ private[spark] class CoarseGrainedExecutorBackend( } case x: DisassociatedEvent => - logError(s"Driver $x disassociated! Shutting down.") - System.exit(1) + if (x.remoteAddress == driver.anchorPath.address) { + logError(s"Driver $x disassociated! 
Shutting down.") + System.exit(1) + } else { + logWarning(s"Received irrelevant DisassociatedEvent $x") + } case StopExecutor => logInfo("Driver commanded a shutdown") From 81251682edfb6a0de47e9dd72b2f63ee298fca33 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sun, 25 Jan 2015 22:18:09 -0800 Subject: [PATCH 25/33] [SPARK-5384][mllib] Vectors.sqdist returns inconsistent results for sparse/dense vectors when the vectors have different lengths JIRA issue: https://issues.apache.org/jira/browse/SPARK-5384 Currently `Vectors.sqdist` return inconsistent result for sparse/dense vectors when the vectors have different lengths, please refer to JIRA for sample PR scope: Unify the sqdist logic for dense/sparse vectors and fix the inconsistency, also remove the possible sparse to dense conversion in the original code. For reviewers: Maybe we should first discuss what's the correct behavior. 1. Vectors for sqdist must have the same length, like in breeze? 2. If they can have different lengths, what's the correct result for sqdist? (should the extra part get into calculation?) I'll update PR with more optimization and additional ut afterwards. Thanks. Author: Yuhao Yang Closes #4183 from hhbyyh/fixDouble and squashes the following commits: 1f17328 [Yuhao Yang] limit PR scope to size constraints only 54cbf97 [Yuhao Yang] fix Vectors.sqdist inconsistence --- .../scala/org/apache/spark/mllib/linalg/Vectors.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 7ee0224ad4662..b3022add38469 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -333,7 +333,7 @@ object Vectors { math.pow(sum, 1.0 / p) } } - + /** * Returns the squared distance between two Vectors. * @param v1 first Vector. @@ -341,8 +341,9 @@ object Vectors { * @return squared distance between two Vectors. */ def sqdist(v1: Vector, v2: Vector): Double = { + require(v1.size == v2.size, "vector dimension mismatch") var squaredDistance = 0.0 - (v1, v2) match { + (v1, v2) match { case (v1: SparseVector, v2: SparseVector) => val v1Values = v1.values val v1Indices = v1.indices @@ -350,12 +351,12 @@ object Vectors { val v2Indices = v2.indices val nnzv1 = v1Indices.size val nnzv2 = v2Indices.size - + var kv1 = 0 var kv2 = 0 while (kv1 < nnzv1 || kv2 < nnzv2) { var score = 0.0 - + if (kv2 >= nnzv2 || (kv1 < nnzv1 && v1Indices(kv1) < v2Indices(kv2))) { score = v1Values(kv1) kv1 += 1 @@ -397,7 +398,7 @@ object Vectors { val nnzv1 = indices.size val nnzv2 = v2.size var iv1 = if (nnzv1 > 0) indices(kv1) else -1 - + while (kv2 < nnzv2) { var score = 0.0 if (kv2 != iv1) { From 142093179a4c40bdd90744191034de7b94a963ff Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 26 Jan 2015 12:51:32 -0800 Subject: [PATCH 26/33] [SPARK-5355] use j.u.c.ConcurrentHashMap instead of TrieMap j.u.c.ConcurrentHashMap is more battle tested. 
cc rxin JoshRosen pwendell Author: Davies Liu Closes #4208 from davies/safe-conf and squashes the following commits: c2182dc [Davies Liu] address comments, fix tests 3a1d821 [Davies Liu] fix test da14ced [Davies Liu] Merge branch 'master' of github.com:apache/spark into safe-conf ae4d305 [Davies Liu] change to j.u.c.ConcurrentMap f8fa1cf [Davies Liu] change to TrieMap a1d769a [Davies Liu] make SparkConf thread-safe --- .../scala/org/apache/spark/SparkConf.scala | 38 ++++++++++--------- .../deploy/worker/WorkerArgumentsTest.scala | 4 +- .../apache/spark/storage/LocalDirsSuite.scala | 2 +- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index f9d4aa4240e9d..cd91c8f87547b 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -17,9 +17,11 @@ package org.apache.spark +import java.util.concurrent.ConcurrentHashMap + import scala.collection.JavaConverters._ -import scala.collection.concurrent.TrieMap -import scala.collection.mutable.{HashMap, LinkedHashSet} +import scala.collection.mutable.LinkedHashSet + import org.apache.spark.serializer.KryoSerializer /** @@ -47,12 +49,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) - private[spark] val settings = new TrieMap[String, String]() + private val settings = new ConcurrentHashMap[String, String]() if (loadDefaults) { // Load any spark.* system properties for ((k, v) <- System.getProperties.asScala if k.startsWith("spark.")) { - settings(k) = v + set(k, v) } } @@ -64,7 +66,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { if (value == null) { throw new NullPointerException("null value for " + key) } - settings(key) = value + settings.put(key, value) this } @@ -130,15 +132,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Set multiple parameters together */ def setAll(settings: Traversable[(String, String)]) = { - this.settings ++= settings + this.settings.putAll(settings.toMap.asJava) this } /** Set a parameter if it isn't already configured */ def setIfMissing(key: String, value: String): SparkConf = { - if (!settings.contains(key)) { - settings(key) = value - } + settings.putIfAbsent(key, value) this } @@ -164,21 +164,23 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Get a parameter; throws a NoSuchElementException if it's not set */ def get(key: String): String = { - settings.getOrElse(key, throw new NoSuchElementException(key)) + getOption(key).getOrElse(throw new NoSuchElementException(key)) } /** Get a parameter, falling back to a default if not set */ def get(key: String, defaultValue: String): String = { - settings.getOrElse(key, defaultValue) + getOption(key).getOrElse(defaultValue) } /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { - settings.get(key) + Option(settings.get(key)) } /** Get all parameters as a list of pairs */ - def getAll: Array[(String, String)] = settings.toArray + def getAll: Array[(String, String)] = { + settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray + } /** Get a parameter as an integer, falling back to a default if not set */ def getInt(key: String, defaultValue: Int): Int = { @@ -225,11 +227,11 @@ class SparkConf(loadDefaults: Boolean) 
extends Cloneable with Logging { def getAppId: String = get("spark.app.id") /** Does the configuration contain a given parameter? */ - def contains(key: String): Boolean = settings.contains(key) + def contains(key: String): Boolean = settings.containsKey(key) /** Copy this object */ override def clone: SparkConf = { - new SparkConf(false).setAll(settings) + new SparkConf(false).setAll(getAll) } /** @@ -241,7 +243,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. */ private[spark] def validateSettings() { - if (settings.contains("spark.local.dir")) { + if (contains("spark.local.dir")) { val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " + "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)." logWarning(msg) @@ -266,7 +268,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } // Validate spark.executor.extraJavaOptions - settings.get(executorOptsKey).map { javaOpts => + getOption(executorOptsKey).map { javaOpts => if (javaOpts.contains("-Dspark")) { val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " + "Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit." @@ -346,7 +348,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { * configuration out for debugging. */ def toDebugString: String = { - settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") + getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index 1a28a9a187cd7..372d7aa453008 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -43,7 +43,7 @@ class WorkerArgumentsTest extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } val conf = new MySparkConf() @@ -62,7 +62,7 @@ class WorkerArgumentsTest extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } val conf = new MySparkConf() diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index dae7bf0e336de..8cf951adb354b 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -49,7 +49,7 @@ class LocalDirsSuite extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } // spark.local.dir only contains invalid directories, but that's not a problem since From c094c73270837602b875b18998133a2364a89e45 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 26 Jan 2015 13:07:49 -0800 Subject: [PATCH 27/33] [SPARK-5339][BUILD] build/mvn doesn't work because of invalid URL for maven's tgz. build/mvn will automatically download tarball of maven. But currently, the URL is invalid. 
Author: Kousuke Saruta Closes #4124 from sarutak/SPARK-5339 and squashes the following commits: 6e96121 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-5339 0e012d1 [Kousuke Saruta] Updated Maven version to 3.2.5 ca26499 [Kousuke Saruta] Fixed URL of the tarball of Maven --- build/mvn | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/build/mvn b/build/mvn index 43471f83e904c..f91e2b4bdcc02 100755 --- a/build/mvn +++ b/build/mvn @@ -68,10 +68,10 @@ install_app() { # Install maven under the build/ folder install_mvn() { install_app \ - "http://apache.claz.org/maven/maven-3/3.2.3/binaries" \ - "apache-maven-3.2.3-bin.tar.gz" \ - "apache-maven-3.2.3/bin/mvn" - MVN_BIN="${_DIR}/apache-maven-3.2.3/bin/mvn" + "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \ + "apache-maven-3.2.5-bin.tar.gz" \ + "apache-maven-3.2.5/bin/mvn" + MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn" } # Install zinc under the build/ folder From 54e7b456dd56c9e52132154e699abca87563465b Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 26 Jan 2015 14:23:42 -0800 Subject: [PATCH 28/33] SPARK-4147 [CORE] Reduce log4j dependency Defer use of log4j class until it's known that log4j 1.2 is being used. This may avoid dealing with log4j dependencies for callers that reroute slf4j to another logging framework. The only change is to push one half of the check in the original `if` condition inside. This is a trivial change, may or may not actually solve a problem, but I think it's all that makes sense to do for SPARK-4147. Author: Sean Owen Closes #4190 from srowen/SPARK-4147 and squashes the following commits: 4e99942 [Sean Owen] Defer use of log4j class until it's known that log4j 1.2 is being used. This may avoid dealing with log4j dependencies for callers that reroute slf4j to another logging framework. 
--- .../main/scala/org/apache/spark/Logging.scala | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index d4f2624061e35..419d093d55643 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -118,15 +118,17 @@ trait Logging { // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently // org.apache.logging.slf4j.Log4jLoggerFactory val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) - val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements - if (!log4j12Initialized && usingLog4j12) { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") + if (usingLog4j12) { + val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + if (!log4j12Initialized) { + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") + } } } Logging.initialized = true From b38034e878546a12c6d52f17fc961fd1a2453b97 Mon Sep 17 00:00:00 2001 From: "David Y. Ross" Date: Mon, 26 Jan 2015 14:26:10 -0800 Subject: [PATCH 29/33] Fix command spaces issue in make-distribution.sh Storing command in variables is tricky in bash, use an array to handle all issues with spaces, quoting, etc. See: http://mywiki.wooledge.org/BashFAQ/050 Author: David Y. Ross Closes #4126 from dyross/dyr-fix-make-distribution and squashes the following commits: 4ce522b [David Y. Ross] Fix command spaces issue in make-distribution.sh --- make-distribution.sh | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/make-distribution.sh b/make-distribution.sh index 4e2f400be3053..0adca7851819b 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -115,7 +115,7 @@ if which git &>/dev/null; then unset GITREV fi -if ! which $MVN &>/dev/null; then +if ! which "$MVN" &>/dev/null; then echo -e "Could not locate Maven command: '$MVN'." echo -e "Specify the Maven command with the --mvn flag" exit -1; @@ -171,13 +171,16 @@ cd "$SPARK_HOME" export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" -BUILD_COMMAND="$MVN clean package -DskipTests $@" +# Store the command as an array because $MVN variable might have spaces in it. +# Normal quoting tricks don't work. +# See: http://mywiki.wooledge.org/BashFAQ/050 +BUILD_COMMAND=("$MVN" clean package -DskipTests $@) # Actually build the jar echo -e "\nBuilding with..." 
-echo -e "\$ $BUILD_COMMAND\n" +echo -e "\$ ${BUILD_COMMAND[@]}\n" -${BUILD_COMMAND} +"${BUILD_COMMAND[@]}" # Make directories rm -rf "$DISTDIR" From 0497ea51ac345f8057d222a18dbbf8eae78f5b92 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 26 Jan 2015 14:32:27 -0800 Subject: [PATCH 30/33] SPARK-960 [CORE] [TEST] JobCancellationSuite "two jobs sharing the same stage" is broken This reenables and fixes this test, after addressing two issues: - The Semaphore that was intended to be shared locally was being serialized and copied; it's now a static member in the companion object as in other tests - Later changes to Spark means that cancelling the first task will not cancel the shared stage and therefore the second task should succeed Author: Sean Owen Closes #4180 from srowen/SPARK-960 and squashes the following commits: 43da66f [Sean Owen] Fix 'two jobs sharing the same stage' test and reenable it: truly share a Semaphore locally as intended, and update expectation of failure in non-cancelled task --- .../org/apache/spark/JobCancellationSuite.scala | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 7584ae79fc920..21487bc24d58a 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -171,11 +171,11 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter assert(jobB.get() === 100) } - ignore("two jobs sharing the same stage") { + test("two jobs sharing the same stage") { // sem1: make sure cancel is issued after some tasks are launched - // sem2: make sure the first stage is not finished until cancel is issued + // twoJobsSharingStageSemaphore: + // make sure the first stage is not finished until cancel is issued val sem1 = new Semaphore(0) - val sem2 = new Semaphore(0) sc = new SparkContext("local[2]", "test") sc.addSparkListener(new SparkListener { @@ -186,7 +186,7 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter // Create two actions that would share the some stages. val rdd = sc.parallelize(1 to 10, 2).map { i => - sem2.acquire() + JobCancellationSuite.twoJobsSharingStageSemaphore.acquire() (i, i) }.reduceByKey(_+_) val f1 = rdd.collectAsync() @@ -196,13 +196,13 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter future { sem1.acquire() f1.cancel() - sem2.release(10) + JobCancellationSuite.twoJobsSharingStageSemaphore.release(10) } - // Expect both to fail now. - // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2. + // Expect f1 to fail due to cancellation, intercept[SparkException] { f1.get() } - intercept[SparkException] { f2.get() } + // but f2 should not be affected + f2.get() } def testCount() { @@ -268,4 +268,5 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter object JobCancellationSuite { val taskStartedSemaphore = new Semaphore(0) val taskCancelledSemaphore = new Semaphore(0) + val twoJobsSharingStageSemaphore = new Semaphore(0) } From 661e0fca5d5d86efab5fb26da600ac2ac96b09ec Mon Sep 17 00:00:00 2001 From: Elmer Garduno Date: Mon, 26 Jan 2015 17:40:48 -0800 Subject: [PATCH 31/33] [SPARK-5052] Add common/base classes to fix guava methods signatures. Fixes problems with incorrect method signatures related to shaded classes. For discussion see the jira issue. 
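To make the symptom concrete, here is a hypothetical caller, not code from this patch: Spark's Java API returns Guava Optional values, and Optional methods such as transform take com.google.common.base.Function. When Function and Supplier inside the assembly are relocated by shading while Optional is not, those method signatures reference the relocated types and a user's unshaded Function no longer matches.

import com.google.common.base.{Function, Optional}
import org.apache.spark.api.java.JavaSparkContext

val jsc = new JavaSparkContext("local", "guava-signature-sketch")
// The Java API hands back a Guava Optional here.
val home: Optional[String] = jsc.getSparkHome()
// transform expects com.google.common.base.Function; with mismatched shading,
// this is the kind of call whose signature no longer lines up.
val upper: Optional[String] = home.transform(new Function[String, String] {
  override def apply(s: String): String = s.toUpperCase
})
jsc.stop()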
Author: Elmer Garduno Closes #3874 from elmer-garduno/fix_guava_signatures and squashes the following commits: aa5d8e0 [Elmer Garduno] Unshade common/base[Function|Supplier] classes to fix guava methods signatures. --- assembly/pom.xml | 2 ++ core/pom.xml | 2 ++ 2 files changed, 4 insertions(+) diff --git a/assembly/pom.xml b/assembly/pom.xml index b2a9d0780ee2b..594fa0c779e1b 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -142,8 +142,10 @@ com/google/common/base/Absent* + com/google/common/base/Function com/google/common/base/Optional* com/google/common/base/Present* + com/google/common/base/Supplier diff --git a/core/pom.xml b/core/pom.xml index d9a49c9e08afc..1984682b9c099 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -372,8 +372,10 @@ com.google.guava:guava com/google/common/base/Absent* + com/google/common/base/Function com/google/common/base/Optional* com/google/common/base/Present* + com/google/common/base/Supplier From f2ba5c6fc3dde81a4d234c75dae2d4e3b46512d1 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Mon, 26 Jan 2015 18:03:21 -0800 Subject: [PATCH 32/33] [SPARK-5119] java.lang.ArrayIndexOutOfBoundsException on trying to train... ... decision tree model Labels loaded from libsvm files are mapped to 0.0 if they are negative labels because they should be nonnegative value. Author: lewuathe Closes #3975 from Lewuathe/map-negative-label-to-positive and squashes the following commits: 12d1d59 [lewuathe] [SPARK-5119] Fix code styles 6d9a18a [lewuathe] [SPARK-5119] Organize test codes 62a150c [lewuathe] [SPARK-5119] Modify Impurities throw exceptions with negatie labels 3336c21 [lewuathe] [SPARK-5119] java.lang.ArrayIndexOutOfBoundsException on trying to train decision tree model --- .../spark/mllib/tree/impurity/Entropy.scala | 5 +++ .../spark/mllib/tree/impurity/Gini.scala | 5 +++ .../spark/mllib/tree/ImpuritySuite.scala | 42 +++++++++++++++++++ 3 files changed, 52 insertions(+) create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 0e02345aa3774..b7950e00786ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -94,6 +94,10 @@ private[tree] class EntropyAggregator(numClasses: Int) throw new IllegalArgumentException(s"EntropyAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } + if (label < 0) { + throw new IllegalArgumentException(s"EntropyAggregator given label $label" + + s"but requires label is non-negative.") + } allStats(offset + label.toInt) += instanceWeight } @@ -147,6 +151,7 @@ private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalc val lbl = label.toInt require(lbl < stats.length, s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + require(lbl >= 0, "Entropy does not support negative labels") val cnt = count if (cnt == 0) { 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 7c83cd48e16a0..c946db9c0d1c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -90,6 +90,10 @@ private[tree] class GiniAggregator(numClasses: Int) throw new 
IllegalArgumentException(s"GiniAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } + if (label < 0) { + throw new IllegalArgumentException(s"GiniAggregator given label $label" + + s"but requires label is non-negative.") + } allStats(offset + label.toInt) += instanceWeight } @@ -143,6 +147,7 @@ private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcula val lbl = label.toInt require(lbl < stats.length, s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + require(lbl >= 0, "GiniImpurity does not support negative labels") val cnt = count if (cnt == 0) { 0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala new file mode 100644 index 0000000000000..92b498580af03 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. + */ +class ImpuritySuite extends FunSuite with MLlibTestSparkContext { + test("Gini impurity does not support negative labels") { + val gini = new GiniAggregator(2) + intercept[IllegalArgumentException] { + gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) + } + } + + test("Entropy does not support negative labels") { + val entropy = new EntropyAggregator(2) + intercept[IllegalArgumentException] { + entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) + } + } +} From d6894b1c5314c751cfdaf78005b99b2104e6e4d1 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Mon, 26 Jan 2015 19:46:17 -0800 Subject: [PATCH 33/33] [SPARK-3726] [MLlib] Allow sampling_rate not equal to 1.0 in RandomForests I've added support for sampling_rate not equal to 1.0 . I have two major questions. 1. A Scala style test is failing, since the number of parameters now exceed 10. 2. I would like suggestions to understand how to test this. 
Author: MechCoder Closes #4073 from MechCoder/spark-3726 and squashes the following commits: 8012fb2 [MechCoder] Add test in Strategy e0e0d9c [MechCoder] TST: Add better test d1df1b2 [MechCoder] Add test to verify subsampling behavior a7bfc70 [MechCoder] [SPARK-3726] Allow sampling_rate not equal to 1.0 --- .../apache/spark/mllib/tree/RandomForest.scala | 16 +++++----------- .../mllib/tree/configuration/Strategy.scala | 3 +++ .../spark/mllib/tree/RandomForestSuite.scala | 16 ++++++++++++++++ 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index e9304b5e5c650..482dd4b272d1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -140,6 +140,7 @@ private class RandomForest ( logDebug("maxBins = " + metadata.maxBins) logDebug("featureSubsetStrategy = " + featureSubsetStrategy) logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode) + logDebug("subsamplingRate = " + strategy.subsamplingRate) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. @@ -155,19 +156,12 @@ private class RandomForest ( // Cache input RDD for speedup during multiple passes. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) - val (subsample, withReplacement) = { - // TODO: Have a stricter check for RF in the strategy - val isRandomForest = numTrees > 1 - if (isRandomForest) { - (1.0, true) - } else { - (strategy.subsamplingRate, false) - } - } + val withReplacement = if (numTrees > 1) true else false val baggedInput - = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed) - .persist(StorageLevel.MEMORY_AND_DISK) + = BaggedPoint.convertToBaggedRDD(treeInput, + strategy.subsamplingRate, numTrees, + withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK) // depth of the decision tree val maxDepth = strategy.maxDepth diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 972959885f396..3308adb6752ff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -156,6 +156,9 @@ class Strategy ( s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") require(maxMemoryInMB <= 10240, s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") + require(subsamplingRate > 0 && subsamplingRate <= 1, + s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " + + s"$subsamplingRate") } /** Returns a shallow copy of this instance. 
*/ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index f7f0f20c6c125..55e963977b54f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -196,6 +196,22 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { featureSubsetStrategy = "sqrt", seed = 12345) EnsembleTestHelper.validateClassifier(model, arr, 1.0) } + + test("subsampling rate in RandomForest"){ + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int], + useNodeIdCache = true) + + val rf1 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3, + featureSubsetStrategy = "auto", seed = 123) + strategy.subsamplingRate = 0.5 + val rf2 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3, + featureSubsetStrategy = "auto", seed = 123) + assert(rf1.toDebugString != rf2.toDebugString) + } + }
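The new RandomForestSuite test above drives the changed behavior through the Strategy mutator. As a rough sketch of the user-facing path this patch enables (assuming an existing SparkContext `sc`; the data and parameter values below are made up for illustration and are not part of the patch):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    // Toy training data: alternating binary labels over a single numeric feature.
    val data = sc.parallelize((0 until 100).map { i =>
      LabeledPoint((i % 2).toDouble, Vectors.dense(i.toDouble))
    })

    // Same Strategy shape as in RandomForestSuite above.
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
      numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int])

    // With this patch the rate is honored when numTrees > 1 (sampling with
    // replacement); values outside (0, 1] are rejected by the new require in Strategy.
    strategy.subsamplingRate = 0.5

    val model = RandomForest.trainClassifier(data, strategy, numTrees = 3,
      featureSubsetStrategy = "auto", seed = 123)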
<!-- Trailing fragment without a patch header: rows from a Spark configuration property table. -->
<table class="table">
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
<tr>
  <td><code>spark.driver.extraJavaOptions</code></td>
  <td>(none)</td>
  <td>
    A string of extra JVM options to pass to the driver. For instance, GC settings or other logging.
  </td>
</tr>
<tr>
  <td><code>spark.driver.extraClassPath</code></td>
  <td>(none)</td>
  <td>
    Extra classpath entries to append to the classpath of the driver.
  </td>
</tr>
<tr>
  <td><code>spark.driver.extraLibraryPath</code></td>
  <td>(none)</td>
  <td>
    Set a special library path to use when launching the driver JVM.
  </td>
</tr>
<tr>
  <td><code>spark.executor.extraJavaOptions</code></td>
  <td>(none)</td>
  <td></td>
</tr>
</table>
<!-- Orphaned stage-page UI fragment: "Output" / "Shuffle Read Blocked Time" column headers and the cells {shuffleReadBlockedTimeReadable}, {shuffleReadReadable}. -->
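These launch-time properties are normally supplied through spark-submit or spark-defaults.conf; in particular, the spark.driver.* options cannot reliably be set from application code in client mode, because the driver JVM is already running by then. A hedged sketch of setting the executor-side option programmatically (the JVM flags and app name are illustrative only):

    import org.apache.spark.{SparkConf, SparkContext}

    // Executor JVM options can be set on SparkConf before the context starts.
    val conf = new SparkConf()
      .setAppName("ExtraJavaOptionsExample")
      .set("spark.executor.extraJavaOptions", "-XX:+PrintGCDetails -XX:+PrintGCTimeStamps")

    val sc = new SparkContext(conf)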