Skip to content

Commit

Permalink
[SPARK-24386][SS] coalesce(1) aggregates in continuous processing
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Provide a continuous processing implementation of coalesce(1), as well as allowing aggregates on top of it.

The changes in ContinuousQueuedDataReader and related classes switch them to split.index (the ID of the partition within the RDD currently being compute()d) instead of context.partitionId() (the partition ID of the scheduled task within the Spark job - that is, the post-coalesce writer). In the absence of a narrow dependency, these values were previously always the same, so there was no need to distinguish them.

## How was this patch tested?

new unit test

Author: Jose Torres <[email protected]>

Closes #21560 from jose-torres/coalesce.
  • Loading branch information
jose-torres authored and tdas committed Jun 28, 2018
1 parent 2224861 commit f6e6899
Show file tree
Hide file tree
Showing 13 changed files with 310 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,17 @@ object UnsupportedOperationChecker {
_: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias |
_: TypedFilter) =>
case node if node.nodeName == "StreamingRelationV2" =>
case Repartition(1, false, _) =>
case node: Aggregate =>
val aboveSinglePartitionCoalesce = node.find {
case Repartition(1, false, _) => true
case _ => false
}.isDefined

if (!aboveSinglePartitionCoalesce) {
throwError(s"In continuous processing mode, coalesce(1) must be called before " +
s"aggregate operation ${node.nodeName}.")
}
case node =>
throwError(s"Continuous processing does not support ${node.nodeName} operations.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ import scala.collection.mutable
import org.apache.spark.sql.{sources, Strategy}
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader

object DataSourceV2Strategy extends Strategy {

Expand Down Expand Up @@ -141,6 +142,17 @@ object DataSourceV2Strategy extends Strategy {
case WriteToContinuousDataSource(writer, query) =>
WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil

case Repartition(1, false, child) =>
val isContinuous = child.collectFirst {
case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r
}.isDefined

if (isContinuous) {
ContinuousCoalesceExec(1, planLater(child)) :: Nil
} else {
Nil
}

case _ => Nil
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.continuous

import java.util.UUID

import org.apache.spark.{HashPartitioner, SparkEnv}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD}

/**
 * Physical operator that coalesces the output of a continuous-processing child plan.
 *
 * Only a single target partition is supported for now, so `numPartitions` has to be 1; this is
 * asserted when the plan is executed.
 */
case class ContinuousCoalesceExec(numPartitions: Int, child: SparkPlan) extends SparkPlan {
  // Exposes exactly the attributes produced by the child.
  override def output: Seq[Attribute] = child.output

  override def children: Seq[SparkPlan] = List(child)

  override def outputPartitioning: Partitioning = SinglePartition

  /** Builds the [[ContinuousCoalesceRDD]] that wraps the child's RDD. */
  override def doExecute(): RDD[InternalRow] = {
    assert(numPartitions == 1)

    val queueSize = conf.continuousStreamingExecutorQueueSize
    // The epoch interval is read from the SparkContext local property as a string.
    val epochIntervalMs =
      sparkContext.getLocalProperty(ContinuousExecution.EPOCH_INTERVAL_KEY).toLong
    val childRdd = child.execute()

    new ContinuousCoalesceRDD(sparkContext, numPartitions, queueSize, epochIntervalMs, childRdd)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.continuous

import java.util.UUID

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.continuous.shuffle._
import org.apache.spark.util.ThreadUtils

/**
* Partition of a continuous coalesce RDD. It carries the settings used to create the
* shuffle reader for this partition and lazily owns that reader and its RPC endpoint.
*
* @param index index of this partition
* @param endpointName name under which the shuffle reader is registered as an RPC endpoint
* @param queueSize forwarded to the [[RPCContinuousShuffleReader]] constructor
* @param numShuffleWriters forwarded to the [[RPCContinuousShuffleReader]] constructor
* @param epochIntervalMs forwarded to the [[RPCContinuousShuffleReader]] constructor
*/
case class ContinuousCoalesceRDDPartition(
index: Int,
endpointName: String,
queueSize: Int,
numShuffleWriters: Int,
epochIntervalMs: Long)
extends Partition {
// Initialized only on the executor, and only once even as we call compute() multiple times.
// Both values come from a single lazy initializer, so touching either `reader` or `endpoint`
// creates the receiver and registers it under `endpointName`.
lazy val (reader: ContinuousShuffleReader, endpoint) = {
val env = SparkEnv.get.rpcEnv
val receiver = new RPCContinuousShuffleReader(
queueSize, numShuffleWriters, epochIntervalMs, env)
val endpoint = env.setupEndpoint(endpointName, receiver)

// Stop the endpoint when the task that triggered this initialization completes.
// TaskContext.get() is dereferenced directly, so this must run on a thread that has a task
// context set.
TaskContext.get().addTaskCompletionListener { ctx =>
env.stop(endpoint)
}
(receiver, endpoint)
}
// This flag will be flipped on the executors to indicate that the threads processing
// partitions of the write-side RDD have been started. These will run indefinitely
// asynchronously as epochs of the coalesce RDD complete on the read side.
private[continuous] var writersInitialized: Boolean = false
}

/**
 * RDD for continuous coalescing. Asynchronously writes all partitions of `prev` into a local
 * continuous shuffle, and then reads them in the task thread using the partition's `reader`.
 *
 * @param context the SparkContext this RDD belongs to
 * @param numPartitions number of output partitions; one reader endpoint name is generated per
 *                      output partition
 * @param readerQueueSize queue size handed to each partition's shuffle reader
 * @param epochIntervalMs epoch interval handed to each partition's shuffle reader
 * @param prev the RDD whose partitions are written into the shuffle by background threads
 */
class ContinuousCoalesceRDD(
    context: SparkContext,
    numPartitions: Int,
    readerQueueSize: Int,
    epochIntervalMs: Long,
    prev: RDD[InternalRow])
  extends RDD[InternalRow](context, Nil) {

  // When we support more than 1 target partition, we'll need to figure out how to pass in the
  // required partitioner.
  private val outputPartitioner = new HashPartitioner(1)

  // Generated once when the RDD is constructed, so the read side (partition endpoints) and the
  // write side (endpoint refs looked up in compute()) agree on the names.
  private val readerEndpointNames = (0 until numPartitions).map { i =>
    s"ContinuousCoalesceRDD-part$i-${UUID.randomUUID()}"
  }

  override def getPartitions: Array[Partition] = {
    (0 until numPartitions).map { partIndex =>
      ContinuousCoalesceRDDPartition(
        partIndex,
        readerEndpointNames(partIndex),
        readerQueueSize,
        prev.getNumPartitions,
        epochIntervalMs)
    }.toArray
  }

  // One thread per partition of `prev`. The pool is created lazily, i.e. wherever compute()
  // first touches it. `name` is unset unless setName() was called and is not carried along when
  // the RDD is serialized, so fall back to a descriptive prefix instead of naming the threads
  // "null-N".
  private lazy val threadPool = ThreadUtils.newDaemonFixedThreadPool(
    prev.getNumPartitions,
    Option(this.name).getOrElse(s"ContinuousCoalesceRDD-$id"))

  /**
   * On the first call for a given partition object, starts one writer thread per partition of
   * `prev`; every call then returns the iterator produced by the partition's shuffle reader.
   */
  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
    val part = split.asInstanceOf[ContinuousCoalesceRDDPartition]

    if (!part.writersInitialized) {
      val rpcEnv = SparkEnv.get.rpcEnv

      // trigger lazy initialization
      part.endpoint
      // The reader endpoints are looked up at this RPC environment's own address.
      val endpointRefs = readerEndpointNames.map { endpointName =>
        rpcEnv.setupEndpointRef(rpcEnv.address, endpointName)
      }

      val runnables = prev.partitions.map { prevSplit =>
        new Runnable() {
          override def run(): Unit = {
            // Make the task context visible on this pool thread.
            TaskContext.setTaskContext(context)

            val writer: ContinuousShuffleWriter = new RPCContinuousShuffleWriter(
              prevSplit.index, outputPartitioner, endpointRefs.toArray)

            EpochTracker.initializeCurrentEpoch(
              context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
            // Each loop iteration computes `prev` once and writes the result, then advances
            // this thread's epoch.
            while (!context.isInterrupted() && !context.isCompleted()) {
              writer.write(prev.compute(prevSplit, context).asInstanceOf[Iterator[UnsafeRow]])
              // Note that current epoch is a non-inheritable thread local, so each writer thread
              // can properly increment its own epoch without affecting the main task thread.
              EpochTracker.incrementCurrentEpoch()
            }
          }
        }
      }

      // Tear the writer threads down together with the task.
      context.addTaskCompletionListener { ctx =>
        threadPool.shutdownNow()
      }

      part.writersInitialized = true

      runnables.foreach(threadPool.execute)
    }

    part.reader.read()
  }

  override def clearDependencies(): Unit = {
    throw new IllegalStateException("Continuous RDDs cannot be checkpointed")
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ class ContinuousDataSourceRDD(
sc: SparkContext,
dataQueueSize: Int,
epochPollIntervalMs: Long,
@transient private val readerFactories: Seq[InputPartition[UnsafeRow]])
private val readerInputPartitions: Seq[InputPartition[UnsafeRow]])
extends RDD[UnsafeRow](sc, Nil) {

override protected def getPartitions: Array[Partition] = {
readerFactories.zipWithIndex.map {
readerInputPartitions.zipWithIndex.map {
case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition)
}.toArray
}
Expand All @@ -74,8 +74,7 @@ class ContinuousDataSourceRDD(
val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition]
if (partition.queueReader == null) {
partition.queueReader =
new ContinuousQueuedDataReader(
partition.inputPartition, context, dataQueueSize, epochPollIntervalMs)
new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs)
}

partition.queueReader
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ class ContinuousExecution(
currentEpochCoordinatorId = epochCoordinatorId
sparkSessionForQuery.sparkContext.setLocalProperty(
ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId)
sparkSessionForQuery.sparkContext.setLocalProperty(
ContinuousExecution.EPOCH_INTERVAL_KEY,
trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString)

// Use the parent Spark session for the endpoint since it's where this query ID is registered.
val epochEndpoint =
Expand Down Expand Up @@ -382,4 +385,5 @@ class ContinuousExecution(
object ContinuousExecution {
val START_EPOCH_KEY = "__continuous_start_epoch"
val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id"
val EPOCH_INTERVAL_KEY = "__continuous_epoch_interval"
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ import org.apache.spark.util.ThreadUtils
* offsets across epochs. Each compute() should call the next() method here until null is returned.
*/
class ContinuousQueuedDataReader(
partition: InputPartition[UnsafeRow],
partition: ContinuousDataSourceRDDPartition,
context: TaskContext,
dataQueueSize: Int,
epochPollIntervalMs: Long) extends Closeable {
private val reader = partition.createPartitionReader()
private val reader = partition.inputPartition.createPartitionReader()

// Important sequencing - we must get our starting point before the provider threads start running
private var currentOffset: PartitionOffset =
Expand Down Expand Up @@ -113,7 +113,7 @@ class ContinuousQueuedDataReader(
currentEntry match {
case EpochMarker =>
epochCoordEndpoint.send(ReportPartitionOffset(
context.partitionId(), EpochTracker.getCurrentEpoch.get, currentOffset))
partition.index, EpochTracker.getCurrentEpoch.get, currentOffset))
null
case ContinuousRow(row, offset) =>
currentOffset = offset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ import java.util.UUID

import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.RpcAddress
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.NextIterator

case class ContinuousShuffleReadPartition(
index: Int,
endpointName: String,
queueSize: Int,
numShuffleWriters: Int,
epochIntervalMs: Long)
Expand All @@ -36,7 +38,7 @@ case class ContinuousShuffleReadPartition(
val env = SparkEnv.get.rpcEnv
val receiver = new RPCContinuousShuffleReader(
queueSize, numShuffleWriters, epochIntervalMs, env)
val endpoint = env.setupEndpoint(s"RPCContinuousShuffleReader-${UUID.randomUUID()}", receiver)
val endpoint = env.setupEndpoint(endpointName, receiver)

TaskContext.get().addTaskCompletionListener { ctx =>
env.stop(endpoint)
Expand All @@ -61,12 +63,14 @@ class ContinuousShuffleReadRDD(
numPartitions: Int,
queueSize: Int = 1024,
numShuffleWriters: Int = 1,
epochIntervalMs: Long = 1000)
epochIntervalMs: Long = 1000,
val endpointNames: Seq[String] = Seq(s"RPCContinuousShuffleReader-${UUID.randomUUID()}"))
extends RDD[UnsafeRow](sc, Nil) {

override protected def getPartitions: Array[Partition] = {
(0 until numPartitions).map { partIndex =>
ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, epochIntervalMs)
ContinuousShuffleReadPartition(
partIndex, endpointNames(partIndex), queueSize, numShuffleWriters, epochIntervalMs)
}.toArray
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContin
* TODO: Support multiple source tasks. We need to output a single epoch marker once all
* source tasks have sent one.
*/
private[shuffle] class RPCContinuousShuffleReader(
private[continuous] class RPCContinuousShuffleReader(
queueSize: Int,
numShuffleWriters: Int,
epochIntervalMs: Long,
Expand Down Expand Up @@ -107,7 +107,7 @@ private[shuffle] class RPCContinuousShuffleReader(
}
logWarning(
s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " +
s"for writers $writerIdsUncommitted to send epoch markers.")
s"for writers ${writerIdsUncommitted.mkString(",")} to send epoch markers.")

// The completion service guarantees this future will be available immediately.
case future => future.get() match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.collection.SortedMap
import scala.collection.mutable.ListBuffer

import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import org.apache.spark.SparkEnv
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.sql.{Encoder, Row, SQLContext}
import org.apache.spark.sql.execution.streaming._
Expand Down Expand Up @@ -184,6 +185,14 @@ class ContinuousMemoryStreamInputPartitionReader(
private var currentOffset = startOffset
private var current: Option[Row] = None

// Defense-in-depth against failing to propagate the task context. Since it's not inheritable,
// we have to do a bit of error prone work to get it into every thread used by continuous
// processing. We hope that some unit test will end up instantiating a continuous memory stream
// in such cases.
if (TaskContext.get() == null) {
throw new IllegalStateException("Task context was not set!")
}

override def next(): Boolean = {
current = getRecord
while (current.isEmpty) {
Expand Down
Loading

0 comments on commit f6e6899

Please sign in to comment.