diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index b4b0067801259..74aa441619bd2 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -146,8 +146,9 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) private val secretKey = generateSecretKey() - logInfo("SecurityManager, is authentication enabled: " + authOn + - " are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString()) + logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + + "; ui acls " + (if (uiAclsOn) "enabled" else "disabled") + + "; users with view permissions: " + viewAcls.toString()) // Set our own authenticator to properly negotiate user/password for HTTP connections. // This is needed by the HTTP client fetching from the HttpServer. Put here so its diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index dc012cc381346..fc4812753d005 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -42,9 +42,13 @@ class TaskContext( // List of callback functions to execute when the task completes. @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit] + // Set to true when the task is completed, before the onCompleteCallbacks are executed. + @volatile var completed: Boolean = false + /** * Add a callback function to be executed on task completion. An example use * is for HadoopRDD to register a callback to close the input stream. + * Will be called in any situation - success, failure, or cancellation. * @param f Callback function. */ def addOnCompleteCallback(f: () => Unit) { @@ -52,6 +56,7 @@ class TaskContext( } def executeOnCompleteCallbacks() { + completed = true // Process complete callbacks in the reverse order of registration onCompleteCallbacks.reverse.foreach{_()} } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 61407007087c6..fecd9762f3f60 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -56,122 +56,37 @@ private[spark] class PythonRDD[T: ClassTag]( val env = SparkEnv.get val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) - // Ensure worker socket is closed on task completion. Closing sockets is idempotent. - context.addOnCompleteCallback(() => + // Start a thread to feed the process input from our parent's iterator + val writerThread = new WriterThread(env, worker, split, context) + + context.addOnCompleteCallback { () => + writerThread.shutdownOnTaskCompletion() + + // Cleanup the worker socket. This will also cause the Python worker to exit. 
try { worker.close() } catch { case e: Exception => logWarning("Failed to close worker socket", e) } - ) - - @volatile var readerException: Exception = null - - // Start a thread to feed the process input from our parent's iterator - new Thread("stdin writer for " + pythonExec) { - override def run() { - try { - SparkEnv.set(env) - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) - val dataOut = new DataOutputStream(stream) - // Partition index - dataOut.writeInt(split.index) - // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) - // Broadcast variables - dataOut.writeInt(broadcastVars.length) - for (broadcast <- broadcastVars) { - dataOut.writeLong(broadcast.id) - dataOut.writeInt(broadcast.value.length) - dataOut.write(broadcast.value) - } - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.length) - for (include <- pythonIncludes) { - PythonRDD.writeUTF(include, dataOut) - } - dataOut.flush() - // Serialized command: - dataOut.writeInt(command.length) - dataOut.write(command) - // Data values - PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) - dataOut.flush() - worker.shutdownOutput() - } catch { - - case e: java.io.FileNotFoundException => - readerException = e - Try(worker.shutdownOutput()) // kill Python worker process - - case e: IOException => - // This can happen for legitimate reasons if the Python code stops returning data - // before we are done passing elements through, e.g., for take(). Just log a message to - // say it happened (as it could also be hiding a real IOException from a data source). - logInfo("stdin writer to Python finished early (may not be an error)", e) - - case e: Exception => - // We must avoid throwing exceptions here, because the thread uncaught exception handler - // will kill the whole executor (see Executor). - readerException = e - Try(worker.shutdownOutput()) // kill Python worker process - } - } - }.start() - - // Necessary to distinguish between a task that has failed and a task that is finished - @volatile var complete: Boolean = false - - // It is necessary to have a monitor thread for python workers if the user cancels with - // interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the - // threads can block indefinitely. - new Thread(s"Worker Monitor for $pythonExec") { - override def run() { - // Kill the worker if it is interrupted or completed - // When a python task completes, the context is always set to interupted - while (!context.interrupted) { - Thread.sleep(2000) - } - if (!complete) { - try { - logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.toMap) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) - } - } - } - }.start() - - /* - * Partial fix for SPARK-1019: Attempts to stop reading the input stream since - * other completion callbacks might invalidate the input. Because interruption - * is not synchronous this still leaves a potential race where the interruption is - * processed only after the stream becomes invalid. 
- */ - context.addOnCompleteCallback{ () => - complete = true // Indicate that the task has completed successfully - context.interrupted = true } + writerThread.start() + new MonitorThread(env, worker, context).start() + // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) val stdoutIterator = new Iterator[Array[Byte]] { def next(): Array[Byte] = { val obj = _nextObj if (hasNext) { - // FIXME: can deadlock if worker is waiting for us to - // respond to current message (currently irrelevant because - // output is shutdown before we read any input) _nextObj = read() } obj } private def read(): Array[Byte] = { - if (readerException != null) { - throw readerException + if (writerThread.exception.isDefined) { + throw writerThread.exception.get } try { stream.readInt() match { @@ -190,13 +105,14 @@ private[spark] class PythonRDD[T: ClassTag]( val total = finishTime - startTime logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) - read + read() case SpecialLengths.PYTHON_EXCEPTION_THROWN => // Signals that an exception has been thrown in python val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) - throw new PythonException(new String(obj, "utf-8"), readerException) + throw new PythonException(new String(obj, "utf-8"), + writerThread.exception.getOrElse(null)) case SpecialLengths.END_OF_DATA_SECTION => // We've finished the data section of the output, but we can still // read some accumulator updates: @@ -210,10 +126,15 @@ private[spark] class PythonRDD[T: ClassTag]( Array.empty[Byte] } } catch { - case e: Exception if readerException != null => + + case e: Exception if context.interrupted => + logDebug("Exception thrown after task interruption", e) + throw new TaskKilledException + + case e: Exception if writerThread.exception.isDefined => logError("Python worker exited unexpectedly (crashed)", e) - logError("Python crash may have been caused by prior exception:", readerException) - throw readerException + logError("This may have been caused by a prior exception:", writerThread.exception.get) + throw writerThread.exception.get case eof: EOFException => throw new SparkException("Python worker exited unexpectedly (crashed)", eof) @@ -224,10 +145,100 @@ private[spark] class PythonRDD[T: ClassTag]( def hasNext = _nextObj.length != 0 } - stdoutIterator + new InterruptibleIterator(context, stdoutIterator) } val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) + + /** + * The thread responsible for writing the data from the PythonRDD's parent iterator to the + * Python process. + */ + class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext) + extends Thread(s"stdout writer for $pythonExec") { + + @volatile private var _exception: Exception = null + + setDaemon(true) + + /** Contains the exception thrown while writing the parent iterator to the Python process. */ + def exception: Option[Exception] = Option(_exception) + + /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. 
*/ + def shutdownOnTaskCompletion() { + assert(context.completed) + this.interrupt() + } + + override def run() { + try { + SparkEnv.set(env) + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val dataOut = new DataOutputStream(stream) + // Partition index + dataOut.writeInt(split.index) + // sparkFilesDir + PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) + // Broadcast variables + dataOut.writeInt(broadcastVars.length) + for (broadcast <- broadcastVars) { + dataOut.writeLong(broadcast.id) + dataOut.writeInt(broadcast.value.length) + dataOut.write(broadcast.value) + } + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.length) + for (include <- pythonIncludes) { + PythonRDD.writeUTF(include, dataOut) + } + dataOut.flush() + // Serialized command: + dataOut.writeInt(command.length) + dataOut.write(command) + // Data values + PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) + dataOut.flush() + } catch { + case e: Exception if context.completed || context.interrupted => + logDebug("Exception thrown after task completion (likely due to cleanup)", e) + + case e: Exception => + // We must avoid throwing exceptions here, because the thread uncaught exception handler + // will kill the whole executor (see org.apache.spark.executor.Executor). + _exception = e + } finally { + Try(worker.shutdownOutput()) // kill Python worker process + } + } + } + + /** + * It is necessary to have a monitor thread for python workers if the user cancels with + * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the + * threads can block indefinitely. + */ + class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) + extends Thread(s"Worker Monitor for $pythonExec") { + + setDaemon(true) + + override def run() { + // Kill the worker if it is interrupted, checking until task completion. + // TODO: This has a race condition if interruption occurs, as completed may still become true. + while (!context.interrupted && !context.completed) { + Thread.sleep(2000) + } + if (!context.completed) { + try { + logWarning("Incomplete task interrupted: Attempting to kill Python Worker") + env.destroyPythonWorker(pythonExec, envVars.toMap) + } catch { + case e: Exception => + logError("Exception when trying to kill worker", e) + } + } + } + } } /** Thrown for exceptions in user Python code. */ diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index cf69fa1d53fde..6d3e257c4d5df 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.python -import java.io.File +import java.io.{File, InputStream, IOException, OutputStream} import scala.collection.mutable.ArrayBuffer @@ -40,3 +40,28 @@ private[spark] object PythonUtils { paths.filter(_ != "").mkString(File.pathSeparator) } } + + +/** + * A utility class to redirect the child process's stdout or stderr. + */ +private[spark] class RedirectThread( + in: InputStream, + out: OutputStream, + name: String) + extends Thread(name) { + + setDaemon(true) + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME: We copy the stream on the level of bytes to avoid encoding problems. 
+ val buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + out.write(buf, 0, len) + out.flush() + len = in.read(buf) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index b0bf4e052b3e9..002f2acd94dee 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,15 +17,18 @@ package org.apache.spark.api.python -import java.io.{DataInputStream, File, IOException, OutputStreamWriter} +import java.io.{DataInputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} import scala.collection.JavaConversions._ import org.apache.spark._ +import org.apache.spark.util.Utils private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) - extends Logging { + extends Logging { + + import PythonWorkerFactory._ // Because forking processes from Java is expensive, we prefer to launch a single Python daemon // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently @@ -38,7 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String var daemonPort: Int = 0 val pythonPath = PythonUtils.mergePythonPaths( - PythonUtils.sparkPythonPath, envVars.getOrElse("PYTHONPATH", "")) + PythonUtils.sparkPythonPath, + envVars.getOrElse("PYTHONPATH", ""), + sys.env.getOrElse("PYTHONPATH", "")) def create(): Socket = { if (useDaemon) { @@ -61,12 +66,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { new Socket(daemonHost, daemonPort) } catch { - case exc: SocketException => { + case exc: SocketException => logWarning("Python daemon unexpectedly quit, attempting to restart") stopDaemon() startDaemon() new Socket(daemonHost, daemonPort) - } case e: Throwable => throw e } } @@ -87,39 +91,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String workerEnv.put("PYTHONPATH", pythonPath) val worker = pb.start() - // Redirect the worker's stderr to ours - new Thread("stderr reader for " + pythonExec) { - setDaemon(true) - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - val in = worker.getErrorStream - val buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - System.err.write(buf, 0, len) - len = in.read(buf) - } - } - } - }.start() - - // Redirect worker's stdout to our stderr - new Thread("stdout reader for " + pythonExec) { - setDaemon(true) - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - val in = worker.getInputStream - val buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - System.err.write(buf, 0, len) - len = in.read(buf) - } - } - } - }.start() + // Redirect worker stdout and stderr + redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream) // Tell the worker our port val out = new OutputStreamWriter(worker.getOutputStream) @@ -142,10 +115,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String null } - def stop() { - stopDaemon() - } - private def startDaemon() { synchronized { // Is it already running? 
@@ -161,46 +130,38 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String workerEnv.put("PYTHONPATH", pythonPath) daemon = pb.start() - // Redirect the stderr to ours - new Thread("stderr reader for " + pythonExec) { - setDaemon(true) - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - val in = daemon.getErrorStream - val buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - System.err.write(buf, 0, len) - len = in.read(buf) - } - } - } - }.start() - val in = new DataInputStream(daemon.getInputStream) daemonPort = in.readInt() - // Redirect further stdout output to our stderr - new Thread("stdout reader for " + pythonExec) { - setDaemon(true) - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - val buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - System.err.write(buf, 0, len) - len = in.read(buf) - } - } - } - }.start() + // Redirect daemon stdout and stderr + redirectStreamsToStderr(in, daemon.getErrorStream) + } catch { - case e: Throwable => { + case e: Exception => + + // If the daemon exists, wait for it to finish and get its stderr + val stderr = Option(daemon) + .flatMap { d => Utils.getStderr(d, PROCESS_WAIT_TIMEOUT_MS) } + .getOrElse("") + stopDaemon() - throw e - } + + if (stderr != "") { + val formattedStderr = stderr.replace("\n", "\n ") + val errorMessage = s""" + |Error from python worker: + | $formattedStderr + |PYTHONPATH was: + | $pythonPath + |$e""" + + // Append error message from python daemon, but keep original stack trace + val wrappedException = new SparkException(errorMessage.stripMargin) + wrappedException.setStackTrace(e.getStackTrace) + throw wrappedException + } else { + throw e + } } // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly @@ -208,6 +169,19 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } } + /** + * Redirect the given streams to our stderr in separate threads. 
+ */ + private def redirectStreamsToStderr(stdout: InputStream, stderr: InputStream) { + try { + new RedirectThread(stdout, System.err, "stdout reader for " + pythonExec).start() + new RedirectThread(stderr, System.err, "stderr reader for " + pythonExec).start() + } catch { + case e: Exception => + logError("Exception in redirecting streams", e) + } + } + private def stopDaemon() { synchronized { // Request shutdown of existing daemon by sending SIGTERM @@ -219,4 +193,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String daemonPort = 0 } } + + def stop() { + stopDaemon() + } +} + +private object PythonWorkerFactory { + val PROCESS_WAIT_TIMEOUT_MS = 10000 } diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index f2e7c7a508b3f..e20d4486c8f0c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -17,13 +17,10 @@ package org.apache.spark.deploy -import java.io.{IOException, File, InputStream, OutputStream} - import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ -import org.apache.spark.SparkContext -import org.apache.spark.api.python.PythonUtils +import org.apache.spark.api.python.{PythonUtils, RedirectThread} /** * A main class used by spark-submit to launch Python applications. It executes python as a @@ -62,23 +59,4 @@ object PythonRunner { System.exit(process.waitFor()) } - - /** - * A utility class to redirect the child process's stdout or stderr - */ - class RedirectThread(in: InputStream, out: OutputStream, name: String) extends Thread(name) { - setDaemon(true) - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - val buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - out.write(buf, 0, len) - out.flush() - len = in.read(buf) - } - } - } - } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 3b3524f33e811..a1ca612cc9a09 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -128,7 +128,7 @@ abstract class RDD[T: ClassTag]( @transient var name: String = null /** Assign a name to this RDD */ - def setName(_name: String): RDD[T] = { + def setName(_name: String): this.type = { name = _name this } @@ -138,7 +138,7 @@ abstract class RDD[T: ClassTag]( * it is computed. This can only be used to assign a new storage level if the RDD does not * have a storage level set yet.. */ - def persist(newLevel: StorageLevel): RDD[T] = { + def persist(newLevel: StorageLevel): this.type = { // TODO: Handle changes of StorageLevel if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) { throw new UnsupportedOperationException( @@ -152,10 +152,10 @@ abstract class RDD[T: ClassTag]( } /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ - def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY) + def persist(): this.type = persist(StorageLevel.MEMORY_ONLY) /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ - def cache(): RDD[T] = persist() + def cache(): this.type = persist() /** * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. 
@@ -163,7 +163,7 @@ abstract class RDD[T: ClassTag]( * @param blocking Whether to block until all blocks are deleted. * @return This RDD. */ - def unpersist(blocking: Boolean = true): RDD[T] = { + def unpersist(blocking: Boolean = true): this.type = { logInfo("Removing RDD " + id + " from persistence list") sc.unpersistRDD(id, blocking) storageLevel = StorageLevel.NONE diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 02b62de7e36b6..2259df0b56bad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -17,11 +17,13 @@ package org.apache.spark.scheduler +import scala.language.existentials + import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.HashMap -import scala.language.existentials +import scala.util.Try import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics @@ -196,7 +198,11 @@ private[spark] class ShuffleMapTask( } finally { // Release the writers back to the shuffle block manager. if (shuffle != null && shuffle.writers != null) { - shuffle.releaseWriters(success) + try { + shuffle.releaseWriters(success) + } catch { + case e: Exception => logError("Failed to release shuffle writers", e) + } } // Execute the callbacks on task completion. context.executeOnCompleteCallbacks() diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 202bd46956f87..3f0ed61c5bbfb 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1088,4 +1088,41 @@ private[spark] object Utils extends Logging { def stripDirectory(path: String): String = { path.split(File.separator).last } + + /** + * Wait for a process to terminate for at most the specified duration. + * Return whether the process actually terminated after the given timeout. + */ + def waitForProcess(process: Process, timeoutMs: Long): Boolean = { + var terminated = false + val startTime = System.currentTimeMillis + while (!terminated) { + try { + process.exitValue + terminated = true + } catch { + case e: IllegalThreadStateException => + // Process not terminated yet + if (System.currentTimeMillis - startTime > timeoutMs) { + return false + } + Thread.sleep(100) + } + } + true + } + + /** + * Return the stderr of a process after waiting for the process to terminate. + * If the process does not terminate within the specified timeout, return None. + */ + def getStderr(process: Process, timeoutMs: Long): Option[String] = { + val terminated = Utils.waitForProcess(process, timeoutMs) + if (terminated) { + Some(Source.fromInputStream(process.getErrorStream).getLines().mkString("\n")) + } else { + None + } + } + } diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 296277e58b341..acf0feff42a8d 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -93,17 +93,14 @@ The recursive tree construction is stopped at a node when one of the two conditi 1. The node depth is equal to the `maxDepth` training parameter 2. No split candidate leads to an information gain at the node. +### Max memory requirements + +For faster processing, the decision tree algorithm performs simultaneous histogram computations for all nodes at each level of the tree. 
This can result in high memory requirements at deeper levels of the tree, leading to memory overflow errors. To alleviate this problem, a `maxMemoryInMB` training parameter is provided, which specifies the maximum amount of memory at the workers (twice as much at the master) to be allocated to the histogram computation. The default value is conservatively chosen to be 128 MB to allow the decision tree algorithm to work in most scenarios. Once the memory requirement for a level-wise computation exceeds the `maxMemoryInMB` threshold, the node training tasks at each subsequent level are split into smaller tasks. + ### Practical limitations -1. The tree implementation stores an `Array[Double]` of size *O(#features \* #splits \* 2^maxDepth)* - in memory for aggregating histograms over partitions. The current implementation might not scale - to very deep trees since the memory requirement grows exponentially with tree depth. -2. The implemented algorithm reads both sparse and dense data. However, it is not optimized for - sparse input. -3. Python is not supported in this release. - -We are planning to solve these problems in the near future. Please drop us a line if you encounter -any issues. +1. The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input. +2. Python is not supported in this release. ## Examples diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index a77dfb2577835..33700ab4f8c53 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -36,14 +36,13 @@ def rmse(R, ms, us): def update(i, vec, mat, ratings): uu = mat.shape[0] ff = mat.shape[1] - XtX = matrix(np.zeros((ff, ff))) - Xty = np.zeros((ff, 1)) - - for j in range(uu): - v = mat[j, :] - XtX += v.T * v - Xty += v.T * ratings[i, j] - XtX += np.eye(ff, ff) * LAMBDA * uu + + XtX = mat.T * mat + Xty = mat.T * ratings[i, :].T + + for j in range(ff): + XtX[j, j] += LAMBDA * uu + return np.linalg.solve(XtX, Xty) if __name__ == "__main__": diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 0bd847d7bab30..9832bec90d7ee 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -51,7 +51,7 @@ object DecisionTreeRunner { algo: Algo = Classification, maxDepth: Int = 5, impurity: ImpurityType = Gini, - maxBins: Int = 20) + maxBins: Int = 100) def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 703f02255b94b..0e4447e0de24f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -43,7 +43,8 @@ object MovieLensALS { kryo: Boolean = false, numIterations: Int = 20, lambda: Double = 1.0, - rank: Int = 10) + rank: Int = 10, + implicitPrefs: Boolean = false) def main(args: Array[String]) { val defaultParams = Params() @@ -62,6 +63,9 @@ object MovieLensALS { opt[Unit]("kryo") .text(s"use Kryo serialization") .action((_, c) => c.copy(kryo = true)) + opt[Unit]("implicitPrefs") + .text("use implicit preference") + .action((_, c) => c.copy(implicitPrefs = true)) arg[String]("<input>") .required() .text("input paths to a
MovieLens dataset of ratings") @@ -88,7 +92,25 @@ object MovieLensALS { val ratings = sc.textFile(params.input).map { line => val fields = line.split("::") - Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble) + if (params.implicitPrefs) { + /* + * MovieLens ratings are on a scale of 1-5: + * 5: Must see + * 4: Will enjoy + * 3: It's okay + * 2: Fairly bad + * 1: Awful + * So we should not recommend a movie if the predicted rating is less than 3. + * To map ratings to confidence scores, we use + * 5 -> 2.5, 4 -> 1.5, 3 -> 0.5, 2 -> -0.5, 1 -> -1.5. This mapping means unobserved + * entries are generally between "It's okay" and "Fairly bad". + * The semantics of 0 in this expanded world of non-positive weights + * are "the same as never having interacted at all". + */ + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) + } else { + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble) + } }.cache() val numRatings = ratings.count() @@ -99,7 +121,18 @@ object MovieLensALS { val splits = ratings.randomSplit(Array(0.8, 0.2)) val training = splits(0).cache() - val test = splits(1).cache() + val test = if (params.implicitPrefs) { + /* + * 0 means "don't know" and positive values mean "confident that the prediction should be 1". + * Negative values mean "confident that the prediction should be 0". + * In this case, we use a weighted RMSE. The weight is the absolute value of + * the confidence. The error is the difference between the prediction and either 1 or 0, + * depending on whether r is positive or negative. + */ + splits(1).map(x => Rating(x.user, x.product, if (x.rating > 0) 1.0 else 0.0)) + } else { + splits(1) + }.cache() val numTraining = training.count() val numTest = test.count() @@ -111,9 +144,10 @@ object MovieLensALS { .setRank(params.rank) .setIterations(params.numIterations) .setLambda(params.lambda) + .setImplicitPrefs(params.implicitPrefs) .run(training) - val rmse = computeRmse(model, test, numTest) + val rmse = computeRmse(model, test, params.implicitPrefs) println(s"Test RMSE = $rmse.") @@ -121,11 +155,14 @@ object MovieLensALS { } /** Compute RMSE (Root Mean Squared Error). 
*/ - def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], n: Long) = { + def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) = { + + def mapPredictedRating(r: Double) = if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r + val predictions: RDD[Rating] = model.predict(data.map(x => (x.user, x.product))) - val predictionsAndRatings = predictions.map(x => ((x.user, x.product), x.rating)) - .join(data.map(x => ((x.user, x.product), x.rating))) - .values - math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).reduce(_ + _) / n) + val predictionsAndRatings = predictions.map{ x => + ((x.user, x.product), mapPredictedRating(x.rating)) + }.join(data.map(x => ((x.user, x.product), x.rating))).values + math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).mean()) } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index 6d04bf790e3a5..fa78ca99b8891 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -51,18 +51,12 @@ class EdgeRDD[@specialized ED: ClassTag]( override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() - override def persist(newLevel: StorageLevel): EdgeRDD[ED] = { + override def persist(newLevel: StorageLevel): this.type = { partitionsRDD.persist(newLevel) this } - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ - override def persist(): EdgeRDD[ED] = persist(StorageLevel.MEMORY_ONLY) - - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ - override def cache(): EdgeRDD[ED] = persist() - - override def unpersist(blocking: Boolean = true): EdgeRDD[ED] = { + override def unpersist(blocking: Boolean = true): this.type = { partitionsRDD.unpersist(blocking) this } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index d6788d4d4b9fd..f0fc605c88575 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -71,18 +71,12 @@ class VertexRDD[@specialized VD: ClassTag]( override protected def getPreferredLocations(s: Partition): Seq[String] = partitionsRDD.preferredLocations(s) - override def persist(newLevel: StorageLevel): VertexRDD[VD] = { + override def persist(newLevel: StorageLevel): this.type = { partitionsRDD.persist(newLevel) this } - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ - override def persist(): VertexRDD[VD] = persist(StorageLevel.MEMORY_ONLY) - - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). 
*/ - override def cache(): VertexRDD[VD] = persist() - - override def unpersist(blocking: Boolean = true): VertexRDD[VD] = { + override def unpersist(blocking: Boolean = true): this.type = { partitionsRDD.unpersist(blocking) this } diff --git a/make-distribution.sh b/make-distribution.sh index ebcd8c74fc5a6..759e555b4b69a 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -189,7 +189,7 @@ if [ "$SPARK_TACHYON" == "true" ]; then TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` pushd $TMPD > /dev/null - echo "Fetchting tachyon tgz" + echo "Fetching tachyon tgz" wget "$TACHYON_URL" tar xf "tachyon-${TACHYON_VERSION}-bin.tar.gz" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 59ed01debf150..0fe30a3e7040b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -54,12 +54,13 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - logDebug("numSplits = " + bins(0).length) + val numBins = bins(0).length + logDebug("numBins = " + numBins) // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 + val maxNumNodes = math.pow(2, maxDepth).toInt - 1 // Initialize an array to hold filters applied to points for each node. val filters = new Array[List[Filter]](maxNumNodes) // The filter at the top node is an empty list. @@ -68,7 +69,28 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val parentImpurities = new Array[Double](maxNumNodes) // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) + // num features + val numFeatures = input.take(1)(0).features.size + + // Calculate level for single group construction + // Max memory usage for aggregates + val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 + logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") + val numElementsPerNode = + strategy.algo match { + case Classification => 2 * numBins * numFeatures + case Regression => 3 * numBins * numFeatures + } + + logDebug("numElementsPerNode = " + numElementsPerNode) + val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array + val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) + logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) + // nodes at a level is 2^level. level is zero indexed. + val maxLevelForSingleGroup = math.max( + (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0) + logDebug("max level for single group = " + maxLevelForSingleGroup) /* * The main idea here is to perform level-wise training of the decision tree nodes thus @@ -88,7 +110,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find best split for all nodes at a level. 
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, - level, filters, splits, bins) + level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { // Extract info for nodes at the current level. @@ -98,7 +120,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo filters) logDebug("final best split = " + nodeSplitStats._1) } - require(scala.math.pow(2, level) == splitsStatsForLevel.length) + require(math.pow(2, level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level at leaves. val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) @@ -109,6 +131,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } } + logDebug("#####################################") + logDebug("Extracting tree model") + logDebug("#####################################") + // Initialize the top or root node of the tree. val topNode = nodes(0) // Build the full tree using the node info calculated in the level-wise best split calculations. @@ -127,7 +153,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo nodes: Array[Node]): Unit = { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 - val nodeIndex = scala.math.pow(2, level).toInt - 1 + index + val nodeIndex = math.pow(2, level).toInt - 1 + index val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) @@ -148,7 +174,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo var i = 0 while (i <= 1) { // Calculate the index of the node from the node level and the index at the current level. - val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i + val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i if (level < maxDepth - 1) { val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity @@ -249,7 +275,8 @@ object DecisionTree extends Serializable with Logging { private val InvalidBinIndex = -1 /** - * Returns an array of optimal splits for all nodes at a given level + * Returns an array of optimal splits for all nodes at a given level. Splits the task into + * multiple groups if the level-wise training task could lead to memory overflow. * * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree @@ -260,6 +287,7 @@ object DecisionTree extends Serializable with Logging { * @param filters Filters for all nodes at a given level * @param splits possible splits for all features * @param bins possible bins for all features + * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. * @return array of splits with best splits for all nodes at a given level. 
*/ protected[tree] def findBestSplits( @@ -269,7 +297,57 @@ object DecisionTree extends Serializable with Logging { level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], - bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = { + bins: Array[Array[Bin]], + maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { + // split into groups to avoid memory overflow during aggregation + if (level > maxLevelForSingleGroup) { + // When information for all nodes at a given level cannot be stored in memory, + // the nodes are divided into multiple groups at each level with the number of groups + // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, + // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. + val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt + logDebug("numGroups = " + numGroups) + var bestSplits = new Array[(Split, InformationGainStats)](0) + // Iterate over each group of nodes at a level. + var groupIndex = 0 + while (groupIndex < numGroups) { + val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, + filters, splits, bins, numGroups, groupIndex) + bestSplits = Array.concat(bestSplits, bestSplitsForGroup) + groupIndex += 1 + } + bestSplits + } else { + findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins) + } + } + + /** + * Returns an array of optimal splits for a group of nodes at a given level + * + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param parentImpurities Impurities for all parent nodes for the current level + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for construction the DecisionTree + * @param level Level of the tree + * @param filters Filters for all nodes at a given level + * @param splits possible splits for all features + * @param bins possible bins for all features + * @param numGroups total number of node groups at the current level. Default value is set to 1. + * @param groupIndex index of the node group being processed. Default value is set to 0. + * @return array of splits with best splits for all nodes at a given level. + */ + private def findBestSplitsPerGroup( + input: RDD[LabeledPoint], + parentImpurities: Array[Double], + strategy: Strategy, + level: Int, + filters: Array[List[Filter]], + splits: Array[Array[Split]], + bins: Array[Array[Bin]], + numGroups: Int = 1, + groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { /* * The high-level description for the best split optimizations are noted here. @@ -296,7 +374,7 @@ object DecisionTree extends Serializable with Logging { */ // common calculations for multiple nested methods - val numNodes = scala.math.pow(2, level).toInt + val numNodes = math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. val numFeatures = input.first().features.size @@ -304,12 +382,15 @@ object DecisionTree extends Serializable with Logging { val numBins = bins(0).length logDebug("numBins = " + numBins) + // shift when more than one group is used at deep tree level + val groupShift = numNodes * groupIndex + /** Find the filters used before reaching the current code. 
*/ def findParentFilters(nodeIndex: Int): List[Filter] = { if (level == 0) { List[Filter]() } else { - val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift filters(nodeFilterIndex) } } @@ -878,7 +959,7 @@ object DecisionTree extends Serializable with Logging { // Iterating over all nodes at this level var node = 0 while (node < numNodes) { - val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 8767aca47cd5a..1b505fd76eb75 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -35,6 +35,9 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. + * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is + * 128 MB. + * */ @Experimental class Strategy ( @@ -43,4 +46,5 @@ class Strategy ( val maxDepth: Int, val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, - val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable + val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + val maxMemoryInMB: Int = 128) extends Serializable diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index c4b433499a091..8a16284118cf7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -81,11 +81,11 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa // Add a extra variable consisting of all 1.0's for the intercept. val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42) val data = testData.map { case LabeledPoint(label, features) => - label -> Vectors.dense(1.0, features.toArray: _*) + label -> Vectors.dense(1.0 +: features.toArray) } val dataRDD = sc.parallelize(data, 2).cache() - val initialWeightsWithIntercept = Vectors.dense(1.0, initialWeights: _*) + val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray) val (_, loss) = GradientDescent.runMiniBatchSGD( dataRDD, @@ -111,7 +111,7 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa // Add a extra variable consisting of all 1.0's for the intercept. 
val testData = GradientDescentSuite.generateGDInput(2.0, -1.5, 10000, 42) val data = testData.map { case LabeledPoint(label, features) => - label -> Vectors.dense(1.0, features.toArray: _*) + label -> Vectors.dense(1.0 +: features.toArray) } val dataRDD = sc.parallelize(data, 2).cache() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index be383aab714d3..35e92d71dc63f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -22,7 +22,8 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.Filter -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.model.Split +import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vectors @@ -242,7 +243,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -269,7 +270,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -298,7 +299,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -321,7 +322,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -345,7 +346,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -369,7 +370,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) 
assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -378,13 +379,60 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.rightImpurity === 0) assert(bestSplits(0)._2.predict === 1) } + + test("test second level node building with/without groups") { + val arr = DecisionTreeSuite.generateOrderedLabeledPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1) + val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1) + val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter)) + val parentImpurities = Array(0.5, 0.5, 0.5) + + // Single group second level tree construction. + val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters, + splits, bins, 10) + assert(bestSplits.length === 2) + assert(bestSplits(0)._2.gain > 0) + assert(bestSplits(1)._2.gain > 0) + + // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second + // level tree construction. + val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, + filters, splits, bins, 0) + assert(bestSplitsWithGroups.length === 2) + assert(bestSplitsWithGroups(0)._2.gain > 0) + assert(bestSplitsWithGroups(1)._2.gain > 0) + + // Verify whether the splits obtained using single group and multiple group level + // construction strategies are the same. 
+ for (i <- 0 until bestSplits.length) { + assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1) + assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain) + assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity) + assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity) + assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) + assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) + } + + } + } object DecisionTreeSuite { def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } @@ -393,17 +441,31 @@ object DecisionTreeSuite { def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i)) arr(i) = lp } arr } + def generateOrderedLabeledPoints(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000) { + if (i < 600) { + val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) + arr(i) = lp + } else { + val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i)) + arr(i) = lp + } + } + arr + } + def generateCategoricalDataPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ - if (i < 600){ + for (i <- 0 until 1000) { + if (i < 600) { arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) } else { arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0)) diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index d540dc0a986e9..efdb38e907d14 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -74,6 +74,8 @@ object MimaBuild { ) ++ excludeSparkClass("rdd.ClassTags") ++ excludeSparkClass("util.XORShiftRandom") ++ + excludeSparkClass("graphx.EdgeRDD") ++ + excludeSparkClass("graphx.VertexRDD") ++ excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ excludeSparkClass("mllib.optimization.SquaredGradient") ++ excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ diff --git a/python/pyspark/context.py b/python/pyspark/context.py index c7dc85ea03544..cac133d0fcf6c 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -453,7 +453,7 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False): >>> lock = threading.Lock() >>> def map_func(x): ... sleep(100) - ... return x * x + ... raise Exception("Task should have been cancelled") >>> def start_job(x): ... global result ... try: diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index eb18ec08c9139..b2f226a55ec13 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -74,6 +74,17 @@ def handle_sigchld(*args): raise signal.signal(SIGCHLD, handle_sigchld) + # Blocks until the socket is closed by draining the input stream + # until it raises an exception or returns EOF. + def waitSocketClose(sock): + try: + while True: + # Empty string is returned upon EOF (and only then). + if sock.recv(4096) == '': + return + except: + pass + # Handle clients while not should_exit(): # Wait until a client arrives or we have to exit @@ -105,7 +116,8 @@ def handle_sigchld(*args): exit_code = exc.code finally: outfile.flush() - sock.close() + # The Scala side will close the socket upon task completion. 
+ waitSocketClose(sock) os._exit(compute_real_exit_code(exit_code)) else: sock.close() diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 0aa3a51de706b..7511ca7573ddb 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -49,8 +49,7 @@ def __init__(self, size, *args): >>> print SparseVector(4, [1, 3], [1.0, 5.5]) [1: 1.0, 3: 5.5] """ - assert type(size) == int, "first argument must be an int" - self.size = size + self.size = int(size) assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments" if len(args) == 1: pairs = args[0] diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py new file mode 100644 index 0000000000000..50d0cdd087625 --- /dev/null +++ b/python/pyspark/mllib/util.py @@ -0,0 +1,177 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np + +from pyspark.mllib.linalg import Vectors, SparseVector +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib._common import _convert_vector + +class MLUtils: + """ + Helper methods to load, save and pre-process data used in MLlib. + """ + + @staticmethod + def _parse_libsvm_line(line, multiclass): + """ + Parses a line in LIBSVM format into (label, indices, values). + """ + items = line.split(None) + label = float(items[0]) + if not multiclass: + label = 1.0 if label > 0.5 else 0.0 + nnz = len(items) - 1 + indices = np.zeros(nnz, dtype=np.int32) + values = np.zeros(nnz) + for i in xrange(nnz): + index, value = items[1 + i].split(":") + indices[i] = int(index) - 1 + values[i] = float(value) + return label, indices, values + + + @staticmethod + def _convert_labeled_point_to_libsvm(p): + """Converts a LabeledPoint to a string in LIBSVM format.""" + items = [str(p.label)] + v = _convert_vector(p.features) + if type(v) == np.ndarray: + for i in xrange(len(v)): + items.append(str(i + 1) + ":" + str(v[i])) + elif type(v) == SparseVector: + nnz = len(v.indices) + for i in xrange(nnz): + items.append(str(v.indices[i] + 1) + ":" + str(v.values[i])) + else: + raise TypeError("_convert_labeled_point_to_libsvm needs either ndarray or SparseVector" + " but got %s" % type(v)) + return " ".join(items) + + + @staticmethod + def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=None): + """ + Loads labeled data in the LIBSVM format into an RDD of + LabeledPoint. The LIBSVM format is a text-based format used by + LIBSVM and LIBLINEAR. Each line represents a labeled sparse + feature vector using the following format: + + label index1:value1 index2:value2 ... + + where the indices are one-based and in ascending order. This + method parses each line into a LabeledPoint, where the feature + indices are converted to zero-based. 
+ + @param sc: Spark context + @param path: file or directory path in any Hadoop-supported file + system URI + @param multiclass: whether the input labels contain more than + two classes. If false, any label with value + greater than 0.5 will be mapped to 1.0, or + 0.0 otherwise. So it works for both +1/-1 and + 1/0 cases. If true, the double value parsed + directly from the label string will be used + as the label value. + @param numFeatures: number of features, which will be determined + from the input data if a nonpositive value + is given. This is useful when the dataset is + already split into multiple files and you + want to load them separately, because some + features may not present in certain files, + which leads to inconsistent feature + dimensions. + @param minPartitions: min number of partitions + @return: labeled data stored as an RDD of LabeledPoint + + >>> from tempfile import NamedTemporaryFile + >>> from pyspark.mllib.util import MLUtils + >>> tempFile = NamedTemporaryFile(delete=True) + >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") + >>> tempFile.flush() + >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() + >>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name, True).collect() + >>> tempFile.close() + >>> examples[0].label + 1.0 + >>> examples[0].features.size + 6 + >>> print examples[0].features + [0: 1.0, 2: 2.0, 4: 3.0] + >>> examples[1].label + 0.0 + >>> examples[1].features.size + 6 + >>> print examples[1].features + [] + >>> examples[2].label + 0.0 + >>> examples[2].features.size + 6 + >>> print examples[2].features + [1: 4.0, 3: 5.0, 5: 6.0] + >>> multiclass_examples[1].label + -1.0 + """ + + lines = sc.textFile(path, minPartitions) + parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l, multiclass)) + if numFeatures <= 0: + parsed.cache() + numFeatures = parsed.map(lambda x: 0 if x[1].size == 0 else x[1][-1]).reduce(max) + 1 + return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2]))) + + + @staticmethod + def saveAsLibSVMFile(data, dir): + """ + Save labeled data in LIBSVM format. 
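+        Each LabeledPoint is written on its own line as
+        "label index1:value1 index2:value2 ...", with one-based feature
+        indices (the inverse of the conversion done by loadLibSVMFile).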
+ + @param data: an RDD of LabeledPoint to be saved + @param dir: directory to save the data + + >>> from tempfile import NamedTemporaryFile + >>> from fileinput import input + >>> from glob import glob + >>> from pyspark.mllib.util import MLUtils + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \ + LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> tempFile = NamedTemporaryFile(delete=True) + >>> tempFile.close() + >>> MLUtils.saveAsLibSVMFile(sc.parallelize(examples), tempFile.name) + >>> ''.join(sorted(input(glob(tempFile.name + "/part-0000*")))) + '0.0 1:1.01 2:2.02 3:3.03\\n1.1 1:1.23 3:4.56\\n' + """ + lines = data.map(lambda p: MLUtils._convert_labeled_point_to_libsvm(p)) + lines.saveAsTextFile(dir) + + +def _test(): + import doctest + from pyspark.context import SparkContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 1a62031db5c41..6789d7002b3b7 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -360,6 +360,35 @@ def getCheckpointFile(self): else: return None + def coalesce(self, numPartitions, shuffle=False): + rdd = self._jschema_rdd.coalesce(numPartitions, shuffle) + return SchemaRDD(rdd, self.sql_ctx) + + def distinct(self): + rdd = self._jschema_rdd.distinct() + return SchemaRDD(rdd, self.sql_ctx) + + def intersection(self, other): + if (other.__class__ is SchemaRDD): + rdd = self._jschema_rdd.intersection(other._jschema_rdd) + return SchemaRDD(rdd, self.sql_ctx) + else: + raise ValueError("Can only intersect with another SchemaRDD") + + def repartition(self, numPartitions): + rdd = self._jschema_rdd.repartition(numPartitions) + return SchemaRDD(rdd, self.sql_ctx) + + def subtract(self, other, numPartitions=None): + if (other.__class__ is SchemaRDD): + if numPartitions is None: + rdd = self._jschema_rdd.subtract(other._jschema_rdd) + else: + rdd = self._jschema_rdd.subtract(other._jschema_rdd, numPartitions) + return SchemaRDD(rdd, self.sql_ctx) + else: + raise ValueError("Can only subtract another SchemaRDD") + def _test(): import doctest from pyspark.context import SparkContext diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 8c76a3aa96546..b3a3a1ef1b5eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -114,6 +114,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val JOIN = Keyword("JOIN") protected val LEFT = Keyword("LEFT") protected val LIMIT = Keyword("LIMIT") + protected val MAX = Keyword("MAX") + protected val MIN = Keyword("MIN") protected val NOT = Keyword("NOT") protected val NULL = Keyword("NULL") protected val ON = Keyword("ON") @@ -318,6 +320,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers { COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } | FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } | AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } | + MIN 
~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } | + MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } | IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { case c ~ "," ~ t ~ "," ~ f => If(c,t,f) } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 40d2b42a0cda3..0b3a4e728ec54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -182,7 +182,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } - def cast: Any => Any = dataType match { + private lazy val cast: Any => Any = dataType match { case StringType => castToString case BinaryType => castToBinary case DecimalType => castToDecimal diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index b152f95f96c70..7777d372903e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -86,6 +86,67 @@ abstract class AggregateFunction override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray) } +case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def references = child.references + override def nullable = child.nullable + override def dataType = child.dataType + override def toString = s"MIN($child)" + + override def asPartial: SplitEvaluation = { + val partialMin = Alias(Min(child), "PartialMin")() + SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil) + } + + override def newInstance() = new MinFunction(child, this) +} + +case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { + def this() = this(null, null) // Required for serialization. + + var currentMin: Any = _ + + override def update(input: Row): Unit = { + if (currentMin == null) { + currentMin = expr.eval(input) + } else if(GreaterThan(Literal(currentMin, expr.dataType), expr).eval(input) == true) { + currentMin = expr.eval(input) + } + } + + override def eval(input: Row): Any = currentMin +} + +case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def references = child.references + override def nullable = child.nullable + override def dataType = child.dataType + override def toString = s"MAX($child)" + + override def asPartial: SplitEvaluation = { + val partialMax = Alias(Max(child), "PartialMax")() + SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil) + } + + override def newInstance() = new MaxFunction(child, this) +} + +case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { + def this() = this(null, null) // Required for serialization. 
+ + var currentMax: Any = _ + + override def update(input: Row): Unit = { + if (currentMax == null) { + currentMax = expr.eval(input) + } else if(LessThan(Literal(currentMax, expr.dataType), expr).eval(input) == true) { + currentMax = expr.eval(input) + } + } + + override def eval(input: Row): Any = currentMax +} + + case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { override def references = child.references override def nullable = false @@ -97,7 +158,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil) } - override def newInstance()= new CountFunction(child, this) + override def newInstance() = new CountFunction(child, this) } case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression { @@ -106,7 +167,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi override def nullable = false override def dataType = IntegerType override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})" - override def newInstance()= new CountDistinctFunction(expressions, this) + override def newInstance() = new CountDistinctFunction(expressions, this) } case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -126,7 +187,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN partialCount :: partialSum :: Nil) } - override def newInstance()= new AverageFunction(child, this) + override def newInstance() = new AverageFunction(child, this) } case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -142,7 +203,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ partialSum :: Nil) } - override def newInstance()= new SumFunction(child, this) + override def newInstance() = new SumFunction(child, this) } case class SumDistinct(child: Expression) @@ -153,7 +214,7 @@ case class SumDistinct(child: Expression) override def dataType = child.dataType override def toString = s"SUM(DISTINCT $child)" - override def newInstance()= new SumDistinctFunction(child, this) + override def newInstance() = new SumDistinctFunction(child, this) } case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -168,7 +229,7 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod First(partialFirst.toAttribute), partialFirst :: Nil) } - override def newInstance()= new FirstFunction(child, this) + override def newInstance() = new FirstFunction(child, this) } case class AverageFunction(expr: Expression, base: AggregateExpression) @@ -176,11 +237,13 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) def this() = this(null, null) // Required for serialization. 
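+
+  // Wrapping expr in Coalesce(Seq(expr, zero)) below treats null input values as
+  // zero when accumulating the sum, so a single null row no longer makes the
+  // whole average null (see the "aggregates with nulls" test).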
+ private val zero = Cast(Literal(0), expr.dataType) + private var count: Long = _ - private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(EmptyRow)) + private val sum = MutableLiteral(zero.eval(EmptyRow)) private val sumAsDouble = Cast(sum, DoubleType) - private val addFunction = Add(sum, expr) + private val addFunction = Add(sum, Coalesce(Seq(expr, zero))) override def eval(input: Row): Any = sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble @@ -209,9 +272,11 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { def this() = this(null, null) // Required for serialization. - private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(null)) + private val zero = Cast(Literal(0), expr.dataType) + + private val sum = MutableLiteral(zero.eval(null)) - private val addFunction = Add(sum, expr) + private val addFunction = Add(sum, Coalesce(Seq(expr, zero))) override def update(input: Row): Unit = { sum.update(addFunction, input) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index d7782d6b32819..34200be3ac955 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -19,14 +19,16 @@ package org.apache.spark.sql import net.razorvine.pickle.Pickler -import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext} +import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext} import org.apache.spark.annotation.{AlphaComponent, Experimental} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.api.java.JavaSchemaRDD import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.types.BooleanType +import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} import org.apache.spark.api.java.JavaRDD import java.util.{Map => JMap} @@ -296,6 +298,13 @@ class SchemaRDD( */ def toSchemaRDD = this + /** + * Returns this RDD as a JavaSchemaRDD. + * + * @group schema + */ + def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan) + private[sql] def javaToPython: JavaRDD[Array[Byte]] = { val fieldNames: Seq[String] = this.queryExecution.analyzed.output.map(_.name) this.mapPartitions { iter => @@ -314,4 +323,60 @@ class SchemaRDD( } } } + + /** + * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value + * of base RDD functions that do not change schema. 
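+   * The new SchemaRDD simply re-attaches this plan's output attributes to the
+   * given rows via SparkLogicalPlan(ExistingRdd(logicalPlan.output, rdd)).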
+ * + * @param rdd RDD derived from this one and has same schema + * + * @group schema + */ + private def applySchema(rdd: RDD[Row]): SchemaRDD = { + new SchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(logicalPlan.output, rdd))) + } + + // ======================================================================= + // Base RDD functions that do NOT change schema + // ======================================================================= + + // Transformations (return a new RDD) + + override def coalesce(numPartitions: Int, shuffle: Boolean = false) + (implicit ord: Ordering[Row] = null): SchemaRDD = + applySchema(super.coalesce(numPartitions, shuffle)(ord)) + + override def distinct(): SchemaRDD = + applySchema(super.distinct()) + + override def distinct(numPartitions: Int) + (implicit ord: Ordering[Row] = null): SchemaRDD = + applySchema(super.distinct(numPartitions)(ord)) + + override def filter(f: Row => Boolean): SchemaRDD = + applySchema(super.filter(f)) + + override def intersection(other: RDD[Row]): SchemaRDD = + applySchema(super.intersection(other)) + + override def intersection(other: RDD[Row], partitioner: Partitioner) + (implicit ord: Ordering[Row] = null): SchemaRDD = + applySchema(super.intersection(other, partitioner)(ord)) + + override def intersection(other: RDD[Row], numPartitions: Int): SchemaRDD = + applySchema(super.intersection(other, numPartitions)) + + override def repartition(numPartitions: Int) + (implicit ord: Ordering[Row] = null): SchemaRDD = + applySchema(super.repartition(numPartitions)(ord)) + + override def subtract(other: RDD[Row]): SchemaRDD = + applySchema(super.subtract(other)) + + override def subtract(other: RDD[Row], numPartitions: Int): SchemaRDD = + applySchema(super.subtract(other, numPartitions)) + + override def subtract(other: RDD[Row], p: Partitioner) + (implicit ord: Ordering[Row] = null): SchemaRDD = + applySchema(super.subtract(other, p)(ord)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index d43d672938f51..22f57b758dd02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.api.java +import org.apache.spark.Partitioner import org.apache.spark.api.java.{JavaRDDLike, JavaRDD} +import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel /** * An RDD of [[Row]] objects that is returned as the result of a Spark SQL query. In addition to @@ -45,4 +48,141 @@ class JavaSchemaRDD( override def wrapRDD(rdd: RDD[Row]): JavaRDD[Row] = JavaRDD.fromRDD(rdd) val rdd = baseSchemaRDD.map(new Row(_)) + + override def toString: String = baseSchemaRDD.toString + + // ======================================================================= + // Base RDD functions that do NOT change schema + // ======================================================================= + + // Common RDD functions + + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + def cache(): JavaSchemaRDD = { + baseSchemaRDD.cache() + this + } + + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). 
*/ + def persist(): JavaSchemaRDD = { + baseSchemaRDD.persist() + this + } + + /** + * Set this RDD's storage level to persist its values across operations after the first time + * it is computed. This can only be used to assign a new storage level if the RDD does not + * have a storage level set yet.. + */ + def persist(newLevel: StorageLevel): JavaSchemaRDD = { + baseSchemaRDD.persist(newLevel) + this + } + + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * + * @param blocking Whether to block until all blocks are deleted. + * @return This RDD. + */ + def unpersist(blocking: Boolean = true): JavaSchemaRDD = { + baseSchemaRDD.unpersist(blocking) + this + } + + /** Assign a name to this RDD */ + def setName(name: String): JavaSchemaRDD = { + baseSchemaRDD.setName(name) + this + } + + // Transformations (return a new RDD) + + /** + * Return a new RDD that is reduced into `numPartitions` partitions. + */ + def coalesce(numPartitions: Int, shuffle: Boolean = false): JavaSchemaRDD = + baseSchemaRDD.coalesce(numPartitions, shuffle).toJavaSchemaRDD + + /** + * Return a new RDD containing the distinct elements in this RDD. + */ + def distinct(): JavaSchemaRDD = + baseSchemaRDD.distinct().toJavaSchemaRDD + + /** + * Return a new RDD containing the distinct elements in this RDD. + */ + def distinct(numPartitions: Int): JavaSchemaRDD = + baseSchemaRDD.distinct(numPartitions).toJavaSchemaRDD + + /** + * Return a new RDD containing only the elements that satisfy a predicate. + */ + def filter(f: JFunction[Row, java.lang.Boolean]): JavaSchemaRDD = + baseSchemaRDD.filter(x => f.call(new Row(x)).booleanValue()).toJavaSchemaRDD + + /** + * Return the intersection of this RDD and another one. The output will not contain any + * duplicate elements, even if the input RDDs did. + * + * Note that this method performs a shuffle internally. + */ + def intersection(other: JavaSchemaRDD): JavaSchemaRDD = + this.baseSchemaRDD.intersection(other.baseSchemaRDD).toJavaSchemaRDD + + /** + * Return the intersection of this RDD and another one. The output will not contain any + * duplicate elements, even if the input RDDs did. + * + * Note that this method performs a shuffle internally. + * + * @param partitioner Partitioner to use for the resulting RDD + */ + def intersection(other: JavaSchemaRDD, partitioner: Partitioner): JavaSchemaRDD = + this.baseSchemaRDD.intersection(other.baseSchemaRDD, partitioner).toJavaSchemaRDD + + /** + * Return the intersection of this RDD and another one. The output will not contain any + * duplicate elements, even if the input RDDs did. Performs a hash partition across the cluster + * + * Note that this method performs a shuffle internally. + * + * @param numPartitions How many partitions to use in the resulting RDD + */ + def intersection(other: JavaSchemaRDD, numPartitions: Int): JavaSchemaRDD = + this.baseSchemaRDD.intersection(other.baseSchemaRDD, numPartitions).toJavaSchemaRDD + + /** + * Return a new RDD that has exactly `numPartitions` partitions. + * + * Can increase or decrease the level of parallelism in this RDD. Internally, this uses + * a shuffle to redistribute data. + * + * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, + * which can avoid performing a shuffle. + */ + def repartition(numPartitions: Int): JavaSchemaRDD = + baseSchemaRDD.repartition(numPartitions).toJavaSchemaRDD + + /** + * Return an RDD with the elements from `this` that are not in `other`. 
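+   * Rows of `this` that also occur in `other` are dropped; the schema of the
+   * result is unchanged.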
+ * + * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting + * RDD will be <= us. + */ + def subtract(other: JavaSchemaRDD): JavaSchemaRDD = + this.baseSchemaRDD.subtract(other.baseSchemaRDD).toJavaSchemaRDD + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: JavaSchemaRDD, numPartitions: Int): JavaSchemaRDD = + this.baseSchemaRDD.subtract(other.baseSchemaRDD, numPartitions).toJavaSchemaRDD + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: JavaSchemaRDD, p: Partitioner): JavaSchemaRDD = + this.baseSchemaRDD.subtract(other.baseSchemaRDD, p).toJavaSchemaRDD } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index dde957d715a28..e966d89c30cf5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -50,6 +50,13 @@ class SQLQuerySuite extends QueryTest { Seq((1,3),(2,3),(3,3))) } + test("aggregates with nulls") { + checkAnswer( + sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), + (1, 3, 2, 6, 3) :: Nil + ) + } + test("select *") { checkAnswer( sql("SELECT * FROM testData"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index b5973c0f51be8..aa71e274f7f4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -84,4 +84,14 @@ object TestData { List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) nullableRepeatedData.registerAsTable("nullableRepeatedData") + + case class NullInts(a: Integer) + val nullInts = + TestSQLContext.sparkContext.parallelize( + NullInts(1) :: + NullInts(2) :: + NullInts(3) :: + NullInts(null) :: Nil + ) + nullInts.registerAsTable("nullInts") } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index 96f8aa93394f5..32f8861dc9503 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -21,7 +21,7 @@ import java.io.File import java.net.URI import scala.collection.JavaConversions._ -import scala.collection.mutable.HashMap +import scala.collection.mutable.{HashMap, ListBuffer} import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api._ @@ -44,9 +44,9 @@ trait ExecutorRunnableUtil extends Logging { hostname: String, executorMemory: Int, executorCores: Int, - localResources: HashMap[String, LocalResource]) = { + localResources: HashMap[String, LocalResource]): List[String] = { // Extra options for the JVM - var JAVA_OPTS = "" + val JAVA_OPTS = ListBuffer[String]() // Set the JVM memory val executorMemoryString = executorMemory + "m" JAVA_OPTS += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " " @@ -56,10 +56,21 @@ trait ExecutorRunnableUtil extends Logging { JAVA_OPTS += opts } - JAVA_OPTS += " -Djava.io.tmpdir=" + - new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " " + JAVA_OPTS += "-Djava.io.tmpdir=" + + new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) JAVA_OPTS += 
ClientBase.getLog4jConfiguration(localResources) + // Certain configs need to be passed here because they are needed before the Executor + // registers with the Scheduler and transfers the spark configs. Since the Executor backend + // uses Akka to connect to the scheduler, the akka settings are needed as well as the + // authentication settings. + sparkConf.getAll. + filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }. + foreach { case (k, v) => JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" } + + sparkConf.getAkkaConf. + foreach { case (k, v) => JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" } + // Commenting it out for now - so that people can refer to the properties if required. Remove // it once cpuset version is pushed out. // The context is, default gc for server class machines end up using all cores to do gc - hence @@ -85,25 +96,25 @@ trait ExecutorRunnableUtil extends Logging { } */ - val commands = List[String]( - Environment.JAVA_HOME.$() + "/bin/java" + - " -server " + + val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", + "-server", // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. // Not killing the task leaves various aspects of the executor and (to some extent) the jvm in // an inconsistent state. // TODO: If the OOM is not recoverable by rescheduling it on different node, then do // 'something' to fail job ... akin to blacklisting trackers in mapred ? - " -XX:OnOutOfMemoryError='kill %p' " + - JAVA_OPTS + - " org.apache.spark.executor.CoarseGrainedExecutorBackend " + - masterAddress + " " + - slaveId + " " + - hostname + " " + - executorCores + - " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + - " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") - - commands + "-XX:OnOutOfMemoryError='kill %p'") ++ + JAVA_OPTS ++ + Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", + masterAddress.toString, + slaveId.toString, + hostname.toString, + executorCores.toString, + "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", + "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") + + // TODO: it would be nicer to just make sure there are no null commands here + commands.map(s => if (s == null) "null" else s).toList } private def setupDistributedCache(