[SPARK-14176][SQL] Add DataFrameWriter.trigger to set the stream batch period

## What changes were proposed in this pull request?

Add a processing-time trigger to control the batch processing speed.
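
For illustration, a minimal sketch of how the new trigger API might be used once this patch is in (the parquet format, file paths, and surrounding `startStream` call are illustrative assumptions, not part of this diff):

```scala
// Sketch only: run one micro-batch at most every 10 seconds.
val query = df.write
  .format("parquet")                          // hypothetical sink format
  .option("checkpointLocation", "/tmp/cp")    // hypothetical checkpoint path
  .trigger(ProcessingTime("10 seconds"))      // the new trigger API in this patch
  .startStream("/tmp/out")                    // hypothetical output path
```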

## How was this patch tested?

Unit tests

Author: Shixiong Zhu <[email protected]>

Closes #11976 from zsxwing/trigger.
zsxwing authored and marmbrus committed Apr 4, 2016
1 parent 89f3bef commit 855ed44
Showing 9 changed files with 413 additions and 13 deletions.
sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
@@ -171,13 +171,20 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
       name: String,
       checkpointLocation: String,
       df: DataFrame,
-      sink: Sink): ContinuousQuery = {
+      sink: Sink,
+      trigger: Trigger = ProcessingTime(0)): ContinuousQuery = {
     activeQueriesLock.synchronized {
       if (activeQueries.contains(name)) {
         throw new IllegalArgumentException(
           s"Cannot start query with name $name as a query with that name is already active")
       }
-      val query = new StreamExecution(sqlContext, name, checkpointLocation, df.logicalPlan, sink)
+      val query = new StreamExecution(
+        sqlContext,
+        name,
+        checkpointLocation,
+        df.logicalPlan,
+        sink,
+        trigger)
       query.start()
       activeQueries.put(name, query)
       query
sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -77,6 +77,35 @@ final class DataFrameWriter private[sql](df: DataFrame) {
     this
   }

+  /**
+   * :: Experimental ::
+   * Set the trigger for the stream query. The default value is `ProcessingTime(0)`, which runs
+   * the query as fast as possible.
+   *
+   * Scala Example:
+   * {{{
+   *   df.write.trigger(ProcessingTime("10 seconds"))
+   *
+   *   import scala.concurrent.duration._
+   *   df.write.trigger(ProcessingTime(10.seconds))
+   * }}}
+   *
+   * Java Example:
+   * {{{
+   *   df.write.trigger(ProcessingTime.create("10 seconds"))
+   *
+   *   import java.util.concurrent.TimeUnit
+   *   df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
+   * }}}
+   *
+   * @since 2.0.0
+   */
+  @Experimental
+  def trigger(trigger: Trigger): DataFrameWriter = {
+    this.trigger = trigger
+    this
+  }

   /**
    * Specifies the underlying output data source. Built-in options include "parquet", "json", etc.
    *
@@ -261,7 +290,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
       queryName,
       checkpointLocation,
       df,
-      dataSource.createSink())
+      dataSource.createSink(),
+      trigger)
   }

/**
@@ -552,6 +582,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {

   private var mode: SaveMode = SaveMode.ErrorIfExists
 
+  private var trigger: Trigger = ProcessingTime(0L)
+
   private var extraOptions = new scala.collection.mutable.HashMap[String, String]
 
   private var partitioningColumns: Option[Seq[String]] = None
133 changes: 133 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala
@@ -0,0 +1,133 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import java.util.concurrent.TimeUnit

import scala.concurrent.duration.Duration

import org.apache.commons.lang3.StringUtils

import org.apache.spark.annotation.Experimental
import org.apache.spark.unsafe.types.CalendarInterval

/**
 * :: Experimental ::
 * Used to indicate how often results should be produced by a [[ContinuousQuery]].
 */
@Experimental
sealed trait Trigger {}

/**
 * :: Experimental ::
 * A trigger that runs a query periodically based on the processing time. If `intervalMs` is 0,
 * the query will run as fast as possible.
 *
 * Scala Example:
 * {{{
 *   df.write.trigger(ProcessingTime("10 seconds"))
 *
 *   import scala.concurrent.duration._
 *   df.write.trigger(ProcessingTime(10.seconds))
 * }}}
 *
 * Java Example:
 * {{{
 *   df.write.trigger(ProcessingTime.create("10 seconds"))
 *
 *   import java.util.concurrent.TimeUnit
 *   df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
 * }}}
 */
@Experimental
case class ProcessingTime(intervalMs: Long) extends Trigger {
  require(intervalMs >= 0, "the interval of trigger should not be negative")
}

/**
 * :: Experimental ::
 * Used to create [[ProcessingTime]] triggers for [[ContinuousQuery]]s.
 */
@Experimental
object ProcessingTime {

  /**
   * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible.
   *
   * Example:
   * {{{
   *   df.write.trigger(ProcessingTime("10 seconds"))
   * }}}
   */
  def apply(interval: String): ProcessingTime = {
    if (StringUtils.isBlank(interval)) {
      throw new IllegalArgumentException("interval cannot be null or blank.")
    }
    val cal = if (interval.startsWith("interval")) {
      CalendarInterval.fromString(interval)
    } else {
      CalendarInterval.fromString("interval " + interval)
    }
    if (cal == null) {
      throw new IllegalArgumentException(s"Invalid interval: $interval")
    }
    if (cal.months > 0) {
      throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
    }
    new ProcessingTime(cal.microseconds / 1000)
  }

  /**
   * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible.
   *
   * Example:
   * {{{
   *   import scala.concurrent.duration._
   *   df.write.trigger(ProcessingTime(10.seconds))
   * }}}
   */
  def apply(interval: Duration): ProcessingTime = {
    new ProcessingTime(interval.toMillis)
  }

  /**
   * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible.
   *
   * Example:
   * {{{
   *   df.write.trigger(ProcessingTime.create("10 seconds"))
   * }}}
   */
  def create(interval: String): ProcessingTime = {
    apply(interval)
  }

  /**
   * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible.
   *
   * Example:
   * {{{
   *   import java.util.concurrent.TimeUnit
   *   df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
   * }}}
   */
  def create(interval: Long, unit: TimeUnit): ProcessingTime = {
    new ProcessingTime(unit.toMillis(interval))
  }
}
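
All four factory shapes above normalize to whole milliseconds. As a quick sketch of the equivalences (mirroring the assertions in the new ProcessingTimeSuite further down this diff):

```scala
import java.util.concurrent.TimeUnit
import scala.concurrent.duration._

// Each of these constructs ProcessingTime(intervalMs = 10000):
ProcessingTime(10000)                        // raw milliseconds
ProcessingTime("10 seconds")                 // parsed via CalendarInterval
ProcessingTime("interval 10 seconds")        // explicit "interval" prefix also accepted
ProcessingTime(10.seconds)                   // scala.concurrent.duration.Duration
ProcessingTime.create(10, TimeUnit.SECONDS)  // Java-friendly factory
```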
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -47,16 +47,14 @@ class StreamExecution(
     override val name: String,
     val checkpointRoot: String,
     private[sql] val logicalPlan: LogicalPlan,
-    val sink: Sink) extends ContinuousQuery with Logging {
+    val sink: Sink,
+    val trigger: Trigger) extends ContinuousQuery with Logging {
 
   /** A monitor used to wait/notify when batches complete. */
   private val awaitBatchLock = new Object
   private val startLatch = new CountDownLatch(1)
   private val terminationLatch = new CountDownLatch(1)
 
-  /** Minimum amount of time in between the start of each batch. */
-  private val minBatchTime = 10
-
   /**
    * Tracks how much data we have processed and committed to the sink or state store from each
    * input source.
@@ -79,6 +77,10 @@
   /** A list of unique sources in the query plan. */
   private val uniqueSources = sources.distinct
 
+  private val triggerExecutor = trigger match {
+    case t: ProcessingTime => ProcessingTimeExecutor(t)
+  }
+
   /** Defines the internal state of execution */
   @volatile
   private var state: State = INITIALIZED
@@ -154,11 +156,15 @@
         SQLContext.setActive(sqlContext)
         populateStartOffsets()
         logDebug(s"Stream running from $committedOffsets to $availableOffsets")
-        while (isActive) {
-          if (dataAvailable) runBatch()
-          commitAndConstructNextBatch()
-          Thread.sleep(minBatchTime) // TODO: Could be tighter
-        }
+        triggerExecutor.execute(() => {
+          if (isActive) {
+            if (dataAvailable) runBatch()
+            commitAndConstructNextBatch()
+            true
+          } else {
+            false
+          }
+        })
       } catch {
         case _: InterruptedException if state == TERMINATED => // interrupted by stop()
         case NonFatal(e) =>
72 changes: 72 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala
@@ -0,0 +1,72 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.ProcessingTime
import org.apache.spark.util.{Clock, SystemClock}

trait TriggerExecutor {

  /**
   * Execute batches using `batchRunner`. If `batchRunner` returns `false`, terminate the
   * execution.
   */
  def execute(batchRunner: () => Boolean): Unit
}

/**
 * A trigger executor that runs a batch every `intervalMs` milliseconds.
 */
case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = new SystemClock())
  extends TriggerExecutor with Logging {

  private val intervalMs = processingTime.intervalMs

  override def execute(batchRunner: () => Boolean): Unit = {
    while (true) {
      val batchStartTimeMs = clock.getTimeMillis()
      val terminated = !batchRunner()
      if (intervalMs > 0) {
        val batchEndTimeMs = clock.getTimeMillis()
        val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs
        if (batchElapsedTimeMs > intervalMs) {
          notifyBatchFallingBehind(batchElapsedTimeMs)
        }
        if (terminated) {
          return
        }
        clock.waitTillTime(nextBatchTime(batchEndTimeMs))
      } else {
        if (terminated) {
          return
        }
      }
    }
  }

  /** Called when a batch falls behind. Exposed for tests only. */
  def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = {
    logWarning("Current batch is falling behind. The trigger interval is " +
      s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds")
  }

  /** Returns the next multiple of `intervalMs` at or after `now`. */
  def nextBatchTime(now: Long): Long = {
    (now - 1) / intervalMs * intervalMs + intervalMs
  }
}
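
The `nextBatchTime` formula above rounds up to a multiple of the interval rather than simply adding `intervalMs` to `now`, so batch start times stay aligned to the interval grid even when a batch overruns. A small worked sketch of the arithmetic (the standalone helper here is a copy for illustration, not part of the patch):

```scala
// Standalone copy of the formula from ProcessingTimeExecutor, for illustration.
def nextBatchTime(now: Long, intervalMs: Long): Long =
  (now - 1) / intervalMs * intervalMs + intervalMs

assert(nextBatchTime(100, 100) == 100)  // an exact boundary is its own deadline
assert(nextBatchTime(101, 100) == 200)  // otherwise round up to the next multiple
assert(nextBatchTime(199, 100) == 200)
```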
40 changes: 40 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala
@@ -0,0 +1,40 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import java.util.concurrent.TimeUnit

import scala.concurrent.duration._

import org.apache.spark.SparkFunSuite

class ProcessingTimeSuite extends SparkFunSuite {

  test("create") {
    assert(ProcessingTime(10.seconds).intervalMs === 10 * 1000)
    assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs === 10 * 1000)
    assert(ProcessingTime("1 minute").intervalMs === 60 * 1000)
    assert(ProcessingTime("interval 1 minute").intervalMs === 60 * 1000)

    intercept[IllegalArgumentException] { ProcessingTime(null: String) }
    intercept[IllegalArgumentException] { ProcessingTime("") }
    intercept[IllegalArgumentException] { ProcessingTime("invalid") }
    intercept[IllegalArgumentException] { ProcessingTime("1 month") }
    intercept[IllegalArgumentException] { ProcessingTime("1 year") }
  }
}
sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -288,7 +288,11 @@ trait StreamTest extends QueryTest with Timeouts {
         currentStream =
           sqlContext
             .streams
-            .startQuery(StreamExecution.nextName, metadataRoot, stream, sink)
+            .startQuery(
+              StreamExecution.nextName,
+              metadataRoot,
+              stream,
+              sink)
             .asInstanceOf[StreamExecution]
         currentStream.microBatchThread.setUncaughtExceptionHandler(
           new UncaughtExceptionHandler {