[SPARK-50489][SQL][PYTHON] Fix self-join after applyInArrow
### What changes were proposed in this pull request?
Fix self-join after `applyInArrow` by handling `FlatMapGroupsInArrow` and `FlatMapCoGroupsInArrow` in the `DeduplicateRelations` rule; the same issue for `applyInPandas` was fixed in apache#31429.

### Why are the changes needed?
Bug fix: a self-join on the result of `applyInArrow` fails during analysis with conflicting attribute references.

Before:
```
In [1]: import pyarrow as pa

In [2]: df = spark.createDataFrame([(1, 1)], ("k", "v"))

In [3]: def arrow_func(key, table):
   ...:     return pa.Table.from_pydict({"x": [2], "y": [2]})
   ...:

In [4]: df2 = df.groupby("k").applyInArrow(arrow_func, schema="x long, y long")

In [5]: df2.show()
24/12/04 17:47:43 WARN CheckAllocator: More than one DefaultAllocationManager on classpath. Choosing first found
+---+---+
|  x|  y|
+---+---+
|  2|  2|
+---+---+

In [6]: df2.join(df2)
...
Failure when resolving conflicting references in Join:
'Join Inner
:- FlatMapGroupsInArrow [k#0L], arrow_func(k#0L, v#1L)#2, [x#3L, y#4L]
:  +- Project [k#0L, k#0L, v#1L]
:     +- LogicalRDD [k#0L, v#1L], false
+- FlatMapGroupsInArrow [k#12L], arrow_func(k#12L, v#13L)#2, [x#3L, y#4L]
   +- Project [k#12L, k#12L, v#13L]
      +- LogicalRDD [k#12L, v#13L], false

Conflicting attributes: "x", "y". SQLSTATE: XX000
	at org.apache.spark.SparkException$.internalError(SparkException.scala:92)
	at org.apache.spark.SparkException$.internalError(SparkException.scala:79)
	at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$2(CheckAnalysis.scala:798)
```

After:
```
In [6]: df2.join(df2)
Out[6]: DataFrame[x: bigint, y: bigint, x: bigint, y: bigint]

In [7]: df2.join(df2).show()
+---+---+---+---+
|  x|  y|  x|  y|
+---+---+---+---+
|  2|  2|  2|  2|
+---+---+---+---+
```
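
The fix also covers the cogrouped path (`FlatMapCoGroupsInArrow`). A minimal sketch mirroring the added test, assuming a running SparkSession bound to `spark` as in the sessions above:

```
import pyarrow as pa

df = spark.createDataFrame([(1, 1)], ("k", "v"))

def arrow_func(key, left, right):
    # Ignore the inputs and return a constant one-row table.
    return pa.Table.from_pydict({"x": [2], "y": [2]})

df2 = df.groupby("k").cogroup(df.groupby("k")).applyInArrow(arrow_func, "x long, y long")

# Previously this self-join failed analysis with conflicting "x"/"y" references.
assert df2.join(df2).count() == 1
```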

### Does this PR introduce _any_ user-facing change?
Yes, it fixes the bug described above: self-joins after `applyInArrow` no longer fail during analysis.

### How was this patch tested?
Added unit tests covering self-joins for both the grouped and cogrouped `applyInArrow` paths.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#49056 from zhengruifeng/fix_arrow_join.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
zhengruifeng authored and HyukjinKwon committed Dec 5, 2024
1 parent fe904e6 commit 7278bc7
Showing 3 changed files with 34 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/pyspark/sql/tests/test_arrow_cogrouped_map.py
@@ -299,6 +299,16 @@ def summarize(left, right):
"+---------+------------+----------+-------------+\n",
)

def test_self_join(self):
df = self.spark.createDataFrame([(1, 1)], ("k", "v"))

def arrow_func(key, left, right):
return pa.Table.from_pydict({"x": [2], "y": [2]})

df2 = df.groupby("k").cogroup(df.groupby("k")).applyInArrow(arrow_func, "x long, y long")

self.assertEqual(df2.join(df2).count(), 1)


class CogroupedMapInArrowTests(CogroupedMapInArrowTestsMixin, ReusedSQLTestCase):
@classmethod
10 changes: 10 additions & 0 deletions python/pyspark/sql/tests/test_arrow_grouped_map.py
@@ -255,6 +255,16 @@ def foo(_):
self.assertEqual(r.a, "hi")
self.assertEqual(r.b, 1)

def test_self_join(self):
df = self.spark.createDataFrame([(1, 1)], ("k", "v"))

def arrow_func(key, table):
return pa.Table.from_pydict({"x": [2], "y": [2]})

df2 = df.groupby("k").applyInArrow(arrow_func, schema="x long, y long")

self.assertEqual(df2.join(df2).count(), 1)


class GroupedMapInArrowTests(GroupedMapInArrowTestsMixin, ReusedSQLTestCase):
@classmethod
14 changes: 14 additions & 0 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -132,13 +132,27 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
_.output.map(_.exprId.id),
newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance())))

case f: FlatMapGroupsInArrow =>
deduplicateAndRenew[FlatMapGroupsInArrow](
existingRelations,
f,
_.output.map(_.exprId.id),
newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance())))

case f: FlatMapCoGroupsInPandas =>
deduplicateAndRenew[FlatMapCoGroupsInPandas](
existingRelations,
f,
_.output.map(_.exprId.id),
newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance())))

case f: FlatMapCoGroupsInArrow =>
deduplicateAndRenew[FlatMapCoGroupsInArrow](
existingRelations,
f,
_.output.map(_.exprId.id),
newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance())))

case m: MapInPandas =>
deduplicateAndRenew[MapInPandas](
existingRelations,
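
For context, each new case follows the existing pattern in this rule: when the analyzer encounters a second occurrence of the same plan node, it copies the node with freshly instantiated output attributes (new expression IDs), so the two sides of a self-join no longer conflict. A minimal sketch of how to observe this from Python, assuming a build with the fix and a SparkSession bound to `spark` (the exact plan text will vary):

```
import pyarrow as pa

df = spark.createDataFrame([(1, 1)], ("k", "v"))

def arrow_func(key, table):
    return pa.Table.from_pydict({"x": [2], "y": [2]})

df2 = df.groupby("k").applyInArrow(arrow_func, schema="x long, y long")

# After the fix, the analyzed plan shows the right-hand FlatMapGroupsInArrow
# with re-instantiated x/y attributes (distinct expression IDs).
df2.join(df2).explain(extended=True)
```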
