[SPARK-29664][PYTHON][SQL] Column.getItem behavior is not consistent with Scala

### What changes were proposed in this pull request?

This PR changes the behavior of `Column.getItem` to call `Column.getItem` on Scala side instead of `Column.apply`.
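
At its core this is a one-line redirect in `python/pyspark/sql/column.py` (the full diff appears below). A minimal before/after sketch, where the `_before`/`_after` suffixes are added here only to show both versions side by side:

```Python
def getItem_before(self, key):
    # Old behavior: delegate to Column.__getitem__, which forwards to
    # Column.apply on the JVM side and so silently accepted a Column key.
    return self[key]

def getItem_after(self, key):
    # New behavior: forward to Column.getItem on the JVM side, which wraps
    # the key in a Literal and therefore rejects a Column, as Scala does.
    return _bin_op("getItem")(self, key)
```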

### Why are the changes needed?

The current behavior is not consistent with that of Scala.

In PySpark:
```Python
df = spark.range(2)
map_col = create_map(lit(0), lit(100), lit(1), lit(200))
df.withColumn("mapped", map_col.getItem(col('id'))).show()
# +---+------+
# | id|mapped|
# +---+------+
# |  0|   100|
# |  1|   200|
# +---+------+
```
In Scala:
```Scala
val df = spark.range(2)
val map_col = map(lit(0), lit(100), lit(1), lit(200))
// The following getItem results in the following exception, which is the right behavior:
// java.lang.RuntimeException: Unsupported literal type class org.apache.spark.sql.Column id
//  at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:78)
//  at org.apache.spark.sql.Column.getItem(Column.scala:856)
//  ... 49 elided
df.withColumn("mapped", map_col.getItem(col("id"))).show
```

### Does this PR introduce any user-facing change?

Yes. If the user wants to pass a `Column` object to `getItem`, they now need to use the indexing operator to achieve the previous behavior.

```Python
df = spark.range(2)
map_col = create_map(lit(0), lit(100), lit(1), lit(200))
df.withColumn("mapped", map_col[col('id'))].show()
# +---+------+
# | id|mapped|
# +---+------+
# |  0|   100|
# |  1|   200|
# +---+------+
```
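
The indexing operator preserves the old behavior because `Column.__getitem__` (unchanged by this PR) forwards non-slice keys to `Column.apply` on the Scala side rather than to `getItem`. A paraphrased sketch from `python/pyspark/sql/column.py`:

```Python
def __getitem__(self, k):
    if isinstance(k, slice):
        if k.step is not None:
            raise ValueError("slice with step is not supported.")
        return self.substr(k.start, k.stop)
    else:
        # Column.apply on the Scala side accepts another Column as the
        # key, so map_col[col('id')] still resolves the value per row.
        return _bin_op("apply")(self, k)
```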

### How was this patch tested?

Existing tests.

Closes #26351 from imback82/spark-29664.

Authored-by: Terry Kim <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
imback82 authored and HyukjinKwon committed Nov 1, 2019
1 parent 8a8ac00 commit 3175f4b
Showing 3 changed files with 21 additions and 8 deletions.
3 changes: 3 additions & 0 deletions docs/pyspark-migration-guide.md

@@ -84,6 +84,9 @@ Please refer [Migration Guide: SQL, Datasets and DataFrame](sql-migration-guide.

 - Since Spark 3.0, `createDataFrame(..., verifySchema=True)` validates `LongType` as well in PySpark. Previously, `LongType` was not verified and resulted in `None` in case the value overflows. To restore this behavior, `verifySchema` can be set to `False` to disable the validation.

+- Since Spark 3.0, `Column.getItem` is fixed such that it does not call `Column.apply`. Consequently, if `Column` is used as an argument to `getItem`, the indexing operator should be used.
+For example, `map_col.getItem(col('id'))` should be replaced with `map_col[col('id')]`.
+
 ## Upgrading from PySpark 2.3 to 2.4

 - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`.
12 changes: 5 additions & 7 deletions python/pyspark/sql/column.py

@@ -296,14 +296,12 @@ def getItem(self, key):
         +----+------+
         |   1| value|
         +----+------+
-        >>> df.select(df.l[0], df.d["key"]).show()
-        +----+------+
-        |l[0]|d[key]|
-        +----+------+
-        |   1| value|
-        +----+------+
+
+        .. versionchanged:: 3.0
+           If `key` is a `Column` object, the indexing operator should be used instead.
+           For example, `map_col.getItem(col('id'))` should be replaced with `map_col[col('id')]`.
         """
-        return self[key]
+        return _bin_op("getItem")(self, key)

     @since(1.3)
     def getField(self, name):
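
For context, `_bin_op` is the existing helper in `python/pyspark/sql/column.py` that turns a named JVM `Column` method into a Python method; a paraphrased sketch (not part of this diff):

```Python
def _bin_op(name, doc="binary operator"):
    """Create a method that forwards the named operation to the JVM Column."""
    def _(self, other):
        # Unwrap a PySpark Column to its underlying Java Column; pass
        # plain Python values through as-is.
        jc = other._jc if isinstance(other, Column) else other
        njc = getattr(self._jc, name)(jc)  # e.g. jvmColumn.getItem(jc)
        return Column(njc)
    _.__doc__ = doc
    return _
```

With `name="getItem"`, a `Column` key is unwrapped and handed to Scala's `Column.getItem(Any)`, which calls `Literal(key)` and raises the `Unsupported literal type` error shown in the stack trace above — exactly the behavior the new test asserts.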
14 changes: 13 additions & 1 deletion python/pyspark/sql/tests/test_column.py

@@ -18,6 +18,8 @@

 import sys

+from py4j.protocol import Py4JJavaError
+
 from pyspark.sql import Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.utils import AnalysisException
@@ -85,14 +87,24 @@ def test_column_operators(self):
             "Cannot apply 'in' operator against a column",
             lambda: 1 in cs)

-    def test_column_getitem(self):
+    def test_column_apply(self):
         from pyspark.sql.functions import col

         self.assertIsInstance(col("foo")[1:3], Column)
         self.assertIsInstance(col("foo")[0], Column)
         self.assertIsInstance(col("foo")["bar"], Column)
         self.assertRaises(ValueError, lambda: col("foo")[0:10:2])

+    def test_column_getitem(self):
+        from pyspark.sql.functions import col, create_map, lit
+
+        map_col = create_map(lit(0), lit(100), lit(1), lit(200))
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            "Unsupported literal type class org.apache.spark.sql.Column id",
+            lambda: map_col.getItem(col('id'))
+        )
+
     def test_column_select(self):
         df = self.df
         self.assertEqual(self.testData, df.select("*").collect())
