diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
index 2306df09b8b76..d7f71bd4b0895 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
 import scala.collection.mutable
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution}
+import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
 import org.apache.spark.sql.util.ContinuousQueryListener
 
@@ -178,11 +178,19 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
       throw new IllegalArgumentException(
         s"Cannot start query with name $name as a query with that name is already active")
     }
+    val logicalPlan = df.logicalPlan.transform {
+      case StreamingRelation(dataSource, _, output) =>
+        // Materialize source to avoid creating it in every batch
+        val source = dataSource.createSource()
+        // We still need to use the previous `output` instead of `source.schema` as attributes in
+        // "df.logicalPlan" has already used attributes of the previous `output`.
+        StreamingExecutionRelation(source, output)
+    }
     val query = new StreamExecution(
       sqlContext,
       name,
       checkpointLocation,
-      df.logicalPlan,
+      logicalPlan,
       sink,
       trigger)
     query.start()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index a5a6e01e99874..15f2344df6ab2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -176,7 +176,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
         userSpecifiedSchema = userSpecifiedSchema,
         className = source,
         options = extraOptions.toMap)
-    Dataset.ofRows(sqlContext, StreamingRelation(dataSource.createSource()))
+    Dataset.ofRows(sqlContext, StreamingRelation(dataSource))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index db2134b020167..f472a5068e4b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
 import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.execution.streaming.StreamingRelation
+import org.apache.spark.sql.execution.streaming.{StreamingExecutionRelation, StreamingRelation}
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
@@ -462,7 +462,9 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   @Experimental
-  def isStreaming: Boolean = logicalPlan.find(_.isInstanceOf[StreamingRelation]).isDefined
+  def isStreaming: Boolean = logicalPlan.find { n =>
+    n.isInstanceOf[StreamingRelation] || n.isInstanceOf[StreamingExecutionRelation]
+  }.isDefined
 
   /**
    * Displays the [[Dataset]] in a tabular form. Strings more than 20 characters will be truncated,
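Editor's note on the isStreaming change above: the predicate is broadened so it holds both before a query starts, when the plan contains a StreamingRelation with an un-materialized DataSource, and after the query is running, when the plan holds StreamingExecutionRelation nodes instead. A minimal hedged sketch, assuming an implicit SQLContext and the usual test implicits are in scope, and reusing the FakeDefaultSource test source defined later in this patch:

    // Sketch only: isStreaming is true for plans containing either relation node.
    val batchDF = sqlContext.range(0, 10).toDF()
    assert(!batchDF.isStreaming)

    // reader.stream() now yields a plan with a StreamingRelation, before any query starts.
    val streamingDF = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream()
    assert(streamingDF.isStreaming)

    // MemoryStream-backed Datasets carry a StreamingExecutionRelation instead.
    assert(MemoryStream[Int].toDS().isStreaming)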
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 64f80699ced34..3e4acb752a573 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -43,9 +43,9 @@ import org.apache.spark.util.UninterruptibleThread
  * and the results are committed transactionally to the given [[Sink]].
  */
 class StreamExecution(
-    val sqlContext: SQLContext,
+    override val sqlContext: SQLContext,
     override val name: String,
-    val checkpointRoot: String,
+    checkpointRoot: String,
     private[sql] val logicalPlan: LogicalPlan,
     val sink: Sink,
     val trigger: Trigger) extends ContinuousQuery with Logging {
@@ -72,7 +72,7 @@ class StreamExecution(
 
   /** All stream sources present the query plan. */
   private val sources =
-    logicalPlan.collect { case s: StreamingRelation => s.source }
+    logicalPlan.collect { case s: StreamingExecutionRelation => s.source }
 
   /** A list of unique sources in the query plan. */
   private val uniqueSources = sources.distinct
@@ -295,7 +295,7 @@ class StreamExecution(
     var replacements = new ArrayBuffer[(Attribute, Attribute)]
     // Replace sources in the logical plan with data that has arrived since the last batch.
     val withNewSources = logicalPlan transform {
-      case StreamingRelation(source, output) =>
+      case StreamingExecutionRelation(source, output) =>
         newData.get(source).map { data =>
           val newPlan = data.logicalPlan
           assert(output.size == newPlan.output.size,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
index e35c444348f48..f951dea735d9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -19,16 +19,37 @@ package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+import org.apache.spark.sql.execution.datasources.DataSource
 
 object StreamingRelation {
-  def apply(source: Source): StreamingRelation =
-    StreamingRelation(source, source.schema.toAttributes)
+  def apply(dataSource: DataSource): StreamingRelation = {
+    val source = dataSource.createSource()
+    StreamingRelation(dataSource, source.toString, source.schema.toAttributes)
+  }
+}
+
+/**
+ * Used to link a streaming [[DataSource]] into a
+ * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating
+ * a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]].
+ * It should be used to create a [[Source]] and be converted to [[StreamingExecutionRelation]]
+ * when passed to [[StreamExecution]] to run a query.
+ */
+case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute])
+  extends LeafNode {
+  override def toString: String = sourceName
 }
 
 /**
  * Used to link a streaming [[Source]] of data into a
  * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]].
  */
-case class StreamingRelation(source: Source, output: Seq[Attribute]) extends LeafNode {
+case class StreamingExecutionRelation(source: Source, output: Seq[Attribute]) extends LeafNode {
   override def toString: String = source.toString
 }
+
+object StreamingExecutionRelation {
+  def apply(source: Source): StreamingExecutionRelation = {
+    StreamingExecutionRelation(source, source.schema.toAttributes)
+  }
+}
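A hedged sketch of the two-node lifecycle described in the scaladoc above; df is an unstarted streaming DataFrame and query is a running ContinuousQuery, both assumed to exist already. Before a query starts, the analyzed plan carries StreamingRelation nodes holding an un-materialized DataSource; once ContinuousQueryManager.startQuery runs, the plan owned by the StreamExecution carries StreamingExecutionRelation nodes with the Source it created:

    // Before start: the analyzed plan only records the DataSource and a display name.
    val pendingSources = df.queryExecution.analyzed.collect {
      case StreamingRelation(_, sourceName, _) => sourceName
    }

    // After start: each running query owns its own materialized Source instances.
    val activeSources = query.asInstanceOf[StreamExecution].logicalPlan.collect {
      case StreamingExecutionRelation(source, _) => source
    }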
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 7d97f81b0f10a..b652530d7c78c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -22,11 +22,9 @@ import java.util.concurrent.atomic.AtomicInteger
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
-import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext}
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder}
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.types.StructType
 
 object MemoryStream {
@@ -45,7 +43,7 @@ object MemoryStream {
 case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
     extends Source with Logging {
   protected val encoder = encoderFor[A]
-  protected val logicalPlan = StreamingRelation(this)
+  protected val logicalPlan = StreamingExecutionRelation(this)
   protected val output = logicalPlan.output
 
   protected val batches = new ArrayBuffer[Dataset[A]]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index 3444e56e9ec90..6ccc99fe179d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -36,6 +36,7 @@ import org.scalatest.time.SpanSugar._
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.util.Utils
 
@@ -66,9 +67,9 @@ import org.apache.spark.util.Utils
 trait StreamTest extends QueryTest with Timeouts {
 
   implicit class RichSource(s: Source) {
-    def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingRelation(s))
+    def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingExecutionRelation(s))
 
-    def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingRelation(s))
+    def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingExecutionRelation(s))
   }
 
   /** How long to wait for an active stream to catch up when checking a result. */
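Where a concrete Source object already exists, as with MemoryStream here and the RichSource helper above, there is no DataSource step to defer, so the plan wraps the source in a StreamingExecutionRelation directly. A small hedged usage sketch, assuming an implicit SQLContext and the usual test implicits are in scope as in the suites below:

    // MemoryStream's logical plan is StreamingExecutionRelation(this), so the
    // resulting Dataset is streaming even though no DataSource is involved.
    val input = MemoryStream[Int]
    val doubled = input.toDS().map(_ * 2)
    input.addData(1, 2, 3)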
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
index 29bd3e018ed04..33787de9da388 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
@@ -29,7 +29,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest}
-import org.apache.spark.sql.execution.streaming.{MemorySink, MemoryStream, StreamExecution, StreamingRelation}
+import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.util.Utils
 
@@ -294,8 +294,8 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
       if (withError) {
        logDebug(s"Terminating query ${queryToStop.name} with error")
        queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect {
-          case StreamingRelation(memoryStream, _) =>
-            memoryStream.asInstanceOf[MemoryStream[Int]].addData(0)
+          case StreamingExecutionRelation(source, _) =>
+            source.asInstanceOf[MemoryStream[Int]].addData(0)
        }
      } else {
        logDebug(s"Stopping query ${queryToStop.name}")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 054f5c9fa2d8c..09daa7f81a979 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -71,8 +71,9 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext {
     }
     reader.stream(path)
       .queryExecution.analyzed
-      .collect { case StreamingRelation(s: FileStreamSource, _) => s }
-      .head
+      .collect { case StreamingRelation(dataSource, _, _) =>
+        dataSource.createSource().asInstanceOf[FileStreamSource]
+      }.head
   }
 
   val valueSchema = new StructType().add("value", StringType)
@@ -96,8 +97,9 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
       reader.stream()
     }
     df.queryExecution.analyzed
-      .collect { case StreamingRelation(s: FileStreamSource, _) => s }
-      .head
+      .collect { case StreamingRelation(dataSource, _, _) =>
+        dataSource.createSource().asInstanceOf[FileStreamSource]
+      }.head
       .schema
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index fbb1792596b18..e4ea55552691d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -17,9 +17,13 @@
 
 package org.apache.spark.sql.streaming
 
-import org.apache.spark.sql.{Row, StreamTest}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.sql.{DataFrame, Row, SQLContext, StreamTest}
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.sources.StreamSourceProvider
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 
 class StreamSuite extends StreamTest with SharedSQLContext {
 
@@ -81,4 +85,60 @@ class StreamSuite extends StreamTest with SharedSQLContext {
       AddData(inputData, 1, 2, 3, 4),
       CheckAnswer(2, 4))
   }
+
+  test("DataFrame reuse") {
+    def assertDF(df: DataFrame) {
+      withTempDir { outputDir =>
+        withTempDir { checkpointDir =>
+          val query = df.write.format("parquet")
+            .option("checkpointLocation", checkpointDir.getAbsolutePath)
+            .startStream(outputDir.getAbsolutePath)
+          try {
+            query.processAllAvailable()
+            val outputDf = sqlContext.read.parquet(outputDir.getAbsolutePath).as[Long]
+            checkDataset[Long](outputDf, (0L to 10L).toArray: _*)
+          } finally {
+            query.stop()
+          }
+        }
+      }
+    }
+
+    val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream()
+    assertDF(df)
+    assertDF(df)
+  }
+}
+
+/**
+ * A fake StreamSourceProvider that creates a fake Source that cannot be reused.
+ */
+class FakeDefaultSource extends StreamSourceProvider {
+
+  override def createSource(
+      sqlContext: SQLContext,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
+    // Create a fake Source that emits 0 to 10.
+    new Source {
+      private var offset = -1L
+
+      override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil)
+
+      override def getOffset: Option[Offset] = {
+        if (offset >= 10) {
+          None
+        } else {
+          offset += 1
+          Some(LongOffset(offset))
+        }
+      }
+
+      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+        val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1
+        sqlContext.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a")
+      }
+    }
+  }
 }
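The new test above exercises the user-visible point of this patch: because DataFrameReader.stream now returns a plan holding an un-materialized DataSource, and each startQuery call creates its own Source, the same streaming DataFrame can back several independent queries. A hedged usage sketch mirroring the test (the output and checkpoint locations are hypothetical):

    // Sketch: one streaming DataFrame, two independent queries, each with its own Source.
    val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream()

    val q1 = df.write.format("parquet")
      .option("checkpointLocation", "/tmp/checkpoints/q1")
      .startStream("/tmp/out/q1")
    val q2 = df.write.format("parquet")
      .option("checkpointLocation", "/tmp/checkpoints/q2")
      .startStream("/tmp/out/q2")

    q1.processAllAvailable()
    q2.processAllAvailable()
    q1.stop()
    q2.stop()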