From 2341c22cde694bda13f3f50a55c23d0815502733 Mon Sep 17 00:00:00 2001 From: Steve Vaughan Jr Date: Tue, 23 Apr 2024 11:30:15 -0400 Subject: [PATCH] [SPARK-47050][SQL] Collect and publish partition level metrics Capture the partition sub-paths, along with the number of files, bytes, and rows per partition for each task. --- .../write/PartitionMetricsWriteInfo.java | 99 ++++++++++++++ .../connector/write/PartitionMetrics.scala | 54 ++++++++ .../SparkListenerSQLPartitionMetrics.scala | 54 ++++++++ .../datasources/BasicWriteStatsTracker.scala | 69 +++++++++- .../datasources/FileFormatDataWriter.scala | 15 ++- .../datasources/FileFormatWriter.scala | 9 +- .../datasources/WriteStatsTracker.scala | 14 +- .../datasources/v2/FileBatchWrite.scala | 3 +- .../BasicWriteTaskStatsTrackerSuite.scala | 52 ++++++++ .../CustomWriteTaskStatsTrackerSuite.scala | 2 +- .../metric/SQLMetricsTestUtils.scala | 121 ++++++++++++------ 11 files changed, 441 insertions(+), 51 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PartitionMetricsWriteInfo.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/PartitionMetrics.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/SparkListenerSQLPartitionMetrics.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PartitionMetricsWriteInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PartitionMetricsWriteInfo.java new file mode 100644 index 0000000000000..82455796444e1 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PartitionMetricsWriteInfo.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.write; + +import java.io.Serializable; +import java.util.Collections; +import java.util.Map; +import java.util.TreeMap; + +/** + * An aggregator of partition metrics collected during write operations. + *

+ * This is patterned after {@code org.apache.spark.util.AccumulatorV2}

+ */
+public class PartitionMetricsWriteInfo implements Serializable {
+
+  private final Map<String, PartitionMetrics> metrics = new TreeMap<>();
+
+  /**
+   * Merges another same-type accumulator into this one and updates its state, i.e. this should be
+   * merge-in-place.
+   *
+   * @param otherAccumulator Another object containing aggregated partition metrics
+   */
+  public void merge(PartitionMetricsWriteInfo otherAccumulator) {
+    otherAccumulator.metrics.forEach((p, m) ->
+        metrics.computeIfAbsent(p, key -> new PartitionMetrics(0L, 0L, 0))
+            .merge(m));
+  }
+
+  /**
+   * Update the partition metrics for the specified path by adding to the existing state. This will
+   * add the partition if it has not been referenced previously.
+   *
+   * @param partitionPath The path for the written partition
+   * @param bytes The number of additional bytes
+   * @param records The number of additional records
+   * @param files The number of additional files
+   */
+  public void update(String partitionPath, long bytes, long records, int files) {
+    metrics.computeIfAbsent(partitionPath, key -> new PartitionMetrics(0L, 0L, 0))
+        .merge(new PartitionMetrics(bytes, records, files));
+  }
+
+  /**
+   * Update the partition metrics for the specified path by adding to the existing state from an
+   * individual file. This will add the partition if it has not been referenced previously.
+   *
+   * @param partitionPath The path for the written partition
+   * @param bytes The number of additional bytes
+   * @param records The number of additional records
+   */
+  public void updateFile(String partitionPath, long bytes, long records) {
+    update(partitionPath, bytes, records, 1);
+  }
+
+  /**
+   * Convert this instance into an immutable {@code java.util.Map}. This is used for posting to the
+   * listener bus.
+   *
+   * @return an immutable map of partition paths to their metrics
+   */
+  public Map<String, PartitionMetrics> toMap() {
+    return Collections.unmodifiableMap(metrics);
+  }
+
+  /**
+   * Returns if this accumulator is zero value or not. For a map accumulator this indicates if the
+   * map is empty.
+   *
+   * @return {@code true} if there are no partition metrics
+   */
+  boolean isZero() {
+    return metrics.isEmpty();
+  }
+
+  @Override
+  public String toString() {
+    return "PartitionMetricsWriteInfo{" +
+        "metrics=" + metrics +
+        '}';
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/PartitionMetrics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/PartitionMetrics.scala
new file mode 100644
index 0000000000000..7372e46bb40ef
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/PartitionMetrics.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.write
+
+/**
+ * The metrics collected for an individual partition.
+ *
+ * @param numBytes the number of bytes
+ * @param numRecords the number of records (rows)
+ * @param numFiles the number of files
+ */
+case class PartitionMetrics(var numBytes: Long = 0, var numRecords: Long = 0, var numFiles: Int = 0)
+  extends Serializable {
+
+  /**
+   * Updates the metrics for an individual file.
+   *
+   * @param bytes the number of bytes
+   * @param records the number of records (rows)
+   */
+  def updateFile(bytes: Long, records: Long): Unit = {
+    numBytes += bytes
+    numRecords += records
+    numFiles += 1
+  }
+
+  /**
+   * Merges another same-type accumulator into this one and updates its state, i.e. this should be
+   * merge-in-place.
+   *
+   * @param other Another set of metrics for the same partition
+   */
+  def merge(other: PartitionMetrics): Unit = {
+    numBytes += other.numBytes
+    numRecords += other.numRecords
+    numFiles += other.numFiles
+  }
+
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/SparkListenerSQLPartitionMetrics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/SparkListenerSQLPartitionMetrics.scala
new file mode 100644
index 0000000000000..43c3b8bca4c5e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/SparkListenerSQLPartitionMetrics.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.write
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.scheduler.SparkListenerEvent
+
+@DeveloperApi
+case class SparkListenerSQLPartitionMetrics(executionId: Long,
+    metrics: java.util.Map[String, PartitionMetrics])
+  extends SparkListenerEvent
+
+object SQLPartitionMetrics {
+
+  /**
+   * Post any aggregated partition write statistics to the listener bus using a
+   * [[SparkListenerSQLPartitionMetrics]] event.
+   *
+   * @param sc The Spark context
+   * @param executionId The identifier for the SQL execution that resulted in the partition writes
+   * @param writeInfo The aggregated partition writes for this SQL execution
+   */
+  def postDriverMetricUpdates(sc: SparkContext, executionId: String,
+      writeInfo: PartitionMetricsWriteInfo): Unit = {
+    // Don't bother firing an event if there are no collected metrics
+    if (writeInfo.isZero) {
+      return
+    }
+
+    // There are some cases we don't care about the metrics and call `SparkPlan.doExecute`
+    // directly without setting an execution id. We should be tolerant to it.
+ if (executionId != null) { + sc.listenerBus.post( + SparkListenerSQLPartitionMetrics(executionId.toLong, writeInfo.toMap)) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index a34b235d2d127..63bd49617e70a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -29,6 +29,7 @@ import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKey.{ACTUAL_NUM_FILES, EXPECTED_NUM_FILES} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.write.{PartitionMetricsWriteInfo, SQLPartitionMetrics} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -43,10 +44,18 @@ case class BasicWriteTaskStats( partitions: Seq[InternalRow], numFiles: Int, numBytes: Long, - numRows: Long) + numRows: Long, + partitionsStats: Map[InternalRow, BasicWritePartitionTaskStats] + = Map[InternalRow, BasicWritePartitionTaskStats]()) extends WriteTaskStats +case class BasicWritePartitionTaskStats( + numFiles: Int, + numBytes: Long, + numRows: Long) + extends PartitionTaskStats + /** * Simple [[WriteTaskStatsTracker]] implementation that produces [[BasicWriteTaskStats]]. */ @@ -56,12 +65,19 @@ class BasicWriteTaskStatsTracker( extends WriteTaskStatsTracker with Logging { private[this] val partitions: mutable.ArrayBuffer[InternalRow] = mutable.ArrayBuffer.empty + // Map each partition to counts of the number of files, bytes, and rows written + // partition -> (files, bytes, rows) + private[this] val partitionsStats: mutable.Map[InternalRow, (Int, Long, Long)] = + mutable.Map.empty.withDefaultValue((0, 0L, 0L)) private[this] var numFiles: Int = 0 private[this] var numSubmittedFiles: Int = 0 private[this] var numBytes: Long = 0L private[this] var numRows: Long = 0L private[this] val submittedFiles = mutable.HashSet[String]() + private[this] val submittedPartitionFiles = mutable.Map[String, InternalRow]() + + private[this] val numFileRows: mutable.Map[String, Long] = mutable.Map.empty.withDefaultValue(0) /** * Get the size of the file expected to have been written by a worker. 
@@ -138,25 +154,45 @@ class BasicWriteTaskStatsTracker( partitions.append(partitionValues) } - override def newFile(filePath: String): Unit = { + override def newFile(filePath: String, partitionValues: Option[InternalRow] = None): Unit = { submittedFiles += filePath numSubmittedFiles += 1 + + // Submitting a file for a partition + if (partitionValues.isDefined) { + submittedPartitionFiles += (filePath -> partitionValues.get) + } } override def closeFile(filePath: String): Unit = { updateFileStats(filePath) submittedFiles.remove(filePath) + submittedPartitionFiles.remove(filePath) + numFileRows.remove(filePath) } private def updateFileStats(filePath: String): Unit = { getFileSize(filePath).foreach { len => numBytes += len numFiles += 1 + + submittedPartitionFiles.get(filePath) + .foreach(partition => { + val stats = partitionsStats(partition) + partitionsStats(partition) = stats.copy( + stats._1 + 1, + stats._2 + len, + stats._3 + numFileRows(filePath)) + }) } } override def newRow(filePath: String, row: InternalRow): Unit = { numRows += 1 + + // Track the number of rows added to each file, which may be accumulated with an associated + // partition + numFileRows(filePath) += 1 } override def getFinalStats(taskCommitTime: Long): WriteTaskStats = { @@ -172,7 +208,15 @@ class BasicWriteTaskStatsTracker( log"writing empty files, or files being not immediately visible in the filesystem.") } taskCommitTimeMetric.foreach(_ += taskCommitTime) - BasicWriteTaskStats(partitions.toSeq, numFiles, numBytes, numRows) + + val publish: ((InternalRow, (Int, Long, Long))) => + (InternalRow, BasicWritePartitionTaskStats) = { + case (key, value) => + val newValue = BasicWritePartitionTaskStats(value._1, value._2, value._3) + key -> newValue + } + BasicWriteTaskStats(partitions.toSeq, numFiles, numBytes, numRows, + partitionsStats.map(publish).toMap) } } @@ -189,6 +233,8 @@ class BasicWriteJobStatsTracker( taskCommitTimeMetric: SQLMetric) extends WriteJobStatsTracker { + val partitionMetrics: PartitionMetricsWriteInfo = new PartitionMetricsWriteInfo() + def this( serializableHadoopConf: SerializableConfiguration, metrics: Map[String, SQLMetric]) = { @@ -199,7 +245,8 @@ class BasicWriteJobStatsTracker( new BasicWriteTaskStatsTracker(serializableHadoopConf.value, Some(taskCommitTimeMetric)) } - override def processStats(stats: Seq[WriteTaskStats], jobCommitTime: Long): Unit = { + override def processStats(stats: Seq[WriteTaskStats], jobCommitTime: Long, + partitionsMap: Map[InternalRow, String]): Unit = { val sparkContext = SparkContext.getActive.get val partitionsSet: mutable.Set[InternalRow] = mutable.HashSet.empty var numFiles: Long = 0L @@ -213,6 +260,14 @@ class BasicWriteJobStatsTracker( numFiles += summary.numFiles totalNumBytes += summary.numBytes totalNumOutput += summary.numRows + + summary.partitionsStats.foreach(s => { + // Check if we know the mapping of the internal row to a partition path + if (partitionsMap.contains(s._1)) { + val path = partitionsMap(s._1) + partitionMetrics.update(path, s._2.numBytes, s._2.numRows, s._2.numFiles) + } + }) } driverSideMetrics(JOB_COMMIT_TIME).add(jobCommitTime) @@ -223,6 +278,9 @@ class BasicWriteJobStatsTracker( val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, driverSideMetrics.values.toList) + + SQLPartitionMetrics.postDriverMetricUpdates(sparkContext, executionId, + partitionMetrics) } } @@ -247,4 +305,7 @@ object BasicWriteJobStatsTracker { JOB_COMMIT_TIME -> 
SQLMetrics.createTimingMetric(sparkContext, "job commit time") ) } + + def partitionMetrics: mutable.Map[String, PartitionTaskStats] = + mutable.Map.empty.withDefaultValue(BasicWritePartitionTaskStats(0, 0L, 0L)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 1dbb6ce26f693..b4470df048226 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -55,6 +55,8 @@ abstract class FileFormatDataWriter( */ protected val MAX_FILE_COUNTER: Int = 1000 * 1000 protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]() + protected val updatedPartitionsMap: mutable.Map[InternalRow, String] + = mutable.Map[InternalRow, String]() protected var currentWriter: OutputWriter = _ /** Trackers for computing various statistics on the data as it's being written out. */ @@ -126,7 +128,8 @@ abstract class FileFormatDataWriter( } val summary = ExecutedWriteSummary( updatedPartitions = updatedPartitions.toSet, - stats = statsTrackers.map(_.getFinalStats(taskCommitTime))) + stats = statsTrackers.map(_.getFinalStats(taskCommitTime)), + updatedPartitionsMap.toMap) WriteTaskResult(taskCommitMessage, summary) } @@ -178,7 +181,7 @@ class SingleDirectoryDataWriter( dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) - statsTrackers.foreach(_.newFile(currentPath)) + statsTrackers.foreach(_.newFile(currentPath, None)) } override def write(record: InternalRow): Unit = { @@ -287,6 +290,9 @@ abstract class BaseDynamicPartitionDataWriter( val partDir = partitionValues.map(getPartitionPath(_)) partDir.foreach(updatedPartitions.add) + if (partDir.isDefined) { + partitionValues.foreach(updatedPartitionsMap(_) = partDir.get) + } val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") @@ -315,7 +321,7 @@ abstract class BaseDynamicPartitionDataWriter( dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) - statsTrackers.foreach(_.newFile(currentPath)) + statsTrackers.foreach(_.newFile(currentPath, partitionValues)) } /** @@ -628,4 +634,5 @@ case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteS */ case class ExecutedWriteSummary( updatedPartitions: Set[String], - stats: Seq[WriteTaskStats]) + stats: Seq[WriteTaskStats], + updatedPartitionsMap: Map[InternalRow, String]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 3bfa3413f6796..01bafda63515f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -275,7 +275,8 @@ object FileFormatWriter extends Logging { logInfo(s"Write Job ${description.uuid} committed. 
Elapsed time: $duration ms.") processStats( - description.statsTrackers, ret.map(_.summary.stats).toImmutableArraySeq, duration) + description.statsTrackers, ret.map(_.summary.stats).toImmutableArraySeq, duration, + ret.map(_.summary.updatedPartitionsMap).reduce(_ ++ _)) logInfo(s"Finished processing stats for write job ${description.uuid}.") // return a set of all the partition paths that were updated during this job @@ -417,7 +418,8 @@ object FileFormatWriter extends Logging { private[datasources] def processStats( statsTrackers: Seq[WriteJobStatsTracker], statsPerTask: Seq[Seq[WriteTaskStats]], - jobCommitDuration: Long) + jobCommitDuration: Long, + partitionsMap: Map[InternalRow, String]) : Unit = { val numStatsTrackers = statsTrackers.length @@ -434,7 +436,8 @@ object FileFormatWriter extends Logging { } statsTrackers.zip(statsPerTracker).foreach { - case (statsTracker, stats) => statsTracker.processStats(stats, jobCommitDuration) + case (statsTracker, stats) => + statsTracker.processStats(stats, jobCommitDuration, partitionsMap) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala index 157ed0120bf3a..fdd4762e69f7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala @@ -28,6 +28,12 @@ import org.apache.spark.sql.catalyst.InternalRow trait WriteTaskStats extends Serializable +trait PartitionTaskStats extends Serializable { + def numFiles: Int + def numBytes: Long + def numRows: Long +} + /** * A trait for classes that are capable of collecting statistics on data that's being processed by * a single write task in [[FileFormatWriter]] - i.e. there should be one instance per executor. @@ -46,8 +52,10 @@ trait WriteTaskStatsTracker { /** * Process the fact that a new file is about to be written. * @param filePath Path of the file into which future rows will be written. + * @param partitionValues Optional reference to the partition associated with this new file. This + * avoids trying to extract the partition values from the filePath. */ - def newFile(filePath: String): Unit + def newFile(filePath: String, partitionValues: Option[InternalRow] = None): Unit /** * Process the fact that a file is finished to be written and closed. @@ -95,6 +103,7 @@ trait WriteJobStatsTracker extends Serializable { * E.g. aggregate them, write them to memory / disk, issue warnings, whatever. * @param stats One [[WriteTaskStats]] object from each successful write task. * @param jobCommitTime Time of committing the job. + * @param partitionsMap A map of [[InternalRow]] to a partition subpath * @note The type of @param `stats` is too generic. These classes should probably be parametrized: * WriteTaskStatsTracker[S <: WriteTaskStats] * WriteJobStatsTracker[S <: WriteTaskStats, T <: WriteTaskStatsTracker[S]] @@ -105,5 +114,6 @@ trait WriteJobStatsTracker extends Serializable { * to the expected derived type when implementing this method in a derived class. * The framework will make sure to call this with the right arguments. 
*/ - def processStats(stats: Seq[WriteTaskStats], jobCommitTime: Long): Unit + def processStats(stats: Seq[WriteTaskStats], jobCommitTime: Long, + partitionsMap: Map[InternalRow, String] = Map.empty): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala index 2f443a0bb1fad..7b52e647f2548 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala @@ -39,7 +39,8 @@ class FileBatchWrite( logInfo(s"Write Job ${description.uuid} committed. Elapsed time: $duration ms.") processStats( - description.statsTrackers, results.map(_.summary.stats).toImmutableArraySeq, duration) + description.statsTrackers, results.map(_.summary.stats).toImmutableArraySeq, duration, + results.map(_.summary.updatedPartitionsMap).reduce(_ ++ _)) logInfo(s"Finished processing stats for write job ${description.uuid}.") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala index e8b9dcf172a78..661ad5e95da85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala @@ -23,7 +23,11 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FilterFileSystem, Path} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker.FILE_LENGTH_XATTR +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.Utils /** @@ -72,6 +76,16 @@ class BasicWriteTaskStatsTrackerSuite extends SparkFunSuite { assert(bytes === stats.numBytes, "Wrong byte count of file size") } + private def assertPartitionStats( + tracker: BasicWriteTaskStatsTracker, + partition: InternalRow, + files: Int, + bytes: Int): Unit = { + val stats = finalStatus(tracker).partitionsStats.apply(partition) + assert(files === stats.numFiles, "Wrong number of files") + assert(bytes === stats.numBytes, "Wrong byte count of file size") + } + private def finalStatus(tracker: BasicWriteTaskStatsTracker): BasicWriteTaskStats = { tracker.getFinalStats(0L).asInstanceOf[BasicWriteTaskStats] } @@ -123,6 +137,23 @@ class BasicWriteTaskStatsTrackerSuite extends SparkFunSuite { assertStats(tracker, 1, len1) } + test("Partitioned file with data") { + val schema: StructType = new StructType() + .add("day", StringType) + val row = Row("day", "Monday") + val lit = Literal.create(row, schema) + val internalRow = lit.value.asInstanceOf[InternalRow] + + val file = new Path(tempDirPath, "day=Monday/file-with-data") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newPartition(internalRow) + tracker.newFile(file.toString, Some(internalRow)) + write1(file) + tracker.closeFile(file.toString) + assertStats(tracker, 1, len1) + assertPartitionStats(tracker, internalRow, 1, len1) + } + test("Open file") { val file = new Path(tempDirPath, "file-open") val tracker = new BasicWriteTaskStatsTracker(conf) @@ -152,6 +183,27 
@@ class BasicWriteTaskStatsTrackerSuite extends SparkFunSuite { assertStats(tracker, 2, len1 + len2) } + test("Partitioned two files") { + val schema: StructType = new StructType() + .add("day", StringType) + val row = Row("day", "Monday") + val lit = Literal.create(row, schema) + val internalRow = lit.value.asInstanceOf[InternalRow] + + val file1 = new Path(tempDirPath, "day=Monday/f-2-1") + val file2 = new Path(tempDirPath, "day=Monday/f-2-2") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newPartition(internalRow) + tracker.newFile(file1.toString, Some(internalRow)) + write1(file1) + tracker.closeFile(file1.toString) + tracker.newFile(file2.toString, Some(internalRow)) + write2(file2) + tracker.closeFile(file2.toString) + assertStats(tracker, 2, len1 + len2) + assertPartitionStats(tracker, internalRow, 2, len1 + len2) + } + test("Three files, last one empty") { val file1 = new Path(tempDirPath, "f-3-1") val file2 = new Path(tempDirPath, "f-3-2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CustomWriteTaskStatsTrackerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CustomWriteTaskStatsTrackerSuite.scala index e9f625b2ded9f..8e2e4b10a18be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CustomWriteTaskStatsTrackerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CustomWriteTaskStatsTrackerSuite.scala @@ -54,7 +54,7 @@ class CustomWriteTaskStatsTracker extends WriteTaskStatsTracker { override def newPartition(partitionValues: InternalRow): Unit = {} - override def newFile(filePath: String): Unit = { + override def newFile(filePath: String, partitionValues: Option[InternalRow]): Unit = { numRowsPerFile.put(filePath, 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 72aa607591d57..b796ecf9b6e59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -19,12 +19,14 @@ package org.apache.spark.sql.execution.metric import java.io.File +import scala.collection.mutable import scala.collection.mutable.HashMap import org.apache.spark.TestUtils -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} -import org.apache.spark.sql.DataFrame +import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerTaskEnd} +import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.connector.write.SparkListenerSQLPartitionMetrics import org.apache.spark.sql.execution.{SparkPlan, SparkPlanInfo} import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore} import org.apache.spark.sql.internal.SQLConf.WHOLESTAGE_CODEGEN_ENABLED @@ -104,55 +106,102 @@ trait SQLMetricsTestUtils extends SQLTestUtils { assert(totalNumBytes > 0) } + private class CapturePartitionMetrics extends SparkListener { + + val events: mutable.Buffer[SparkListenerSQLPartitionMetrics] = + mutable.Buffer[SparkListenerSQLPartitionMetrics]() + + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case metrics: SparkListenerSQLPartitionMetrics => events += metrics + case _ => + } + } + } + + protected def withSparkListener[L <: SparkListener] + (spark: 
SparkSession, listener: L) + (body: L => Unit): Unit = { + spark.sparkContext.addSparkListener(listener) + try { + body(listener) + } + finally { + spark.sparkContext.removeSparkListener(listener) + } + } + protected def testMetricsNonDynamicPartition( dataFormat: String, tableName: String): Unit = { - withTable(tableName) { - Seq((1, 2)).toDF("i", "j") - .write.format(dataFormat).mode("overwrite").saveAsTable(tableName) - - val tableLocation = - new File(spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).location) - - // 2 files, 100 rows, 0 dynamic partition. - verifyWriteDataMetrics(Seq(2, 0, 100)) { - (0 until 100).map(i => (i, i + 1)).toDF("i", "j").repartition(2) - .write.format(dataFormat).mode("overwrite").insertInto(tableName) + val listener = new CapturePartitionMetrics() + withSparkListener(spark, listener) { _ => + withTable(tableName) { + Seq((1, 2)).toDF("i", "j") + .write.format(dataFormat).mode("overwrite").saveAsTable(tableName) + + val tableLocation = + new File(spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).location) + + // 2 files, 100 rows, 0 dynamic partition. + verifyWriteDataMetrics(Seq(2, 0, 100)) { + (0 until 100).map(i => (i, i + 1)).toDF("i", "j").repartition(2) + .write.format(dataFormat).mode("overwrite").insertInto(tableName) + } + assert(TestUtils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) } - assert(TestUtils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) } + + // Verify that there are no partition metrics for the entire write process. + assert(listener.events.isEmpty) } protected def testMetricsDynamicPartition( provider: String, dataFormat: String, tableName: String): Unit = { - withTable(tableName) { - withTempPath { dir => - spark.sql( - s""" - |CREATE TABLE $tableName(a int, b int) - |USING $provider - |PARTITIONED BY(a) - |LOCATION '${dir.toURI}' + val listener = new CapturePartitionMetrics() + withSparkListener(spark, listener) { _ => + withTable(tableName) { + withTempPath { dir => + spark.sql( + s""" + |CREATE TABLE $tableName(a int, b int) + |USING $provider + |PARTITIONED BY(a) + |LOCATION '${dir.toURI}' """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) - assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) - - val df = spark.range(start = 0, end = 40, step = 1, numPartitions = 1) - .selectExpr("id a", "id b") - - // 40 files, 80 rows, 40 dynamic partitions. - verifyWriteDataMetrics(Seq(40, 40, 80)) { - df.union(df).repartition(2, $"a") - .write - .format(dataFormat) - .mode("overwrite") - .insertInto(tableName) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + val df = spark.range(start = 0, end = 40, step = 1, numPartitions = 1) + .selectExpr("id a", "id b") + + // 40 files, 80 rows, 40 dynamic partitions. + verifyWriteDataMetrics(Seq(40, 40, 80)) { + df.union(df).repartition(2, $"a") + .write + .format(dataFormat) + .mode("overwrite") + .insertInto(tableName) + } + assert(TestUtils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) } - assert(TestUtils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) } } + + // Verify that there a single event for partition metrics for the entire write process. This + // test creates the table and performs a repartitioning, but only 1 action actually results + // in collecting partition metrics. 
+    assert(listener.events.length == 1)
+    val event = listener.events.head
+
+    // Verify the number of partitions
+    assert(event.metrics.keySet.size == 40)
+    // Verify the number of files per partition
+    event.metrics.values.forEach(partitionStats => assert(partitionStats.numFiles == 1))
+    // Verify the number of rows per partition
+    event.metrics.values.forEach(partitionStats => assert(partitionStats.numRecords == 2))
   }
 
   /**
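
Reviewer note: the test helper CapturePartitionMetrics above shows the intended consumption pattern for the new event. A minimal sketch of an application-side listener (the class name, logging, and registration comment are illustrative and not part of this patch) that reads SparkListenerSQLPartitionMetrics from the listener bus:

    import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
    import org.apache.spark.sql.connector.write.SparkListenerSQLPartitionMetrics

    // Logs the per-partition write metrics posted by BasicWriteJobStatsTracker
    // after a partitioned write commits.
    class LogPartitionMetricsListener extends SparkListener {
      override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
        case m: SparkListenerSQLPartitionMetrics =>
          // metrics is a java.util.Map[String, PartitionMetrics]: partition sub-path -> stats
          m.metrics.forEach { (path, stats) =>
            println(s"partition=$path files=${stats.numFiles} " +
              s"bytes=${stats.numBytes} rows=${stats.numRecords}")
          }
        case _ => // ignore other events
      }
    }

    // Hypothetical usage: register before running a partitioned write.
    // spark.sparkContext.addSparkListener(new LogPartitionMetricsListener)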