diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index 8aad39dacca3f..a6c4cd220e42f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -50,12 +50,7 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) val id = ssc.getNewInputStreamId() // Keep track of the freshest rate for this stream using the rateEstimator - protected[streaming] val rateController: Option[RateController] = - RateEstimator.makeEstimator(ssc.conf).map { estimator => - new RateController(id, estimator) { - override def publish(rate: Long): Unit = () - } - } + protected[streaming] val rateController: Option[RateController] = None /** A human-readable name of this InputDStream */ private[streaming] def name: String = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 5b4df67d9ce11..e79ba5018d9fd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -45,7 +45,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. 
*/ override protected[streaming] val rateController: Option[RateController] = - RateEstimator.makeEstimator(ssc.conf).map { new ReceiverRateController(id, _) } + RateEstimator.create(ssc.conf).map { new ReceiverRateController(id, _) } /** * Gets the receiver object that will be sent to the worker nodes diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 4aed1aa1d92d2..58bdda7794bf2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -66,7 +66,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } eventLoop.start() - // Estimators receive updates from batch completion + // attach rate controllers of input streams to receive batch completion updates for { inputDStream <- ssc.graph.getInputStreams rateController <- inputDStream.rateController diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala index 0fea6838da032..f1e75da1644f3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala @@ -21,6 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.concurrent.{ExecutionContext, Future} +import org.apache.spark.SparkConf +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.ThreadUtils @@ -29,8 +31,8 @@ import org.apache.spark.util.ThreadUtils * an estimate of the speed at which this stream should ingest messages, * given an estimate computation from a `RateEstimator` */ -private [streaming] abstract class RateController(val streamUID: Int, 
rateEstimator: RateEstimator) - extends StreamingListener with Serializable { +private[streaming] abstract class RateController(val streamUID: Int, rateEstimator: RateEstimator) + extends StreamingListener with Serializable { protected def publish(rate: Long): Unit @@ -46,8 +48,8 @@ private [streaming] abstract class RateController(val streamUID: Int, rateEstima */ private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit = Future[Unit] { - val newSpeed = rateEstimator.compute(time, elems, workDelay, waitDelay) - newSpeed foreach { s => + val newRate = rateEstimator.compute(time, elems, workDelay, waitDelay) + newRate.foreach { s => rateLimit.set(s.toLong) publish(getLatestRate()) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala index 592d173e99bdc..a08685119e5d5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -52,8 +52,8 @@ object RateEstimator { * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any * known estimators. 
*/ - def makeEstimator(conf: SparkConf): Option[RateEstimator] = - conf.getOption("spark.streaming.RateEstimator") map { estimator => + def create(conf: SparkConf): Option[RateEstimator] = + conf.getOption("spark.streaming.backpressure.rateEstimator").map { estimator => throw new IllegalArgumentException(s"Unkown rate estimator: $estimator") } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index 8457aa78de5f9..3136cba8b4f63 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -90,6 +90,7 @@ class ReceiverTrackerSuite extends TestSuiteBase { ssc.addStreamingListener(ReceiverStartedWaiter) ssc.scheduler.listenerBus.start(ssc.sc) + SingletonDummyReceiver.reset() val newRateLimit = 100L val inputDStream = new RateLimitInputDStream(ssc) @@ -109,7 +110,14 @@ class ReceiverTrackerSuite extends TestSuiteBase { } } -/** An input DStream with a hard-coded receiver that gives access to internals for testing. */ +/** + * An input DStream with a hard-coded receiver that gives access to internals for testing. + * + * @note Make sure to call {{{SingletonDummyReceiver.reset()}}} before using this in a test, + * or otherwise you may get {{{NotSerializableException}}} when trying to serialize + * the receiver. + * @see [[SingletonDummyReceiver]]. + */ private class RateLimitInputDStream(@transient ssc_ : StreamingContext) extends ReceiverInputDStream[Int](ssc_) {