[SPARK-42168][3.4][SQL][PYTHON][FOLLOW-UP] Test FlatMapCoGroupsInPandas with Window function

### What changes were proposed in this pull request?
This ports tests from #39717 in branch-3.2 to branch-3.4. See #39752 (comment).

### Why are the changes needed?
To make sure this use case (FlatMapCoGroupsInPandas combined with a Window function) is covered by tests.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
E2E test in `test_pandas_cogrouped_map.py` and analysis test in `EnsureRequirementsSuite.scala`.
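For reference, here is a condensed, hypothetical sketch of the scenario the E2E test exercises: cogrouped `applyInPandas` where one side carries a window function whose partition keys equal the cogroup keys but in a different order. It is an illustration only (not part of the patch) and assumes an existing SparkSession named `spark`; the toy data and the `merge` helper are made up for the example.

```python
import pandas as pd

from pyspark.sql.functions import col, sum
from pyspark.sql.window import Window

# toy data; real tests use much larger cross-joined frames
left_df = spark.createDataFrame([(1, 10, 100)], "id long, day long, left long")
right_df = spark.createDataFrame([(1, 10, 200)], "id long, day long, right long")

# same keys as the groupBy below, but in reversed order (the SPARK-42168 trigger)
window = Window.partitionBy("day", "id")
right_df = right_df.withColumn("day_sum", sum(col("day")).over(window))

def merge(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    # emit one row per (id, day) group with the group sizes of both sides
    return pd.DataFrame([{
        "id": left["id"][0] if not left.empty else right["id"][0],
        "day": left["day"][0] if not left.empty else right["day"][0],
        "lefts": len(left.index),
        "rights": len(right.index),
    }])

result = (
    left_df.groupBy("id", "day")
    .cogroup(right_df.groupBy("id", "day"))
    .applyInPandas(merge, schema="id long, day long, lefts integer, rights integer")
)
result.show()
```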

Closes #39803 from EnricoMi/branch-3.4-cogroup-window-bug-test.

Authored-by: Enrico Minack <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
EnricoMi authored and HyukjinKwon committed Jan 30, 2023
1 parent 545df6d commit 837fb7c
Showing 2 changed files with 109 additions and 1 deletion.
54 changes: 53 additions & 1 deletion python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -18,8 +18,9 @@
import unittest
from typing import cast

from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf
from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf, sum
from pyspark.sql.types import DoubleType, StructType, StructField, Row
from pyspark.sql.window import Window
from pyspark.errors import IllegalArgumentException, PythonException
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
@@ -365,6 +366,57 @@ def test_self_join(self):

self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())

def test_with_window_function(self):
# SPARK-42168: a window function with the same partition keys but a different key order
ids = 2
days = 100
vals = 10000
parts = 10

id_df = self.spark.range(ids)
day_df = self.spark.range(days).withColumnRenamed("id", "day")
vals_df = self.spark.range(vals).withColumnRenamed("id", "value")
df = id_df.join(day_df).join(vals_df)
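# joins without a condition: the cross product of ids × days × values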

left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
# SPARK-42132: this bug requires us to alias all columns from df here
right_df = (
df.select(col("id").alias("id"), col("day").alias("day"), col("value").alias("right"))
.repartition(parts)
.cache()
)

# note the column order is different from the groupBy("id", "day") column order below
window = Window.partitionBy("day", "id")

left_grouped_df = left_df.groupBy("id", "day")
right_grouped_df = right_df.withColumn("day_sum", sum(col("day")).over(window)).groupBy(
"id", "day"
)

def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
return pd.DataFrame(
[
{
"id": left["id"][0]
if not left.empty
else (right["id"][0] if not right.empty else None),
"day": left["day"][0]
if not left.empty
else (right["day"][0] if not right.empty else None),
"lefts": len(left.index),
"rights": len(right.index),
}
]
)

df = left_grouped_df.cogroup(right_grouped_df).applyInPandas(
cogroup, schema="id long, day long, lefts integer, rights integer"
)

actual = df.orderBy("id", "day").take(days)
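# the first `days` rows (all belonging to id 0) must each see all `vals` rows on both sides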
self.assertEqual(actual, [Row(0, day, vals, vals) for day in range(days)])

@staticmethod
def _test_with_key(left, right, isLeft):
def right_assign_key(key, lft, rgt):
56 changes: 56 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -17,17 +17,22 @@

package org.apache.spark.sql.execution.exchange

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class EnsureRequirementsSuite extends SharedSparkSession {
private val exprA = Literal(1)
@@ -1104,6 +1109,57 @@ class EnsureRequirementsSuite extends SharedSparkSession {
}
}

test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing key order") {
val lKey = AttributeReference("key", IntegerType)()
val lKey2 = AttributeReference("key2", IntegerType)()

val rKey = AttributeReference("key", IntegerType)()
val rKey2 = AttributeReference("key2", IntegerType)()
val rValue = AttributeReference("value", IntegerType)()

val left = DummySparkPlan()
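// the right child mimics the plan of a window function partitioned by the cogroup keys in reverse order (key2, key)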
val right = WindowExec(
Alias(
WindowExpression(
Sum(rValue).toAggregateExpression(),
WindowSpecDefinition(
Seq(rKey2, rKey),
Nil,
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
)
), "sum")() :: Nil,
Seq(rKey2, rKey),
Nil,
DummySparkPlan()
)

val pythonUdf = PythonUDF("pyUDF", null,
StructType(Seq(StructField("value", IntegerType))),
Seq.empty,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
true)

val flatMapCoGroup = FlatMapCoGroupsInPandasExec(
Seq(lKey, lKey2),
Seq(rKey, rKey2),
pythonUdf,
AttributeReference("value", IntegerType)() :: Nil,
left,
right
)

val result = EnsureRequirements.apply(flatMapCoGroup)
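// EnsureRequirements should add sorts on both children, ordered by the respective cogroup keys,
// despite the window's partition keys appearing in a different order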
result match {
case FlatMapCoGroupsInPandasExec(leftKeys, rightKeys, _, _,
SortExec(leftOrder, false, _, _), SortExec(rightOrder, false, _, _)) =>
assert(leftKeys === Seq(lKey, lKey2))
assert(rightKeys === Seq(rKey, rKey2))
assert(leftKeys.map(k => SortOrder(k, Ascending)) === leftOrder)
assert(rightKeys.map(k => SortOrder(k, Ascending)) === rightOrder)
case other => fail(other.toString)
}
}

def bucket(numBuckets: Int, expr: Expression): TransformExpression = {
TransformExpression(BucketFunction, Seq(expr), Some(numBuckets))
}
