diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index e784e6695dbd0..f56136b2d517a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -2024,10 +2024,17 @@ case class CollectMetrics(
     dataframeId: Long)
   extends UnaryNode {
 
+  import CollectMetrics._
+
   override lazy val resolved: Boolean = {
     name.nonEmpty && metrics.nonEmpty && metrics.forall(_.resolved) && childrenResolved
   }
 
+  if (isForStreamSource(name)) {
+    assert(references.isEmpty,
+      "The node should not refer to any column when used as a stream source output counter!")
+  }
+
   override def maxRows: Option[Long] = child.maxRows
   override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition
   override def output: Seq[Attribute] = child.output
@@ -2040,6 +2047,14 @@ case class CollectMetrics(
   }
 }
 
+object CollectMetrics {
+  val STREAM_SOURCE_PREFIX = "__stream_source_"
+
+  def nameForStreamSource(name: String): String = s"$STREAM_SOURCE_PREFIX$name"
+
+  def isForStreamSource(name: String): Boolean = name.startsWith(STREAM_SOURCE_PREFIX)
+}
+
 /**
  * A placeholder for domain join that can be added when decorrelating subqueries.
  * It should be rewritten during the optimization phase.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 27019ab047ff2..bad6ff402ecbd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -24,11 +24,11 @@ import scala.collection.mutable
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{NUM_PRUNED, POST_SCAN_FILTERS, PUSHED_FILTERS, TOTAL}
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable}
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, LogicalPlan}
 import org.apache.spark.sql.catalyst.trees.TreePattern.{PLAN_EXPRESSION, SCALAR_SUBQUERY}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
@@ -61,6 +61,25 @@ import org.apache.spark.util.collection.BitSet
  */
 object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
 
+  private type HadoopFsRelationHolderRetType =
+    (LogicalRelation, HadoopFsRelation, Option[CatalogTable], Option[CollectMetrics])
+
+  private object HadoopFsRelationHolder {
+    def unapply(plan: LogicalPlan): Option[HadoopFsRelationHolderRetType] = {
+      plan match {
+        case c @ CollectMetrics(name, _,
+            l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _), _)
+            if CollectMetrics.isForStreamSource(name) =>
+          Some((l, fsRelation, table, Some(c)))
+
+        case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _) =>
+          Some((l, fsRelation, table, None))
+
+        case _ => None
+      }
+    }
+  }
+
   // should prune buckets iff num buckets is greater than 1 and there is only one bucket column
   private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = {
     bucketSpec match {
@@ -151,7 +170,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
 
   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case ScanOperation(projects, stayUpFilters, filters,
-      l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
+      HadoopFsRelationHolder(l, fsRelation, table, collectMetricsOpt)) =>
       // Filters on this relation fall into four categories based on where we can use them to avoid
       // reading unneeded data:
       //  - partition keys only - used to prune directories to read
@@ -342,9 +361,25 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
 
         val metadataAlias = Alias(KnownNotNull(CreateStruct(structColumns.toImmutableArraySeq)),
          FileFormat.METADATA_NAME)(exprId = metadataStruct.exprId)
+
+        val nodeExec = if (collectMetricsOpt.isDefined) {
+          val collectMetricsLogical = collectMetricsOpt.get
+          execution.CollectMetricsExec(
+            collectMetricsLogical.name, collectMetricsLogical.metrics, scan)
+        } else {
+          scan
+        }
         execution.ProjectExec(
-          readDataColumns ++ partitionColumns :+ metadataAlias, scan)
-      }.getOrElse(scan)
+          readDataColumns ++ partitionColumns :+ metadataAlias, nodeExec)
+      }.getOrElse {
+        if (collectMetricsOpt.isDefined) {
+          val collectMetricsLogical = collectMetricsOpt.get
+          execution.CollectMetricsExec(
+            collectMetricsLogical.name, collectMetricsLogical.metrics, scan)
+        } else {
+          scan
+        }
+      }
 
       // bottom-most filters are put in the left of the list.
       val finalFilters = afterScanFilters.toSeq.reduceOption(expressions.And).toSeq ++ stayUpFilters
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index f59cdca8aefec..2e1dde0cd6972 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -17,14 +17,17 @@
 package org.apache.spark.sql.execution.streaming
 
+import java.util.UUID
+
 import scala.collection.mutable.{Map => MutableMap}
 import scala.collection.mutable
 
 import org.apache.spark.internal.{LogKeys, MDC}
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.{Column, Dataset, SparkSession}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, FileSourceMetadataAttribute, LocalTimestamp}
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, LeafNode, LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.streaming.{StreamingRelationV2, WriteToStream}
 import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
 import org.apache.spark.sql.catalyst.util.truncatedString
@@ -35,6 +38,8 @@ import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec}
 import org.apache.spark.sql.execution.streaming.sources.{WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1}
+import org.apache.spark.sql.functions.count
+import org.apache.spark.sql.internal
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.Trigger
 import org.apache.spark.util.{Clock, Utils}
@@ -731,10 +736,15 @@ class MicroBatchExecution(
     }
 
     // Replace sources in the logical plan with data that has arrived since the last batch.
+    import sparkSessionToRunBatch.RichColumn
+
+    val uuidToStream = new mutable.HashMap[String, SparkDataStream]()
+    val streamToCollectMetrics = new mutable.HashMap[SparkDataStream, CollectMetrics]()
+
     val newBatchesPlan = logicalPlan transform {
       // For v1 sources.
       case StreamingExecutionRelation(source, output, catalogTable) =>
-        mutableNewData.get(source).map { dataPlan =>
+        val node = mutableNewData.get(source).map { dataPlan =>
           val hasFileMetadata = output.exists {
             case FileSourceMetadataAttribute(_) => true
             case _ => false
@@ -782,16 +792,54 @@ class MicroBatchExecution(
           LocalRelation(output, isStreaming = true)
         }
+        val collectMetricsName = CollectMetrics.nameForStreamSource(
+          UUID.randomUUID().toString)
+        uuidToStream.put(collectMetricsName, source)
+        val cachedCollectMetrics = streamToCollectMetrics.getOrElseUpdate(source,
+          CollectMetrics(
+            collectMetricsName,
+            Seq(
+              count(
+                new Column(internal.Literal(1))).as("row_count")
+            ).map(_.named),
+            UnresolvedRelation(Seq("dummy")),
+            -1
+          )
+        )
+
+        val colMetrics = cachedCollectMetrics.copy(child = node)
+        sparkSessionToRunBatch.sessionState.analyzer.execute(colMetrics)
+
       // For v2 sources.
-      case r: StreamingDataSourceV2ScanRelation =>
-        mutableNewData.get(r.stream).map {
+      case r: StreamingDataSourceV2ScanRelation
+          if r.startOffset.isEmpty && r.endOffset.isEmpty =>
+        val node = mutableNewData.get(r.stream).map {
           case OffsetHolder(start, end) =>
             r.copy(startOffset = Some(start), endOffset = Some(end))
         }.getOrElse {
           LocalRelation(r.output, isStreaming = true)
         }
+
+        val collectMetricsName = CollectMetrics.nameForStreamSource(
+          UUID.randomUUID().toString)
+        uuidToStream.put(collectMetricsName, r.stream)
+        val cachedCollectMetrics = streamToCollectMetrics.getOrElseUpdate(r.stream,
+          CollectMetrics(
+            collectMetricsName,
+            Seq(
+              count(
+                new Column(internal.Literal(1))).as("row_count")
+            ).map(_.named),
+            UnresolvedRelation(Seq("dummy")),
+            -1
+          )
+        )
+
+        val colMetrics = cachedCollectMetrics.copy(child = node)
+        sparkSessionToRunBatch.sessionState.analyzer.execute(colMetrics)
     }
     execCtx.newData = mutableNewData.toMap
+    execCtx.uuidToStream = uuidToStream.toMap
 
     // Rewire the plan to use the new attributes that were returned by the source.
     val newAttributePlan = newBatchesPlan.transformAllExpressionsWithPruning(
       _.containsPattern(CURRENT_LIKE)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index c440ec451b724..5fa02fe33ac1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -27,14 +27,13 @@ import scala.jdk.CollectionConverters._
 
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
 import org.apache.spark.sql.{Row, SparkSession}
-import org.apache.spark.sql.catalyst.optimizer.InlineCTE
-import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan, WithCTE}
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.catalog.Table
-import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, ReportsSinkMetrics, ReportsSourceMetrics, SparkDataStream}
+import org.apache.spark.sql.connector.read.streaming.{ReportsSinkMetrics, ReportsSourceMetrics, SparkDataStream}
 import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.execution.datasources.v2.{MicroBatchScanExec, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress}
+import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent}
 import org.apache.spark.util.{Clock, Utils}
@@ -144,6 +143,8 @@ abstract class ProgressContext(
   // the most recent input data for each source.
   protected def newData: Map[SparkDataStream, LogicalPlan]
 
+  protected def uuidToStream: Map[String, SparkDataStream]
+
   /** Flag that signals whether any error with input metrics have already been logged */
   protected var metricWarningLogged: Boolean = false
 
@@ -409,103 +410,21 @@ abstract class ProgressContext(
       tuples.groupBy(_._1).transform((_, v) => v.map(_._2).sum) // sum up rows for each source
     }
 
-    def unrollCTE(plan: LogicalPlan): LogicalPlan = {
-      val containsCTE = plan.exists {
-        case _: WithCTE => true
-        case _ => false
-      }
-
-      if (containsCTE) {
-        InlineCTE(alwaysInline = true).apply(plan)
-      } else {
-        plan
-      }
-    }
-
-    val onlyDataSourceV2Sources = {
-      // Check whether the streaming query's logical plan has only V2 micro-batch data sources
-      val allStreamingLeaves = progressReporter.logicalPlan().collect {
-        case s: StreamingDataSourceV2ScanRelation => s.stream.isInstanceOf[MicroBatchStream]
-        case _: StreamingExecutionRelation => false
-      }
-      allStreamingLeaves.forall(_ == true)
-    }
+    import org.apache.spark.sql.execution.CollectMetricsExec
 
-    if (onlyDataSourceV2Sources) {
-      // It's possible that multiple DataSourceV2ScanExec instances may refer to the same source
-      // (can happen with self-unions or self-joins). This means the source is scanned multiple
-      // times in the query, we should count the numRows for each scan.
+    if (uuidToStream != null) {
       val sourceToInputRowsTuples = lastExecution.executedPlan.collect {
-        case s: MicroBatchScanExec =>
-          val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
-          val source = s.stream
-          source -> numRows
+        case c: CollectMetricsExec if uuidToStream.contains(c.name) =>
+          val stream = uuidToStream(c.name)
+          val numRows = c.collectedMetrics.getAs[Long]("row_count")
+          stream -> numRows
       }
+
       logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t"))
-      sumRows(sourceToInputRowsTuples)
+      sumRows(sourceToInputRowsTuples.toSeq)
     } else {
-
-      // Since V1 source do not generate execution plan leaves that directly link with source that
-      // generated it, we can only do a best-effort association between execution plan leaves to the
-      // sources. This is known to fail in a few cases, see SPARK-24050.
-      //
-      // We want to associate execution plan leaves to sources that generate them, so that we match
-      // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following.
-      // Consider the translation from the streaming logical plan to the final executed plan.
-      //
-      // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan
-      //
-      // 1. We keep track of streaming sources associated with each leaf in trigger's logical plan
-      //    - Each logical plan leaf will be associated with a single streaming source.
-      //    - There can be multiple logical plan leaves associated with a streaming source.
-      //    - There can be leaves not associated with any streaming source, because they were
-      //      generated from a batch source (e.g. stream-batch joins)
-      //
-      // 2. Assuming that the executed plan has same number of leaves in the same order as that of
-      //    the trigger logical plan, we associate executed plan leaves with corresponding
-      //    streaming sources.
-      //
-      // 3. For each source, we sum the metrics of the associated execution plan leaves.
-      //
-      val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) =>
-        logicalPlan.collectLeaves().map { leaf => leaf -> source }
-      }
-
-      // SPARK-41198: CTE is inlined in optimization phase, which ends up with having different
-      // number of leaf nodes between (analyzed) logical plan and executed plan. Here we apply
-      // inlining CTE against logical plan manually if there is a CTE node.
-      val finalLogicalPlan = unrollCTE(lastExecution.logical)
-
-      val allLogicalPlanLeaves = finalLogicalPlan.collectLeaves() // includes non-streaming
-      val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves()
-      if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) {
-        val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap {
-          case (_, ep: MicroBatchScanExec) =>
-            // SPARK-41199: `logicalPlanLeafToSource` contains OffsetHolder instance for DSv2
-            // streaming source, hence we cannot lookup the actual source from the map.
-            // The physical node for DSv2 streaming source contains the information of the source
-            // by itself, so leverage it.
-            Some(ep -> ep.stream)
-          case (lp, ep) =>
-            logicalPlanLeafToSource.get(lp).map { source => ep -> source }
-        }
-        val sourceToInputRowsTuples = execLeafToSource.map { case (execLeaf, source) =>
-          val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
-          source -> numRows
-        }
-        sumRows(sourceToInputRowsTuples)
-      } else {
-        if (!metricWarningLogged) {
-          def toString[T](seq: Seq[T]): String = s"(size = ${seq.size}), ${seq.mkString(", ")}"
-
-          logWarning(log"Could not report metrics as number leaves in trigger logical plan did " +
-            log"not match that of the execution plan:\nlogical plan leaves: " +
-            log"${MDC(LogKeys.LOGICAL_PLAN_LEAVES, toString(allLogicalPlanLeaves))}\nexecution " +
-            log"plan leaves: ${MDC(LogKeys.EXECUTION_PLAN_LEAVES, toString(allExecPlanLeaves))}\n")
-          metricWarningLogged = true
-        }
-        Map.empty
-      }
+      logWarning("Association between streaming sources and output counters has been lost.")
+      Map.empty
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecutionContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecutionContext.scala
index c5e14df3e20e1..e60eaa0ac4e8d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecutionContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecutionContext.scala
@@ -49,6 +49,8 @@ abstract class StreamExecutionContext(
   /** Holds the most recent input data for each source. */
   var newData: Map[SparkDataStream, LogicalPlan] = _
 
+  var uuidToStream: Map[String, SparkDataStream] = _
+
   /**
    * Stores the start offset for this batch.
    * Only the scheduler thread should modify this field, and only in atomic steps.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 2767f2dd46b2e..d529af59f4849 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -1448,6 +1448,65 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
     }
   }
 
+  // FIXME: ...testing...
+  test("SPARK-XXXXX: partition filter is defined against parquet streaming source") {
+    withView("view1") {
+      withTable("table1") {
+        spark.conf.set(SQLConf.PLAN_CHANGE_LOG_LEVEL.key, "INFO")
+        spark.sql(
+          """
+            |CREATE TABLE table1
+            |(row_date date, value int)
+            |USING parquet
+            |PARTITIONED BY (row_date)
+            |""".stripMargin)
+
+        (1 to 31).foreach { day =>
+          val dateStr = f"2024-01-$day%02d"
+          spark.range(0 + 100 * day, 100 + 100 * day)
+            // .selectExpr(s"CAST('$dateStr' AS date) AS row_date", "id AS value")
+            // insertInto() is by-position; partition columns are moved to the end of the schema.
+            .selectExpr("id AS value", s"CAST('$dateStr' AS date) AS row_date")
+            // intended to write a single file
+            .repartition(1)
+            .write
+            .insertInto("table1")
+        }
+
+        val df1 = spark.readStream
+          .format("parquet")
+          // This is just to simplify the test. In production it won't be 1 but 1000 by default,
+          // but the backlog could be huge as well, which would lead to the same issue.
+ .option("maxFilesPerTrigger", "1") + .table("table1") + .where("row_date > CAST('2024-01-15' AS date)") + + val query1 = df1 + .selectExpr("CAST(row_date AS string)", "value") + .writeStream + .format("memory") + .queryName("table1output") + .trigger(Trigger.AvailableNow()) + .start() + + query1.processAllAvailable() + + val (batchesZeroInputRows, batchesNonZeroInputRows) = query1.recentProgress + // This filters out update events from idle trigger + .filter(_.durationMs.containsKey("addBatch")) + .partition(_.numInputRows == 0) + + // FIXME: This requires filter pushdown to take effect in DSv1 streaming source. + assert(batchesZeroInputRows.map(_.batchId) === (0 until 15)) + assert(batchesNonZeroInputRows.map(_.batchId) === (15 until 31)) + + logWarning(s"DEBUG: progresses: ${query1.recentProgress.mkString("Array(", ", ", ")")}") + + query1.explain(extended = true) + } + } + } + private def checkAppendOutputModeException(df: DataFrame): Unit = { withTempDir { outputDir => withTempDir { checkpointDir =>