[WIP] Use CollectMetrics for numOutputRows in streaming sources #11

Open · wants to merge 1 commit into base: master
@@ -2024,10 +2024,17 @@ case class CollectMetrics(
dataframeId: Long)
extends UnaryNode {

import CollectMetrics._

override lazy val resolved: Boolean = {
name.nonEmpty && metrics.nonEmpty && metrics.forall(_.resolved) && childrenResolved
}

if (isForStreamSource(name)) {
assert(references.isEmpty,
"The node should not refer any column if it's used for stream source output counter!")
}

override def maxRows: Option[Long] = child.maxRows
override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition
override def output: Seq[Attribute] = child.output
@@ -2040,6 +2047,14 @@ case class CollectMetrics(
}
}

object CollectMetrics {
val STREAM_SOURCE_PREFIX = "__stream_source_"

def nameForStreamSource(name: String): String = s"$STREAM_SOURCE_PREFIX$name"

def isForStreamSource(name: String): Boolean = name.startsWith(STREAM_SOURCE_PREFIX)
}
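For reference, a minimal sketch of the naming round trip the rest of this change relies on (illustration only; the user-facing metric name shown is hypothetical):

import java.util.UUID

// A generated counter name carries the reserved prefix, so the plan can later
// be scanned for stream-source counters while user-supplied observe() names
// are left untouched.
val counterName = CollectMetrics.nameForStreamSource(UUID.randomUUID().toString)
assert(CollectMetrics.isForStreamSource(counterName))
assert(!CollectMetrics.isForStreamSource("my_observed_metrics"))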

/**
* A placeholder for domain join that can be added when decorrelating subqueries.
* It should be rewritten during the optimization phase.
@@ -24,11 +24,11 @@ import scala.collection.mutable
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{NUM_PRUNED, POST_SCAN_FILTERS, PUSHED_FILTERS, TOTAL}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.TreePattern.{PLAN_EXPRESSION, SCALAR_SUBQUERY}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
@@ -61,6 +61,25 @@ import org.apache.spark.util.collection.BitSet
*/
object FileSourceStrategy extends Strategy with PredicateHelper with Logging {

private type HadoopFsRelationHolderRetType =
(LogicalRelation, HadoopFsRelation, Option[CatalogTable], Option[CollectMetrics])

private object HadoopFsRelationHolder {
def unapply(plan: LogicalPlan): Option[HadoopFsRelationHolderRetType] = {
plan match {
case c @ CollectMetrics(name, _,
l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _), _)
if CollectMetrics.isForStreamSource(name) =>
Some((l, fsRelation, table, Some(c)))

case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _) =>
Some((l, fsRelation, table, None))

case _ => None
}
}
}

// should prune buckets iff num buckets is greater than 1 and there is only one bucket column
private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = {
bucketSpec match {
@@ -151,7 +170,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ScanOperation(projects, stayUpFilters, filters,
l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
HadoopFsRelationHolder(l, fsRelation, table, collectMetricsOpt)) =>
// Filters on this relation fall into four categories based on where we can use them to avoid
// reading unneeded data:
// - partition keys only - used to prune directories to read
@@ -342,9 +361,25 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
val metadataAlias =
Alias(KnownNotNull(CreateStruct(structColumns.toImmutableArraySeq)),
FileFormat.METADATA_NAME)(exprId = metadataStruct.exprId)

val nodeExec = if (collectMetricsOpt.isDefined) {
val collectMetricsLogical = collectMetricsOpt.get
execution.CollectMetricsExec(
collectMetricsLogical.name, collectMetricsLogical.metrics, scan)
} else {
scan
}
execution.ProjectExec(
readDataColumns ++ partitionColumns :+ metadataAlias, scan)
}.getOrElse(scan)
readDataColumns ++ partitionColumns :+ metadataAlias, nodeExec)
}.getOrElse {
if (collectMetricsOpt.isDefined) {
val collectMetricsLogical = collectMetricsOpt.get
execution.CollectMetricsExec(
collectMetricsLogical.name, collectMetricsLogical.metrics, scan)
} else {
scan
}
}
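The counter wrapping above is spelled out twice, once under the metadata projection and once without it. As a reading aid, the same logic can be viewed as a small helper (a sketch only; the helper name is hypothetical and it assumes the imports already present in this file):

// Hypothetical helper, not part of the PR: wrap the scan with a
// CollectMetricsExec when the logical plan carried a stream-source
// CollectMetrics node, otherwise leave the scan unchanged.
private def wrapWithStreamSourceCounter(
    scan: SparkPlan,
    collectMetricsOpt: Option[CollectMetrics]): SparkPlan = {
  collectMetricsOpt
    .map(cm => execution.CollectMetricsExec(cm.name, cm.metrics, scan))
    .getOrElse(scan)
}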

// bottom-most filters are put in the left of the list.
val finalFilters = afterScanFilters.toSeq.reduceOption(expressions.And).toSeq ++ stayUpFilters
@@ -17,14 +17,17 @@

package org.apache.spark.sql.execution.streaming

import java.util.UUID

import scala.collection.mutable.{Map => MutableMap}
import scala.collection.mutable

import org.apache.spark.internal.{LogKeys, MDC}
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.{Column, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, FileSourceMetadataAttribute, LocalTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, LeafNode, LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.streaming.{StreamingRelationV2, WriteToStream}
import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
import org.apache.spark.sql.catalyst.util.truncatedString
@@ -35,6 +38,8 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec}
import org.apache.spark.sql.execution.streaming.sources.{WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1}
import org.apache.spark.sql.functions.count
import org.apache.spark.sql.internal
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.Trigger
import org.apache.spark.util.{Clock, Utils}
@@ -731,10 +736,15 @@ class MicroBatchExecution(
}

// Replace sources in the logical plan with data that has arrived since the last batch.
import sparkSessionToRunBatch.RichColumn

val uuidToStream = new mutable.HashMap[String, SparkDataStream]()
val streamToCollectMetrics = new mutable.HashMap[SparkDataStream, CollectMetrics]()

val newBatchesPlan = logicalPlan transform {
// For v1 sources.
case StreamingExecutionRelation(source, output, catalogTable) =>
mutableNewData.get(source).map { dataPlan =>
val node = mutableNewData.get(source).map { dataPlan =>
val hasFileMetadata = output.exists {
case FileSourceMetadataAttribute(_) => true
case _ => false
@@ -782,16 +792,54 @@
LocalRelation(output, isStreaming = true)
}

val collectMetricsName = CollectMetrics.nameForStreamSource(
UUID.randomUUID().toString)
uuidToStream.put(collectMetricsName, source)
val cachedCollectMetrics = streamToCollectMetrics.getOrElseUpdate(source,
CollectMetrics(
collectMetricsName,
Seq(
count(
new Column(internal.Literal(1))).as("row_count")
).map(_.named),
UnresolvedRelation(Seq("dummy")),
-1
)
)

val colMetrics = cachedCollectMetrics.copy(child = node)
sparkSessionToRunBatch.sessionState.analyzer.execute(colMetrics)

// For v2 sources.
case r: StreamingDataSourceV2ScanRelation =>
mutableNewData.get(r.stream).map {
case r: StreamingDataSourceV2ScanRelation
if r.startOffset.isEmpty && r.endOffset.isEmpty =>
val node = mutableNewData.get(r.stream).map {
case OffsetHolder(start, end) =>
r.copy(startOffset = Some(start), endOffset = Some(end))
}.getOrElse {
LocalRelation(r.output, isStreaming = true)
}

val collectMetricsName = CollectMetrics.nameForStreamSource(
UUID.randomUUID().toString)
uuidToStream.put(collectMetricsName, r.stream)
val cachedCollectMetrics = streamToCollectMetrics.getOrElseUpdate(r.stream,
CollectMetrics(
collectMetricsName,
Seq(
count(
new Column(internal.Literal(1))).as("row_count")
).map(_.named),
UnresolvedRelation(Seq("dummy")),
-1
)
)

val colMetrics = cachedCollectMetrics.copy(child = node)
sparkSessionToRunBatch.sessionState.analyzer.execute(colMetrics)
}
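The counter node built in both branches above is the same operator that backs Dataset.observe. A standalone batch-mode sketch of the equivalent row counter (illustration only, not this PR's code path; the streaming path wires the node in directly during the transform, and the Observation name here is hypothetical):

import org.apache.spark.sql.{Observation, SparkSession}
import org.apache.spark.sql.functions.{count, lit}

object RowCountSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("row-count-sketch").getOrCreate()
    import spark.implicits._

    // observe() plants a CollectMetrics node over the plan; the physical
    // CollectMetricsExec evaluates count(lit(1)) as "row_count" for the query.
    val observation = Observation("row_counter")
    val observed = Seq(1, 2, 3).toDF("value")
      .observe(observation, count(lit(1)).as("row_count"))

    observed.collect()
    println(observation.get)  // expected: Map(row_count -> 3)
    spark.stop()
  }
}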
execCtx.newData = mutableNewData.toMap
execCtx.uuidToStream = uuidToStream.toMap
// Rewire the plan to use the new attributes that were returned by the source.
val newAttributePlan = newBatchesPlan.transformAllExpressionsWithPruning(
_.containsPattern(CURRENT_LIKE)) {
@@ -27,14 +27,13 @@ import scala.jdk.CollectionConverters._

import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.optimizer.InlineCTE
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan, WithCTE}
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, ReportsSinkMetrics, ReportsSourceMetrics, SparkDataStream}
import org.apache.spark.sql.connector.read.streaming.{ReportsSinkMetrics, ReportsSourceMetrics, SparkDataStream}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.v2.{MicroBatchScanExec, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress}
import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent}
import org.apache.spark.util.{Clock, Utils}
@@ -144,6 +143,8 @@ abstract class ProgressContext(
// the most recent input data for each source.
protected def newData: Map[SparkDataStream, LogicalPlan]

protected def uuidToStream: Map[String, SparkDataStream]

/** Flag that signals whether any error with input metrics have already been logged */
protected var metricWarningLogged: Boolean = false

@@ -409,103 +410,21 @@ abstract class ProgressContext(
tuples.groupBy(_._1).transform((_, v) => v.map(_._2).sum) // sum up rows for each source
}
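As a concrete illustration of the summing above: a self-union makes the same source appear twice, and the per-scan counts are added up (standalone sketch; String stands in for SparkDataStream):

// Same shape as sumRows, but over plain strings so it runs on its own.
def sumRowsSketch(tuples: Seq[(String, Long)]): Map[String, Long] =
  tuples.groupBy(_._1).transform((_, v) => v.map(_._2).sum)

assert(sumRowsSketch(Seq("a" -> 10L, "a" -> 10L, "b" -> 5L)) == Map("a" -> 20L, "b" -> 5L))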

def unrollCTE(plan: LogicalPlan): LogicalPlan = {
val containsCTE = plan.exists {
case _: WithCTE => true
case _ => false
}

if (containsCTE) {
InlineCTE(alwaysInline = true).apply(plan)
} else {
plan
}
}

val onlyDataSourceV2Sources = {
// Check whether the streaming query's logical plan has only V2 micro-batch data sources
val allStreamingLeaves = progressReporter.logicalPlan().collect {
case s: StreamingDataSourceV2ScanRelation => s.stream.isInstanceOf[MicroBatchStream]
case _: StreamingExecutionRelation => false
}
allStreamingLeaves.forall(_ == true)
}
import org.apache.spark.sql.execution.CollectMetricsExec

if (onlyDataSourceV2Sources) {
// It's possible that multiple DataSourceV2ScanExec instances may refer to the same source
// (can happen with self-unions or self-joins). This means the source is scanned multiple
// times in the query, we should count the numRows for each scan.
if (uuidToStream != null) {
val sourceToInputRowsTuples = lastExecution.executedPlan.collect {
case s: MicroBatchScanExec =>
val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
val source = s.stream
source -> numRows
case c: CollectMetricsExec if uuidToStream.contains(c.name) =>
val stream = uuidToStream(c.name)
val numRows = c.collectedMetrics.getAs[Long]("row_count")
stream -> numRows
}

logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t"))
sumRows(sourceToInputRowsTuples)
sumRows(sourceToInputRowsTuples.toSeq)
} else {

// Since V1 source do not generate execution plan leaves that directly link with source that
// generated it, we can only do a best-effort association between execution plan leaves to the
// sources. This is known to fail in a few cases, see SPARK-24050.
//
// We want to associate execution plan leaves to sources that generate them, so that we match
// the their metrics (e.g. numOutputRows) to the sources. To do this we do the following.
// Consider the translation from the streaming logical plan to the final executed plan.
//
// streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan
//
// 1. We keep track of streaming sources associated with each leaf in trigger's logical plan
// - Each logical plan leaf will be associated with a single streaming source.
// - There can be multiple logical plan leaves associated with a streaming source.
// - There can be leaves not associated with any streaming source, because they were
// generated from a batch source (e.g. stream-batch joins)
//
// 2. Assuming that the executed plan has same number of leaves in the same order as that of
// the trigger logical plan, we associate executed plan leaves with corresponding
// streaming sources.
//
// 3. For each source, we sum the metrics of the associated execution plan leaves.
//
val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) =>
logicalPlan.collectLeaves().map { leaf => leaf -> source }
}

// SPARK-41198: CTE is inlined in optimization phase, which ends up with having different
// number of leaf nodes between (analyzed) logical plan and executed plan. Here we apply
// inlining CTE against logical plan manually if there is a CTE node.
val finalLogicalPlan = unrollCTE(lastExecution.logical)

val allLogicalPlanLeaves = finalLogicalPlan.collectLeaves() // includes non-streaming
val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves()
if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) {
val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap {
case (_, ep: MicroBatchScanExec) =>
// SPARK-41199: `logicalPlanLeafToSource` contains OffsetHolder instance for DSv2
// streaming source, hence we cannot lookup the actual source from the map.
// The physical node for DSv2 streaming source contains the information of the source
// by itself, so leverage it.
Some(ep -> ep.stream)
case (lp, ep) =>
logicalPlanLeafToSource.get(lp).map { source => ep -> source }
}
val sourceToInputRowsTuples = execLeafToSource.map { case (execLeaf, source) =>
val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
source -> numRows
}
sumRows(sourceToInputRowsTuples)
} else {
if (!metricWarningLogged) {
def toString[T](seq: Seq[T]): String = s"(size = ${seq.size}), ${seq.mkString(", ")}"

logWarning(log"Could not report metrics as number leaves in trigger logical plan did " +
log"not match that of the execution plan:\nlogical plan leaves: " +
log"${MDC(LogKeys.LOGICAL_PLAN_LEAVES, toString(allLogicalPlanLeaves))}\nexecution " +
log"plan leaves: ${MDC(LogKeys.EXECUTION_PLAN_LEAVES, toString(allExecPlanLeaves))}\n")
metricWarningLogged = true
}
Map.empty
}
logWarning("Association for streaming source output has been lost.")
Map.empty
}
}

@@ -49,6 +49,8 @@ abstract class StreamExecutionContext(
/** Holds the most recent input data for each source. */
var newData: Map[SparkDataStream, LogicalPlan] = _

var uuidToStream: Map[String, SparkDataStream] = _

/**
* Stores the start offset for this batch.
* Only the scheduler thread should modify this field, and only in atomic steps.