From 271091d85b5d285655a514f61e805d6285994920 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Fri, 27 Jan 2023 09:20:08 -0800 Subject: [PATCH] [SPARK-42168][SQL][PYTHON][FOLLOW-UP] Test FlatMapCoGroupsInPandas with Window function ### What changes were proposed in this pull request? This ports tests from #39717 in branch-3.2 to master. ### Why are the changes needed? To make sure this use case is tested. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? E2E test in `test_pandas_cogrouped_map.py` and analysis test in `EnsureRequirementsSuite.scala`. Closes #39752 from EnricoMi/branch-cogroup-window-bug-test. Authored-by: Enrico Minack Signed-off-by: Chao Sun --- .../tests/pandas/test_pandas_cogrouped_map.py | 54 +++++++++++++++++- .../exchange/EnsureRequirementsSuite.scala | 56 +++++++++++++++++++ 2 files changed, 109 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py index d92a105f5d4b5..5cbc9e1caa430 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py @@ -18,8 +18,9 @@ import unittest from typing import cast -from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf +from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf, sum from pyspark.sql.types import DoubleType, StructType, StructField, Row +from pyspark.sql.window import Window from pyspark.errors import IllegalArgumentException, PythonException from pyspark.testing.sqlutils import ( ReusedSQLTestCase, @@ -365,6 +366,57 @@ def test_self_join(self): self.assertEqual(row.asDict(), Row(column=2, value=2).asDict()) + def test_with_window_function(self): + # SPARK-42168: a window function with same partition keys but differing key order + ids = 2 + days = 100 + vals = 10000 + parts = 10 + + id_df = self.spark.range(ids) + day_df = self.spark.range(days).withColumnRenamed("id", "day") + vals_df = self.spark.range(vals).withColumnRenamed("id", "value") + df = id_df.join(day_df).join(vals_df) + + left_df = df.withColumnRenamed("value", "left").repartition(parts).cache() + # SPARK-42132: this bug requires us to alias all columns from df here + right_df = ( + df.select(col("id").alias("id"), col("day").alias("day"), col("value").alias("right")) + .repartition(parts) + .cache() + ) + + # note the column order is different to the groupBy("id", "day") column order below + window = Window.partitionBy("day", "id") + + left_grouped_df = left_df.groupBy("id", "day") + right_grouped_df = right_df.withColumn("day_sum", sum(col("day")).over(window)).groupBy( + "id", "day" + ) + + def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: + return pd.DataFrame( + [ + { + "id": left["id"][0] + if not left.empty + else (right["id"][0] if not right.empty else None), + "day": left["day"][0] + if not left.empty + else (right["day"][0] if not right.empty else None), + "lefts": len(left.index), + "rights": len(right.index), + } + ] + ) + + df = left_grouped_df.cogroup(right_grouped_df).applyInPandas( + cogroup, schema="id long, day long, lefts integer, rights integer" + ) + + actual = df.orderBy("id", "day").take(days) + self.assertEqual(actual, [Row(0, day, vals, vals) for day in range(days)]) + @staticmethod def _test_with_key(left, right, isLeft): def right_assign_key(key, lft, rgt): diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index bc1fd7a5fa5b7..844037339ab9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution.exchange +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan @@ -25,9 +27,12 @@ import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.execution.{DummySparkPlan, SortExec} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec +import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class EnsureRequirementsSuite extends SharedSparkSession { private val exprA = Literal(1) @@ -1104,6 +1109,57 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } + test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing key order") { + val lKey = AttributeReference("key", IntegerType)() + val lKey2 = AttributeReference("key2", IntegerType)() + + val rKey = AttributeReference("key", IntegerType)() + val rKey2 = AttributeReference("key2", IntegerType)() + val rValue = AttributeReference("value", IntegerType)() + + val left = DummySparkPlan() + val right = WindowExec( + Alias( + WindowExpression( + Sum(rValue).toAggregateExpression(), + WindowSpecDefinition( + Seq(rKey2, rKey), + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) + ) + ), "sum")() :: Nil, + Seq(rKey2, rKey), + Nil, + DummySparkPlan() + ) + + val pythonUdf = PythonUDF("pyUDF", null, + StructType(Seq(StructField("value", IntegerType))), + Seq.empty, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + true) + + val flapMapCoGroup = FlatMapCoGroupsInPandasExec( + Seq(lKey, lKey2), + Seq(rKey, rKey2), + pythonUdf, + AttributeReference("value", IntegerType)() :: Nil, + left, + right + ) + + val result = EnsureRequirements.apply(flapMapCoGroup) + result match { + case FlatMapCoGroupsInPandasExec(leftKeys, rightKeys, _, _, + SortExec(leftOrder, false, _, _), SortExec(rightOrder, false, _, _)) => + assert(leftKeys === Seq(lKey, lKey2)) + assert(rightKeys === Seq(rKey, rKey2)) + assert(leftKeys.map(k => SortOrder(k, Ascending)) === leftOrder) + assert(rightKeys.map(k => SortOrder(k, Ascending)) === rightOrder) + case other => fail(other.toString) + } + } + def bucket(numBuckets: Int, expr: Expression): TransformExpression = { TransformExpression(BucketFunction, Seq(expr), Some(numBuckets)) }