diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 2301caafb07ff..dc1e8f6c21b62 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -244,6 +244,36 @@ trait AccumulatorParam[T] extends AccumulableParam[T, T] { } } +object AccumulatorParam { + + // The following implicit objects were in SparkContext before 1.2 and users had to + // `import SparkContext._` to enable them. Now we move them here to make the compiler find + // them automatically. However, as there are duplicate codes in SparkContext for backward + // compatibility, please update them accordingly if you modify the following implicit objects. + + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { + def addInPlace(t1: Double, t2: Double): Double = t1 + t2 + def zero(initialValue: Double) = 0.0 + } + + implicit object IntAccumulatorParam extends AccumulatorParam[Int] { + def addInPlace(t1: Int, t2: Int): Int = t1 + t2 + def zero(initialValue: Int) = 0 + } + + implicit object LongAccumulatorParam extends AccumulatorParam[Long] { + def addInPlace(t1: Long, t2: Long) = t1 + t2 + def zero(initialValue: Long) = 0L + } + + implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { + def addInPlace(t1: Float, t2: Float) = t1 + t2 + def zero(initialValue: Float) = 0f + } + + // TODO: Add AccumulatorParams for other types, e.g. lists and strings +} + // TODO: The multi-thread support in accumulators is kind of lame; check // if there's a more intuitive way of doing it right private object Accumulators { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ae8bbfb56f493..9b0d5be7a7ab2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -83,6 +83,8 @@ class SparkContext(config: SparkConf) extends Logging { // contains a map from hostname to a list of input format splits on the host. private[spark] var preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map() + val startTime = System.currentTimeMillis() + /** * Create a SparkContext that loads settings from system properties (for instance, when * launching with ./bin/spark-submit). @@ -269,8 +271,6 @@ class SparkContext(config: SparkConf) extends Logging { /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) - val startTime = System.currentTimeMillis() - // Add each JAR given through the constructor if (jars != null) { jars.foreach(addJar) @@ -1624,47 +1624,74 @@ object SparkContext extends Logging { private[spark] val DRIVER_IDENTIFIER = "" - implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { + // The following deprecated objects have already been copied to `object AccumulatorParam` to + // make the compiler find them automatically. They are duplicate codes only for backward + // compatibility, please update `object AccumulatorParam` accordingly if you plan to modify the + // following ones. + + @deprecated("Replaced by implicit objects in AccumulatorParam. 
This is kept here only for " + + "backward compatibility.", "1.2.0") + object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0 } - implicit object IntAccumulatorParam extends AccumulatorParam[Int] { + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object IntAccumulatorParam extends AccumulatorParam[Int] { def addInPlace(t1: Int, t2: Int): Int = t1 + t2 def zero(initialValue: Int) = 0 } - implicit object LongAccumulatorParam extends AccumulatorParam[Long] { + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object LongAccumulatorParam extends AccumulatorParam[Long] { def addInPlace(t1: Long, t2: Long) = t1 + t2 def zero(initialValue: Long) = 0L } - implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object FloatAccumulatorParam extends AccumulatorParam[Float] { def addInPlace(t1: Float, t2: Float) = t1 + t2 def zero(initialValue: Float) = 0f } - // TODO: Add AccumulatorParams for other types, e.g. lists and strings + // The following deprecated functions have already been moved to `object RDD` to + // make the compiler find them automatically. They are still kept here for backward compatibility + // and just call the corresponding functions in `object RDD`. - implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { - new PairRDDFunctions(rdd) + RDD.rddToPairRDDFunctions(rdd) } - implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd) + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = RDD.rddToAsyncRDDActions(rdd) - implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( rdd: RDD[(K, V)]) = - new SequenceFileRDDFunctions(rdd) + RDD.rddToSequenceFileRDDFunctions(rdd) - implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( rdd: RDD[(K, V)]) = - new OrderedRDDFunctions[K, V, (K, V)](rdd) + RDD.rddToOrderedRDDFunctions(rdd) - implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd) + @deprecated("Replaced by implicit functions in the RDD companion object. 
This is " + + "kept here only for backward compatibility.", "1.2.0") + def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = RDD.doubleRDDToDoubleRDDFunctions(rdd) - implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = - new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + RDD.numericRDDToDoubleRDDFunctions(rdd) // Implicit conversions to common Writable types, for saveAsSequenceFile @@ -1690,40 +1717,49 @@ object SparkContext extends Logging { arr.map(x => anyToWritable(x)).toArray) } - // Helper objects for converting common types to Writable - private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) - : WritableConverter[T] = { - val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]] - new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) - } + // The following deprecated functions have already been moved to `object WritableConverter` to + // make the compiler find them automatically. They are still kept here for backward compatibility + // and just call the corresponding functions in `object WritableConverter`. - implicit def intWritableConverter(): WritableConverter[Int] = - simpleWritableConverter[Int, IntWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def intWritableConverter(): WritableConverter[Int] = + WritableConverter.intWritableConverter() - implicit def longWritableConverter(): WritableConverter[Long] = - simpleWritableConverter[Long, LongWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def longWritableConverter(): WritableConverter[Long] = + WritableConverter.longWritableConverter() - implicit def doubleWritableConverter(): WritableConverter[Double] = - simpleWritableConverter[Double, DoubleWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def doubleWritableConverter(): WritableConverter[Double] = + WritableConverter.doubleWritableConverter() - implicit def floatWritableConverter(): WritableConverter[Float] = - simpleWritableConverter[Float, FloatWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def floatWritableConverter(): WritableConverter[Float] = + WritableConverter.floatWritableConverter() - implicit def booleanWritableConverter(): WritableConverter[Boolean] = - simpleWritableConverter[Boolean, BooleanWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def booleanWritableConverter(): WritableConverter[Boolean] = + WritableConverter.booleanWritableConverter() - implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = { - simpleWritableConverter[Array[Byte], BytesWritable](bw => - // getBytes method returns array which is longer then data to be returned - Arrays.copyOfRange(bw.getBytes, 0, bw.getLength) - ) - } + @deprecated("Replaced by implicit functions in WritableConverter. 
This is kept here only for " + + "backward compatibility.", "1.2.0") + def bytesWritableConverter(): WritableConverter[Array[Byte]] = + WritableConverter.bytesWritableConverter() - implicit def stringWritableConverter(): WritableConverter[String] = - simpleWritableConverter[String, Text](_.toString) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def stringWritableConverter(): WritableConverter[String] = + WritableConverter.stringWritableConverter() - implicit def writableWritableConverter[T <: Writable]() = - new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def writableWritableConverter[T <: Writable]() = + WritableConverter.writableWritableConverter() /** * Find the JAR from which a given class was loaded, to make it easy for users to pass @@ -1950,3 +1986,46 @@ private[spark] class WritableConverter[T]( val writableClass: ClassTag[T] => Class[_ <: Writable], val convert: Writable => T) extends Serializable + +object WritableConverter { + + // Helper objects for converting common types to Writable + private[spark] def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) + : WritableConverter[T] = { + val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]] + new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) + } + + // The following implicit functions were in SparkContext before 1.2 and users had to + // `import SparkContext._` to enable them. Now we move them here to make the compiler find + // them automatically. However, we still keep the old functions in SparkContext for backward + // compatibility and forward to the following functions directly. 
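Moving these converters into `object WritableConverter` puts them in the implicit scope of the `WritableConverter[T]` type itself, which is why the compiler can now find them without a wildcard import. A minimal sketch of the user-visible effect, assuming a local context (the path and app name are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local"))

// sequenceFile asks for implicit WritableConverter factories for the key and
// value types. They now resolve from object WritableConverter (the companion
// of the requested type), with no `import SparkContext._` needed.
val pairs = sc.sequenceFile[Int, String]("/tmp/seqfile-input")
```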
+ + implicit def intWritableConverter(): WritableConverter[Int] = + simpleWritableConverter[Int, IntWritable](_.get) + + implicit def longWritableConverter(): WritableConverter[Long] = + simpleWritableConverter[Long, LongWritable](_.get) + + implicit def doubleWritableConverter(): WritableConverter[Double] = + simpleWritableConverter[Double, DoubleWritable](_.get) + + implicit def floatWritableConverter(): WritableConverter[Float] = + simpleWritableConverter[Float, FloatWritable](_.get) + + implicit def booleanWritableConverter(): WritableConverter[Boolean] = + simpleWritableConverter[Boolean, BooleanWritable](_.get) + + implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = { + simpleWritableConverter[Array[Byte], BytesWritable](bw => + // getBytes method returns array which is longer then data to be returned + Arrays.copyOfRange(bw.getBytes, 0, bw.getLength) + ) + } + + implicit def stringWritableConverter(): WritableConverter[String] = + simpleWritableConverter[String, Text](_.toString) + + implicit def writableWritableConverter[T <: Writable]() = + new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index e37f3acaf6e30..7af3538262fd6 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -32,13 +32,13 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.Partitioner._ -import org.apache.spark.SparkContext.rddToPairRDDFunctions import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} +import org.apache.spark.rdd.RDD.rddToPairRDDFunctions import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 6a6d9bf6857d3..97f5c9f257e09 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ -import org.apache.spark.SparkContext._ +import org.apache.spark.AccumulatorParam._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index 6ff2aa5244847..36a2e2c6a6349 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -18,12 +18,13 @@ package org.apache.spark.deploy.master import java.io._ -import java.nio.ByteBuffer + +import 
scala.reflect.ClassTag + +import akka.serialization.Serialization import org.apache.spark.Logging -import org.apache.spark.serializer.Serializer -import scala.reflect.ClassTag /** * Stores data in a single on-disk directory with one file per application and worker. @@ -34,10 +35,9 @@ import scala.reflect.ClassTag */ private[spark] class FileSystemPersistenceEngine( val dir: String, - val serialization: Serializer) + val serialization: Serialization) extends PersistenceEngine with Logging { - val serializer = serialization.newInstance() new File(dir).mkdir() override def persist(name: String, obj: Object): Unit = { @@ -56,17 +56,17 @@ private[spark] class FileSystemPersistenceEngine( private def serializeIntoFile(file: File, value: AnyRef) { val created = file.createNewFile() if (!created) { throw new IllegalStateException("Could not create file: " + file) } - - val out = serializer.serializeStream(new FileOutputStream(file)) + val serializer = serialization.findSerializerFor(value) + val serialized = serializer.toBinary(value) + val out = new FileOutputStream(file) try { - out.writeObject(value) + out.write(serialized) } finally { out.close() } - } - def deserializeFromFile[T](file: File): T = { + private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = { val fileData = new Array[Byte](file.length().asInstanceOf[Int]) val dis = new DataInputStream(new FileInputStream(file)) try { @@ -74,7 +74,9 @@ private[spark] class FileSystemPersistenceEngine( } finally { dis.close() } - - serializer.deserializeStream(dis).readObject() + val clazz = m.runtimeClass.asInstanceOf[Class[T]] + val serializer = serialization.serializerFor(clazz) + serializer.fromBinary(fileData).asInstanceOf[T] } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 021454e25804c..7b32c505def9b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -30,6 +30,7 @@ import scala.util.Random import akka.actor._ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import akka.serialization.Serialization import akka.serialization.SerializationExtension import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} @@ -132,15 +133,18 @@ private[spark] class Master( val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match { case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") - val zkFactory = new ZooKeeperRecoveryModeFactory(conf) + val zkFactory = + new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system)) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => - val fsFactory = new FileSystemRecoveryModeFactory(conf) + val fsFactory = + new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system)) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) - val factory = clazz.getConstructor(conf.getClass) - .newInstance(conf).asInstanceOf[StandaloneRecoveryModeFactory] + val factory = clazz.getConstructor(conf.getClass, Serialization.getClass) + .newInstance(conf, SerializationExtension(context.system)) + .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => (new 
BlackHolePersistenceEngine(), new MonarchyLeaderAgent(this)) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala index d9d36c1ed5f9f..1096eb0368357 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -17,9 +17,10 @@ package org.apache.spark.deploy.master +import akka.serialization.Serialization + import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.serializer.JavaSerializer /** * ::DeveloperApi:: @@ -29,7 +30,7 @@ import org.apache.spark.serializer.JavaSerializer * */ @DeveloperApi -abstract class StandaloneRecoveryModeFactory(conf: SparkConf) { +abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) { /** * PersistenceEngine defines how the persistent data(Information about worker, driver etc..) @@ -48,21 +49,21 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf) { * LeaderAgent in this case is a no-op. Since leader is forever leader as the actual * recovery is made by restoring from filesystem. */ -private[spark] class FileSystemRecoveryModeFactory(conf: SparkConf) - extends StandaloneRecoveryModeFactory(conf) with Logging { +private[spark] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) + extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") def createPersistenceEngine() = { logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) - new FileSystemPersistenceEngine(RECOVERY_DIR, new JavaSerializer(conf)) + new FileSystemPersistenceEngine(RECOVERY_DIR, serializer) } def createLeaderElectionAgent(master: LeaderElectable) = new MonarchyLeaderAgent(master) } -private[spark] class ZooKeeperRecoveryModeFactory(conf: SparkConf) - extends StandaloneRecoveryModeFactory(conf) { - def createPersistenceEngine() = new ZooKeeperPersistenceEngine(new JavaSerializer(conf), conf) +private[spark] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) + extends StandaloneRecoveryModeFactory(conf, serializer) { + def createPersistenceEngine() = new ZooKeeperPersistenceEngine(conf, serializer) def createLeaderElectionAgent(master: LeaderElectable) = new ZooKeeperLeaderElectionAgent(master, conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 96c2139eb02f0..e11ac031fb9c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -17,27 +17,24 @@ package org.apache.spark.deploy.master +import akka.serialization.Serialization + import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.serializer.Serializer -import java.nio.ByteBuffer -import scala.reflect.ClassTag - -private[spark] class ZooKeeperPersistenceEngine(val serialization: Serializer, conf: SparkConf) +private[spark] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) extends 
PersistenceEngine with Logging { val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" val zk: CuratorFramework = SparkCuratorUtil.newClient(conf) - val serializer = serialization.newInstance() - SparkCuratorUtil.mkdir(zk, WORKING_DIR) @@ -59,14 +56,17 @@ private[spark] class ZooKeeperPersistenceEngine(val serialization: Serializer, c } private def serializeIntoFile(path: String, value: AnyRef) { - val serialized = serializer.serialize(value) - zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized.array()) + val serializer = serialization.findSerializerFor(value) + val serialized = serializer.toBinary(value) + zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized) } - def deserializeFromFile[T](filename: String): Option[T] = { + def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) + val clazz = m.runtimeClass.asInstanceOf[Class[T]] + val serializer = serialization.serializerFor(clazz) try { - Some(serializer.deserialize(ByteBuffer.wrap(fileData))) + Some(serializer.fromBinary(fileData).asInstanceOf[T]) } catch { case e: Exception => { logWarning("Exception while reading persisted file, deleting", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e4025bcf48db6..3add4a76192ca 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -21,6 +21,7 @@ import java.util.{Properties, Random} import scala.collection.{mutable, Map} import scala.collection.mutable.ArrayBuffer +import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus @@ -28,6 +29,7 @@ import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text +import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ @@ -1383,3 +1385,31 @@ abstract class RDD[T: ClassTag]( new JavaRDD(this)(elementClassTag) } } + +object RDD { + + // The following implicit functions were in SparkContext before 1.2 and users had to + // `import SparkContext._` to enable them. Now we move them here to make the compiler find + // them automatically. However, we still keep the old functions in SparkContext for backward + // compatibility and forward to the following functions directly. 
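The same mechanism applies here: Scala's implicit search also looks in the companion object of the receiver's type, so pair-RDD methods now resolve with no import at all. A small sketch (app name and data are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local"))
val words: RDD[String] = sc.parallelize(Seq("a", "b", "a"))

// Before 1.2 this required `import SparkContext._` to bring
// rddToPairRDDFunctions into scope; the compiler now finds the conversion in
// object RDD, the companion object of the receiver's type.
val counts = words.map(w => (w, 1)).reduceByKey(_ + _)
```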
+ + implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { + new PairRDDFunctions(rdd) + } + + implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd) + + implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( + rdd: RDD[(K, V)]) = + new SequenceFileRDDFunctions(rdd) + + implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( + rdd: RDD[(K, V)]) = + new OrderedRDDFunctions[K, V, (K, V)](rdd) + + implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd) + + implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 22449517d100f..b1222af662e9b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -751,14 +751,15 @@ class DAGScheduler( localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 if (shouldRunLocally) { // Compute very short actions like first() or take() with no parent stages locally. - listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties)) + listenerBus.post(SparkListenerJobStart(job.jobId, Seq.empty, properties)) runLocally(job) } else { jobIdToActiveJob(jobId) = job activeJobs += job finalStage.resultOfJob = Some(job) - listenerBus.post(SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray, - properties)) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post(SparkListenerJobStart(job.jobId, stageInfos, properties)) submitStage(finalStage) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 86afe3bd5265f..b62b0c1312693 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -56,8 +56,15 @@ case class SparkListenerTaskEnd( extends SparkListenerEvent @DeveloperApi -case class SparkListenerJobStart(jobId: Int, stageIds: Seq[Int], properties: Properties = null) - extends SparkListenerEvent +case class SparkListenerJobStart( + jobId: Int, + stageInfos: Seq[StageInfo], + properties: Properties = null) + extends SparkListenerEvent { + // Note: this is here for backwards-compatibility with older versions of this event which + // only stored stageIds and not StageInfos: + val stageIds: Seq[Int] = stageInfos.map(_.stageId) +} @DeveloperApi case class SparkListenerJobEnd(jobId: Int, jobResult: JobResult) extends SparkListenerEvent diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 049938f827291..176907dffa46a 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -23,7 +23,7 @@ import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab} import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab} -import 
org.apache.spark.ui.jobs.{JobProgressListener, JobProgressTab} +import org.apache.spark.ui.jobs.{JobsTab, JobProgressListener, StagesTab} import org.apache.spark.ui.storage.{StorageListener, StorageTab} /** @@ -43,17 +43,20 @@ private[spark] class SparkUI private ( extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI") with Logging { + val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false) + /** Initialize all components of the server. */ def initialize() { - val jobProgressTab = new JobProgressTab(this) - attachTab(jobProgressTab) + attachTab(new JobsTab(this)) + val stagesTab = new StagesTab(this) + attachTab(stagesTab) attachTab(new StorageTab(this)) attachTab(new EnvironmentTab(this)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(createRedirectHandler("/", "/stages", basePath = basePath)) + attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) attachHandler( - createRedirectHandler("/stages/stage/kill", "/stages", jobProgressTab.handleKillRequest)) + createRedirectHandler("/stages/stage/kill", "/stages", stagesTab.handleKillRequest)) // If the UI is live, then serve sc.foreach { _.env.metricsSystem.getServletHandlers.foreach(attachHandler) } } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 7bc1e24d58711..0c418beaf7581 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -169,7 +169,8 @@ private[spark] object UIUtils extends Logging { title: String, content: => Seq[Node], activeTab: SparkUITab, - refreshInterval: Option[Int] = None): Seq[Node] = { + refreshInterval: Option[Int] = None, + helpText: Option[String] = None): Seq[Node] = { val appName = activeTab.appName val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." @@ -178,6 +179,9 @@ private[spark] object UIUtils extends Logging { {tab.name} } + val helpButton: Seq[Node] = helpText.map { helpText => + (?) + }.getOrElse(Seq.empty) @@ -201,6 +205,7 @@ private[spark] object UIUtils extends Logging {

<h3 style="vertical-align: bottom; display: inline-block;"> {title} + {helpButton} </h3>

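The new `helpText` parameter above renders as a small `(?)` tooltip next to the page title. A hedged sketch of a caller, mirroring the call AllJobsPage makes later in this patch (`content` and `parent` are assumed to be in scope in the page's `render` method):

```scala
// Sketch only: opting a WebUIPage into the help tooltip.
val helpText = "A job is triggered by an action, like count() or saveAsTextFile()."
UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText))
```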
@@ -283,4 +288,24 @@ private[spark] object UIUtils extends Logging { } + + def makeProgressBar( + started: Int, + completed: Int, + failed: Int, + skipped: Int, + total: Int): Seq[Node] = { + val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) + val startWidth = "width: %s%%".format((started.toDouble/total)*100) + + <div class="progress"> + <span style="text-align:center; position:absolute; width:100%; left:0;"> + {completed}/{total} + { if (failed > 0) s"($failed failed)" } + { if (skipped > 0) s"($skipped skipped)" } + </span> + <div class="bar bar-completed" style={completeWidth}></div> + <div class="bar bar-running" style={startWidth}></div> + </div>
+ } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala new file mode 100644 index 0000000000000..ea2d187a0e8e4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import scala.xml.{Node, NodeSeq} + +import javax.servlet.http.HttpServletRequest + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.jobs.UIData.JobUIData + +/** Page showing list of all ongoing and recently finished jobs */ +private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { + private val startTime: Option[Long] = parent.sc.map(_.startTime) + private val listener = parent.listener + + private def jobsTable(jobs: Seq[JobUIData]): Seq[Node] = { + val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined) + + val columns: Seq[Node] = { + {if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"} + Description + Submitted + Duration + Stages: Succeeded/Total + Tasks (for all stages): Succeeded/Total + } + + def makeRow(job: JobUIData): Seq[Node] = { + val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max) + val lastStageData = lastStageInfo.flatMap { s => + listener.stageIdToData.get((s.stageId, s.attemptId)) + } + val isComplete = job.status == JobExecutionStatus.SUCCEEDED + val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") + val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("") + val duration: Option[Long] = { + job.startTime.map { start => + val end = job.endTime.getOrElse(System.currentTimeMillis()) + end - start + } + } + val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") + val formattedSubmissionTime = job.startTime.map(UIUtils.formatDate).getOrElse("Unknown") + val detailUrl = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), job.jobId) + + + {job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")} + + +
<div><em>{lastStageDescription}</em></div>
+ {lastStageName} + + + {formattedSubmissionTime} + + {formattedDuration} + + {job.completedStageIndices.size}/{job.stageIds.size - job.numSkippedStages} + {if (job.numFailedStages > 0) s"(${job.numFailedStages} failed)"} + {if (job.numSkippedStages > 0) s"(${job.numSkippedStages} skipped)"} + + + {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks, + failed = job.numFailedTasks, skipped = job.numSkippedTasks, + total = job.numTasks - job.numSkippedTasks)} + + + } + + + {columns} + + {jobs.map(makeRow)} + +
+ } + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val activeJobs = listener.activeJobs.values.toSeq + val completedJobs = listener.completedJobs.reverse.toSeq + val failedJobs = listener.failedJobs.reverse.toSeq + val now = System.currentTimeMillis + + val activeJobsTable = + jobsTable(activeJobs.sortBy(_.startTime.getOrElse(-1L)).reverse) + val completedJobsTable = + jobsTable(completedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + val failedJobsTable = + jobsTable(failedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + + val summary: NodeSeq = +
+ +
+ + val content = summary ++ +

<h4 id="active">Active Jobs ({activeJobs.size})</h4>

++ activeJobsTable ++ +

<h4 id="completed">Completed Jobs ({completedJobs.size})</h4>

++ completedJobsTable ++ +

<h4 id="failed">Failed Jobs ({failedJobs.size})</h4>

++ failedJobsTable + + val helpText = """A job is triggered by an action, like "count()" or "saveAsTextFile()".""" + + " Click on a job's title to see information about the stages and tasks associated with" + + " the job." + + UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala similarity index 87% rename from core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala rename to core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 83a7898071c9b..b0f8ca2ab0d3f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -25,7 +25,7 @@ import org.apache.spark.scheduler.Schedulable import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing list of all ongoing and recently finished stages and pools */ -private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") { +private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { private val sc = parent.sc private val listener = parent.listener private def isFairScheduler = parent.isFairScheduler @@ -41,11 +41,14 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") val activeStagesTable = new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, - parent, parent.killEnabled) + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = parent.killEnabled) val completedStagesTable = - new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent) + new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false) val failedStagesTable = - new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent) + new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler) // For now, pool information is only accessible in live UIs val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]) @@ -93,7 +96,7 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")

<h4 id="failed">Failed Stages ({numFailedStages})</h4>

++ failedStagesTable.toNodeSeq - UIUtils.headerSparkPage("Spark Stages", content, parent) + UIUtils.headerSparkPage("Spark Stages (for all jobs)", content, parent) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index fa0f96bff34ff..35bbe8b4f9ac8 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -25,7 +25,7 @@ import org.apache.spark.ui.jobs.UIData.StageUIData import org.apache.spark.util.Utils /** Stage summary grouped by executors. */ -private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobProgressTab) { +private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: StagesTab) { private val listener = parent.listener def toNodeSeq: Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala new file mode 100644 index 0000000000000..77d36209c6048 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import scala.collection.mutable +import scala.xml.{NodeSeq, Node} + +import javax.servlet.http.HttpServletRequest + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.scheduler.StageInfo +import org.apache.spark.ui.{UIUtils, WebUIPage} + +/** Page showing statistics and stage list for a given job */ +private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { + private val listener = parent.listener + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val jobId = request.getParameter("id").toInt + val jobDataOption = listener.jobIdToData.get(jobId) + if (jobDataOption.isEmpty) { + val content = +
+ <div id="no-info"> + <p>No information to display for job {jobId}</p> + </div>
+ return UIUtils.headerSparkPage( + s"Details for Job $jobId", content, parent) + } + val jobData = jobDataOption.get + val isComplete = jobData.status != JobExecutionStatus.RUNNING + val stages = jobData.stageIds.map { stageId => + // This could be empty if the JobProgressListener hasn't received information about the + // stage or if the stage information has been garbage collected + listener.stageIdToInfo.getOrElse(stageId, + new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, "Unknown")) + } + + val activeStages = mutable.Buffer[StageInfo]() + val completedStages = mutable.Buffer[StageInfo]() + // If the job is completed, then any pending stages are displayed as "skipped": + val pendingOrSkippedStages = mutable.Buffer[StageInfo]() + val failedStages = mutable.Buffer[StageInfo]() + for (stage <- stages) { + if (stage.submissionTime.isEmpty) { + pendingOrSkippedStages += stage + } else if (stage.completionTime.isDefined) { + if (stage.failureReason.isDefined) { + failedStages += stage + } else { + completedStages += stage + } + } else { + activeStages += stage + } + } + + val activeStagesTable = + new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = parent.killEnabled) + val pendingOrSkippedStagesTable = + new StageTableBase(pendingOrSkippedStages.sortBy(_.stageId).reverse, + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = false) + val completedStagesTable = + new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false) + val failedStagesTable = + new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler) + + val shouldShowActiveStages = activeStages.nonEmpty + val shouldShowPendingStages = !isComplete && pendingOrSkippedStages.nonEmpty + val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowSkippedStages = isComplete && pendingOrSkippedStages.nonEmpty + val shouldShowFailedStages = failedStages.nonEmpty + + val summary: NodeSeq = +
+ +
+ + var content = summary + if (shouldShowActiveStages) { + content ++=

<h4 id="active">Active Stages ({activeStages.size})</h4>

++ + activeStagesTable.toNodeSeq + } + if (shouldShowPendingStages) { + content ++=

<h4 id="pending">Pending Stages ({pendingOrSkippedStages.size})</h4>

++ + pendingOrSkippedStagesTable.toNodeSeq + } + if (shouldShowCompletedStages) { + content ++=

<h4 id="completed">Completed Stages ({completedStages.size})</h4>

++ + completedStagesTable.toNodeSeq + } + if (shouldShowSkippedStages) { + content ++=

<h4 id="skipped">Skipped Stages ({pendingOrSkippedStages.size})</h4>

++ + pendingOrSkippedStagesTable.toNodeSeq + } + if (shouldShowFailedStages) { + content ++=

<h4 id="failed">Failed Stages ({failedStages.size})</h4>

++ + failedStagesTable.toNodeSeq + } + UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index ccdcf0e047f48..72935beb3a34a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import scala.collection.mutable.{HashMap, ListBuffer} +import scala.collection.mutable.{HashMap, HashSet, ListBuffer} import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi @@ -49,8 +49,6 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { type PoolName = String type ExecutorId = String - // Define all of our state: - // Jobs: val activeJobs = new HashMap[JobId, JobUIData] val completedJobs = ListBuffer[JobUIData]() @@ -60,9 +58,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // Stages: val activeStages = new HashMap[StageId, StageInfo] val completedStages = ListBuffer[StageInfo]() + val skippedStages = ListBuffer[StageInfo]() val failedStages = ListBuffer[StageInfo]() val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData] val stageIdToInfo = new HashMap[StageId, StageInfo] + val stageIdToActiveJobIds = new HashMap[StageId, HashSet[JobId]] val poolToActiveStages = HashMap[PoolName, HashMap[StageId, StageInfo]]() // Total of completed and failed stages that have ever been run. These may be greater than // `completedStages.size` and `failedStages.size` if we have run more stages or jobs than @@ -95,7 +95,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { Map( "activeStages" -> activeStages.size, "activeJobs" -> activeJobs.size, - "poolToActiveStages" -> poolToActiveStages.values.map(_.size).sum + "poolToActiveStages" -> poolToActiveStages.values.map(_.size).sum, + "stageIdToActiveJobIds" -> stageIdToActiveJobIds.values.map(_.size).sum ) } @@ -106,6 +107,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { "completedJobs" -> completedJobs.size, "failedJobs" -> failedJobs.size, "completedStages" -> completedStages.size, + "skippedStages" -> skippedStages.size, "failedStages" -> failedStages.size ) } @@ -144,11 +146,39 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } override def onJobStart(jobStart: SparkListenerJobStart) = synchronized { - val jobGroup = Option(jobStart.properties).map(_.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) + val jobGroup = for ( + props <- Option(jobStart.properties); + group <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) + ) yield group val jobData: JobUIData = - new JobUIData(jobStart.jobId, jobStart.stageIds, jobGroup, JobExecutionStatus.RUNNING) + new JobUIData( + jobId = jobStart.jobId, + startTime = Some(System.currentTimeMillis), + endTime = None, + stageIds = jobStart.stageIds, + jobGroup = jobGroup, + status = JobExecutionStatus.RUNNING) + // Compute (a potential underestimate of) the number of tasks that will be run by this job. + // This may be an underestimate because the job start event references all of the result + // stages's transitive stage dependencies, but some of these stages might be skipped if their + // output is available from earlier runs. 
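(Illustration, not part of the patch: if the job start event references stages [s0, s1, s2] but s1 already completed in an earlier job, only s0.numTasks + s2.numTasks are summed below, since s1 has a completion time and is filtered out of missingStages.)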
+ // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. + jobData.numTasks = { + val allStages = jobStart.stageInfos + val missingStages = allStages.filter(_.completionTime.isEmpty) + missingStages.map(_.numTasks).sum + } jobIdToData(jobStart.jobId) = jobData activeJobs(jobStart.jobId) = jobData + for (stageId <- jobStart.stageIds) { + stageIdToActiveJobIds.getOrElseUpdate(stageId, new HashSet[StageId]).add(jobStart.jobId) + } + // If there's no information for a stage, store the StageInfo received from the scheduler + // so that we can display stage descriptions for pending stages: + for (stageInfo <- jobStart.stageInfos) { + stageIdToInfo.getOrElseUpdate(stageInfo.stageId, stageInfo) + stageIdToData.getOrElseUpdate((stageInfo.stageId, stageInfo.attemptId), new StageUIData) + } } override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { @@ -156,6 +186,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { logWarning(s"Job completed for unknown job ${jobEnd.jobId}") new JobUIData(jobId = jobEnd.jobId) } + jobData.endTime = Some(System.currentTimeMillis()) jobEnd.jobResult match { case JobSucceeded => completedJobs += jobData @@ -166,6 +197,20 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { trimJobsIfNecessary(failedJobs) jobData.status = JobExecutionStatus.FAILED } + for (stageId <- jobData.stageIds) { + stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage => + jobsUsingStage.remove(jobEnd.jobId) + stageIdToInfo.get(stageId).foreach { stageInfo => + if (stageInfo.submissionTime.isEmpty) { + // if this stage is pending, it won't complete, so mark it as "skipped": + skippedStages += stageInfo + trimStagesIfNecessary(skippedStages) + jobData.numSkippedStages += 1 + jobData.numSkippedTasks += stageInfo.numTasks + } + } + } + } } override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { @@ -193,6 +238,19 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { numFailedStages += 1 trimStagesIfNecessary(failedStages) } + + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveStages -= 1 + if (stage.failureReason.isEmpty) { + jobData.completedStageIndices.add(stage.stageId) + } else { + jobData.numFailedStages += 1 + } + } } /** For FIFO, all stages are contained by "default" pool but "default" pool here is meaningless */ @@ -214,6 +272,14 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo]) stages(stage.stageId) = stage + + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveStages += 1 + } } override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { @@ -226,6 +292,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.numActiveTasks += 1 stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo)) } + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveTasks += 1 + } } override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { @@ -283,6 
+356,20 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskData.taskInfo = info taskData.taskMetrics = metrics taskData.errorMessage = errorMessage + + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskEnd.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveTasks -= 1 + taskEnd.reason match { + case Success => + jobData.numCompletedTasks += 1 + case _ => + jobData.numFailedTasks += 1 + } + } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala new file mode 100644 index 0000000000000..b2bbfdee56946 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import org.apache.spark.scheduler.SchedulingMode +import org.apache.spark.ui.{SparkUI, SparkUITab} + +/** Web UI showing progress status of all jobs in the given SparkContext. 
*/ +private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { + val sc = parent.sc + val killEnabled = parent.killEnabled + def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) + val listener = parent.jobProgressListener + + attachPage(new AllJobsPage(this)) + attachPage(new JobPage(this)) +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 770d99eea1c9d..5fc6cc7533150 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -25,7 +25,7 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo} import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing specific pool details */ -private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") { +private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { private val sc = parent.sc private val listener = parent.listener @@ -37,8 +37,9 @@ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") { case Some(s) => s.values.toSeq case None => Seq[StageInfo]() } - val activeStagesTable = - new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, parent) + val activeStagesTable = new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = parent.killEnabled) // For now, pool information is only accessible in live UIs val pools = sc.map(_.getPoolForName(poolName).get).toSeq diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index 64178e1e33d41..df1899e7a9b84 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -24,7 +24,7 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo} import org.apache.spark.ui.UIUtils /** Table showing list of pools */ -private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) { +private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) { private val listener = parent.listener def toNodeSeq: Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 36afc4942e085..40e05f86b661d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.{Utils, Distribution} import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} /** Page showing statistics and task list for a given stage */ -private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { +private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 2ff561ccc7da0..e7d6244dcd679 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -31,11 +31,10 @@ import org.apache.spark.util.Utils /** Page showing list of all ongoing and recently finished stages */ private[ui] class StageTableBase( stages: 
Seq[StageInfo], - parent: JobProgressTab, - killEnabled: Boolean = false) { - - private val listener = parent.listener - protected def isFairScheduler = parent.isFairScheduler + basePath: String, + listener: JobProgressListener, + isFairScheduler: Boolean, + killEnabled: Boolean) { protected def columns: Seq[Node] = { Stage Id ++ @@ -73,25 +72,11 @@ private[ui] class StageTableBase( } - private def makeProgressBar(started: Int, completed: Int, failed: Int, total: Int): Seq[Node] = - { - val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) - val startWidth = "width: %s%%".format((started.toDouble/total)*100) - -
- <div class="progress"> - <span style="text-align:center; position:absolute; width:100%; left:0;"> - {completed}/{total} { if (failed > 0) s"($failed failed)" else "" } - </span> - <div class="bar bar-completed" style={completeWidth}></div> - <div class="bar bar-running" style={startWidth}></div> - </div>
- } - private def makeDescription(s: StageInfo): Seq[Node] = { // scalastyle:off val killLink = if (killEnabled) { val killLinkUri = "%s/stages/stage/kill?id=%s&terminate=true" - .format(UIUtils.prependBaseUri(parent.basePath), s.stageId) + .format(UIUtils.prependBaseUri(basePath), s.stageId) val confirm = "return window.confirm('Are you sure you want to kill stage %s ?');" .format(s.stageId) @@ -101,7 +86,7 @@ private[ui] class StageTableBase( // scalastyle:on val nameLinkUri ="%s/stages/stage?id=%s&attempt=%s" - .format(UIUtils.prependBaseUri(parent.basePath), s.stageId, s.attemptId) + .format(UIUtils.prependBaseUri(basePath), s.stageId, s.attemptId) val nameLink = {s.name} val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) @@ -115,7 +100,7 @@ private[ui] class StageTableBase( Text("RDD: ") ++ // scalastyle:off cachedRddInfos.map { i => - {i.name} + {i.name} } // scalastyle:on }} @@ -167,7 +152,7 @@ private[ui] class StageTableBase( {if (isFairScheduler) { + .format(UIUtils.prependBaseUri(basePath), stageData.schedulingPool)}> {stageData.schedulingPool} @@ -180,8 +165,9 @@ private[ui] class StageTableBase( {formattedDuration} - {makeProgressBar(stageData.numActiveTasks, stageData.completedIndices.size, - stageData.numFailedTasks, s.numTasks)} + {UIUtils.makeProgressBar(started = stageData.numActiveTasks, + completed = stageData.completedIndices.size, failed = stageData.numFailedTasks, + skipped = 0, total = s.numTasks)} {inputReadWithUnit} {outputWriteWithUnit} @@ -195,9 +181,10 @@ private[ui] class StageTableBase( private[ui] class FailedStageTable( stages: Seq[StageInfo], - parent: JobProgressTab, - killEnabled: Boolean = false) - extends StageTableBase(stages, parent, killEnabled) { + basePath: String, + listener: JobProgressListener, + isFairScheduler: Boolean) + extends StageTableBase(stages, basePath, listener, isFairScheduler, killEnabled = false) { override protected def columns: Seq[Node] = super.columns ++ Failure Reason diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala rename to core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 03ca918e2e8b3..937261de00e3a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -19,18 +19,16 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest -import org.apache.spark.SparkConf import org.apache.spark.scheduler.SchedulingMode import org.apache.spark.ui.{SparkUI, SparkUITab} -/** Web UI showing progress status of all jobs in the given SparkContext. */ -private[ui] class JobProgressTab(parent: SparkUI) extends SparkUITab(parent, "stages") { +/** Web UI showing progress status of all stages in the given SparkContext. 
*/ +private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") { val sc = parent.sc - val conf = sc.map(_.conf).getOrElse(new SparkConf) - val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false) + val killEnabled = parent.killEnabled val listener = parent.jobProgressListener - attachPage(new JobProgressPage(this)) + attachPage(new AllStagesPage(this)) attachPage(new StagePage(this)) attachPage(new PoolPage(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 2f7d618df5f6f..48fd7caa1a1ed 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -40,9 +40,28 @@ private[jobs] object UIData { class JobUIData( var jobId: Int = -1, + var startTime: Option[Long] = None, + var endTime: Option[Long] = None, var stageIds: Seq[Int] = Seq.empty, var jobGroup: Option[String] = None, - var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN + var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN, + /* Tasks */ + // `numTasks` is a potential underestimate of the true number of tasks that this job will run. + // This may be an underestimate because the job start event references all of the result + // stages' transitive stage dependencies, but some of these stages might be skipped if their + // output is available from earlier runs. + // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. + var numTasks: Int = 0, + var numActiveTasks: Int = 0, + var numCompletedTasks: Int = 0, + var numSkippedTasks: Int = 0, + var numFailedTasks: Int = 0, + /* Stages */ + var numActiveStages: Int = 0, + // This needs to be a set instead of a simple count to prevent double-counting of rerun stages: + var completedStageIndices: OpenHashSet[Int] = new OpenHashSet[Int](), + var numSkippedStages: Int = 0, + var numFailedStages: Int = 0 ) class StageUIData { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 7e536edfe807b..7b5db1ed76265 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -31,6 +31,21 @@ import org.apache.spark.scheduler._ import org.apache.spark.storage._ import org.apache.spark._ +/** + * Serializes SparkListener events to/from JSON. This protocol provides strong backwards- + * and forwards-compatibility guarantees: any version of Spark should be able to read JSON output + * written by any other version, including newer versions. + * + * JsonProtocolSuite contains backwards-compatibility tests which check that the current version of + * JsonProtocol is able to read output written by earlier versions. We do not currently have tests + * for reading newer JSON output with older Spark versions. + * + * To ensure that we provide these guarantees, follow these rules when modifying these methods: + * + * - Never delete any JSON fields. + * - Any new JSON fields should be optional; use `Utils.jsonOption` when reading these fields + * in `*FromJson` methods. + */ private[spark] object JsonProtocol { // TODO: Remove this file and put JSON serialization into each individual class.
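On the read side, the rules above reduce to one pattern: treat any field added after a release as optional, and derive a fallback from the older fields. A minimal self-contained sketch of that pattern using json4s (which JsonProtocol builds on); `jsonOption` here is a local stand-in for `Utils.jsonOption`, and the field names mirror the job-start event:

import org.json4s._
import org.json4s.jackson.JsonMethods.parse

object JsonCompatSketch {
  implicit val formats: Formats = DefaultFormats

  // An absent field parses as JNothing; map it to None so callers can getOrElse.
  def jsonOption(json: JValue): Option[JValue] = json match {
    case JNothing => None
    case value => Some(value)
  }

  // "Stage Infos" exists only in 1.2.0+ logs, so fall back to the old
  // "Stage IDs" field when reading events written by earlier versions.
  def readStageIds(json: JValue): Seq[Int] = {
    jsonOption(json \ "Stage Infos")
      .map(_.extract[Seq[JValue]].map(info => (info \ "Stage ID").extract[Int]))
      .getOrElse((json \ "Stage IDs").extract[Seq[Int]])
  }

  def main(args: Array[String]): Unit = {
    val oldEvent = parse("""{"Job ID": 10, "Stage IDs": [1, 2, 3]}""")
    println(readStageIds(oldEvent))  // List(1, 2, 3)
  }
}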
@@ -121,6 +136,7 @@ private[spark] object JsonProtocol { val properties = propertiesToJson(jobStart.properties) ("Event" -> Utils.getFormattedClassName(jobStart)) ~ ("Job ID" -> jobStart.jobId) ~ + ("Stage Infos" -> jobStart.stageInfos.map(stageInfoToJson)) ~ // Added in Spark 1.2.0 ("Stage IDs" -> jobStart.stageIds) ~ ("Properties" -> properties) } @@ -455,7 +471,12 @@ private[spark] object JsonProtocol { val jobId = (json \ "Job ID").extract[Int] val stageIds = (json \ "Stage IDs").extract[List[JValue]].map(_.extract[Int]) val properties = propertiesFromJson(json \ "Properties") - SparkListenerJobStart(jobId, stageIds, properties) + // The "Stage Infos" field was added in Spark 1.2.0 + val stageInfos = Utils.jsonOption(json \ "Stage Infos") + .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { + stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, "unknown")) + } + SparkListenerJobStart(jobId, stageInfos, properties) } def jobEndFromJson(json: JValue): SparkListenerJobEnd = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index c617ff5c51d04..15bda1c9cc29c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -205,6 +205,13 @@ private[spark] class ExternalSorter[K, V, C]( map.changeValue((getPartition(kv._1), kv._1), update) maybeSpillCollection(usingMap = true) } + } else if (bypassMergeSort) { + // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies + if (records.hasNext) { + spillToPartitionFiles(records.map { kv => + ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) + }) + } } else { // Stick values into our buffer while (records.hasNext) { @@ -336,6 +343,10 @@ private[spark] class ExternalSorter[K, V, C]( * @param collection whichever collection we're using (map or buffer) */ private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { + spillToPartitionFiles(collection.iterator) + } + + private def spillToPartitionFiles(iterator: Iterator[((Int, K), C)]): Unit = { assert(bypassMergeSort) // Create our file writers if we haven't done so yet @@ -350,9 +361,9 @@ private[spark] class ExternalSorter[K, V, C]( } } - val it = collection.iterator // No need to sort stuff, just write each element out - while (it.hasNext) { - val elem = it.next() + // No need to sort stuff, just write each element out + while (iterator.hasNext) { + val elem = iterator.next() val partitionId = elem._1._1 val key = elem._1._2 val value = elem._2 @@ -748,6 +759,12 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled context.taskMetrics.diskBytesSpilled += diskBytesSpilled + context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m => + if (curWriteMetrics != null) { + m.shuffleBytesWritten += curWriteMetrics.shuffleBytesWritten + m.shuffleWriteTime += curWriteMetrics.shuffleWriteTime + } + } lengths } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index cda942e15a704..85e5f9ab444b3 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -95,14 +95,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex // Use a local 
cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - // 10 partitions from 4 keys - val NUM_BLOCKS = 10 + // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys + val NUM_BLOCKS = 201 val a = sc.parallelize(1 to 4, NUM_BLOCKS) val b = a.map(x => (x, x*2)) // NOTE: The default Java serializer doesn't create zero-sized blocks. // So, use Kryo - val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10)) + val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(NUM_BLOCKS)) .setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId @@ -122,13 +122,13 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - // 10 partitions from 4 keys - val NUM_BLOCKS = 10 + // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys + val NUM_BLOCKS = 201 val a = sc.parallelize(1 to 4, NUM_BLOCKS) val b = a.map(x => (x, x*2)) // NOTE: The default Java serializer should create zero-sized blocks - val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10)) + val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(NUM_BLOCKS)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 9e454ddcc52a6..1362022104195 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -82,7 +82,7 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { bytesWritable.set(inputArray, 0, 10) bytesWritable.set(inputArray, 0, 5) - val converter = SparkContext.bytesWritableConverter() + val converter = WritableConverter.bytesWritableConverter() val byteArray = converter.convert(bytesWritable) assert(byteArray.length === 5) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index bacf6a16fc233..d2857b8b55664 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -17,16 +17,20 @@ package org.apache.spark.ui -import org.apache.spark.api.java.StorageLevels -import org.apache.spark.{SparkException, SparkConf, SparkContext} -import org.openqa.selenium.WebDriver +import scala.collection.JavaConversions._ + +import org.openqa.selenium.{By, WebDriver} import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ +import org.apache.spark._ +import org.apache.spark.SparkContext._ import org.apache.spark.LocalSparkContext._ +import org.apache.spark.api.java.StorageLevels +import org.apache.spark.shuffle.FetchFailedException /** * Selenium tests for the Spark Web UI. 
These tests are not run by default @@ -89,7 +93,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { sc.parallelize(1 to 10).map { x => throw new Exception()}.collect() } eventually(timeout(5 seconds), interval(50 milliseconds)) { - go to sc.ui.get.appUIAddress + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") find(id("active")).get.text should be("Active Stages (0)") find(id("failed")).get.text should be("Failed Stages (1)") } @@ -101,7 +105,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { sc.parallelize(1 to 10).map { x => unserializableObject}.collect() } eventually(timeout(5 seconds), interval(50 milliseconds)) { - go to sc.ui.get.appUIAddress + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") find(id("active")).get.text should be("Active Stages (0)") // The failure occurs before the stage becomes active, hence we should still show only one // failed stage, not two: @@ -109,4 +113,191 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { } } } + + test("spark.ui.killEnabled should properly control kill button display") { + def getSparkContext(killEnabled: Boolean): SparkContext = { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.ui.enabled", "true") + .set("spark.ui.killEnabled", killEnabled.toString) + new SparkContext(conf) + } + + def hasKillLink = find(className("kill-link")).isDefined + def runSlowJob(sc: SparkContext) { + sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() + } + + withSpark(getSparkContext(killEnabled = true)) { sc => + runSlowJob(sc) + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") + assert(hasKillLink) + } + } + + withSpark(getSparkContext(killEnabled = false)) { sc => + runSlowJob(sc) + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") + assert(!hasKillLink) + } + } + } + + test("jobs page should not display job group name unless some job was submitted in a job group") { + withSpark(newSparkContext()) { sc => + // If no job has been run in a job group, then "(Job Group)" should not appear in the header + sc.parallelize(Seq(1, 2, 3)).count() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + val tableHeaders = findAll(cssSelector("th")).map(_.text).toSeq + tableHeaders should not contain "Job Id (Job Group)" + } + // Once at least one job has been run in a job group, then we should display the group name: + sc.setJobGroup("my-job-group", "my-job-group-description") + sc.parallelize(Seq(1, 2, 3)).count() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + val tableHeaders = findAll(cssSelector("th")).map(_.text).toSeq + tableHeaders should contain ("Job Id (Job Group)") + } + } + } + + test("job progress bars should handle stage / task failures") { + withSpark(newSparkContext()) { sc => + val data = sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity) + val shuffleHandle = + data.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle + // Simulate fetch failures: + val mappedData = data.map { x => + val taskContext = TaskContext.get + if (taskContext.attemptId() == 1) { // Cause this stage to fail on its first attempt. 
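+          // Throwing FetchFailedException (below) registers as a shuffle-fetch failure
+          // rather than an ordinary task failure, so the scheduler resubmits the whole
+          // stage; the "(1 failed)" progress-bar assertions later rely on that retry.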
+ val env = SparkEnv.get + val bmAddress = env.blockManager.blockManagerId + val shuffleId = shuffleHandle.shuffleId + val mapId = 0 + val reduceId = taskContext.partitionId() + val message = "Simulated fetch failure" + throw new FetchFailedException(bmAddress, shuffleId, mapId, reduceId, message) + } else { + x + } + } + mappedData.count() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + find(cssSelector(".stage-progress-cell")).get.text should be ("2/2 (1 failed)") + // Ideally, the following test would pass, but currently we overcount completed tasks + // if task recomputations occur: + // find(cssSelector(".progress-cell .progress")).get.text should be ("2/2 (1 failed)") + // Instead, we guarantee that the total number of tasks is always correct, while the number + // of completed tasks may be higher: + find(cssSelector(".progress-cell .progress")).get.text should be ("3/2 (1 failed)") + } + } + } + + test("job details page should display useful information for stages that haven't started") { + withSpark(newSparkContext()) { sc => + // Create a multi-stage job with a long delay in the first stage: + val rdd = sc.parallelize(Seq(1, 2, 3)).map { x => + // This long sleep call won't slow down the tests because we don't actually need to wait + // for the job to finish. + Thread.sleep(20000) + }.groupBy(identity).map(identity).groupBy(identity).map(identity) + // Start the job: + rdd.countAsync() + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/?id=0") + find(id("active")).get.text should be ("Active Stages (1)") + find(id("pending")).get.text should be ("Pending Stages (2)") + // Essentially, we want to check that none of the stage rows show + // "No data available for this stage". Checking for the absence of that string is brittle + // because someone could change the error message and cause this test to pass by accident. + // Instead, it's safer to check that each row contains a link to a stage details page. + findAll(cssSelector("tbody tr")).foreach { row => + val link = row.underlying.findElement(By.xpath(".//a")) + link.getAttribute("href") should include ("stage") + } + } + } + } + + test("job progress bars / cells reflect skipped stages / tasks") { + withSpark(newSparkContext()) { sc => + // Create an RDD that involves multiple stages: + val rdd = sc.parallelize(1 to 8, 8) + .map(x => x).groupBy((x: Int) => x, numPartitions = 8) + .flatMap(x => x._2).groupBy((x: Int) => x, numPartitions = 8) + // Run it twice; this will cause the second job to have two "phantom" stages that were + // mentioned in its job start event but which were never actually executed: + rdd.count() + rdd.count() + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + // The completed jobs table should have two rows. 
The first row will be the most recent job: + val firstRow = find(cssSelector("tbody tr")).get.underlying + val firstRowColumns = firstRow.findElements(By.tagName("td")) + firstRowColumns(0).getText should be ("1") + firstRowColumns(4).getText should be ("1/1 (2 skipped)") + firstRowColumns(5).getText should be ("8/8 (16 skipped)") + // The second row is the first run of the job, where nothing was skipped: + val secondRow = findAll(cssSelector("tbody tr")).toSeq(1).underlying + val secondRowColumns = secondRow.findElements(By.tagName("td")) + secondRowColumns(0).getText should be ("0") + secondRowColumns(4).getText should be ("3/3") + secondRowColumns(5).getText should be ("24/24") + } + } + } + + test("stages that aren't run appear as 'skipped stages' after a job finishes") { + withSpark(newSparkContext()) { sc => + // Create an RDD that involves multiple stages: + val rdd = + sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) + // Run it twice; this will cause the second job to have two "phantom" stages that were + // mentioned in its job start event but which were never actually executed: + rdd.count() + rdd.count() + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/?id=1") + find(id("pending")) should be (None) + find(id("active")) should be (None) + find(id("failed")) should be (None) + find(id("completed")).get.text should be ("Completed Stages (1)") + find(id("skipped")).get.text should be ("Skipped Stages (2)") + // Essentially, we want to check that none of the stage rows show + // "No data available for this stage". Checking for the absence of that string is brittle + // because someone could change the error message and cause this test to pass by accident. + // Instead, it's safer to check that each row contains a link to a stage details page. 
+ findAll(cssSelector("tbody tr")).foreach { row => + val link = row.underlying.findElement(By.xpath(".//a")) + link.getAttribute("href") should include ("stage") + } + } + } + } + + test("jobs with stages that are skipped should show correct link descriptions on all jobs page") { + withSpark(newSparkContext()) { sc => + // Create an RDD that involves multiple stages: + val rdd = + sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) + // Run it twice; this will cause the second job to have two "phantom" stages that were + // mentioned in its job start event but which were never actually executed: + rdd.count() + rdd.count() + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + findAll(cssSelector("tbody tr a")).foreach { link => + link.text.toLowerCase should include ("count") + link.text.toLowerCase should not include "unknown" + } + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 15c5b4e702efa..12af60caf7d54 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -43,7 +43,10 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc } private def createJobStartEvent(jobId: Int, stageIds: Seq[Int]) = { - SparkListenerJobStart(jobId, stageIds) + val stageInfos = stageIds.map { stageId => + new StageInfo(stageId, 0, stageId.toString, 0, null, "") + } + SparkListenerJobStart(jobId, stageInfos) } private def createJobEndEvent(jobId: Int, failed: Boolean = false) = { @@ -52,8 +55,9 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc } private def runJob(listener: SparkListener, jobId: Int, shouldFail: Boolean = false) { + val stagesThatWontBeRun = jobId * 200 to jobId * 200 + 10 val stageIds = jobId * 100 to jobId * 100 + 50 - listener.onJobStart(createJobStartEvent(jobId, stageIds)) + listener.onJobStart(createJobStartEvent(jobId, stageIds ++ stagesThatWontBeRun)) for (stageId <- stageIds) { listener.onStageSubmitted(createStageStartEvent(stageId)) listener.onStageCompleted(createStageEndEvent(stageId, failed = stageId % 2 == 0)) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 50f42054b9296..0bc9492675863 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -47,7 +47,12 @@ class JsonProtocolSuite extends FunSuite { val taskEndWithOutput = SparkListenerTaskEnd(1, 0, "ResultTask", Success, makeTaskInfo(123L, 234, 67, 345L, false), makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = true, hasOutput = true)) - val jobStart = SparkListenerJobStart(10, Seq[Int](1, 2, 3, 4), properties) + val jobStart = { + val stageIds = Seq[Int](1, 2, 3, 4) + val stageInfos = stageIds.map(x => + makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L)) + SparkListenerJobStart(10, stageInfos, properties) + } val jobEnd = SparkListenerJobEnd(20, JobSucceeded) val environmentUpdate = SparkListenerEnvironmentUpdate(Map[String, Seq[(String, String)]]( "JVM Information" -> Seq(("GC speed", "9999 objects/s"), ("Java home", "Land of coffee")), @@ -224,6 +229,19 @@ class JsonProtocolSuite 
extends FunSuite { assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent)) } + test("SparkListenerJobStart backward compatibility") { + // Prior to Spark 1.2.0, SparkListenerJobStart did not have a "Stage Infos" property. + val stageIds = Seq[Int](1, 2, 3, 4) + val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400, x * 500)) + val dummyStageInfos = + stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, "unknown")) + val jobStart = SparkListenerJobStart(10, stageInfos, properties) + val oldEvent = JsonProtocol.jobStartToJson(jobStart).removeField({_._1 == "Stage Infos"}) + val expectedJobStart = + SparkListenerJobStart(10, dummyStageInfos, properties) + assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldEvent)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -306,7 +324,7 @@ class JsonProtocolSuite extends FunSuite { case (e1: SparkListenerJobStart, e2: SparkListenerJobStart) => assert(e1.jobId === e2.jobId) assert(e1.properties === e2.properties) - assertSeqEquals(e1.stageIds, e2.stageIds, (i1: Int, i2: Int) => assert(i1 === i2)) + assert(e1.stageIds === e2.stageIds) case (e1: SparkListenerJobEnd, e2: SparkListenerJobEnd) => assert(e1.jobId === e2.jobId) assertEquals(e1.jobResult, e2.jobResult) @@ -1051,6 +1069,260 @@ class JsonProtocolSuite extends FunSuite { |{ | "Event": "SparkListenerJobStart", | "Job ID": 10, + | "Stage Infos": [ + | { + | "Stage ID": 1, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 200, + | "RDD Info": [ + | { + | "RDD ID": 1, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 200, + | "Number of Cached Partitions": 300, + | "Memory Size": 400, + | "Tachyon Size": 0, + | "Disk Size": 500 + | } + | ], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": " Accumulable 2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": " Accumulable 1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | }, + | { + | "Stage ID": 2, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 400, + | "RDD Info": [ + | { + | "RDD ID": 2, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 400, + | "Number of Cached Partitions": 600, + | "Memory Size": 800, + | "Tachyon Size": 0, + | "Disk Size": 1000 + | }, + | { + | "RDD ID": 3, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 401, + | "Number of Cached Partitions": 601, + | "Memory Size": 801, + | "Tachyon Size": 0, + | "Disk Size": 1001 + | } + | ], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": " Accumulable 2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": " Accumulable 1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | }, + | { + | "Stage ID": 3, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 600, + | "RDD Info": [ + | { + | "RDD ID": 3, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | 
"Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 600, + | "Number of Cached Partitions": 900, + | "Memory Size": 1200, + | "Tachyon Size": 0, + | "Disk Size": 1500 + | }, + | { + | "RDD ID": 4, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 601, + | "Number of Cached Partitions": 901, + | "Memory Size": 1201, + | "Tachyon Size": 0, + | "Disk Size": 1501 + | }, + | { + | "RDD ID": 5, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 602, + | "Number of Cached Partitions": 902, + | "Memory Size": 1202, + | "Tachyon Size": 0, + | "Disk Size": 1502 + | } + | ], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": " Accumulable 2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": " Accumulable 1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | }, + | { + | "Stage ID": 4, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 800, + | "RDD Info": [ + | { + | "RDD ID": 4, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 800, + | "Number of Cached Partitions": 1200, + | "Memory Size": 1600, + | "Tachyon Size": 0, + | "Disk Size": 2000 + | }, + | { + | "RDD ID": 5, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 801, + | "Number of Cached Partitions": 1201, + | "Memory Size": 1601, + | "Tachyon Size": 0, + | "Disk Size": 2001 + | }, + | { + | "RDD ID": 6, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 802, + | "Number of Cached Partitions": 1202, + | "Memory Size": 1602, + | "Tachyon Size": 0, + | "Disk Size": 2002 + | }, + | { + | "RDD ID": 7, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 803, + | "Number of Cached Partitions": 1203, + | "Memory Size": 1603, + | "Tachyon Size": 0, + | "Disk Size": 2003 + | } + | ], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": " Accumulable 2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": " Accumulable 1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | } + | ], | "Stage IDs": [ | 1, | 2, diff --git a/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala new file mode 100644 index 0000000000000..4918e2d92beb4 --- /dev/null +++ b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
* The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.sparktest + +/** + * A test suite to make sure all `implicit` functions work correctly. + * Please don't `import org.apache.spark.SparkContext._` in this class. + * + * As `implicit` is a compiler feature, we don't need to run this class. + * All we need to do is make the compiler happy. + */ +class ImplicitSuite { + + // We only want to test whether `implicit` works well with the compiler, so we don't need a real + // SparkContext. + def mockSparkContext[T]: org.apache.spark.SparkContext = null + + // We only want to test whether `implicit` works well with the compiler, so we don't need a real RDD. + def mockRDD[T]: org.apache.spark.rdd.RDD[T] = null + + def testRddToPairRDDFunctions(): Unit = { + val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD + rdd.groupByKey() + } + + def testRddToAsyncRDDActions(): Unit = { + val rdd: org.apache.spark.rdd.RDD[Int] = mockRDD + rdd.countAsync() + } + + def testRddToSequenceFileRDDFunctions(): Unit = { + // TODO: eliminating `import intToIntWritable` requires refactoring SequenceFileRDDFunctions, + // which would be a breaking change. + import org.apache.spark.SparkContext.intToIntWritable + val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD + rdd.saveAsSequenceFile("/a/test/path") + } + + def testRddToOrderedRDDFunctions(): Unit = { + val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD + rdd.sortByKey() + } + + def testDoubleRDDToDoubleRDDFunctions(): Unit = { + val rdd: org.apache.spark.rdd.RDD[Double] = mockRDD + rdd.stats() + } + + def testNumericRDDToDoubleRDDFunctions(): Unit = { + val rdd: org.apache.spark.rdd.RDD[Int] = mockRDD + rdd.stats() + } + + def testDoubleAccumulatorParam(): Unit = { + val sc = mockSparkContext + sc.accumulator(123.4) + } + + def testIntAccumulatorParam(): Unit = { + val sc = mockSparkContext + sc.accumulator(123) + } + + def testLongAccumulatorParam(): Unit = { + val sc = mockSparkContext + sc.accumulator(123L) + } + + def testFloatAccumulatorParam(): Unit = { + val sc = mockSparkContext + sc.accumulator(123F) + } + + def testIntWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Int, Int]("/a/test/path") + } + + def testLongWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Long, Long]("/a/test/path") + } + + def testDoubleWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Double, Double]("/a/test/path") + } + + def testFloatWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Float, Float]("/a/test/path") + } + + def testBooleanWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Boolean, Boolean]("/a/test/path") + } + + def testBytesWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Array[Byte], Array[Byte]]("/a/test/path") + } + + def testStringWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[String, String]("/a/test/path") + } + + def
testWritableWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[org.apache.hadoop.io.Text, org.apache.hadoop.io.Text]("/a/test/path") + } +} diff --git a/docs/README.md b/docs/README.md index d2d58e435d4c4..119484038083f 100644 --- a/docs/README.md +++ b/docs/README.md @@ -43,7 +43,7 @@ You can modify the default Jekyll build as follows: ## Pygments We also use pygments (http://pygments.org) for syntax highlighting in documentation markdown pages, -so you will also need to install that (it requires Python) by running `sudo easy_install Pygments`. +so you will also need to install that (it requires Python) by running `sudo pip install Pygments`. To mark a block of code in your markdown to be syntax highlighted by jekyll during the compile phase, use the following syntax: @@ -53,6 +53,11 @@ phase, use the following syntax: // supported languages too. {% endhighlight %} +## Sphinx + +We use Sphinx to generate Python API docs, so you will need to install it by running +`sudo pip install sphinx`. + ## API Docs (Scaladoc and Sphinx) You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory. diff --git a/docs/building-spark.md b/docs/building-spark.md index bb18414092aae..fee6a8440634c 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -92,8 +92,11 @@ mvn -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -DskipTests clean package # Apache Hadoop 2.3.X mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -DskipTests clean package -# Apache Hadoop 2.4.X -mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package +# Apache Hadoop 2.4.X or 2.5.X +mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=VERSION -DskipTests clean package + +Versions of Hadoop after 2.5.X may or may not work with the -Phadoop-2.4 profile (they were +released after this version of Spark). # Different versions of HDFS and YARN. mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 28bb98175188a..e298c51f8a5b7 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -57,77 +57,15 @@ title: GraphX Programming Guide # Overview -GraphX is the new (alpha) Spark API for graphs and graph-parallel computation. At a high level, -GraphX extends the Spark [RDD](api/scala/index.html#org.apache.spark.rdd.RDD) by introducing the -[Resilient Distributed Property Graph](#property_graph): a directed multigraph with properties +GraphX is a new component in Spark for graphs and graph-parallel computation. At a high level, +GraphX extends the Spark [RDD](api/scala/index.html#org.apache.spark.rdd.RDD) by introducing a +new [Graph](#property_graph) abstraction: a directed multigraph with properties attached to each vertex and edge. To support graph computation, GraphX exposes a set of fundamental operators (e.g., [subgraph](#structural_operators), [joinVertices](#join_operators), and -[aggregateMessages](#aggregateMessages)) as well as an optimized variant of the [Pregel](#pregel) API. In -addition, GraphX includes a growing collection of graph [algorithms](#graph_algorithms) and +[aggregateMessages](#aggregateMessages)) as well as an optimized variant of the [Pregel](#pregel) API. In addition, GraphX includes a growing collection of graph [algorithms](#graph_algorithms) and [builders](#graph_builders) to simplify graph analytics tasks.
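To make the revised overview concrete, here is a minimal sketch of `aggregateMessages` (one of the operators listed above) computing vertex in-degrees; the tiny vertex and edge sets are invented for illustration:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx.{Edge, Graph}

// In-degree of every vertex via aggregateMessages: each edge sends the
// message 1 to its destination vertex; messages are merged by summation.
object InDegreeExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("in-degrees"))
    val vertices = sc.parallelize(Seq((1L, "a"), (2L, "b"), (3L, "c")))
    val edges = sc.parallelize(Seq(Edge(1L, 2L, 1), Edge(1L, 3L, 1), Edge(2L, 3L, 1)))
    val graph = Graph(vertices, edges)
    val inDeg = graph.aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _)
    inDeg.collect().foreach(println)  // e.g. (2,1), (3,2)
    sc.stop()
  }
}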
-## Motivation
-
-From social networks to language modeling, the growing scale and importance of
-graph data has driven the development of numerous new *graph-parallel* systems
-(e.g., [Giraph](http://giraph.apache.org) and
-[GraphLab](http://graphlab.org)). By restricting the types of computation that can be
-expressed and introducing new techniques to partition and distribute graphs,
-these systems can efficiently execute sophisticated graph algorithms orders of
-magnitude faster than more general *data-parallel* systems.
-
-[Figure: Data-Parallel vs. Graph-Parallel]
-
-However, the same restrictions that enable these substantial performance gains also make it
-difficult to express many of the important stages in a typical graph-analytics pipeline:
-constructing the graph, modifying its structure, or expressing computation that spans multiple
-graphs. Furthermore, how we look at data depends on our objectives and the same raw data may have
-many different table and graph views.
-
-[Figure: Tables and Graphs]
-
-As a consequence, it is often necessary to be able to move between table and graph views.
-However, existing graph analytics pipelines must compose graph-parallel and data-parallel
-systems, leading to extensive data movement and duplication and a complicated programming
-model.
-
-[Figure: Graph Analytics Pipeline]
-
- -The goal of the GraphX project is to unify graph-parallel and data-parallel computation in one -system with a single composable API. The GraphX API enables users to view data both as a graph and -as collections (i.e., RDDs) without data movement or duplication. By incorporating recent advances -in graph-parallel systems, GraphX is able to optimize the execution of graph operations. - - - ## Migrating from Spark 1.1 GraphX in Spark {{site.SPARK_VERSION}} contains a few user facing API changes: @@ -174,7 +112,7 @@ identifiers. The property graph is parameterized over the vertex (`VD`) and edge (`ED`) types. These are the types of the objects associated with each vertex and edge respectively. -> GraphX optimizes the representation of vertex and edge types when they are plain old data types +> GraphX optimizes the representation of vertex and edge types when they are primitive data types > (e.g., int, double, etc...) reducing the in memory footprint by storing them in specialized > arrays. @@ -791,14 +729,13 @@ Graphs are inherently recursive data structures as properties of vertices depend their neighbors which in turn depend on properties of *their* neighbors. As a consequence many important graph algorithms iteratively recompute the properties of each vertex until a fixed-point condition is reached. A range of graph-parallel abstractions have been proposed -to express these iterative algorithms. GraphX exposes a Pregel-like operator which is a fusion of -the widely used Pregel and GraphLab abstractions. +to express these iterative algorithms. GraphX exposes a variant of the Pregel API. At a high level the Pregel operator in GraphX is a bulk-synchronous parallel messaging abstraction *constrained to the topology of the graph*. The Pregel operator executes in a series of super steps in which vertices receive the *sum* of their inbound messages from the previous super step, compute a new value for the vertex property, and then send messages to neighboring vertices in the next -super step. Unlike Pregel and instead more like GraphLab messages are computed in parallel as a +super step. Unlike Pregel, messages are computed in parallel as a function of the edge triplet and the message computation has access to both the source and destination vertex attributes. Vertices that do not receive a message are skipped within a super step. 
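The super-step loop described in this paragraph is easiest to see in code. The sketch below is the canonical single-source shortest-paths formulation from the GraphX guide, assuming a graph with Double edge weights and a chosen `sourceId`:

import org.apache.spark.graphx._

object PregelSSSP {
  // Single-source shortest paths written against the Pregel operator
  // described above; `graph` and `sourceId` are assumed inputs.
  def shortestPaths[VD](graph: Graph[VD, Double], sourceId: VertexId): Graph[Double, Double] = {
    // Every vertex starts at infinite distance except the source.
    val initialGraph = graph.mapVertices((id, _) =>
      if (id == sourceId) 0.0 else Double.PositiveInfinity)
    initialGraph.pregel(Double.PositiveInfinity)(
      (id, dist, newDist) => math.min(dist, newDist),  // vertex program: keep the best distance
      triplet =>                                       // send: relax edges that improve dstAttr
        if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
          Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
        } else {
          Iterator.empty
        },
      (a, b) => math.min(a, b)                         // merge: min over incoming messages
    )
  }
}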
The Pregel operator terminates iteration and returns the final graph when there are no messages remaining. diff --git a/docs/img/data_parallel_vs_graph_parallel.png b/docs/img/data_parallel_vs_graph_parallel.png deleted file mode 100644 index d3918f01d8f3b..0000000000000 Binary files a/docs/img/data_parallel_vs_graph_parallel.png and /dev/null differ diff --git a/docs/img/graph_analytics_pipeline.png b/docs/img/graph_analytics_pipeline.png deleted file mode 100644 index 6d606e01894ae..0000000000000 Binary files a/docs/img/graph_analytics_pipeline.png and /dev/null differ diff --git a/docs/img/tables_and_graphs.png b/docs/img/tables_and_graphs.png deleted file mode 100644 index ec37bb45a62f0..0000000000000 Binary files a/docs/img/tables_and_graphs.png and /dev/null differ diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 0c52ef8ed96ac..227acc117502d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -27,6 +27,7 @@ object HiveFromSpark { def main(args: Array[String]) { val sparkConf = new SparkConf().setAppName("HiveFromSpark") val sc = new SparkContext(sparkConf) + val path = s"${System.getenv("SPARK_HOME")}/examples/src/main/resources/kv1.txt" // A local hive context creates an instance of the Hive Metastore in process, storing // the warehouse data in the current directory. This location can be overridden by @@ -35,7 +36,7 @@ object HiveFromSpark { import hiveContext._ sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src") + sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE src") // Queries are expressed in HiveQL println("Result of 'SELECT *': ") diff --git a/make-distribution.sh b/make-distribution.sh index 2267b1aa08a6c..7c0fb8992a155 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -119,7 +119,7 @@ VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v " SPARK_HADOOP_VERSION=$(mvn help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ | grep -v "INFO"\ | tail -n 1) -SPARK_HIVE=$(mvn help:evaluate -Dexpression=project.activeProfiles $@ 2>/dev/null\ +SPARK_HIVE=$(mvn help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ | grep -v "INFO"\ | fgrep --count "hive";\ # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index b6f7618171224..f04df1c156898 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -74,10 +74,28 @@ class PythonMLLibAPI extends Serializable { learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], initialWeights: Vector): JList[Object] = { - // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
- learner.disableUncachedWarning() - val model = learner.run(data.rdd, initialWeights) - List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + try { + val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights) + List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + } finally { + data.rdd.unpersist(blocking = false) + } + } + + /** + * Return the Updater from string + */ + def getUpdaterFromString(regType: String): Updater = { + if (regType == "l2") { + new SquaredL2Updater + } else if (regType == "l1") { + new L1Updater + } else if (regType == null || regType == "none") { + new SimpleUpdater + } else { + throw new IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: ['l1', 'l2', None].") + } } /** @@ -99,16 +117,7 @@ class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) - if (regType == "l2") { - lrAlg.optimizer.setUpdater(new SquaredL2Updater) - } else if (regType == "l1") { - lrAlg.optimizer.setUpdater(new L1Updater) - } else if (regType == null) { - lrAlg.optimizer.setUpdater(new SimpleUpdater) - } else { - throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." - + " Can only be initialized using the following string values: ['l1', 'l2', None].") - } + lrAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( lrAlg, data, @@ -178,16 +187,7 @@ class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) - if (regType == "l2") { - SVMAlg.optimizer.setUpdater(new SquaredL2Updater) - } else if (regType == "l1") { - SVMAlg.optimizer.setUpdater(new L1Updater) - } else if (regType == null) { - SVMAlg.optimizer.setUpdater(new SimpleUpdater) - } else { - throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." - + " Can only be initialized using the following string values: ['l1', 'l2', None].") - } + SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( SVMAlg, data, @@ -213,16 +213,7 @@ class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) - if (regType == "l2") { - LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) - } else if (regType == "l1") { - LogRegAlg.optimizer.setUpdater(new L1Updater) - } else if (regType == null) { - LogRegAlg.optimizer.setUpdater(new SimpleUpdater) - } else { - throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." - + " Can only be initialized using the following string values: ['l1', 'l2', None].") - } + LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( LogRegAlg, data, @@ -248,16 +239,7 @@ class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setNumCorrections(corrections) .setConvergenceTol(tolerance) - if (regType == "l2") { - LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) - } else if (regType == "l1") { - LogRegAlg.optimizer.setUpdater(new L1Updater) - } else if (regType == null) { - LogRegAlg.optimizer.setUpdater(new SimpleUpdater) - } else { - throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." 
- + " Can only be initialized using the following string values: ['l1', 'l2', None].") - } + LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( LogRegAlg, data, @@ -289,9 +271,11 @@ class PythonMLLibAPI extends Serializable { .setMaxIterations(maxIterations) .setRuns(runs) .setInitializationMode(initializationMode) - // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. - .disableUncachedWarning() - kMeansAlg.run(data.rdd) + try { + kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) + } finally { + data.rdd.unpersist(blocking = false) + } } /** @@ -425,16 +409,18 @@ class PythonMLLibAPI extends Serializable { numPartitions: Int, numIterations: Int, seed: Long): Word2VecModelWrapper = { - val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER) val word2vec = new Word2Vec() .setVectorSize(vectorSize) .setLearningRate(learningRate) .setNumPartitions(numPartitions) .setNumIterations(numIterations) .setSeed(seed) - val model = word2vec.fit(data) - data.unpersist() - new Word2VecModelWrapper(model) + try { + val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)) + new Word2VecModelWrapper(model) + } finally { + dataJRDD.rdd.unpersist(blocking = false) + } } private[python] class Word2VecModelWrapper(model: Word2VecModel) { @@ -495,8 +481,11 @@ class PythonMLLibAPI extends Serializable { categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap, minInstancesPerNode = minInstancesPerNode, minInfoGain = minInfoGain) - - DecisionTree.train(data.rdd, strategy) + try { + DecisionTree.train(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), strategy) + } finally { + data.rdd.unpersist(blocking = false) + } } /** @@ -526,10 +515,15 @@ class PythonMLLibAPI extends Serializable { numClassesForClassification = numClasses, maxBins = maxBins, categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap) - if (algo == Algo.Classification) { - RandomForest.trainClassifier(data.rdd, strategy, numTrees, featureSubsetStrategy, seed) - } else { - RandomForest.trainRegressor(data.rdd, strategy, numTrees, featureSubsetStrategy, seed) + val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) + try { + if (algo == Algo.Classification) { + RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, seed) + } else { + RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, seed) + } + } finally { + cached.unpersist(blocking = false) } } @@ -711,7 +705,7 @@ private[spark] object SerDe extends Serializable { def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { if (obj == this) { out.write(Opcodes.GLOBAL) - out.write((module + "\n" + name + "\n").getBytes()) + out.write((module + "\n" + name + "\n").getBytes) } else { pickler.save(this) // it will be memorized by Pickler saveState(obj, out, pickler) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 7443f232ec3e7..34ea0de706f08 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -113,22 +113,13 @@ class KMeans private ( this } - /** Whether a warning should be logged if the input RDD is uncached. */ - private var warnOnUncachedInput = true - - /** Disable warnings about uncached input. 
*/ - private[spark] def disableUncachedWarning(): this.type = { - warnOnUncachedInput = false - this - } - /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. */ def run(data: RDD[Vector]): KMeansModel = { - if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) { + if (data.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } @@ -143,7 +134,7 @@ class KMeans private ( norms.unpersist() // Warn at the end of the run as well, for increased visibility. - if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) { + if (data.getStorageLevel == StorageLevel.NONE) { logWarning("The input data was not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 60ab2aaa8f27a..c6d5fe5bc678c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -76,6 +76,15 @@ sealed trait Vector extends Serializable { def copy: Vector = { throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") } + + /** + * Applies a function `f` to all the active elements of a dense or sparse vector. + * + * @param f the function to apply; its first parameter is the index of the element (an `Int`), + * and its second parameter is the corresponding value (a `Double`). + */ + private[spark] def foreachActive(f: (Int, Double) => Unit) } /** @@ -273,6 +282,17 @@ class DenseVector(val values: Array[Double]) extends Vector { override def copy: DenseVector = { new DenseVector(values.clone()) } + + private[spark] override def foreachActive(f: (Int, Double) => Unit) = { + var i = 0 + val localValuesSize = values.size + val localValues = values + + while (i < localValuesSize) { + f(i, localValues(i)) + i += 1 + } + } } /** @@ -309,4 +329,16 @@ class SparseVector( } private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + + private[spark] override def foreachActive(f: (Int, Double) => Unit) = { + var i = 0 + val localValuesSize = values.size + val localIndices = indices + val localValues = values + + while (i < localValuesSize) { + f(localIndices(i), localValues(i)) + i += 1 + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 00dfc86c9e0bd..0287f04e2c777 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -136,15 +136,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] this } - /** Whether a warning should be logged if the input RDD is uncached. */ - private var warnOnUncachedInput = true - - /** Disable warnings about uncached input. */ - private[spark] def disableUncachedWarning(): this.type = { - warnOnUncachedInput = false - this - } - /** * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries.
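The `foreachActive` hook added to `Vector` above replaces per-type pattern matching in callers; a self-contained restatement of the pattern with simplified stand-in types (these are not the MLlib classes):

// One traversal interface, two storage layouts: callers no longer need to
// match on the concrete vector type to iterate over active (stored) elements.
sealed trait MiniVector {
  def foreachActive(f: (Int, Double) => Unit): Unit
}

final class MiniDense(values: Array[Double]) extends MiniVector {
  def foreachActive(f: (Int, Double) => Unit): Unit = {
    var i = 0
    while (i < values.length) { f(i, values(i)); i += 1 }
  }
}

final class MiniSparse(indices: Array[Int], values: Array[Double]) extends MiniVector {
  def foreachActive(f: (Int, Double) => Unit): Unit = {
    var i = 0
    while (i < values.length) { f(indices(i), values(i)); i += 1 }
  }
}

object MiniVectorDemo {
  def l1(v: MiniVector): Double = {
    var sum = 0.0
    v.foreachActive((_, value) => sum += math.abs(value))
    sum
  }

  def main(args: Array[String]): Unit = {
    println(l1(new MiniDense(Array(0.0, 1.2, 3.1))))          // 4.3
    println(l1(new MiniSparse(Array(1, 2), Array(1.2, 3.1)))) // 4.3
  }
}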
@@ -161,7 +152,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { - if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) { + if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } @@ -241,7 +232,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] } // Warn at the end of the run as well, for increased visibility. - if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) { + if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data was not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 654479ac2dd4f..fcc2a148791bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -17,10 +17,8 @@ package org.apache.spark.mllib.stat -import breeze.linalg.{DenseVector => BDV} - import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * :: DeveloperApi :: @@ -40,37 +38,14 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { private var n = 0 - private var currMean: BDV[Double] = _ - private var currM2n: BDV[Double] = _ - private var currM2: BDV[Double] = _ - private var currL1: BDV[Double] = _ + private var currMean: Array[Double] = _ + private var currM2n: Array[Double] = _ + private var currM2: Array[Double] = _ + private var currL1: Array[Double] = _ private var totalCnt: Long = 0 - private var nnz: BDV[Double] = _ - private var currMax: BDV[Double] = _ - private var currMin: BDV[Double] = _ - - /** - * Adds input value to position i. - */ - private[this] def add(i: Int, value: Double) = { - if (value != 0.0) { - if (currMax(i) < value) { - currMax(i) = value - } - if (currMin(i) > value) { - currMin(i) = value - } - - val prevMean = currMean(i) - val diff = value - prevMean - currMean(i) = prevMean + diff / (nnz(i) + 1.0) - currM2n(i) += (value - currMean(i)) * diff - currM2(i) += value * value - currL1(i) += math.abs(value) - - nnz(i) += 1.0 - } - } + private var nnz: Array[Double] = _ + private var currMax: Array[Double] = _ + private var currMin: Array[Double] = _ /** * Add a new sample to this summarizer, and update the statistical summary. 
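The rewritten `add` below folds each nonzero element into running first and second moments. Isolated from the sparsity bookkeeping, the recurrence is the standard online (Welford-style) update; this sketch uses illustrative names and deliberately omits the implicit-zero correction the summarizer applies when assembling `variance`:

// Online moment update over a stream of values.
final class OnlineMoments {
  private var count = 0.0
  private var mean = 0.0
  private var m2n = 0.0  // running sum of squared deviations from the mean

  def add(value: Double): Unit = {
    val diff = value - mean       // deviation from the old mean
    count += 1.0
    mean += diff / count          // cf. currMean(i) = prevMean + diff / (nnz(i) + 1.0)
    m2n += (value - mean) * diff  // cf. currM2n(i) += (value - currMean(i)) * diff
  }

  def sampleVariance: Double = if (count > 1.0) m2n / (count - 1.0) else 0.0
}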
@@ -83,33 +58,36 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(sample.size > 0, s"Vector should have dimension larger than zero.") n = sample.size - currMean = BDV.zeros[Double](n) - currM2n = BDV.zeros[Double](n) - currM2 = BDV.zeros[Double](n) - currL1 = BDV.zeros[Double](n) - nnz = BDV.zeros[Double](n) - currMax = BDV.fill(n)(Double.MinValue) - currMin = BDV.fill(n)(Double.MaxValue) + currMean = Array.ofDim[Double](n) + currM2n = Array.ofDim[Double](n) + currM2 = Array.ofDim[Double](n) + currL1 = Array.ofDim[Double](n) + nnz = Array.ofDim[Double](n) + currMax = Array.fill[Double](n)(Double.MinValue) + currMin = Array.fill[Double](n)(Double.MaxValue) } require(n == sample.size, s"Dimensions mismatch when adding new sample." + s" Expecting $n but got ${sample.size}.") - sample match { - case dv: DenseVector => { - var j = 0 - while (j < dv.size) { - add(j, dv.values(j)) - j += 1 + sample.foreachActive { (index, value) => + if (value != 0.0) { + if (currMax(index) < value) { + currMax(index) = value } - } - case sv: SparseVector => - var j = 0 - while (j < sv.indices.size) { - add(sv.indices(j), sv.values(j)) - j += 1 + if (currMin(index) > value) { + currMin(index) = value } - case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + + val prevMean = currMean(index) + val diff = value - prevMean + currMean(index) = prevMean + diff / (nnz(index) + 1.0) + currM2n(index) += (value - currMean(index)) * diff + currM2(index) += value * value + currL1(index) += math.abs(value) + + nnz(index) += 1.0 + } } totalCnt += 1 @@ -152,14 +130,14 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } } else if (totalCnt == 0 && other.totalCnt != 0) { this.n = other.n - this.currMean = other.currMean.copy - this.currM2n = other.currM2n.copy - this.currM2 = other.currM2.copy - this.currL1 = other.currL1.copy + this.currMean = other.currMean.clone + this.currM2n = other.currM2n.clone + this.currM2 = other.currM2.clone + this.currL1 = other.currL1.clone this.totalCnt = other.totalCnt - this.nnz = other.nnz.copy - this.currMax = other.currMax.copy - this.currMin = other.currMin.copy + this.nnz = other.nnz.clone + this.currMax = other.currMax.clone + this.currMin = other.currMin.clone } this } @@ -167,19 +145,19 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S override def mean: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val realMean = BDV.zeros[Double](n) + val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { realMean(i) = currMean(i) * (nnz(i) / totalCnt) i += 1 } - Vectors.fromBreeze(realMean) + Vectors.dense(realMean) } override def variance: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val realVariance = BDV.zeros[Double](n) + val realVariance = Array.ofDim[Double](n) val denominator = totalCnt - 1.0 @@ -194,8 +172,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S i += 1 } } - - Vectors.fromBreeze(realVariance) + Vectors.dense(realVariance) } override def count: Long = totalCnt @@ -203,7 +180,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S override def numNonzeros: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - Vectors.fromBreeze(nnz) + Vectors.dense(nnz) } override def max: Vector = { @@ -214,7 +191,7 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } - Vectors.fromBreeze(currMax) + Vectors.dense(currMax) } override def min: Vector = { @@ -225,25 +202,25 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } - Vectors.fromBreeze(currMin) + Vectors.dense(currMin) } override def normL2: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val realMagnitude = BDV.zeros[Double](n) + val realMagnitude = Array.ofDim[Double](n) var i = 0 while (i < currM2.size) { realMagnitude(i) = math.sqrt(currM2(i)) i += 1 } - - Vectors.fromBreeze(realMagnitude) + Vectors.dense(realMagnitude) } override def normL1: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - Vectors.fromBreeze(currL1) + + Vectors.dense(currL1) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 59cd85eab27d0..9492f604af4d5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -173,4 +173,28 @@ class VectorsSuite extends FunSuite { val v = Vectors.fromBreeze(x(::, 0)) assert(v.size === x.rows) } + + test("foreachActive") { + val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0) + val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0))) + + val dvMap = scala.collection.mutable.Map[Int, Double]() + dv.foreachActive { (index, value) => + dvMap.put(index, value) + } + assert(dvMap.size === 4) + assert(dvMap.get(0) === Some(0.0)) + assert(dvMap.get(1) === Some(1.2)) + assert(dvMap.get(2) === Some(3.1)) + assert(dvMap.get(3) === Some(0.0)) + + val svMap = scala.collection.mutable.Map[Int, Double]() + sv.foreachActive { (index, value) => + svMap.put(index, value) + } + assert(svMap.size === 3) + assert(svMap.get(1) === Some(1.2)) + assert(svMap.get(2) === Some(3.1)) + assert(svMap.get(3) === Some(0.0)) + } } diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index fe4c4cc5094d8..e2492eef5bd6a 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -16,7 +16,7 @@ # from pyspark import SparkContext -from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc, callJavaFunc from pyspark.mllib.linalg import SparseVector, _convert_to_vector __all__ = ['KMeansModel', 'KMeans'] @@ -80,10 +80,8 @@ class KMeans(object): @classmethod def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"): """Train a k-means clustering model.""" - # cache serialized data to avoid objects over head in JVM - jcached = _to_java_object_rdd(rdd.map(_convert_to_vector), cache=True) - model = callMLlibFunc("trainKMeansModel", jcached, k, maxIterations, runs, - initializationMode) + model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations, + runs, initializationMode) centers = callJavaFunc(rdd.context, model.clusterCenters) return KMeansModel([c.toArray() for c in centers]) diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index c6149fe391ec8..33c49e2399908 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -54,15 +54,13 @@ def _new_smart_decode(obj): # this will call the MLlib version of 
pythonToJava() -def _to_java_object_rdd(rdd, cache=False): +def _to_java_object_rdd(rdd): """ Return a JavaRDD of Object by unpickling. It will convert each Python object into a Java object by Pyrolite, whether the RDD is serialized in batches or not. """ rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) - if cache: - rdd.cache() return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 2bcbf2aaf8e3e..97ec74eda0b71 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -19,7 +19,7 @@ from pyspark import SparkContext from pyspark.rdd import RDD -from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating'] @@ -110,7 +110,7 @@ def _prepare(cls, ratings): ratings = ratings.map(lambda x: Rating(*x)) else: raise ValueError("rating should be RDD of Rating or tuple/list") - return _to_java_object_rdd(ratings, True) + return ratings @classmethod def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False, diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index f4f5e615fadc3..210060140fd91 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -18,7 +18,7 @@ import numpy as np from numpy import array -from pyspark.mllib.common import callMLlibFunc, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc from pyspark.mllib.linalg import SparseVector, _convert_to_vector __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', @@ -129,8 +129,7 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights): if not isinstance(first, LabeledPoint): raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) initial_weights = initial_weights or [0.0] * len(data.first().features) - weights, intercept = train_func(_to_java_object_rdd(data, cache=True), - _convert_to_vector(initial_weights)) + weights, intercept = train_func(data, _convert_to_vector(initial_weights)) return modelClass(weights, intercept) diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index 7667a9c11979e..da4286c5e4874 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -121,11 +121,14 @@ trait SparkILoopInit { def initializeSpark() { intp.beQuietDuring { command(""" - @transient val sc = org.apache.spark.repl.Main.interp.createSparkContext(); + @transient val sc = { + val _sc = org.apache.spark.repl.Main.interp.createSparkContext() + println("Spark context available as sc.") + _sc + } """) command("import org.apache.spark.SparkContext._") } - echo("Spark context available as sc.") } // code to be executed only after the interpreter is initialized diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index a591e9fc4622b..250727305970d 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -61,11 +61,14 @@ class SparkILoop(in0:
Option[BufferedReader], protected val out: JPrintWriter) def initializeSpark() { intp.beQuietDuring { command( """ - @transient val sc = org.apache.spark.repl.Main.createSparkContext(); + @transient val sc = { + val _sc = org.apache.spark.repl.Main.createSparkContext() + println("Spark context available as sc.") + _sc + } """) command("import org.apache.spark.SparkContext._") } - echo("Spark context available as sc.") } /** Print a welcome message */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d3b4cf8e34242..facbd8b975f10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -179,7 +179,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool val missingInProject = requiredAttributes -- p.output if (missingInProject.nonEmpty) { // Add missing attributes and then project them away after the sort. - Project(projectList, + Project(projectList.map(_.toAttribute), Sort(ordering, Project(projectList ++ missingInProject, child))) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index ff1dc03069ef1..892b7e1a97c8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -84,6 +84,12 @@ object DataType { ("nullable", JBool(nullable)), ("type", dataType: JValue)) => StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata)) + // Support reading schema when 'metadata' is missing. + case JSortedObject( + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => + StructField(name, parseDataType(dataType), nullable) } @deprecated("Use DataType.fromJson instead", "1.2.0") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index cff7a012691dc..d7c811ca89022 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -41,11 +41,21 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una /** We must copy rows when sort based shuffle is on */ protected def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] + private val bypassMergeThreshold = + child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + override def execute() = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. - val rdd = if (sortBasedShuffleOn) { + // This is a workaround for SPARK-4479. When: + // 1. sort based shuffle is on, and + // 2. the partition number is under the merge threshold, and + // 3. no ordering is required + // we can avoid the defensive copies to improve performance. In the long run, we probably + // want to include information in shuffle dependencies to indicate whether elements in the + // source RDD should be copied. 
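To make the rationale above concrete: a sorting shuffle writer buffers incoming records, and Spark SQL operators reuse a single mutable row between calls, so buffering without copying would alias every buffered element to the last value written. A tiny standalone sketch of the failure mode (hypothetical, not the actual shuffle writer):

    val reused = new Array[Int](1)  // stands in for a reused mutable row

    val buffered = (1 to 3).map { i => reused(0) = i; reused }        // no copy
    val copied   = (1 to 3).map { i => reused(0) = i; reused.clone }  // defensive copy

    buffered.map(_(0))  // Seq(3, 3, 3): all three entries alias the same array
    copied.map(_(0))    // Seq(1, 2, 3): each entry kept its own snapshot

Hash-based shuffle writes each record straight to an output stream, which is why the copies can be skipped on that path.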
+ val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) { child.execute().mapPartitions { iter => val hashExpressions = newMutableProjection(expressions, child.output)() iter.map(r => (hashExpressions(r).copy(), r.copy())) @@ -82,6 +92,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una shuffled.map(_._1) case SinglePartition => + // SPARK-4479: Can't turn off defensive copy as what we do for `HashPartitioning`, since + // operators like `TakeOrdered` may require an ordering within the partition, and currently + // `SinglePartition` doesn't include ordering information. + // TODO Add `SingleOrderedPartition` for operators like `TakeOrdered` val rdd = if (sortBasedShuffleOn) { child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 5d0643a64a044..0e36852ddd9b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -361,7 +361,7 @@ private[parquet] class FilteringParquetRowInputFormat private var footers: JList[Footer] = _ - private var fileStatuses= Map.empty[Path, FileStatus] + private var fileStatuses = Map.empty[Path, FileStatus] override def createRecordReader( inputSplit: InputSplit, @@ -405,7 +405,9 @@ private[parquet] class FilteringParquetRowInputFormat } val newFooters = new mutable.HashMap[FileStatus, Footer] if (toFetch.size > 0) { + val startFetch = System.currentTimeMillis val fetched = getFooters(conf, toFetch) + logInfo(s"Fetched $toFetch footers in ${System.currentTimeMillis - startFetch} ms") for ((status, i) <- toFetch.zipWithIndex) { newFooters(status) = fetched.get(i) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala new file mode 100644 index 0000000000000..bea12e6dd674e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ +package org.apache.spark.sql.parquet + +import java.util.{List => JList} + +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce.{JobContext, InputSplit, Job} + +import parquet.hadoop.ParquetInputFormat +import parquet.hadoop.util.ContextUtil + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.{Partition => SparkPartition, Logging} +import org.apache.spark.rdd.{NewHadoopPartition, RDD} + +import org.apache.spark.sql.{SQLConf, Row, SQLContext} +import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, And, Expression, Attribute} +import org.apache.spark.sql.catalyst.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.sources._ + +import scala.collection.JavaConversions._ + +/** + * Allows creation of parquet based tables using the syntax + * `CREATE TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option required + * is `path`, which should be the location of a collection of, optionally partitioned, + * parquet files. + */ +class DefaultSource extends RelationProvider { + /** Returns a new base relation with the given parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val path = + parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) + + ParquetRelation2(path)(sqlContext) + } +} + +private[parquet] case class Partition(partitionValues: Map[String, Any], files: Seq[FileStatus]) + +/** + * An alternative to [[ParquetRelation]] that plugs in using the data sources API. This class is + * currently not intended as a full replacement of the parquet support in Spark SQL though it is + * likely that it will eventually subsume the existing physical plan implementation. + * + * Compared with the current implementation, this class has the following notable differences: + * + * Partitioning: Partitions are auto discovered and must be in the form of directories `key=value/` + * located at `path`. Currently only a single partitioning column is supported and it must + * be an integer. This class supports both fully self-describing data, which contains the partition + * key, and data where the partition key is only present in the folder structure. The presence + * of the partitioning key in the data is also auto-detected. The `null` partition is not yet + * supported. + * + * Metadata: The metadata is automatically discovered by reading the first parquet file present. + * There is currently no support for working with files that have different schemas. Additionally, + * when parquet metadata caching is turned on, the FileStatus objects for all data will be cached + * to improve the speed of interactive querying. When data is added to a table it must be dropped + * and recreated to pick up any changes. + * + * Statistics: Statistics for the size of the table are automatically populated during metadata + * discovery.
*/ +@DeveloperApi +case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) + extends CatalystScan with Logging { + + def sparkContext = sqlContext.sparkContext + + // Minor Hack: scala doesn't seem to respect @transient for vals declared via extraction + @transient + private var partitionKeys: Seq[String] = _ + @transient + private var partitions: Seq[Partition] = _ + discoverPartitions() + + // TODO: Only finds the first partition, assumes the key is of type Integer... + private def discoverPartitions() = { + val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration) + val partValue = "([^=]+)=([^=]+)".r + + val childrenOfPath = fs.listStatus(new Path(path)).filterNot(_.getPath.getName.startsWith("_")) + val childDirs = childrenOfPath.filter(s => s.isDir) + + if (childDirs.size > 0) { + val partitionPairs = childDirs.map(_.getPath.getName).map { + case partValue(key, value) => (key, value) + } + + val foundKeys = partitionPairs.map(_._1).distinct + if (foundKeys.size > 1) { + sys.error(s"Too many distinct partition keys: $foundKeys") + } + + // Do a parallel lookup of partition metadata. + val partitionFiles = + childDirs.par.map { d => + fs.listStatus(d.getPath) + // TODO: Is there a standard hadoop function for this? + .filterNot(_.getPath.getName.startsWith("_")) + .filterNot(_.getPath.getName.startsWith(".")) + }.seq + + partitionKeys = foundKeys.toSeq + partitions = partitionFiles.zip(partitionPairs).map { case (files, (key, value)) => + Partition(Map(key -> value.toInt), files) + }.toSeq + } else { + partitionKeys = Nil + partitions = Partition(Map.empty, childrenOfPath) :: Nil + } + } + + override val sizeInBytes = partitions.flatMap(_.files).map(_.getLen).sum + + val dataSchema = StructType.fromAttributes( // TODO: Parquet code should not deal with attributes. + ParquetTypesConverter.readSchemaFromFile( + partitions.head.files.head.getPath, + Some(sparkContext.hadoopConfiguration), + sqlContext.isParquetBinaryAsString)) + + val dataIncludesKey = + partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true) + + override val schema = + if (dataIncludesKey) { + dataSchema + } else { + StructType(dataSchema.fields :+ StructField(partitionKeys.head, IntegerType)) + } + + override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = { + // This is mostly a hack so that we can use the existing parquet filter code. + val requiredColumns = output.map(_.name) + // TODO: Parquet filters should be based on data sources API, not catalyst expressions. + val filters = DataSourceStrategy.selectFilters(predicates) + + val job = new Job(sparkContext.hadoopConfiguration) + ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) + val jobConf: Configuration = ContextUtil.getConfiguration(job) + + val requestedSchema = StructType(requiredColumns.map(schema(_))) + + // TODO: Make folder based partitioning a first class citizen of the Data Sources API.
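The `key=value` directory convention parsed by `discoverPartitions` above can be exercised in isolation; a quick standalone sketch (not part of the patch) of the same regex:

    val partValue = "([^=]+)=([^=]+)".r

    Seq("p=1", "p=2", "_temporary").collect {
      case partValue(key, value) => key -> value.toInt
    }
    // List((p,1), (p,2)); "_temporary" never reaches the regex in the real code,
    // since children whose names start with "_" are filtered out beforehand

The `value.toInt` mirrors the integer-only assumption noted in the TODO on `discoverPartitions`.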
+ val partitionFilters = filters.collect { + case e @ EqualTo(attr, value) if partitionKeys.contains(attr) => + logInfo(s"Parquet scan partition filter: $attr=$value") + (p: Partition) => p.partitionValues(attr) == value + + case e @ In(attr, values) if partitionKeys.contains(attr) => + logInfo(s"Parquet scan partition filter: $attr IN ${values.mkString("{", ",", "}")}") + val set = values.toSet + (p: Partition) => set.contains(p.partitionValues(attr)) + + case e @ GreaterThan(attr, value) if partitionKeys.contains(attr) => + logInfo(s"Parquet scan partition filter: $attr > $value") + (p: Partition) => p.partitionValues(attr).asInstanceOf[Int] > value.asInstanceOf[Int] + + case e @ GreaterThanOrEqual(attr, value) if partitionKeys.contains(attr) => + logInfo(s"Parquet scan partition filter: $attr >= $value") + (p: Partition) => p.partitionValues(attr).asInstanceOf[Int] >= value.asInstanceOf[Int] + + case e @ LessThan(attr, value) if partitionKeys.contains(attr) => + logInfo(s"Parquet scan partition filter: $attr < $value") + (p: Partition) => p.partitionValues(attr).asInstanceOf[Int] < value.asInstanceOf[Int] + + case e @ LessThanOrEqual(attr, value) if partitionKeys.contains(attr) => + logInfo(s"Parquet scan partition filter: $attr <= $value") + (p: Partition) => p.partitionValues(attr).asInstanceOf[Int] <= value.asInstanceOf[Int] + } + + val selectedPartitions = partitions.filter(p => partitionFilters.forall(_(p))) + val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration) + val selectedFiles = selectedPartitions.flatMap(_.files).map(f => fs.makeQualified(f.getPath)) + org.apache.hadoop.mapreduce.lib.input.FileInputFormat.setInputPaths(job, selectedFiles:_*) + + // Push down filters when possible + predicates + .reduceOption(And) + .flatMap(ParquetFilters.createFilter) + .filter(_ => sqlContext.parquetFilterPushDown) + .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) + + def percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 + logInfo(s"Reading $percentRead% of $path partitions") + + // Store both requested and original schema in `Configuration` + jobConf.set( + RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, + ParquetTypesConverter.convertToString(requestedSchema.toAttributes)) + jobConf.set( + RowWriteSupport.SPARK_ROW_SCHEMA, + ParquetTypesConverter.convertToString(schema.toAttributes)) + + // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata + val useCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true").toBoolean + jobConf.set(SQLConf.PARQUET_CACHE_METADATA, useCache.toString) + + val baseRDD = + new org.apache.spark.rdd.NewHadoopRDD( + sparkContext, + classOf[FilteringParquetRowInputFormat], + classOf[Void], + classOf[Row], + jobConf) { + val cacheMetadata = useCache + + @transient + val cachedStatus = selectedPartitions.flatMap(_.files) + + // Overridden so we can inject our own cached file statuses.
+ override def getPartitions: Array[SparkPartition] = { + val inputFormat = + if (cacheMetadata) { + new FilteringParquetRowInputFormat { + override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus + } + } else { + new FilteringParquetRowInputFormat + } + + inputFormat match { + case configurable: Configurable => + configurable.setConf(getConf) + case _ => + } + val jobContext = newJobContext(getConf, jobId) + val rawSplits = inputFormat.getSplits(jobContext).toArray + val result = new Array[SparkPartition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = + new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } + } + + // The ordinal for the partition key in the result row, if requested. + val partitionKeyLocation = + partitionKeys + .headOption + .map(requiredColumns.indexOf(_)) + .getOrElse(-1) + + // When the data does not include the key and the key is requested then we must fill it in + // based on information from the input split. + if (!dataIncludesKey && partitionKeyLocation != -1) { + baseRDD.mapPartitionsWithInputSplit { case (split, iter) => + val partValue = "([^=]+)=([^=]+)".r + val partValues = + split.asInstanceOf[parquet.hadoop.ParquetInputSplit] + .getPath + .toString + .split("/") + .flatMap { + case partValue(key, value) => Some(key -> value) + case _ => None + }.toMap + + val currentValue = partValues.values.head.toInt + iter.map { pair => + val res = pair._2.asInstanceOf[SpecificMutableRow] + res.setInt(partitionKeyLocation, currentValue) + res + } + } + } else { + baseRDD.map(_._2) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 954e86822de17..37853d4d03019 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -31,6 +31,13 @@ import org.apache.spark.sql.execution.SparkPlan */ private[sql] object DataSourceStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) => + pruneFilterProjectRaw( + l, + projectList, + filters, + (a, f) => t.buildScan(a, f)) :: Nil + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedFilteredScan)) => pruneFilterProject( l, @@ -51,19 +58,35 @@ private[sql] object DataSourceStrategy extends Strategy { case _ => Nil } + // Based on Public API. protected def pruneFilterProject( - relation: LogicalRelation, - projectList: Seq[NamedExpression], - filterPredicates: Seq[Expression], - scanBuilder: (Array[String], Array[Filter]) => RDD[Row]) = { + relation: LogicalRelation, + projectList: Seq[NamedExpression], + filterPredicates: Seq[Expression], + scanBuilder: (Array[String], Array[Filter]) => RDD[Row]) = { + pruneFilterProjectRaw( + relation, + projectList, + filterPredicates, + (requestedColumns, pushedFilters) => { + scanBuilder(requestedColumns.map(_.name).toArray, selectFilters(pushedFilters).toArray) + }) + } + + // Based on Catalyst expressions. 
+ protected def pruneFilterProjectRaw( + relation: LogicalRelation, + projectList: Seq[NamedExpression], + filterPredicates: Seq[Expression], + scanBuilder: (Seq[Attribute], Seq[Expression]) => RDD[Row]) = { val projectSet = AttributeSet(projectList.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) val filterCondition = filterPredicates.reduceLeftOption(And) - val pushedFilters = selectFilters(filterPredicates.map { _ transform { + val pushedFilters = filterPredicates.map { _ transform { case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. - }}).toArray + }} if (projectList.map(_.toAttribute) == projectList && projectSet.size == projectList.size && @@ -74,8 +97,6 @@ private[sql] object DataSourceStrategy extends Strategy { val requestedColumns = projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. .map(relation.attributeMap) // Match original case of attributes. - .map(_.name) - .toArray val scan = execution.PhysicalRDD( @@ -84,14 +105,14 @@ private[sql] object DataSourceStrategy extends Strategy { filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq - val columnNames = requestedColumns.map(_.name).toArray - val scan = execution.PhysicalRDD(requestedColumns, scanBuilder(columnNames, pushedFilters)) + val scan = + execution.PhysicalRDD(requestedColumns, scanBuilder(requestedColumns, pushedFilters)) execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } - protected def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { + protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v) case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 861638b1e99b6..2b8fc05fc0102 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -16,12 +16,13 @@ */ package org.apache.spark.sql.sources -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLConf, Row, SQLContext, StructType} import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} /** + * ::DeveloperApi:: * Implemented by objects that produce relations for a specific kind of data source. When * Spark SQL is given a DDL operation with a USING clause specified, this interface is used to * pass in the parameters specified by a user. @@ -40,6 +41,7 @@ trait RelationProvider { } /** + * ::DeveloperApi:: * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must * be able to produce the schema of their data in the form of a [[StructType]] Concrete * implementation should inherit from one of the descendant `Scan` classes, which define various @@ -65,6 +67,7 @@ abstract class BaseRelation { } /** + * ::DeveloperApi:: * A BaseRelation that can produce all of its tuples as an RDD of Row objects. 
*/ @DeveloperApi @@ -73,6 +76,7 @@ abstract class TableScan extends BaseRelation { } /** + * ::DeveloperApi:: * A BaseRelation that can eliminate unneeded columns before producing an RDD * containing all of its tuples as Row objects. */ @@ -82,6 +86,7 @@ abstract class PrunedScan extends BaseRelation { } /** + * ::DeveloperApi:: * A BaseRelation that can eliminate unneeded columns and filter using selected * predicates before producing an RDD containing all matching tuples as Row objects. * @@ -93,3 +98,18 @@ abstract class PrunedScan extends BaseRelation { abstract class PrunedFilteredScan extends BaseRelation { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] } + +/** + * ::Experimental:: + * An interface for experimenting with a more direct connection to the query planner. Compared to + * [[PrunedFilteredScan]], this operator receives the raw expressions from the + * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. Unlike the other APIs this + * interface is not designed to be binary compatible across releases and thus should only be used + * for experimentation. + */ +@Experimental +abstract class CatalystScan extends BaseRelation { + def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] +} + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0a96831c76f57..84ee3051eb682 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -974,6 +974,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { dropTempTable("data") } + test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { + checkAnswer( + sql("SELECT a + b FROM testData2 ORDER BY a"), + Seq(2, 3, 3 ,4 ,4 ,5).map(Seq(_)) + ) + } + test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index 3c7f62af450d9..99c1987158581 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.{SchemaRDD, Row => SparkRow} /** - * A compatibility layer for interacting with Hive version 0.12.0. + * A compatibility layer for interacting with Hive version 0.13.1. 
*/ private[thriftserver] object HiveThriftServerShim { val version = "0.13.1" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala similarity index 63% rename from sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala index cc65242c0da9b..7159ebd0353ad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala @@ -34,71 +34,52 @@ case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) /** - * Tests for our SerDe -> Native parquet scan conversion. + * A suite to test the automatic conversion of metastore tables with parquet data to use the + * built in parquet support. */ -class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { +class ParquetMetastoreSuite extends ParquetTest { override def beforeAll(): Unit = { - val partitionedTableDir = File.createTempFile("parquettests", "sparksql") - partitionedTableDir.delete() - partitionedTableDir.mkdir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDir, s"p=$p") - sparkContext.makeRDD(1 to 10) - .map(i => ParquetData(i, s"part-$p")) - .saveAsParquetFile(partDir.getCanonicalPath) - } - - val partitionedTableDirWithKey = File.createTempFile("parquettests", "sparksql") - partitionedTableDirWithKey.delete() - partitionedTableDirWithKey.mkdir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDirWithKey, s"p=$p") - sparkContext.makeRDD(1 to 10) - .map(i => ParquetDataWithKey(p, i, s"part-$p")) - .saveAsParquetFile(partDir.getCanonicalPath) - } + super.beforeAll() sql(s""" - create external table partitioned_parquet - ( - intField INT, - stringField STRING - ) - PARTITIONED BY (p int) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${partitionedTableDir.getCanonicalPath}' + create external table partitioned_parquet + ( + intField INT, + stringField STRING + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${partitionedTableDir.getCanonicalPath}' """) sql(s""" - create external table partitioned_parquet_with_key - ( - intField INT, - stringField STRING - ) - PARTITIONED BY (p int) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${partitionedTableDirWithKey.getCanonicalPath}' + create external table partitioned_parquet_with_key + ( + intField INT, + stringField STRING + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location 
'${partitionedTableDirWithKey.getCanonicalPath}' """) sql(s""" - create external table normal_parquet - ( - intField INT, - stringField STRING - ) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${new File(partitionedTableDir, "p=1").getCanonicalPath}' + create external table normal_parquet + ( + intField INT, + stringField STRING + ) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${new File(partitionedTableDir, "p=1").getCanonicalPath}' """) (1 to 10).foreach { p => @@ -116,6 +97,82 @@ class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { setConf("spark.sql.hive.convertMetastoreParquet", "false") } + test("conversion is working") { + assert( + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + case _: HiveTableScan => true + }.isEmpty) + assert( + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + case _: ParquetTableScan => true + }.nonEmpty) + } +} + +/** + * A suite of tests for the Parquet support through the data sources API. + */ +class ParquetSourceSuite extends ParquetTest { + override def beforeAll(): Unit = { + super.beforeAll() + + sql( s""" + create temporary table partitioned_parquet + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${partitionedTableDir.getCanonicalPath}' + ) + """) + + sql( s""" + create temporary table partitioned_parquet_with_key + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${partitionedTableDirWithKey.getCanonicalPath}' + ) + """) + + sql( s""" + create temporary table normal_parquet + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${new File(partitionedTableDir, "p=1").getCanonicalPath}' + ) + """) + } +} + +/** + * A collection of tests for parquet data with various forms of partitioning. 
+ */ +abstract class ParquetTest extends QueryTest with BeforeAndAfterAll { + var partitionedTableDir: File = null + var partitionedTableDirWithKey: File = null + + override def beforeAll(): Unit = { + partitionedTableDir = File.createTempFile("parquettests", "sparksql") + partitionedTableDir.delete() + partitionedTableDir.mkdir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDir, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-$p")) + .saveAsParquetFile(partDir.getCanonicalPath) + } + + partitionedTableDirWithKey = File.createTempFile("parquettests", "sparksql") + partitionedTableDirWithKey.delete() + partitionedTableDirWithKey.mkdir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDirWithKey, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetDataWithKey(p, i, s"part-$p")) + .saveAsParquetFile(partDir.getCanonicalPath) + } + } + Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table => test(s"project the partitioning column $table") { checkAnswer( @@ -193,15 +250,4 @@ class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { sql("SELECT COUNT(*) FROM normal_parquet"), 10) } - - test("conversion is working") { - assert( - sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { - case _: HiveTableScan => true - }.isEmpty) - assert( - sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { - case _: ParquetTableScan => true - }.nonEmpty) - } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index eabd61d713e0c..dbf1ebbaf653a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -254,7 +254,7 @@ abstract class DStream[T: ClassTag] ( } private[streaming] def remember(duration: Duration) { - if (duration != null && duration > rememberDuration) { + if (duration != null && (rememberDuration == null || duration > rememberDuration)) { rememberDuration = duration logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 55d6cf6a783ea..5f13fdc5579ed 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -17,18 +17,55 @@ package org.apache.spark.streaming.dstream -import java.io.{ObjectInputStream, IOException} -import scala.collection.mutable.{HashSet, HashMap} +import java.io.{IOException, ObjectInputStream} + +import scala.collection.mutable import scala.reflect.ClassTag + import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.UnionRDD -import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.util.{TimeStampedHashMap, Utils} +import org.apache.spark.rdd.{RDD, UnionRDD} +import org.apache.spark.streaming._ +import org.apache.spark.util.{TimeStampedHashMap, Utils} +/** + * This class represents an input stream that monitors a Hadoop-compatible filesystem for new + * files and 
creates a stream out of them. The way it works is as follows. + * + * At each batch interval, the file system is queried for files in the given directory and + * detected new files are selected for that batch. In this case "new" means files that + * became visible to readers during that time period. Some extra care is needed to deal + * with the fact that files may become visible after they are created. For this purpose, this + * class remembers the information about the files selected in past batches for + * a certain duration (say, "remember window") as shown in the figure below. + * + * |<----- remember window ----->| + * ignore threshold --->| |<--- current batch time + * |____.____.____.____.____.____| + * | | | | | | | + * ---------------------|----|----|----|----|----|----|-----------------------> Time + * |____|____|____|____|____|____| + * remembered batches + * + * The trailing end of the window is the "ignore threshold" and all files whose mod times + * are less than this threshold are assumed to have already been selected and are therefore + * ignored. Files whose mod times are within the "remember window" are checked against files + * that have already been selected. At a high level, this is how new files are identified in + * each batch - files whose mod times are greater than the ignore threshold and + * have not been considered within the remember window. See the documentation on the method + * `isNewFile` for more details. + * + * This makes some assumptions about the underlying file system that is being monitored. + * - The clock of the file system is assumed to be synchronized with the clock of the machine + * running the streaming app. + * - If a file is to be visible in the directory listings, it must be visible within a certain + * duration of the mod time of the file. This duration is the "remember window", which is set to + * 1 minute (see `FileInputDStream.MIN_REMEMBER_DURATION`). Otherwise, the file will never be + * selected as the mod time will be less than the ignore threshold when it becomes visible. + * - Once a file is visible, the mod time cannot change. If it does due to appends, then the + * processing semantics are undefined. + */ private[streaming] class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : ClassTag]( @transient ssc_ : StreamingContext, @@ -37,22 +74,37 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas newFilesOnly: Boolean = true) extends InputDStream[(K, V)](ssc_) { + // Data to be saved as part of the streaming checkpoints protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData - // files found in the last interval - private val lastFoundFiles = new HashSet[String] + // Initial ignore threshold based on which old, existing files in the directory (at the time of + // starting the streaming application) will be ignored or considered + private val initialModTimeIgnoreThreshold = if (newFilesOnly) System.currentTimeMillis() else 0L + + /* + * Make sure that the information of files selected in the last few batches is remembered. + * This would allow us to filter away not-too-old files which have already been recently + * selected and processed. + */ + private val numBatchesToRemember = FileInputDStream.calculateNumBatchesToRemember(slideDuration) + private val durationToRemember = slideDuration * numBatchesToRemember + remember(durationToRemember) - // Files with mod time earlier than this is ignored. This is updated every interval - // such that in the current interval, files older than any file found in the - // previous interval will be ignored. Obviously this time keeps moving forward. - private var ignoreTime = if (newFilesOnly) System.currentTimeMillis() else 0L + // Map of batch-time to selected file info for the remembered batches + @transient private[streaming] var batchTimeToSelectedFiles = + new mutable.HashMap[Time, Array[String]] + + // Set of files that were selected in the remembered batches + @transient private var recentlySelectedFiles = new mutable.HashSet[String]() + + // Read-through cache of file mod times, used to speed up mod time lookups + @transient private var fileToModTime = new TimeStampedHashMap[String, Long](true) + + // Timestamp of the last round of finding files + @transient private var lastNewFileFindingTime = 0L - // Latest file mod time seen till any point of time @transient private var path_ : Path = null @transient private var fs_ : FileSystem = null - @transient private[streaming] var files = new HashMap[Time, Array[String]] - @transient private var fileModTimes = new TimeStampedHashMap[String, Long](true) - @transient private var lastNewFileFindingTime = 0L override def start() { } @@ -68,54 +120,113 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas * the previous call. */ override def compute(validTime: Time): Option[RDD[(K, V)]] = { - assert(validTime.milliseconds >= ignoreTime, - "Trying to get new files for a really old time [" + validTime + " < " + ignoreTime + "]") - // Find new files - val (newFiles, minNewFileModTime) = findNewFiles(validTime.milliseconds) + val newFiles = findNewFiles(validTime.milliseconds) logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n")) - if (!newFiles.isEmpty) { - lastFoundFiles.clear() - lastFoundFiles ++= newFiles - ignoreTime = minNewFileModTime - } - files += ((validTime, newFiles.toArray)) + batchTimeToSelectedFiles += ((validTime, newFiles)) + recentlySelectedFiles ++= newFiles Some(filesToRDD(newFiles)) } /** Clear the old time-to-files mappings along with old RDDs */ protected[streaming] override def clearMetadata(time: Time) { super.clearMetadata(time) - val oldFiles = files.filter(_._1 < (time - rememberDuration)) - files --= oldFiles.keys + val oldFiles = batchTimeToSelectedFiles.filter(_._1 < (time - rememberDuration)) + batchTimeToSelectedFiles --= oldFiles.keys + recentlySelectedFiles --= oldFiles.values.flatten logInfo("Cleared " + oldFiles.size + " old files that were older than " + (time - rememberDuration) + ": " + oldFiles.keys.mkString(", ")) logDebug("Cleared files are:\n" + oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n")) // Delete file mod times that weren't accessed in the last round of getting new files - fileModTimes.clearOldValues(lastNewFileFindingTime - 1) + fileToModTime.clearOldValues(lastNewFileFindingTime - 1) } /** - * Find files which have modification timestamp <= current time and return a 3-tuple of - * (new files found, latest modification time among them, files with latest modification time) + * Find new files for the batch of `currentTime`. This is done by first calculating the + * ignore threshold for file mod times, and then getting a list of files filtered based on + * the current batch time and the ignore threshold. The ignore threshold is the max of + * the initial ignore threshold and the trailing end of the remember window (that is, whichever + * is later in time). */ - private def findNewFiles(currentTime: Long): (Seq[String], Long) = { - logDebug("Trying to get new files for time " + currentTime) - lastNewFileFindingTime = System.currentTimeMillis - val filter = new CustomPathFilter(currentTime) - val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString) - val timeTaken = System.currentTimeMillis - lastNewFileFindingTime - logInfo("Finding new files took " + timeTaken + " ms") - logDebug("# cached file times = " + fileModTimes.size) - if (timeTaken > slideDuration.milliseconds) { - logWarning( - "Time taken to find new files exceeds the batch size. " + - "Consider increasing the batch size or reduceing the number of " + - "files in the monitored directory." + private def findNewFiles(currentTime: Long): Array[String] = { + try { + lastNewFileFindingTime = System.currentTimeMillis + + // Calculate ignore threshold + val modTimeIgnoreThreshold = math.max( + initialModTimeIgnoreThreshold, // initial threshold based on newFilesOnly setting + currentTime - durationToRemember.milliseconds // trailing end of the remember window ) + logDebug(s"Getting new files for time $currentTime, " + + s"ignoring files older than $modTimeIgnoreThreshold") + val filter = new PathFilter { + def accept(path: Path): Boolean = isNewFile(path, currentTime, modTimeIgnoreThreshold) + } + val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString) + val timeTaken = System.currentTimeMillis - lastNewFileFindingTime + logInfo("Finding new files took " + timeTaken + " ms") + logDebug("# cached file times = " + fileToModTime.size) + if (timeTaken > slideDuration.milliseconds) { + logWarning( + "Time taken to find new files exceeds the batch size. " + + "Consider increasing the batch size or reducing the number of " + + "files in the monitored directory." + ) + } + newFiles + } catch { + case e: Exception => + logWarning("Error finding new files", e) + reset() + Array.empty + } + } + + /** + * Identify whether the given `path` is a new file for the batch of `currentTime`. For it to be + * accepted, it has to pass the following criteria. + * - It must pass the user-provided file filter. + * - It must be newer than the ignore threshold. It is assumed that files older than the ignore + * threshold have already been considered or are existing files before start + * (when newFilesOnly = true). + * - It must not be present in the recently selected files that this class remembers. + * - It must not be newer than the time of the batch (i.e. `currentTime`) for which this + * file is being tested. This can occur if the driver was recovered, and the missing batches + * (during downtime) are being generated. In that case, a batch of time T may be generated + * at time T+x. Say x = 5. If that batch T contains a file of mod time T+5, then bad things can + * happen. Let's say the selected files are remembered for 60 seconds. At time T+61, + * the batch of time T is forgotten, and the ignore threshold is still T+1. + * The files with mod time T+5 are not remembered and cannot be ignored (since T+5 > T+1). + * Hence they can get selected as new files again. To prevent this, files whose mod time is more + * than the current batch time are not considered.
+ */ + private def isNewFile(path: Path, currentTime: Long, modTimeIgnoreThreshold: Long): Boolean = { + val pathStr = path.toString + // Reject file if it does not satisfy filter + if (!filter(path)) { + logDebug(s"$pathStr rejected by filter") + return false + } + // Reject file if it was created before the ignore time + val modTime = getFileModTime(path) + if (modTime <= modTimeIgnoreThreshold) { + // Use <= instead of < to avoid SPARK-4518 + logDebug(s"$pathStr ignored as mod time $modTime <= ignore time $modTimeIgnoreThreshold") + return false } - (newFiles, filter.minNewFileModTime) + // Reject file if mod time > current batch time + if (modTime > currentTime) { + logDebug(s"$pathStr not selected as mod time $modTime > current time $currentTime") + return false + } + // Reject file if it was considered earlier + if (recentlySelectedFiles.contains(pathStr)) { + logDebug(s"$pathStr already considered") + return false + } + logDebug(s"$pathStr accepted with mod time $modTime") + return true } /** Generate one RDD from an array of files */ @@ -132,21 +243,21 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas new UnionRDD(context.sparkContext, fileRDDs) } + /** Get file mod time from cache or fetch it from the file system */ + private def getFileModTime(path: Path) = { + fileToModTime.getOrElseUpdate(path.toString, fs.getFileStatus(path).getModificationTime()) + } + private def directoryPath: Path = { if (path_ == null) path_ = new Path(directory) path_ } private def fs: FileSystem = { - if (fs_ == null) fs_ = directoryPath.getFileSystem(new Configuration()) + if (fs_ == null) fs_ = directoryPath.getFileSystem(ssc.sparkContext.hadoopConfiguration) fs_ } - private def getFileModTime(path: Path) = { - // Get file mod time from cache or fetch it from the file system - fileModTimes.getOrElseUpdate(path.toString, fs.getFileStatus(path).getModificationTime()) - } - private def reset() { fs_ = null } @@ -155,9 +266,10 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() - generatedRDDs = new HashMap[Time, RDD[(K,V)]] () - files = new HashMap[Time, Array[String]] - fileModTimes = new TimeStampedHashMap[String, Long](true) + generatedRDDs = new mutable.HashMap[Time, RDD[(K,V)]] () + batchTimeToSelectedFiles = new mutable.HashMap[Time, Array[String]]() + recentlySelectedFiles = new mutable.HashSet[String]() + fileToModTime = new TimeStampedHashMap[String, Long](true) } /** @@ -167,11 +279,11 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas private[streaming] class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { - def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]] + def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]] override def update(time: Time) { hadoopFiles.clear() - hadoopFiles ++= files + hadoopFiles ++= batchTimeToSelectedFiles } override def cleanup(time: Time) { } @@ -182,7 +294,8 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas // Restore the metadata in both files and generatedRDDs logInfo("Restoring files for time " + t + " - " + f.mkString("[", ", ", "]") ) - files += ((t, f)) + batchTimeToSelectedFiles += ((t, f)) + recentlySelectedFiles ++= f generatedRDDs += ((t, filesToRDD(f))) } } @@ -193,57 +306,25 @@ class FileInputDStream[K: 
     hadoopFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n") + "\n]"
     }
   }
+}
+
+private[streaming]
+object FileInputDStream {
 
   /**
-   * Custom PathFilter class to find new files that
-   * ... have modification time more than ignore time
-   * ... have not been seen in the last interval
-   * ... have modification time less than maxModTime
+   * Minimum duration of remembering the information of selected files. Files with mod times
+   * older than this "window" of remembering will be ignored. So if a new file becomes visible
+   * within this window, it will get selected in the next batch.
    */
-  private[streaming]
-  class CustomPathFilter(maxModTime: Long) extends PathFilter {
+  private val MIN_REMEMBER_DURATION = Minutes(1)
 
-    // Minimum of the mod times of new files found in the current interval
-    var minNewFileModTime = -1L
+  def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".")
 
-    def accept(path: Path): Boolean = {
-      try {
-        if (!filter(path)) { // Reject file if it does not satisfy filter
-          logDebug("Rejected by filter " + path)
-          return false
-        }
-        // Reject file if it was found in the last interval
-        if (lastFoundFiles.contains(path.toString)) {
-          logDebug("Mod time equal to last mod time, but file considered already")
-          return false
-        }
-        val modTime = getFileModTime(path)
-        logDebug("Mod time for " + path + " is " + modTime)
-        if (modTime < ignoreTime) {
-          // Reject file if it was created before the ignore time (or, before last interval)
-          logDebug("Mod time " + modTime + " less than ignore time " + ignoreTime)
-          return false
-        } else if (modTime > maxModTime) {
-          // Reject file if it is too new that considering it may give errors
-          logDebug("Mod time more than ")
-          return false
-        }
-        if (minNewFileModTime < 0 || modTime < minNewFileModTime) {
-          minNewFileModTime = modTime
-        }
-        logDebug("Accepted " + path)
-      } catch {
-        case fnfe: java.io.FileNotFoundException =>
-          logWarning("Error finding new files", fnfe)
-          reset()
-          return false
-      }
-      true
-    }
+  /**
+   * Calculate the number of last batches to remember, such that all the files selected in
+   * at least the last MIN_REMEMBER_DURATION can be remembered.
+   */
+  def calculateNumBatchesToRemember(batchDuration: Duration): Int = {
+    math.ceil(MIN_REMEMBER_DURATION.milliseconds.toDouble / batchDuration.milliseconds).toInt
   }
 }
-
-private[streaming]
-object FileInputDStream {
-  def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".")
-}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index e5592e52b0d2d..77ff1ca780a58 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -265,7 +265,7 @@ class CheckpointSuite extends TestSuiteBase {
 
     // Verify whether files created have been recorded correctly or not
     var fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]]
-    def recordedFiles = fileInputDStream.files.values.flatMap(x => x)
+    def recordedFiles = fileInputDStream.batchTimeToSelectedFiles.values.flatten
     assert(!recordedFiles.filter(_.endsWith("1")).isEmpty)
     assert(!recordedFiles.filter(_.endsWith("2")).isEmpty)
     assert(!recordedFiles.filter(_.endsWith("3")).isEmpty)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index fa04fa326e370..307052a4a9cbb 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -28,9 +28,12 @@ import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue}
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQueue}
+import scala.concurrent.duration._
+import scala.language.postfixOps
 
 import com.google.common.io.Files
 import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark.Logging
 import org.apache.spark.storage.StorageLevel
@@ -38,6 +41,9 @@ import org.apache.spark.streaming.util.ManualClock
 import org.apache.spark.util.Utils
 import org.apache.spark.streaming.receiver.{ActorHelper, Receiver}
 import org.apache.spark.rdd.RDD
+import org.apache.hadoop.io.{Text, LongWritable}
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
+import org.apache.hadoop.fs.Path
 
 class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
 
@@ -91,54 +97,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     }
   }
 
-  test("file input stream") {
-    // Disable manual clock as FileInputDStream does not work with manual clock
-    conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock")
-
-    // Set up the streaming context and input streams
-    val testDir = Utils.createTempDir()
-    val ssc = new StreamingContext(conf, batchDuration)
-    val fileStream = ssc.textFileStream(testDir.toString)
-    val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
-    def output = outputBuffer.flatMap(x => x)
-    val outputStream = new TestOutputStream(fileStream, outputBuffer)
-    outputStream.register()
-    ssc.start()
-
-    // Create files in the temporary directory so that Spark Streaming can read data from it
-    val input = Seq(1, 2, 3, 4, 5)
-    val expectedOutput = input.map(_.toString)
-    Thread.sleep(1000)
-    for (i <- 0 until input.size) {
-      val file = new File(testDir, i.toString)
-      Files.write(input(i) + "\n", file, Charset.forName("UTF-8"))
-      logInfo("Created file " + file)
file " + file) - Thread.sleep(batchDuration.milliseconds) - Thread.sleep(1000) - } - val startTime = System.currentTimeMillis() - Thread.sleep(1000) - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received by Spark Streaming was as expected - logInfo("--------------------------------") - logInfo("output, size = " + outputBuffer.size) - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output, size = " + expectedOutput.size) - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.toList === expectedOutput.toList) - - Utils.deleteRecursively(testDir) + test("file input stream - newFilesOnly = true") { + testFileStream(newFilesOnly = true) + } - // Enable manual clock back again for other tests - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + test("file input stream - newFilesOnly = false") { + testFileStream(newFilesOnly = false) } test("multi-thread receiver") { @@ -180,7 +144,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { assert(output.sum === numTotalRecords) } - test("queue input stream - oneAtATime=true") { + test("queue input stream - oneAtATime = true") { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) val queue = new SynchronizedQueue[RDD[String]]() @@ -223,7 +187,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } - test("queue input stream - oneAtATime=false") { + test("queue input stream - oneAtATime = false") { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) val queue = new SynchronizedQueue[RDD[String]]() @@ -268,6 +232,50 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { assert(output(i) === expectedOutput(i)) } } + + def testFileStream(newFilesOnly: Boolean) { + var ssc: StreamingContext = null + val testDir: File = null + try { + val testDir = Utils.createTempDir() + val existingFile = new File(testDir, "0") + Files.write("0\n", existingFile, Charset.forName("UTF-8")) + + Thread.sleep(1000) + // Set up the streaming context and input streams + val newConf = conf.clone.set( + "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") + ssc = new StreamingContext(newConf, batchDuration) + val fileStream = ssc.fileStream[LongWritable, Text, TextInputFormat]( + testDir.toString, (x: Path) => true, newFilesOnly = newFilesOnly).map(_._2.toString) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(fileStream, outputBuffer) + outputStream.register() + ssc.start() + + // Create files in the directory + val input = Seq(1, 2, 3, 4, 5) + input.foreach { i => + Thread.sleep(batchDuration.milliseconds) + val file = new File(testDir, i.toString) + Files.write(i + "\n", file, Charset.forName("UTF-8")) + logInfo("Created file " + file) + } + + // Verify that all the files have been read + val expectedOutput = if (newFilesOnly) { + input.map(_.toString).toSet + } else { + (Seq(0) ++ input).map(_.toString).toSet + } + eventually(timeout(maxWaitTimeMillis milliseconds), interval(100 
+        assert(outputBuffer.flatten.toSet === expectedOutput)
+      }
+    } finally {
+      if (ssc != null) ssc.stop()
+      if (testDir != null) Utils.deleteRecursively(testDir)
+    }
+  }
 }
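
Note on the remember-window arithmetic: to make calculateNumBatchesToRemember concrete, here is a minimal standalone sketch (plain Scala, not Spark code; minRememberMs stands in for MIN_REMEMBER_DURATION = Minutes(1)):

object RememberWindowSketch {
  // Stand-in for MIN_REMEMBER_DURATION = Minutes(1), in milliseconds.
  val minRememberMs = 60 * 1000L

  // Same ceil-division as calculateNumBatchesToRemember: remember enough
  // whole batches to cover at least one minute of selected files.
  def numBatchesToRemember(batchDurationMs: Long): Int =
    math.ceil(minRememberMs.toDouble / batchDurationMs).toInt

  def main(args: Array[String]): Unit = {
    println(numBatchesToRemember(45000)) // 45 s batches -> 2 batches (90 s >= 60 s)
    println(numBatchesToRemember(1000))  // 1 s batches  -> 60 batches
  }
}

Using ceil rather than integer division matters when the batch duration does not divide the minute evenly; rounding down could remember less than MIN_REMEMBER_DURATION of selected files.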
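
The acceptance rule in isNewFile can also be read as a pure predicate over mod times. The sketch below restates it under that reading; the names are illustrative, and the user filter and mod-time lookup are abstracted away:

object SelectionRuleSketch {
  // modTime must fall in the half-open window (ignoreThreshold, currentBatchTime],
  // and the path must not already be remembered. Rejecting modTime <= threshold
  // (rather than <) mirrors the "use <= instead of <" comment (SPARK-4518).
  def isNew(modTime: Long, currentBatchTime: Long, ignoreThreshold: Long,
            recentlySelected: Set[String], path: String): Boolean =
    modTime > ignoreThreshold &&
      modTime <= currentBatchTime &&
      !recentlySelected.contains(path)

  def main(args: Array[String]): Unit = {
    // A file stamped after the batch time is deferred, not dropped:
    println(isNew(modTime = 105, currentBatchTime = 100, ignoreThreshold = 40,
      recentlySelected = Set.empty, path = "f")) // false now...
    println(isNew(modTime = 105, currentBatchTime = 110, ignoreThreshold = 50,
      recentlySelected = Set.empty, path = "f")) // ...true in a later batch
  }
}

This is exactly the recovery scenario described in the isNewFile scaladoc: capping selection at the batch time prevents a replayed batch from claiming files that belong to a later one.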
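
Finally, a sketch of how an application would exercise the newFilesOnly flag that the new tests cover. This is illustrative only: the directory path and app settings are assumptions, but the fileStream signature is the one used in testFileStream above:

import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object FileStreamUsageSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("FileStreamUsageSketch")
    val ssc = new StreamingContext(conf, Seconds(1))
    // newFilesOnly = false also selects files that already existed (within the
    // remember window) when the stream started; true selects only files
    // created afterwards.
    val lines = ssc.fileStream[LongWritable, Text, TextInputFormat](
      "/tmp/watched-dir",                      // hypothetical directory
      (p: Path) => !p.getName.startsWith("."), // same shape as defaultFilter
      newFilesOnly = false
    ).map(_._2.toString)
    lines.print()
    ssc.start()
    ssc.awaitTermination()
  }
}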