diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 2bed41672fe33..5ced1ca200daa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -349,6 +349,17 @@ object UnsupportedOperationChecker {
               _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias |
               _: TypedFilter) =>
         case node if node.nodeName == "StreamingRelationV2" =>
+        case Repartition(1, false, _) =>
+        case node: Aggregate =>
+          val aboveSinglePartitionCoalesce = node.find {
+            case Repartition(1, false, _) => true
+            case _ => false
+          }.isDefined
+
+          if (!aboveSinglePartitionCoalesce) {
+            throwError(s"In continuous processing mode, coalesce(1) must be called before " +
+              s"aggregate operation ${node.nodeName}.")
+          }
         case node =>
           throwError(s"Continuous processing does not support ${node.nodeName} operations.")
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 182aa2906cf1e..2a7f1de2c7c19 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -22,11 +22,12 @@ import scala.collection.mutable
 import org.apache.spark.sql.{sources, Strategy}
 import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
-import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
+import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
 import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
 
 object DataSourceV2Strategy extends Strategy {
 
@@ -141,6 +142,17 @@ object DataSourceV2Strategy extends Strategy {
     case WriteToContinuousDataSource(writer, query) =>
       WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil
 
+    case Repartition(1, false, child) =>
+      val isContinuous = child.collectFirst {
+        case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r
+      }.isDefined
+
+      if (isContinuous) {
+        ContinuousCoalesceExec(1, planLater(child)) :: Nil
+      } else {
+        Nil
+      }
+
     case _ => Nil
   }
 }
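Aside (illustrative, not part of the patch): both new rules key on the logical node that Dataset.coalesce(1) produces. A minimal sketch, assuming an active SparkSession named `spark` (e.g. in spark-shell):

import org.apache.spark.sql.catalyst.plans.logical.Repartition

// coalesce(n) becomes Repartition(n, shuffle = false, child) in the logical plan;
// repartition(n) would set shuffle = true and therefore match neither rule above.
val plan = spark.range(10).coalesce(1).queryExecution.logical
assert(plan.collectFirst { case r @ Repartition(1, false, _) => r }.isDefined)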
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala
new file mode 100644
index 0000000000000..5f60343bacfaa
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.UUID
+
+import org.apache.spark.{HashPartitioner, SparkEnv}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD}
+
+/**
+ * Physical plan for coalescing a continuous processing plan.
+ *
+ * Currently, only coalescing to a single partition is supported; `numPartitions` must be 1.
+ */
+case class ContinuousCoalesceExec(numPartitions: Int, child: SparkPlan) extends SparkPlan {
+  override def output: Seq[Attribute] = child.output
+
+  override def children: Seq[SparkPlan] = child :: Nil
+
+  override def outputPartitioning: Partitioning = SinglePartition
+
+  override def doExecute(): RDD[InternalRow] = {
+    assert(numPartitions == 1)
+    new ContinuousCoalesceRDD(
+      sparkContext,
+      numPartitions,
+      conf.continuousStreamingExecutorQueueSize,
+      sparkContext.getLocalProperty(ContinuousExecution.EPOCH_INTERVAL_KEY).toLong,
+      child.execute())
+  }
+}
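Aside (illustrative, not part of the patch): reporting SinglePartition is what lets an aggregate sit directly on top of this node. SinglePartition satisfies the distribution an ungrouped complete-mode aggregate requires, so the planner does not need to insert a shuffle exchange, which continuous execution cannot run. A rough check of that property, using catalyst internals only for illustration:

import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, SinglePartition}

// A single-partition output already satisfies an "all tuples in one place" requirement,
// so EnsureRequirements leaves the plan alone instead of adding an Exchange above it.
assert(SinglePartition.satisfies(AllTuples))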
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.UUID
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.continuous.shuffle._
+import org.apache.spark.util.ThreadUtils
+
+case class ContinuousCoalesceRDDPartition(
+    index: Int,
+    endpointName: String,
+    queueSize: Int,
+    numShuffleWriters: Int,
+    epochIntervalMs: Long)
+  extends Partition {
+  // Initialized only on the executor, and only once even as we call compute() multiple times.
+  lazy val (reader: ContinuousShuffleReader, endpoint) = {
+    val env = SparkEnv.get.rpcEnv
+    val receiver = new RPCContinuousShuffleReader(
+      queueSize, numShuffleWriters, epochIntervalMs, env)
+    val endpoint = env.setupEndpoint(endpointName, receiver)
+
+    TaskContext.get().addTaskCompletionListener { ctx =>
+      env.stop(endpoint)
+    }
+    (receiver, endpoint)
+  }
+  // This flag will be flipped on the executors to indicate that the threads processing
+  // partitions of the write-side RDD have been started. These run indefinitely and
+  // asynchronously as epochs of the coalesce RDD complete on the read side.
+  private[continuous] var writersInitialized: Boolean = false
+}
+
+/**
+ * RDD for continuous coalescing. Asynchronously writes all partitions of `prev` into a local
+ * continuous shuffle, and then reads them in the task thread using `reader`.
+ */
+class ContinuousCoalesceRDD(
+    context: SparkContext,
+    numPartitions: Int,
+    readerQueueSize: Int,
+    epochIntervalMs: Long,
+    prev: RDD[InternalRow])
+  extends RDD[InternalRow](context, Nil) {
+
+  // When we support more than 1 target partition, we'll need to figure out how to pass in the
+  // required partitioner.
+  private val outputPartitioner = new HashPartitioner(1)
+
+  private val readerEndpointNames = (0 until numPartitions).map { i =>
+    s"ContinuousCoalesceRDD-part$i-${UUID.randomUUID()}"
+  }
+
+  override def getPartitions: Array[Partition] = {
+    (0 until numPartitions).map { partIndex =>
+      ContinuousCoalesceRDDPartition(
+        partIndex,
+        readerEndpointNames(partIndex),
+        readerQueueSize,
+        prev.getNumPartitions,
+        epochIntervalMs)
+    }.toArray
+  }
+
+  private lazy val threadPool = ThreadUtils.newDaemonFixedThreadPool(
+    prev.getNumPartitions,
+    this.name)
+
+  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+    val part = split.asInstanceOf[ContinuousCoalesceRDDPartition]
+
+    if (!part.writersInitialized) {
+      val rpcEnv = SparkEnv.get.rpcEnv
+
+      // trigger lazy initialization
+      part.endpoint
+      val endpointRefs = readerEndpointNames.map { endpointName =>
+        rpcEnv.setupEndpointRef(rpcEnv.address, endpointName)
+      }
+
+      val runnables = prev.partitions.map { prevSplit =>
+        new Runnable() {
+          override def run(): Unit = {
+            TaskContext.setTaskContext(context)
+
+            val writer: ContinuousShuffleWriter = new RPCContinuousShuffleWriter(
+              prevSplit.index, outputPartitioner, endpointRefs.toArray)
+
+            EpochTracker.initializeCurrentEpoch(
+              context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
+            while (!context.isInterrupted() && !context.isCompleted()) {
+              writer.write(prev.compute(prevSplit, context).asInstanceOf[Iterator[UnsafeRow]])
+              // Note that current epoch is a non-inheritable thread local, so each writer thread
+              // can properly increment its own epoch without affecting the main task thread.
+              EpochTracker.incrementCurrentEpoch()
+            }
+          }
+        }
+      }
+
+      context.addTaskCompletionListener { ctx =>
+        threadPool.shutdownNow()
+      }
+
+      part.writersInitialized = true
+
+      runnables.foreach(threadPool.execute)
+    }
+
+    part.reader.read()
+  }
+
+  override def clearDependencies(): Unit = {
+    throw new IllegalStateException("Continuous RDDs cannot be checkpointed")
+  }
+}
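The compute() method above splits work between long-lived writer threads (one per input partition) and the task thread, which only drains the local shuffle reader. A standalone sketch of that producer/consumer shape, using plain JDK executors and a queue since ContinuousShuffleReader and Spark's ThreadUtils are internal to Spark:

import java.util.concurrent.{ArrayBlockingQueue, Executors, ThreadFactory, TimeUnit}

val queue = new ArrayBlockingQueue[Int](1024)  // stand-in for the RPC-backed shuffle reader
val daemonFactory: ThreadFactory = (r: Runnable) => {
  val t = new Thread(r)
  t.setDaemon(true)  // like newDaemonFixedThreadPool: writers must not keep the JVM alive
  t
}
val pool = Executors.newFixedThreadPool(4, daemonFactory)

// Write side: one runnable per "input partition" (the real ones loop until the task ends).
(0 until 4).foreach { writerId =>
  pool.execute(() => (0 until 10).foreach(i => queue.put(writerId * 10 + i)))
}

// Read side: the task thread consumes whatever the writers have produced.
val firstRow = queue.poll(1, TimeUnit.SECONDS)
pool.shutdownNow()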
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
index a7ccce10b0cee..73868d5967e90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
@@ -51,11 +51,11 @@ class ContinuousDataSourceRDD(
     sc: SparkContext,
     dataQueueSize: Int,
     epochPollIntervalMs: Long,
-    @transient private val readerFactories: Seq[InputPartition[UnsafeRow]])
+    private val readerInputPartitions: Seq[InputPartition[UnsafeRow]])
   extends RDD[UnsafeRow](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
-    readerFactories.zipWithIndex.map {
+    readerInputPartitions.zipWithIndex.map {
       case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition)
     }.toArray
   }
@@ -74,8 +74,7 @@ class ContinuousDataSourceRDD(
     val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition]
     if (partition.queueReader == null) {
       partition.queueReader =
-        new ContinuousQueuedDataReader(
-          partition.inputPartition, context, dataQueueSize, epochPollIntervalMs)
+        new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs)
     }
 
     partition.queueReader
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index e3d0cea608b2a..a0bb8292d7766 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -216,6 +216,9 @@ class ContinuousExecution(
     currentEpochCoordinatorId = epochCoordinatorId
     sparkSessionForQuery.sparkContext.setLocalProperty(
       ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId)
+    sparkSessionForQuery.sparkContext.setLocalProperty(
+      ContinuousExecution.EPOCH_INTERVAL_KEY,
+      trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString)
 
     // Use the parent Spark session for the endpoint since it's where this query ID is registered.
     val epochEndpoint =
@@ -382,4 +385,5 @@ class ContinuousExecution(
 object ContinuousExecution {
   val START_EPOCH_KEY = "__continuous_start_epoch"
   val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id"
+  val EPOCH_INTERVAL_KEY = "__continuous_epoch_interval"
 }
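Aside (illustrative, not part of the patch): the new EPOCH_INTERVAL_KEY rides on the same mechanism as the existing START_EPOCH_KEY, namely SparkContext local properties, which are captured per job and exposed to tasks through TaskContext. A minimal sketch, assuming an active SparkContext named `sc`:

import org.apache.spark.TaskContext

// Properties set on the driver thread are visible to the tasks it launches,
// which is how ContinuousCoalesceRDD reads the start epoch and epoch interval.
sc.setLocalProperty("__continuous_epoch_interval", "1000")
val seen = sc.parallelize(0 until 2, 2)
  .map(_ => TaskContext.get().getLocalProperty("__continuous_epoch_interval"))
  .collect()
assert(seen.forall(_ == "1000"))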
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
index f38577b6a9f16..8c74b8244d096 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
@@ -37,11 +37,11 @@ import org.apache.spark.util.ThreadUtils
  * offsets across epochs. Each compute() should call the next() method here until null is returned.
  */
 class ContinuousQueuedDataReader(
-    partition: InputPartition[UnsafeRow],
+    partition: ContinuousDataSourceRDDPartition,
     context: TaskContext,
     dataQueueSize: Int,
     epochPollIntervalMs: Long) extends Closeable {
-  private val reader = partition.createPartitionReader()
+  private val reader = partition.inputPartition.createPartitionReader()
 
   // Important sequencing - we must get our starting point before the provider threads start running
   private var currentOffset: PartitionOffset =
@@ -113,7 +113,7 @@ class ContinuousQueuedDataReader(
     currentEntry match {
       case EpochMarker =>
         epochCoordEndpoint.send(ReportPartitionOffset(
-          context.partitionId(), EpochTracker.getCurrentEpoch.get, currentOffset))
+          partition.index, EpochTracker.getCurrentEpoch.get, currentOffset))
         null
       case ContinuousRow(row, offset) =>
         currentOffset = offset
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
index cf6572d3de1f7..518223f3cd008 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
@@ -21,12 +21,14 @@ import java.util.UUID
 
 import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.rpc.RpcAddress
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.NextIterator
 
 case class ContinuousShuffleReadPartition(
     index: Int,
+    endpointName: String,
     queueSize: Int,
     numShuffleWriters: Int,
     epochIntervalMs: Long)
@@ -36,7 +38,7 @@ case class ContinuousShuffleReadPartition(
     val env = SparkEnv.get.rpcEnv
     val receiver = new RPCContinuousShuffleReader(
       queueSize, numShuffleWriters, epochIntervalMs, env)
-    val endpoint = env.setupEndpoint(s"RPCContinuousShuffleReader-${UUID.randomUUID()}", receiver)
+    val endpoint = env.setupEndpoint(endpointName, receiver)
 
     TaskContext.get().addTaskCompletionListener { ctx =>
       env.stop(endpoint)
@@ -61,12 +63,14 @@ class ContinuousShuffleReadRDD(
     numPartitions: Int,
     queueSize: Int = 1024,
     numShuffleWriters: Int = 1,
-    epochIntervalMs: Long = 1000)
+    epochIntervalMs: Long = 1000,
+    val endpointNames: Seq[String] = Seq(s"RPCContinuousShuffleReader-${UUID.randomUUID()}"))
   extends RDD[UnsafeRow](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
     (0 until numPartitions).map { partIndex =>
-      ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, epochIntervalMs)
+      ContinuousShuffleReadPartition(
+        partIndex, endpointNames(partIndex), queueSize, numShuffleWriters, epochIntervalMs)
     }.toArray
   }
 
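Why the endpoint name now comes in from the outside: the coalesce RDD's writer threads look the reader endpoints up by name via setupEndpointRef, so the read and write sides must agree on the names before any endpoint is registered. A small sketch of building such a shared name list (the "endpt-" prefix below is arbitrary, as in the updated test):

import java.util.UUID

// One unique, pre-agreed name per reader partition. The default endpointNames argument
// has a single element, so callers with numPartitions > 1 must supply their own list.
val numPartitions = 5
val endpointNames = Seq.tabulate(numPartitions)(i => s"endpt-part$i-${UUID.randomUUID()}")
assert(endpointNames.distinct.size == numPartitions)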
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
index 834e84675c7d5..502ae0d4822e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
@@ -46,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContin
  * TODO: Support multiple source tasks. We need to output a single epoch marker once all
  * source tasks have sent one.
  */
-private[shuffle] class RPCContinuousShuffleReader(
+private[continuous] class RPCContinuousShuffleReader(
     queueSize: Int,
     numShuffleWriters: Int,
     epochIntervalMs: Long,
@@ -107,7 +107,7 @@ private[shuffle] class RPCContinuousShuffleReader(
           }
           logWarning(
             s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " +
-              s"for writers $writerIdsUncommitted to send epoch markers.")
+              s"for writers ${writerIdsUncommitted.mkString(",")} to send epoch markers.")
 
         // The completion service guarantees this future will be available immediately.
         case future => future.get() match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
index d1c3498450096..0bf90b8063326 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -23,12 +23,13 @@ import java.util.concurrent.atomic.AtomicInteger
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
+import scala.collection.SortedMap
 import scala.collection.mutable.ListBuffer
 
 import org.json4s.NoTypeHints
 import org.json4s.jackson.Serialization
 
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.sql.{Encoder, Row, SQLContext}
 import org.apache.spark.sql.execution.streaming._
@@ -184,6 +185,14 @@ class ContinuousMemoryStreamInputPartitionReader(
   private var currentOffset = startOffset
   private var current: Option[Row] = None
 
+  // Defense-in-depth against failing to propagate the task context. Since it's not inheritable,
+  // we have to do a bit of error-prone work to get it into every thread used by continuous
+  // processing; this check makes sure any missed propagation surfaces in unit tests that
+  // instantiate a continuous memory stream.
+  if (TaskContext.get() == null) {
+    throw new IllegalStateException("Task context was not set!")
+  }
+
   override def next(): Boolean = {
     current = getRecord
     while (current.isEmpty) {
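The new guard in ContinuousMemoryStreamInputPartitionReader exists because TaskContext lives in a plain, non-inheritable thread local: a thread spawned by a task (such as the coalesce writer threads) does not see it unless the spawning code calls TaskContext.setTaskContext first, as ContinuousCoalesceRDD does. A self-contained demonstration of that thread-local behavior:

// Plain ThreadLocal values (which back TaskContext) are not visible in child threads,
// unlike InheritableThreadLocal values.
val plain = new ThreadLocal[String]
val inheritable = new InheritableThreadLocal[String]
plain.set("parent-value")
inheritable.set("parent-value")

var seenPlain: String = "unset"
var seenInheritable: String = "unset"
val t = new Thread(() => { seenPlain = plain.get(); seenInheritable = inheritable.get() })
t.start()
t.join()

assert(seenPlain == null)                  // not inherited: this is what the guard catches
assert(seenInheritable == "parent-value")  // inherited, but TaskContext does not use this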
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala
index b7ef637f5270e..0223812600961 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala
@@ -31,7 +31,8 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase {
       testStream(input.toDF().agg(max('value)), OutputMode.Complete)()
     }
 
-    assert(ex.getMessage.contains("Continuous processing does not support Aggregate operations"))
+    assert(ex.getMessage.contains(
+      "In continuous processing mode, coalesce(1) must be called before aggregate operation"))
   }
 
   test("basic") {
@@ -50,6 +51,66 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase {
     }
   }
 
+  test("multiple partitions with coalesce") {
+    val input = ContinuousMemoryStream[Int]
+
+    val df = input.toDF().coalesce(1).agg(max('value))
+
+    testStream(df, OutputMode.Complete)(
+      AddData(input, 0, 1, 2),
+      CheckAnswer(2),
+      StopStream,
+      AddData(input, 3, 4, 5),
+      StartStream(),
+      CheckAnswer(5),
+      AddData(input, -1, -2, -3),
+      CheckAnswer(5))
+  }
+
+  test("multiple partitions with coalesce - multiple transformations") {
+    val input = ContinuousMemoryStream[Int]
+
+    // We use a barrier to make sure predicates both before and after coalesce work
+    val df = input.toDF()
+      .select('value as 'copy, 'value)
+      .where('copy =!= 1)
+      .planWithBarrier
+      .coalesce(1)
+      .where('copy =!= 2)
+      .agg(max('value))
+
+    testStream(df, OutputMode.Complete)(
+      AddData(input, 0, 1, 2),
+      CheckAnswer(0),
+      StopStream,
+      AddData(input, 3, 4, 5),
+      StartStream(),
+      CheckAnswer(5),
+      AddData(input, -1, -2, -3),
+      CheckAnswer(5))
+  }
+
+  test("multiple partitions with multiple coalesce") {
+    val input = ContinuousMemoryStream[Int]
+
+    val df = input.toDF()
+      .coalesce(1)
+      .planWithBarrier
+      .coalesce(1)
+      .select('value as 'copy, 'value)
+      .agg(max('value))
+
+    testStream(df, OutputMode.Complete)(
+      AddData(input, 0, 1, 2),
+      CheckAnswer(2),
+      StopStream,
+      AddData(input, 3, 4, 5),
+      StartStream(),
+      CheckAnswer(5),
+      AddData(input, -1, -2, -3),
+      CheckAnswer(5))
+  }
+
   test("repeated restart") {
     withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) {
       val input = ContinuousMemoryStream.singlePartition[Int]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
index e663fa8312da4..0e7e6febb53df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
@@ -92,7 +92,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar {
       }
     }
     val reader = new ContinuousQueuedDataReader(
-      factory,
+      new ContinuousDataSourceRDDPartition(0, factory),
       mockContext,
      dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize,
       epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs)
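End-to-end, the new tests above boil down to the following user-facing pattern. A hedged sketch against the built-in rate source and console sink (assumes an active SparkSession named `spark`; whether a given source or sink supports continuous mode depends on the Spark build):

import org.apache.spark.sql.functions.max
import org.apache.spark.sql.streaming.Trigger

val query = spark.readStream
  .format("rate")
  .load()
  .coalesce(1)                               // required: without it, analysis now fails with
  .agg(max("value"))                         // "coalesce(1) must be called before aggregate..."
  .writeStream
  .outputMode("complete")
  .format("console")
  .trigger(Trigger.Continuous("1 second"))
  .start()

// ... later
query.stop()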
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
index a8e3611b585cf..f84f3d49707bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming.continuous.shuffle
 
+import java.util.UUID
+
 import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl}
 import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
@@ -124,7 +126,10 @@ class ContinuousShuffleSuite extends StreamTest {
   }
 
   test("reader - multiple partitions") {
-    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5)
+    val rdd = new ContinuousShuffleReadRDD(
+      sparkContext,
+      numPartitions = 5,
+      endpointNames = Seq.fill(5)(s"endpt-${UUID.randomUUID()}"))
     // Send all data before processing to ensure there's no crossover.
     for (p <- rdd.partitions) {
       val part = p.asInstanceOf[ContinuousShuffleReadPartition]