[SPARK-14176][SQL] Add DataFrameWriter.trigger to set the stream batch period #11976

Closed · wants to merge 8 commits
@@ -171,13 +171,20 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
       name: String,
       checkpointLocation: String,
       df: DataFrame,
-      sink: Sink): ContinuousQuery = {
+      sink: Sink,
+      trigger: Trigger = ProcessingTime(0)): ContinuousQuery = {
     activeQueriesLock.synchronized {
       if (activeQueries.contains(name)) {
         throw new IllegalArgumentException(
           s"Cannot start query with name $name as a query with that name is already active")
       }
-      val query = new StreamExecution(sqlContext, name, checkpointLocation, df.logicalPlan, sink)
+      val query = new StreamExecution(
+        sqlContext,
+        name,
+        checkpointLocation,
+        df.logicalPlan,
+        sink,
+        trigger)
       query.start()
       activeQueries.put(name, query)
       query
@@ -77,6 +77,35 @@ final class DataFrameWriter private[sql](df: DataFrame) {
     this
   }
 
+  /**
+   * :: Experimental ::
+   * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will run
+   * the query as fast as possible.
+   *

Contributor: This scala doc should have an example right here.

    write.trigger(ProcessingTime("10 seconds"))
    write.trigger("10 seconds")     // less verbose

+   * Scala Example:
+   * {{{
+   *   df.write.trigger(ProcessingTime("10 seconds"))
+   *
+   *   import scala.concurrent.duration._
+   *   df.write.trigger(ProcessingTime(10.seconds))
+   * }}}
+   *
+   * Java Example:
+   * {{{
+   *   df.write.trigger(ProcessingTime.create("10 seconds"))
+   *
+   *   import java.util.concurrent.TimeUnit
+   *   df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
+   * }}}
+   *
+   * @since 2.0.0
+   */
+  @Experimental
+  def trigger(trigger: Trigger): DataFrameWriter = {
+    this.trigger = trigger
+    this
+  }
+
   /**
    * Specifies the underlying output data source. Built-in options include "parquet", "json", etc.
    *
@@ -261,7 +290,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
       queryName,
       checkpointLocation,
       df,
-      dataSource.createSink())
+      dataSource.createSink(),
+      trigger)
   }
 
   /**
@@ -552,6 +582,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {

   private var mode: SaveMode = SaveMode.ErrorIfExists
 
+  private var trigger: Trigger = ProcessingTime(0L)
+
   private var extraOptions = new scala.collection.mutable.HashMap[String, String]
 
   private var partitioningColumns: Option[Seq[String]] = None
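With the writer plumbing above in place, the batch period is set on the write path before the stream starts. A minimal usage sketch (the format, option value, and paths are illustrative placeholders, assuming the Spark 2.0-era `df.write ... startStream(...)` API that this PR extends):

import org.apache.spark.sql.ProcessingTime

// Start a streaming query that begins a batch at most once every 10 seconds.
// The default, ProcessingTime(0), runs batches back to back.
df.write
  .format("parquet")
  .option("checkpointLocation", "/tmp/checkpoint")  // hypothetical path
  .trigger(ProcessingTime("10 seconds"))
  .startStream("/tmp/output")                       // hypothetical path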
sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala (new file, 133 additions)
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.util.concurrent.TimeUnit

import scala.concurrent.duration.Duration

import org.apache.commons.lang3.StringUtils

import org.apache.spark.annotation.Experimental
import org.apache.spark.unsafe.types.CalendarInterval

/**
* :: Experimental ::
* Used to indicate how often results should be produced by a [[ContinuousQuery]].
*/
@Experimental
sealed trait Trigger {}
Contributor: everything here should be @experimental

/**
* :: Experimental ::
* A trigger that runs a query periodically based on the processing time. If `intervalMs` is 0,
* the query will run as fast as possible.
*
* Scala Example:
* {{{
*   df.write.trigger(ProcessingTime("10 seconds"))
*
*   import scala.concurrent.duration._
*   df.write.trigger(ProcessingTime(10.seconds))
* }}}
*
* Java Example:
* {{{
*   df.write.trigger(ProcessingTime.create("10 seconds"))
*
*   import java.util.concurrent.TimeUnit
*   df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
* }}}
Contributor: Nice documentation! Maybe put the typesafe one second and include the imports that are required.
*/
@Experimental
case class ProcessingTime(intervalMs: Long) extends Trigger {
require(intervalMs >= 0, "the interval of trigger should not be negative")
}

/**
* :: Experimental ::
* Used to create [[ProcessingTime]] triggers for [[ContinuousQuery]]s.
*/
@Experimental
object ProcessingTime {
Contributor: Used to create [[ProcessingTime]] triggers for [[ContinuousQueries]]. Or something.

/**
* Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible.
*
* Example:
* {{{
*   df.write.trigger(ProcessingTime("10 seconds"))
* }}}
*/
def apply(interval: String): ProcessingTime = {
if (StringUtils.isBlank(interval)) {
throw new IllegalArgumentException(
"interval cannot be null or blank.")
}
val cal = if (interval.startsWith("interval")) {
Contributor: Is this logic duplicated elsewhere? Should CalendarInterval.fromString just do this internally?

Member (author): SQL also uses CalendarInterval.fromString to parse the interval, and that syntax requires "INTERVAL value unit". Hence, I cannot move the logic into CalendarInterval.fromString.

Contributor: The presence of INTERVAL isn't also enforced by the grammar?

CalendarInterval.fromString(interval)
} else {
CalendarInterval.fromString("interval " + interval)
}
if (cal == null) {
throw new IllegalArgumentException(s"Invalid interval: $interval")
}
if (cal.months > 0) {
throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
}
new ProcessingTime(cal.microseconds / 1000)
}

/**
* Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible.
*
* Example:
* {{{
*   import scala.concurrent.duration._
*   df.write.trigger(ProcessingTime(10.seconds))
* }}}
*/
def apply(interval: Duration): ProcessingTime = {
Contributor: Needs docs with examples
new ProcessingTime(interval.toMillis)
}

/**
* Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible.
*
* Example:
* {{{
*   df.write.trigger(ProcessingTime.create("10 seconds"))
* }}}
*/
def create(interval: String): ProcessingTime = {
Contributor: Needs scala docs with examples

apply(interval)
}

/**
* Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible.
*
* Example:
* {{{
*   import java.util.concurrent.TimeUnit
*   df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
* }}}
*/
def create(interval: Long, unit: TimeUnit): ProcessingTime = {
Contributor: Needs scala docs with examples

new ProcessingTime(unit.toMillis(interval))
}
}
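All of the factory methods above normalize to a millisecond interval, so the following expressions are interchangeable. A quick sketch of the equivalences (the same ones the test suite below exercises):

import java.util.concurrent.TimeUnit
import scala.concurrent.duration._
import org.apache.spark.sql.ProcessingTime

// Each expression evaluates to ProcessingTime(10000), i.e. intervalMs = 10000.
ProcessingTime(10000)                        // raw milliseconds
ProcessingTime("10 seconds")                 // parsed via CalendarInterval.fromString
ProcessingTime("interval 10 seconds")        // the explicit "interval" prefix also works
ProcessingTime(10.seconds)                   // scala.concurrent.duration
ProcessingTime.create(10, TimeUnit.SECONDS)  // Java-friendly factory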
@@ -46,16 +46,14 @@ class StreamExecution(
     override val name: String,
     val checkpointRoot: String,
     private[sql] val logicalPlan: LogicalPlan,
-    val sink: Sink) extends ContinuousQuery with Logging {
+    val sink: Sink,
+    val trigger: Trigger) extends ContinuousQuery with Logging {
 
   /** An monitor used to wait/notify when batches complete. */
   private val awaitBatchLock = new Object
   private val startLatch = new CountDownLatch(1)
   private val terminationLatch = new CountDownLatch(1)
 
-  /** Minimum amount of time in between the start of each batch. */
-  private val minBatchTime = 10
-
   /**
    * Tracks how much data we have processed and committed to the sink or state store from each
    * input source.
@@ -78,6 +76,10 @@
   /** A list of unique sources in the query plan. */
   private val uniqueSources = sources.distinct
 
+  private val triggerExecutor = trigger match {
+    case t: ProcessingTime => ProcessingTimeExecutor(t)
+  }
+
   /** Defines the internal state of execution */
   @volatile
   private var state: State = INITIALIZED
@@ -211,11 +213,15 @@
       SQLContext.setActive(sqlContext)
       populateStartOffsets()
       logDebug(s"Stream running from $committedOffsets to $availableOffsets")
-      while (isActive) {
-        if (dataAvailable) runBatch()
-        commitAndConstructNextBatch()
-        Thread.sleep(minBatchTime) // TODO: Could be tighter
-      }
+      triggerExecutor.execute(() => {
+        if (isActive) {
+          if (dataAvailable) runBatch()
+          commitAndConstructNextBatch()
+          true
+        } else {
+          false
+        }
+      })
     } catch {
       case _: InterruptedException if state == TERMINATED => // interrupted by stop()
       case NonFatal(e) =>
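The run loop now delegates pacing to `triggerExecutor.execute`, whose callback returns `true` to keep going and `false` to terminate. As a hypothetical illustration of that contract (not part of this PR), the simplest conforming implementation of the `TriggerExecutor` trait introduced below would rerun the batch runner with no pacing at all:

// Hypothetical sketch only: an executor with no delay between batches.
case class ImmediateExecutor() extends TriggerExecutor {
  override def execute(batchRunner: () => Boolean): Unit = {
    while (batchRunner()) {}  // rerun immediately until the runner returns false
  }
}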
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala (new file, 72 additions)
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.ProcessingTime
import org.apache.spark.util.{Clock, SystemClock}

trait TriggerExecutor {

/**
* Execute batches using `batchRunner`. If `batchRunner` returns `false`, terminate the execution.
*/
def execute(batchRunner: () => Boolean): Unit
}

/**
* A trigger executor that runs a batch every `intervalMs` milliseconds.
*/
case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = new SystemClock())
extends TriggerExecutor with Logging {

private val intervalMs = processingTime.intervalMs

override def execute(batchRunner: () => Boolean): Unit = {
while (true) {
val batchStartTimeMs = clock.getTimeMillis()
val terminated = !batchRunner()
if (intervalMs > 0) {
val batchEndTimeMs = clock.getTimeMillis()
val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs
if (batchElapsedTimeMs > intervalMs) {
notifyBatchFallingBehind(batchElapsedTimeMs)
}
if (terminated) {
return
}
clock.waitTillTime(nextBatchTime(batchEndTimeMs))
} else {
if (terminated) {
return
}
}
}
}

/** Called when a batch falls behind. Exposed for testing only. */
def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = {
logWarning("Current batch is falling behind. The trigger interval is " +
s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds")
}

/** Return the time of the next batch boundary: the smallest positive multiple of `intervalMs` that is >= `now`. */
def nextBatchTime(now: Long): Long = {
(now - 1) / intervalMs * intervalMs + intervalMs
}
}
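Note that `nextBatchTime` rounds up to a multiple of the interval rather than adding `intervalMs` to the batch start time, so short batches stay aligned to fixed boundaries and a slow batch does not permanently shift the schedule. A worked sketch of the arithmetic, with expected values derived from the definition above:

val executor = ProcessingTimeExecutor(ProcessingTime(100))
executor.nextBatchTime(0)    // 100: before the first boundary
executor.nextBatchTime(1)    // 100
executor.nextBatchTime(100)  // 100: an exact multiple is its own boundary
executor.nextBatchTime(101)  // 200: rounds up to the next boundary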
sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala (new file, 40 additions)
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.util.concurrent.TimeUnit

import scala.concurrent.duration._

import org.apache.spark.SparkFunSuite

class ProcessingTimeSuite extends SparkFunSuite {

test("create") {
assert(ProcessingTime(10.seconds).intervalMs === 10 * 1000)
assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs === 10 * 1000)
assert(ProcessingTime("1 minute").intervalMs === 60 * 1000)
assert(ProcessingTime("interval 1 minute").intervalMs === 60 * 1000)

intercept[IllegalArgumentException] { ProcessingTime(null: String) }
intercept[IllegalArgumentException] { ProcessingTime("") }
intercept[IllegalArgumentException] { ProcessingTime("invalid") }
intercept[IllegalArgumentException] { ProcessingTime("1 month") }
intercept[IllegalArgumentException] { ProcessingTime("1 year") }
}
}
@@ -276,7 +276,11 @@ trait StreamTest extends QueryTest with Timeouts {
         currentStream =
           sqlContext
             .streams
-            .startQuery(StreamExecution.nextName, metadataRoot, stream, sink)
+            .startQuery(
+              StreamExecution.nextName,
+              metadataRoot,
+              stream,
+              sink)
             .asInstanceOf[StreamExecution]
         currentStream.microBatchThread.setUncaughtExceptionHandler(
           new UncaughtExceptionHandler {