From c5b557195f7115b8590f655aed56465ff91f1d88 Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Mon, 16 Dec 2024 17:34:22 -0800
Subject: [PATCH] [SPARK-50593][SQL] SPJ: Support truncate transform

Transform functions such as truncate take a literal argument in addition to
the column reference, so a partition transform expression can have more than
one child. Filter collected leaf expressions down to attribute references
(hoisting isReference to object scope so callers can reuse it), and make
PartitionInternalRow.equals compare key values instead of array references so
that equal partition rows from the two sides of a join compare equal.
---
 .../plans/physical/partitioning.scala         | 19 ++++---
 .../connector/catalog/InMemoryBaseTable.scala |  4 ++--
 .../exchange/EnsureRequirements.scala         |  1 +
 .../KeyGroupedPartitioningSuite.scala         | 55 +++++++++++++++++--
 4 files changed, 64 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 30e223c3c3c87..de9bc645a83b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -383,6 +383,7 @@ case class KeyGroupedPartitioning(
     } else {
       // We'll need to find leaf attributes from the partition expressions first.
       val attributes = expressions.flatMap(_.collectLeaves())
+        .filter(KeyGroupedPartitioning.isReference)
 
       if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
         // check that join keys (required clustering keys)
@@ -457,14 +458,7 @@ object KeyGroupedPartitioning {
 
   def supportsExpressions(expressions: Seq[Expression]): Boolean = {
     def isSupportedTransform(transform: TransformExpression): Boolean = {
-      transform.children.size == 1 && isReference(transform.children.head)
-    }
-
-    @tailrec
-    def isReference(e: Expression): Boolean = e match {
-      case _: Attribute => true
-      case g: GetStructField => isReference(g.child)
-      case _ => false
+      transform.children.count(isReference) == 1
     }
 
     expressions.forall {
@@ -473,6 +467,13 @@
       case _ => false
     }
   }
+
+  @tailrec
+  def isReference(e: Expression): Boolean = e match {
+    case _: Attribute => true
+    case g: GetStructField => isReference(g.child)
+    case _ => false
+  }
 }
 
 /**
@@ -791,7 +792,7 @@ case class KeyGroupedShuffleSpec(
       distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
     }
     partitioning.expressions.map { e =>
-      val leaves = e.collectLeaves()
+      val leaves = e.collectLeaves().filter(KeyGroupedPartitioning.isReference)
       assert(leaves.size == 1, s"Expected exactly one child from $e, but found ${leaves.size}")
       distKeyToPos.getOrElse(leaves.head.canonicalized, mutable.BitSet.empty)
     }
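For context on the filtering above: a truncate transform has two children, the
column reference and the literal width, and a Literal is itself a leaf
expression, so collectLeaves() alone no longer yields exactly one attribute per
partition expression. A minimal sketch of the shapes involved, reusing the
suite's TruncateFunction and attr("data") from the test below (the printed
values are illustrative):

    val expr = TransformExpression(TruncateFunction, Seq(attr("data"), Literal(2)))
    expr.collectLeaves()
    // => Seq(data#1, 2)  -- the Literal(2) width is a leaf too
    expr.collectLeaves().filter(KeyGroupedPartitioning.isReference)
    // => Seq(data#1)     -- exactly one reference, as the size == 1 assert expects

This is also why isSupportedTransform now counts reference children instead of
requiring a single child.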
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index ab17b93ad6146..6b661253c13a6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -648,7 +648,7 @@ case class PartitionInternalRow(keys: Array[Any])
       return false
     }
-    // Just compare by reference, not by value
-    this.keys == other.asInstanceOf[PartitionInternalRow].keys
+    // Compare the key values, not the array references
+    this.keys sameElements other.asInstanceOf[PartitionInternalRow].keys
   }
   override def hashCode: Int = {
     Objects.hashCode(keys)
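Scala's == on arrays is Java reference equality, so two PartitionInternalRow
instances built from equal key values never compared as equal, while
sameElements compares contents (a quick REPL check):

    Array[Any]("aa") == Array[Any]("aa")            // false: distinct array objects
    Array[Any]("aa") sameElements Array[Any]("aa")  // true: element-wise comparison

That value equality is what lets grouped partition rows from the two sides of a
storage-partitioned join match up. One caveat worth noting: hashCode still
delegates to Objects.hashCode(keys), which hashes the array reference, so equal
rows do not necessarily share a hash code.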
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 8ec903f8e61da..10dfa21c1b57c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -627,6 +627,7 @@ case class EnsureRequirements(
       distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = {
     def tryCreate(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = {
       val attributes = partitioning.expressions.flatMap(_.collectLeaves())
+        .filter(KeyGroupedPartitioning.isReference)
       val clustering = distribution.clustering
 
       val satisfies = if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 152896499010c..79be6c14d1802 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -24,7 +24,8 @@ import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression}
 import org.apache.spark.sql.catalyst.plans.physical
-import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTableCatalog}
+import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
+import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTableCatalog, PartitionInternalRow}
 import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.connector.distributions.Distributions
 import org.apache.spark.sql.connector.expressions._
@@ -37,6 +38,7 @@ import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf._
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
   private val functions = Seq(
@@ -195,10 +197,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       s"(2, 'ccc', CAST('2020-01-01' AS timestamp))")
 
     val df = sql(s"SELECT * FROM testcat.ns.$table")
-    val distribution = physical.ClusteredDistribution(
-      Seq(TransformExpression(TruncateFunction, Seq(attr("data"), Literal(2)))))
+    val transformExpression = Seq(TransformExpression(
+      TruncateFunction, Seq(attr("data"), Literal(2))))
+    val distribution = physical.ClusteredDistribution(transformExpression)
+    val partValues = Seq(
+      PartitionInternalRow(Array(UTF8String.fromString("aa"))),
+      PartitionInternalRow(Array(UTF8String.fromString("bb"))),
+      PartitionInternalRow(Array(UTF8String.fromString("cc"))))
+    val partitioning = new KeyGroupedPartitioning(transformExpression, 3, partValues, partValues)
 
-    checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
+    checkQueryPlan(df, distribution, partitioning)
   }
 
   /**
@@ -2504,4 +2512,43 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       assert(scans.forall(_.inputRDD.partitions.length == 2))
     }
   }
+
+  test("SPARK-50593: SPJ: Support truncate transform") {
+    val partitions: Array[Transform] = Array(
+      Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2))
+    )
+
+    // create a table with 3 partitions, partitioned by `truncate` transform
+    createTable("table", columns, partitions)
+    sql(s"INSERT INTO testcat.ns.table VALUES " +
+      s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " +
+      s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " +
+      s"(2, 'ccc', CAST('2020-01-01' AS timestamp))")
+
+    createTable("table2", columns2, partitions)
+    sql(s"INSERT INTO testcat.ns.table2 VALUES " +
+      s"(1, 5, 'aaa')," +
+      s"(5, 10, 'bbb')," +
+      s"(20, 40, 'bbb')," +
+      s"(40, 80, 'ddd')")
+
+    withSQLConf(
+      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") {
+
+      val df =
+        sql(
+          selectWithMergeJoinHint("table", "table2") +
+            "id, store_id, dept_id " +
+            "FROM testcat.ns.table JOIN testcat.ns.table2 " +
+            "ON table.data = table2.data " +
+            "SORT BY id, store_id, dept_id")
+      val shuffles = collectShuffles(df.queryExecution.executedPlan)
+      assert(shuffles.isEmpty, "should not add shuffle for both sides of the join")
+      checkAnswer(df,
+        Seq(Row(0, 1, 5), Row(1, 5, 10), Row(1, 20, 40))
+      )
+      val scans = collectScans(df.queryExecution.executedPlan)
+      assert(scans.forall(_.inputRDD.partitions.length == 4))
+    }
+  }
 }
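For readers unfamiliar with the suite's fixtures: TruncateFunction is
registered alongside the suite's other test functions (see the functions Seq in
the context above). A rough sketch of what such a DSv2 scalar function can look
like, with the argument order taken from the test's
TransformExpression(TruncateFunction, Seq(attr("data"), Literal(2))); the
suite's actual implementation may differ:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
    import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    object TruncateFunction extends ScalarFunction[UTF8String] {
      override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
      override def resultType(): DataType = StringType
      override def name(): String = "truncate"
      override def canonicalName(): String = "truncate"
      // Keep the first `width` characters of the input string.
      override def produceResult(input: InternalRow): UTF8String = {
        val str = input.getUTF8String(0)
        val width = input.getInt(1)
        str.substring(0, width)
      }
    }

With width 2, table's rows fall into partitions 'aa', 'bb', 'cc' and table2's
into 'aa', 'bb', 'dd'; with V2_BUCKETING_PUSH_PART_VALUES_ENABLED the two sides
are merged into the four grouped partitions the test asserts on, and the join
runs with no shuffle on either side.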