[SPARK-14257][SQL] Allow multiple continuous queries to be started from the same DataFrame

## What changes were proposed in this pull request?

Make `StreamingRelation` store the closure used to create the source, so that the source is created inside `StreamExecution` and multiple continuous queries can be started from the same DataFrame.
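
For illustration, a minimal usage sketch of what this enables (it mirrors the new `test("DataFrame reuse")` below; the source format name and filesystem paths are placeholders, and `stream()`/`startStream()` are the Spark 2.0-era API used in this patch):

```scala
// One streaming DataFrame, two independent continuous queries.
val df = sqlContext.read
  .format("some.streaming.provider") // placeholder StreamSourceProvider
  .stream()

// Each query now materializes its own Source from the DataSource stored in
// StreamingRelation, so reusing `df` is safe.
val query1 = df.write.format("parquet")
  .option("checkpointLocation", "/tmp/q1-checkpoint") // placeholder paths
  .startStream("/tmp/q1-output")

val query2 = df.write.format("parquet")
  .option("checkpointLocation", "/tmp/q2-checkpoint")
  .startStream("/tmp/q2-output")
```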

## How was this patch tested?

Added `test("DataFrame reuse")` in `StreamSuite`.

Author: Shixiong Zhu <[email protected]>

Closes #12049 from zsxwing/df-reuse.
zsxwing authored and marmbrus committed Apr 5, 2016
1 parent f77f11c commit 463bac0
Showing 10 changed files with 118 additions and 26 deletions.
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.collection.mutable

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
import org.apache.spark.sql.util.ContinuousQueryListener

@@ -178,11 +178,19 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
throw new IllegalArgumentException(
s"Cannot start query with name $name as a query with that name is already active")
}
val logicalPlan = df.logicalPlan.transform {
case StreamingRelation(dataSource, _, output) =>
// Materialize source to avoid creating it in every batch
val source = dataSource.createSource()
// We still need to use the previous `output` instead of `source.schema`, because attributes in
// "df.logicalPlan" already reference the attributes of the previous `output`.
StreamingExecutionRelation(source, output)
}
val query = new StreamExecution(
sqlContext,
name,
checkpointLocation,
df.logicalPlan,
logicalPlan,
sink,
trigger)
query.start()
@@ -176,7 +176,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap)
Dataset.ofRows(sqlContext, StreamingRelation(dataSource.createSource()))
Dataset.ofRows(sqlContext, StreamingRelation(dataSource))
}

/**
6 changes: 4 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.execution.streaming.StreamingRelation
import org.apache.spark.sql.execution.streaming.{StreamingExecutionRelation, StreamingRelation}
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -462,7 +462,9 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
@Experimental
def isStreaming: Boolean = logicalPlan.find(_.isInstanceOf[StreamingRelation]).isDefined
def isStreaming: Boolean = logicalPlan.find { n =>
n.isInstanceOf[StreamingRelation] || n.isInstanceOf[StreamingExecutionRelation]
}.isDefined

/**
* Displays the [[Dataset]] in a tabular form. Strings more than 20 characters will be truncated,
@@ -43,9 +43,9 @@ import org.apache.spark.util.UninterruptibleThread
* and the results are committed transactionally to the given [[Sink]].
*/
class StreamExecution(
val sqlContext: SQLContext,
override val sqlContext: SQLContext,
override val name: String,
val checkpointRoot: String,
checkpointRoot: String,
private[sql] val logicalPlan: LogicalPlan,
val sink: Sink,
val trigger: Trigger) extends ContinuousQuery with Logging {
@@ -72,7 +72,7 @@ class StreamExecution(

/** All stream sources present in the query plan. */
private val sources =
logicalPlan.collect { case s: StreamingRelation => s.source }
logicalPlan.collect { case s: StreamingExecutionRelation => s.source }

/** A list of unique sources in the query plan. */
private val uniqueSources = sources.distinct
@@ -295,7 +295,7 @@ class StreamExecution(
var replacements = new ArrayBuffer[(Attribute, Attribute)]
// Replace sources in the logical plan with data that has arrived since the last batch.
val withNewSources = logicalPlan transform {
case StreamingRelation(source, output) =>
case StreamingExecutionRelation(source, output) =>
newData.get(source).map { data =>
val newPlan = data.logicalPlan
assert(output.size == newPlan.output.size,
@@ -19,16 +19,37 @@ package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.execution.datasources.DataSource

object StreamingRelation {
def apply(source: Source): StreamingRelation =
StreamingRelation(source, source.schema.toAttributes)
def apply(dataSource: DataSource): StreamingRelation = {
val source = dataSource.createSource()
StreamingRelation(dataSource, source.toString, source.schema.toAttributes)
}
}

/**
* Used to link a streaming [[DataSource]] into a
* [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating
* a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]].
* It is used to create a [[Source]] and is converted to [[StreamingExecutionRelation]] when
* passed to [[StreamExecution]] to run a query.
*/
case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute])
extends LeafNode {
override def toString: String = sourceName
}

/**
* Used to link a streaming [[Source]] of data into a
* [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]].
*/
case class StreamingRelation(source: Source, output: Seq[Attribute]) extends LeafNode {
case class StreamingExecutionRelation(source: Source, output: Seq[Attribute]) extends LeafNode {
override def toString: String = source.toString
}

object StreamingExecutionRelation {
def apply(source: Source): StreamingExecutionRelation = {
StreamingExecutionRelation(source, source.schema.toAttributes)
}
}
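
To make the split concrete: the analysis-time `StreamingRelation` carries the `DataSource` (plus a display name and output attributes), while `StreamingExecutionRelation` carries a materialized `Source`. A condensed sketch of the conversion performed per query in the `ContinuousQueryManager` hunk above:

```scala
// Performed once per started query, so each query gets its own Source.
val executionPlan = df.logicalPlan.transform {
  case StreamingRelation(dataSource, _, output) =>
    // Keep the original output attributes so references elsewhere in the
    // plan remain valid.
    StreamingExecutionRelation(dataSource.createSource(), output)
}
```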
@@ -22,11 +22,9 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext}
import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.types.StructType

object MemoryStream {
@@ -45,7 +43,7 @@
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
extends Source with Logging {
protected val encoder = encoderFor[A]
protected val logicalPlan = StreamingRelation(this)
protected val logicalPlan = StreamingExecutionRelation(this)
protected val output = logicalPlan.output
protected val batches = new ArrayBuffer[Dataset[A]]

5 changes: 3 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -36,6 +36,7 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.util.Utils

@@ -66,9 +67,9 @@ import org.apache.spark.util.Utils
trait StreamTest extends QueryTest with Timeouts {

implicit class RichSource(s: Source) {
def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingRelation(s))
def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingExecutionRelation(s))

def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingRelation(s))
def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingExecutionRelation(s))
}

/** How long to wait for an active stream to catch up when checking a result. */
@@ -29,7 +29,7 @@ import org.scalatest.time.SpanSugar._

import org.apache.spark.SparkException
import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest}
import org.apache.spark.sql.execution.streaming.{MemorySink, MemoryStream, StreamExecution, StreamingRelation}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils

@@ -294,8 +294,8 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
if (withError) {
logDebug(s"Terminating query ${queryToStop.name} with error")
queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect {
case StreamingRelation(memoryStream, _) =>
memoryStream.asInstanceOf[MemoryStream[Int]].addData(0)
case StreamingExecutionRelation(source, _) =>
source.asInstanceOf[MemoryStream[Int]].addData(0)
}
} else {
logDebug(s"Stopping query ${queryToStop.name}")
@@ -71,8 +71,9 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext {
}
reader.stream(path)
.queryExecution.analyzed
.collect { case StreamingRelation(s: FileStreamSource, _) => s }
.head
.collect { case StreamingRelation(dataSource, _, _) =>
dataSource.createSource().asInstanceOf[FileStreamSource]
}.head
}

val valueSchema = new StructType().add("value", StringType)
@@ -96,8 +97,9 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
reader.stream()
}
df.queryExecution.analyzed
.collect { case StreamingRelation(s: FileStreamSource, _) => s }
.head
.collect { case StreamingRelation(dataSource, _, _) =>
dataSource.createSource().asInstanceOf[FileStreamSource]
}.head
.schema
}

@@ -17,9 +17,13 @@

package org.apache.spark.sql.streaming

import org.apache.spark.sql.{Row, StreamTest}
import org.scalatest.concurrent.Eventually._

import org.apache.spark.sql.{DataFrame, Row, SQLContext, StreamTest}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class StreamSuite extends StreamTest with SharedSQLContext {

@@ -81,4 +85,60 @@ class StreamSuite extends StreamTest with SharedSQLContext {
AddData(inputData, 1, 2, 3, 4),
CheckAnswer(2, 4))
}

test("DataFrame reuse") {
def assertDF(df: DataFrame) {
withTempDir { outputDir =>
withTempDir { checkpointDir =>
val query = df.write.format("parquet")
.option("checkpointLocation", checkpointDir.getAbsolutePath)
.startStream(outputDir.getAbsolutePath)
try {
query.processAllAvailable()
val outputDf = sqlContext.read.parquet(outputDir.getAbsolutePath).as[Long]
checkDataset[Long](outputDf, (0L to 10L).toArray: _*)
} finally {
query.stop()
}
}
}
}

val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream()
assertDF(df)
assertDF(df)
}
}

/**
* A fake StreamSourceProvider that creates a fake Source that cannot be reused.
*/
class FakeDefaultSource extends StreamSourceProvider {

override def createSource(
sqlContext: SQLContext,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): Source = {
// Create a fake Source that emits 0 to 10.
new Source {
private var offset = -1L

override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil)

override def getOffset: Option[Offset] = {
if (offset >= 10) {
None
} else {
offset += 1
Some(LongOffset(offset))
}
}

override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1
sqlContext.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a")
}
}
}
}
