From f6e6899a8b8af99cd06e84cae7c69e0fc35bc60a Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 28 Jun 2018 16:25:40 -0700 Subject: [PATCH] [SPARK-24386][SS] coalesce(1) aggregates in continuous processing ## What changes were proposed in this pull request? Provide a continuous processing implementation of coalesce(1), as well as allowing aggregates on top of it. The changes in ContinuousQueuedDataReader and such are to use split.index (the ID of the partition within the RDD currently being compute()d) rather than context.partitionId() (the partition ID of the scheduled task within the Spark job - that is, the post coalesce writer). In the absence of a narrow dependency, these values were previously always the same, so there was no need to distinguish. ## How was this patch tested? new unit test Author: Jose Torres Closes #21560 from jose-torres/coalesce. --- .../UnsupportedOperationChecker.scala | 11 ++ .../datasources/v2/DataSourceV2Strategy.scala | 16 ++- .../continuous/ContinuousCoalesceExec.scala | 51 +++++++ .../continuous/ContinuousCoalesceRDD.scala | 136 ++++++++++++++++++ .../continuous/ContinuousDataSourceRDD.scala | 7 +- .../continuous/ContinuousExecution.scala | 4 + .../ContinuousQueuedDataReader.scala | 6 +- .../shuffle/ContinuousShuffleReadRDD.scala | 10 +- .../shuffle/RPCContinuousShuffleReader.scala | 4 +- .../sources/ContinuousMemoryStream.scala | 11 +- .../ContinuousAggregationSuite.scala | 63 +++++++- .../ContinuousQueuedDataReaderSuite.scala | 2 +- .../shuffle/ContinuousShuffleSuite.scala | 7 +- 13 files changed, 310 insertions(+), 18 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 2bed41672fe33..5ced1ca200daa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -349,6 +349,17 @@ object UnsupportedOperationChecker { _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias | _: TypedFilter) => case node if node.nodeName == "StreamingRelationV2" => + case Repartition(1, false, _) => + case node: Aggregate => + val aboveSinglePartitionCoalesce = node.find { + case Repartition(1, false, _) => true + case _ => false + }.isDefined + + if (!aboveSinglePartitionCoalesce) { + throwError(s"In continuous processing mode, coalesce(1) must be called before " + + s"aggregate operation ${node.nodeName}.") + } case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 182aa2906cf1e..2a7f1de2c7c19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -22,11 +22,12 @@ import scala.collection.mutable import org.apache.spark.sql.{sources, Strategy} import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader object DataSourceV2Strategy extends Strategy { @@ -141,6 +142,17 @@ object DataSourceV2Strategy extends Strategy { case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil + case Repartition(1, false, child) => + val isContinuous = child.collectFirst { + case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r + }.isDefined + + if (isContinuous) { + ContinuousCoalesceExec(1, planLater(child)) :: Nil + } else { + Nil + } + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala new file mode 100644 index 0000000000000..5f60343bacfaa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.UUID + +import org.apache.spark.{HashPartitioner, SparkEnv} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD} + +/** + * Physical plan for coalescing a continuous processing plan. + * + * Currently, only coalesces to a single partition are supported. `numPartitions` must be 1. + */ +case class ContinuousCoalesceExec(numPartitions: Int, child: SparkPlan) extends SparkPlan { + override def output: Seq[Attribute] = child.output + + override def children: Seq[SparkPlan] = child :: Nil + + override def outputPartitioning: Partitioning = SinglePartition + + override def doExecute(): RDD[InternalRow] = { + assert(numPartitions == 1) + new ContinuousCoalesceRDD( + sparkContext, + numPartitions, + conf.continuousStreamingExecutorQueueSize, + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_INTERVAL_KEY).toLong, + child.execute()) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala new file mode 100644 index 0000000000000..ba85b355f974f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.UUID + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.continuous.shuffle._ +import org.apache.spark.util.ThreadUtils + +case class ContinuousCoalesceRDDPartition( + index: Int, + endpointName: String, + queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long) + extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. + lazy val (reader: ContinuousShuffleReader, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new RPCContinuousShuffleReader( + queueSize, numShuffleWriters, epochIntervalMs, env) + val endpoint = env.setupEndpoint(endpointName, receiver) + + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } + // This flag will be flipped on the executors to indicate that the threads processing + // partitions of the write-side RDD have been started. These will run indefinitely + // asynchronously as epochs of the coalesce RDD complete on the read side. + private[continuous] var writersInitialized: Boolean = false +} + +/** + * RDD for continuous coalescing. Asynchronously writes all partitions of `prev` into a local + * continuous shuffle, and then reads them in the task thread using `reader`. + */ +class ContinuousCoalesceRDD( + context: SparkContext, + numPartitions: Int, + readerQueueSize: Int, + epochIntervalMs: Long, + prev: RDD[InternalRow]) + extends RDD[InternalRow](context, Nil) { + + // When we support more than 1 target partition, we'll need to figure out how to pass in the + // required partitioner. + private val outputPartitioner = new HashPartitioner(1) + + private val readerEndpointNames = (0 until numPartitions).map { i => + s"ContinuousCoalesceRDD-part$i-${UUID.randomUUID()}" + } + + override def getPartitions: Array[Partition] = { + (0 until numPartitions).map { partIndex => + ContinuousCoalesceRDDPartition( + partIndex, + readerEndpointNames(partIndex), + readerQueueSize, + prev.getNumPartitions, + epochIntervalMs) + }.toArray + } + + private lazy val threadPool = ThreadUtils.newDaemonFixedThreadPool( + prev.getNumPartitions, + this.name) + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val part = split.asInstanceOf[ContinuousCoalesceRDDPartition] + + if (!part.writersInitialized) { + val rpcEnv = SparkEnv.get.rpcEnv + + // trigger lazy initialization + part.endpoint + val endpointRefs = readerEndpointNames.map { endpointName => + rpcEnv.setupEndpointRef(rpcEnv.address, endpointName) + } + + val runnables = prev.partitions.map { prevSplit => + new Runnable() { + override def run(): Unit = { + TaskContext.setTaskContext(context) + + val writer: ContinuousShuffleWriter = new RPCContinuousShuffleWriter( + prevSplit.index, outputPartitioner, endpointRefs.toArray) + + EpochTracker.initializeCurrentEpoch( + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) + while (!context.isInterrupted() && !context.isCompleted()) { + writer.write(prev.compute(prevSplit, context).asInstanceOf[Iterator[UnsafeRow]]) + // Note that current epoch is a non-inheritable thread local, so each writer thread + // can properly increment its own epoch without affecting the main task thread. + EpochTracker.incrementCurrentEpoch() + } + } + } + } + + context.addTaskCompletionListener { ctx => + threadPool.shutdownNow() + } + + part.writersInitialized = true + + runnables.foreach(threadPool.execute) + } + + part.reader.read() + } + + override def clearDependencies(): Unit = { + throw new IllegalStateException("Continuous RDDs cannot be checkpointed") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index a7ccce10b0cee..73868d5967e90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -51,11 +51,11 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - @transient private val readerFactories: Seq[InputPartition[UnsafeRow]]) + private val readerInputPartitions: Seq[InputPartition[UnsafeRow]]) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { + readerInputPartitions.zipWithIndex.map { case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) }.toArray } @@ -74,8 +74,7 @@ class ContinuousDataSourceRDD( val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] if (partition.queueReader == null) { partition.queueReader = - new ContinuousQueuedDataReader( - partition.inputPartition, context, dataQueueSize, epochPollIntervalMs) + new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs) } partition.queueReader diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index e3d0cea608b2a..a0bb8292d7766 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -216,6 +216,9 @@ class ContinuousExecution( currentEpochCoordinatorId = epochCoordinatorId sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId) + sparkSessionForQuery.sparkContext.setLocalProperty( + ContinuousExecution.EPOCH_INTERVAL_KEY, + trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString) // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = @@ -382,4 +385,5 @@ class ContinuousExecution( object ContinuousExecution { val START_EPOCH_KEY = "__continuous_start_epoch" val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id" + val EPOCH_INTERVAL_KEY = "__continuous_epoch_interval" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index f38577b6a9f16..8c74b8244d096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -37,11 +37,11 @@ import org.apache.spark.util.ThreadUtils * offsets across epochs. Each compute() should call the next() method here until null is returned. */ class ContinuousQueuedDataReader( - partition: InputPartition[UnsafeRow], + partition: ContinuousDataSourceRDDPartition, context: TaskContext, dataQueueSize: Int, epochPollIntervalMs: Long) extends Closeable { - private val reader = partition.createPartitionReader() + private val reader = partition.inputPartition.createPartitionReader() // Important sequencing - we must get our starting point before the provider threads start running private var currentOffset: PartitionOffset = @@ -113,7 +113,7 @@ class ContinuousQueuedDataReader( currentEntry match { case EpochMarker => epochCoordEndpoint.send(ReportPartitionOffset( - context.partitionId(), EpochTracker.getCurrentEpoch.get, currentOffset)) + partition.index, EpochTracker.getCurrentEpoch.get, currentOffset)) null case ContinuousRow(row, offset) => currentOffset = offset diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index cf6572d3de1f7..518223f3cd008 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -21,12 +21,14 @@ import java.util.UUID import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcAddress import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator case class ContinuousShuffleReadPartition( index: Int, + endpointName: String, queueSize: Int, numShuffleWriters: Int, epochIntervalMs: Long) @@ -36,7 +38,7 @@ case class ContinuousShuffleReadPartition( val env = SparkEnv.get.rpcEnv val receiver = new RPCContinuousShuffleReader( queueSize, numShuffleWriters, epochIntervalMs, env) - val endpoint = env.setupEndpoint(s"RPCContinuousShuffleReader-${UUID.randomUUID()}", receiver) + val endpoint = env.setupEndpoint(endpointName, receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) @@ -61,12 +63,14 @@ class ContinuousShuffleReadRDD( numPartitions: Int, queueSize: Int = 1024, numShuffleWriters: Int = 1, - epochIntervalMs: Long = 1000) + epochIntervalMs: Long = 1000, + val endpointNames: Seq[String] = Seq(s"RPCContinuousShuffleReader-${UUID.randomUUID()}")) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { (0 until numPartitions).map { partIndex => - ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, epochIntervalMs) + ContinuousShuffleReadPartition( + partIndex, endpointNames(partIndex), queueSize, numShuffleWriters, epochIntervalMs) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala index 834e84675c7d5..502ae0d4822e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala @@ -46,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContin * TODO: Support multiple source tasks. We need to output a single epoch marker once all * source tasks have sent one. */ -private[shuffle] class RPCContinuousShuffleReader( +private[continuous] class RPCContinuousShuffleReader( queueSize: Int, numShuffleWriters: Int, epochIntervalMs: Long, @@ -107,7 +107,7 @@ private[shuffle] class RPCContinuousShuffleReader( } logWarning( s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " + - s"for writers $writerIdsUncommitted to send epoch markers.") + s"for writers ${writerIdsUncommitted.mkString(",")} to send epoch markers.") // The completion service guarantees this future will be available immediately. case future => future.get() match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index d1c3498450096..0bf90b8063326 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -23,12 +23,13 @@ import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ +import scala.collection.SortedMap import scala.collection.mutable.ListBuffer import org.json4s.NoTypeHints import org.json4s.jackson.Serialization -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.{Encoder, Row, SQLContext} import org.apache.spark.sql.execution.streaming._ @@ -184,6 +185,14 @@ class ContinuousMemoryStreamInputPartitionReader( private var currentOffset = startOffset private var current: Option[Row] = None + // Defense-in-depth against failing to propagate the task context. Since it's not inheritable, + // we have to do a bit of error prone work to get it into every thread used by continuous + // processing. We hope that some unit test will end up instantiating a continuous memory stream + // in such cases. + if (TaskContext.get() == null) { + throw new IllegalStateException("Task context was not set!") + } + override def next(): Boolean = { current = getRecord while (current.isEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala index b7ef637f5270e..0223812600961 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala @@ -31,7 +31,8 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase { testStream(input.toDF().agg(max('value)), OutputMode.Complete)() } - assert(ex.getMessage.contains("Continuous processing does not support Aggregate operations")) + assert(ex.getMessage.contains( + "In continuous processing mode, coalesce(1) must be called before aggregate operation")) } test("basic") { @@ -50,6 +51,66 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase { } } + test("multiple partitions with coalesce") { + val input = ContinuousMemoryStream[Int] + + val df = input.toDF().coalesce(1).agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + + test("multiple partitions with coalesce - multiple transformations") { + val input = ContinuousMemoryStream[Int] + + // We use a barrier to make sure predicates both before and after coalesce work + val df = input.toDF() + .select('value as 'copy, 'value) + .where('copy =!= 1) + .planWithBarrier + .coalesce(1) + .where('copy =!= 2) + .agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(0), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + + test("multiple partitions with multiple coalesce") { + val input = ContinuousMemoryStream[Int] + + val df = input.toDF() + .coalesce(1) + .planWithBarrier + .coalesce(1) + .select('value as 'copy, 'value) + .agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + test("repeated restart") { withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { val input = ContinuousMemoryStream.singlePartition[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index e663fa8312da4..0e7e6febb53df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -92,7 +92,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { } } val reader = new ContinuousQueuedDataReader( - factory, + new ContinuousDataSourceRDDPartition(0, factory), mockContext, dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index a8e3611b585cf..f84f3d49707bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle +import java.util.UUID + import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} @@ -124,7 +126,10 @@ class ContinuousShuffleSuite extends StreamTest { } test("reader - multiple partitions") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + val rdd = new ContinuousShuffleReadRDD( + sparkContext, + numPartitions = 5, + endpointNames = Seq.fill(5)(s"endpt-${UUID.randomUUID()}")) // Send all data before processing to ensure there's no crossover. for (p <- rdd.partitions) { val part = p.asInstanceOf[ContinuousShuffleReadPartition]