Allow multiple continuous queries to be started from the same DataFrame
zsxwing committed Mar 29, 2016
1 parent a7a93a1 commit 50c39b8
Showing 6 changed files with 101 additions and 19 deletions.
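For context before the per-file diffs, this is the user-facing behavior the commit enables, sketched from the new "DataFrame reuse" test added below. The format name and paths are placeholders; the read.stream()/write.startStream() calls are the Spark 2.0-era streaming API already used in this diff:

// Sketch only: mirrors the "DataFrame reuse" test added in this commit.
// Before this change, read.stream() created the Source eagerly, so every query
// started from `df` would have shared the single Source held by the plan.
val df = sqlContext.read.format("some.streaming.source").stream()

val query1 = df.write.format("parquet")
  .option("checkpointLocation", "/tmp/checkpoints/q1")
  .startStream("/tmp/output/q1")

val query2 = df.write.format("parquet")
  .option("checkpointLocation", "/tmp/checkpoints/q2")
  .startStream("/tmp/output/q2")  // now materializes its own Source instance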
@@ -176,7 +176,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap)
-    Dataset.ofRows(sqlContext, StreamingRelation(dataSource.createSource()))
+    Dataset.ofRows(sqlContext, StreamingRelation(dataSource))
}

/**
Expand Down
@@ -42,11 +42,11 @@ import org.apache.spark.util.UninterruptibleThread
* and the results are committed transactionally to the given [[Sink]].
*/
class StreamExecution(
-    val sqlContext: SQLContext,
+    override val sqlContext: SQLContext,
     override val name: String,
-    val checkpointRoot: String,
-    private[sql] val logicalPlan: LogicalPlan,
-    val sink: Sink) extends ContinuousQuery with Logging {
+    checkpointRoot: String,
+    _logicalPlan: LogicalPlan,
+    sink: Sink) extends ContinuousQuery with Logging {

/** A monitor used to wait/notify when batches complete. */
private val awaitBatchLock = new Object
@@ -71,9 +71,18 @@ class StreamExecution(
/** The current batchId or -1 if execution has not yet been initialized. */
private var currentBatchId: Long = -1

+  private[sql] val logicalPlan = _logicalPlan.transform {
+    case StreamingRelation(sourceCreator, output) =>
+      // Materialize the source once to avoid creating it in every batch.
+      val source = sourceCreator()
+      // We still need to use the previous `output` instead of `source.schema` because the
+      // attributes in `_logicalPlan` already reference the previous `output`.
+      StreamingRelation(() => source, output)
+  }

/** All stream sources present in the query plan. */
private val sources =
-    logicalPlan.collect { case s: StreamingRelation => s.source }
+    logicalPlan.collect { case s: StreamingRelation => s.sourceCreator() }

/** A list of unique sources in the query plan. */
private val uniqueSources = sources.distinct
@@ -286,8 +295,8 @@ class StreamExecution(
var replacements = new ArrayBuffer[(Attribute, Attribute)]
// Replace sources in the logical plan with data that has arrived since the last batch.
val withNewSources = logicalPlan transform {
-      case StreamingRelation(source, output) =>
-        newData.get(source).map { data =>
+      case StreamingRelation(sourceCreator, output) =>
+        newData.get(sourceCreator()).map { data =>
val newPlan = data.logicalPlan
assert(output.size == newPlan.output.size,
s"Invalid batch: ${output.mkString(",")} != ${newPlan.output.mkString(",")}")
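The transform added above is the crux on the execution side: each StreamExecution takes the shared, lazy plan and rewrites every StreamingRelation so that its creator returns the one Source materialized for that query. Repeated sourceCreator() calls within the query (for example the newData.get(sourceCreator()) lookup in the second hunk) therefore hit the same instance, while another query started from the same DataFrame materializes its own. A minimal REPL-pasteable sketch of that memoization pattern, using stand-in types (Relation, FakeSource, materializeForQuery are illustrative names, not Spark classes):

// Illustrative only; not Spark code.
trait Source
class FakeSource extends Source

final case class Relation(sourceCreator: () => Source)

// Shared, lazy relation: every call to the creator builds a fresh Source.
val shared = Relation(() => new FakeSource)

// What StreamExecution does at construction time: call the creator once and
// capture the result in a constant function.
def materializeForQuery(r: Relation): Relation = {
  val source = r.sourceCreator()
  r.copy(sourceCreator = () => source)
}

val query1 = materializeForQuery(shared)
val query2 = materializeForQuery(shared)

assert(query1.sourceCreator() eq query1.sourceCreator())  // stable within one query
assert(query1.sourceCreator() ne query2.sourceCreator())  // distinct across queries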
@@ -19,16 +19,25 @@ package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+import org.apache.spark.sql.execution.datasources.DataSource

object StreamingRelation {
-  def apply(source: Source): StreamingRelation =
-    StreamingRelation(source, source.schema.toAttributes)
+  def apply(dataSource: DataSource): StreamingRelation = {
+    val source = dataSource.createSource()
+    StreamingRelation(dataSource.createSource, source.schema.toAttributes)
+  }
+
+  def apply(source: Source): StreamingRelation = {
+    StreamingRelation(() => source, source.schema.toAttributes)
+  }
}

/**
* Used to link a streaming [[Source]] of data into a
* [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]].
*/
-case class StreamingRelation(source: Source, output: Seq[Attribute]) extends LeafNode {
-  override def toString: String = source.toString
+case class StreamingRelation(
+    sourceCreator: () => Source,
+    output: Seq[Attribute]) extends LeafNode {
+  override def toString: String = sourceCreator().toString
}
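A note on the two overloads above: the DataSource variant still calls createSource() once, but only to derive the schema attributes; what the relation stores is dataSource.createSource, eta-expanded to a () => Source, so each query that consumes the plan can materialize its own Source. The Source variant wraps an already-constructed source in a constant creator, preserving the old behavior for callers that hold a concrete Source, such as MemoryStream in the test suites below.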
@@ -293,8 +293,8 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
if (withError) {
logDebug(s"Terminating query ${queryToStop.name} with error")
queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect {
-        case StreamingRelation(memoryStream, _) =>
-          memoryStream.asInstanceOf[MemoryStream[Int]].addData(0)
+        case StreamingRelation(sourceCreator, _) =>
+          sourceCreator().asInstanceOf[MemoryStream[Int]].addData(0)
}
} else {
logDebug(s"Stopping query ${queryToStop.name}")
@@ -71,8 +71,9 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext {
}
reader.stream(path)
.queryExecution.analyzed
-      .collect { case StreamingRelation(s: FileStreamSource, _) => s }
-      .head
+      .collect { case StreamingRelation(sourceCreator, _) =>
+        sourceCreator().asInstanceOf[FileStreamSource]
+      }.head
}

val valueSchema = new StructType().add("value", StringType)
@@ -96,8 +97,9 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
reader.stream()
}
df.queryExecution.analyzed
-      .collect { case StreamingRelation(s: FileStreamSource, _) => s }
-      .head
+      .collect { case StreamingRelation(sourceCreator, _) =>
+        sourceCreator().asInstanceOf[FileStreamSource]
+      }.head
.schema
}

@@ -17,9 +17,13 @@

package org.apache.spark.sql.streaming

-import org.apache.spark.sql.{Row, StreamTest}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.sql.{DataFrame, Row, SQLContext, StreamTest}
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.sources.StreamSourceProvider
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class StreamSuite extends StreamTest with SharedSQLContext {

@@ -81,4 +85,62 @@ class StreamSuite extends StreamTest with SharedSQLContext {
AddData(inputData, 1, 2, 3, 4),
CheckAnswer(2, 4))
}

test("DataFrame reuse") {
def assertDF(df: DataFrame) {
withTempDir { outputDir =>
withTempDir { checkpointDir =>
val query = df.write.format("parquet")
.option("checkpointLocation", checkpointDir.getAbsolutePath)
.startStream(outputDir.getAbsolutePath)
try {
eventually(timeout(streamingTimeout)) {
val outputDf = sqlContext.read.parquet(outputDir.getAbsolutePath).as[Long]
checkDataset[Long](outputDf, (0L to 10L).toArray: _*)
}
} finally {
query.stop()
}
}
}
}

val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream()
assertDF(df)
assertDF(df)
assertDF(df)
}
}

+/**
+ * A fake StreamSourceProvider that creates a fake Source that cannot be reused.
+ */
+class FakeDefaultSource extends StreamSourceProvider {
+
+  override def createSource(
+      sqlContext: SQLContext,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
+    // Create a fake Source that emits 0 to 10.
+    new Source {
+      private var offset = -1L
+
+      override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil)
+
+      override def getOffset: Option[Offset] = {
+        if (offset >= 10) {
+          None
+        } else {
+          offset += 1
+          Some(LongOffset(offset))
+        }
+      }
+
+      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+        val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1
+        sqlContext.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a")
+      }
+    }
+  }
+}
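Why this fake source "cannot be reused": its offset counter is per-instance mutable state that runs from 0 to 10 exactly once. If the repeated startStream calls in the test above shared a single Source, the later queries would find getOffset already returning None and would never produce the expected rows. The repeated assertDF(df) calls passing is therefore direct evidence that each query materialized its own Source via the stored creator.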
