[SPARK-49691][PYTHON][CONNECT] Function substring should accept column names

### What changes were proposed in this pull request?
Make function `substring` accept column names for its `pos` and `len` arguments.

### Why are the changes needed?
Bug fix:

```
>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l'])
>>> df.select('*', sf.substring('s', 'p', 'l')).show()
```

This works in PySpark Classic, but fails in Connect with:
```
NumberFormatException                     Traceback (most recent call last)
Cell In[2], line 1
----> 1 df.select('*', sf.substring('s', 'p', 'l')).show()

File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:1170, in DataFrame.show(self, n, truncate, vertical)
   1169 def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
-> 1170     print(self._show_string(n, truncate, vertical))

File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:927, in DataFrame._show_string(self, n, truncate, vertical)
    910     except ValueError:
    911         raise PySparkTypeError(
    912             errorClass="NOT_BOOL",
    913             messageParameters={
   (...)
    916             },
    917         )
    919 table, _ = DataFrame(
    920     plan.ShowString(
    921         child=self._plan,
    922         num_rows=n,
    923         truncate=_truncate,
    924         vertical=vertical,
    925     ),
    926     session=self._session,
--> 927 )._to_table()
    928 return table[0][0].as_py()

File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:1844, in DataFrame._to_table(self)
   1842 def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]:
   1843     query = self._plan.to_proto(self._session.client)
-> 1844     table, schema, self._execution_info = self._session.client.to_table(
   1845         query, self._plan.observations
   1846     )
   1847     assert table is not None
   1848     return (table, schema)

File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:892, in SparkConnectClient.to_table(self, plan, observations)
    890 req = self._execute_plan_request_with_metadata()
    891 req.plan.CopyFrom(plan)
--> 892 table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req, observations)
    894 # Create a query execution object.
    895 ei = ExecutionInfo(metrics, observed_metrics)

File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1517, in SparkConnectClient._execute_and_fetch(self, req, observations, self_destruct)
   1514 properties: Dict[str, Any] = {}
   1516 with Progress(handlers=self._progress_handlers, operation_id=req.operation_id) as progress:
-> 1517     for response in self._execute_and_fetch_as_iterator(
   1518         req, observations, progress=progress
   1519     ):
   1520         if isinstance(response, StructType):
   1521             schema = response

File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1494, in SparkConnectClient._execute_and_fetch_as_iterator(self, req, observations, progress)
   1492     raise kb
   1493 except Exception as error:
-> 1494     self._handle_error(error)

File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1764, in SparkConnectClient._handle_error(self, error)
   1762 self.thread_local.inside_error_handling = True
   1763 if isinstance(error, grpc.RpcError):
-> 1764     self._handle_rpc_error(error)
   1765 elif isinstance(error, ValueError):
   1766     if "Cannot invoke RPC" in str(error) and "closed" in str(error):

File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1840, in SparkConnectClient._handle_rpc_error(self, rpc_error)
   1837             if info.metadata["errorClass"] == "INVALID_HANDLE.SESSION_CHANGED":
   1838                 self._closed = True
-> 1840             raise convert_exception(
   1841                 info,
   1842                 status.message,
   1843                 self._fetch_enriched_error(info),
   1844                 self._display_server_stack_trace(),
   1845             ) from None
   1847     raise SparkConnectGrpcException(status.message) from None
   1848 else:

NumberFormatException: [CAST_INVALID_INPUT] The value 'p' of the type "STRING" cannot be cast to "INT" because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. SQLSTATE: 22018
...
```
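The root cause: before this fix, Connect's `substring` wrapped `pos` and `len` in `lit(...)` unconditionally, so a column name like `'p'` reached the server as the string literal `'p'`, which then failed the cast to `INT` (see the first diff below). A minimal sketch of the failure mode and the intended behavior, assuming a running `spark` session (illustration only, not code from this PR):

```
import pyspark.sql.functions as sf

df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l'])

# What Connect effectively built before the fix: the *string literal* 'p'
# is sent as substring's pos argument and fails the cast to INT when executed.
broken = df.select(sf.substring(df.s, sf.lit('p'), sf.lit('l')))

# What the user intended: references to the columns p and l.
df.select('*', sf.substring(df.s, sf.col('p'), sf.col('l'))).show()
# +-----+---+---+------------------+
# |    s|  p|  l|substring(s, p, l)|
# +-----+---+---+------------------+
# |Spark|  2|  3|               par|
# +-----+---+---+------------------+
```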

### Does this PR introduce _any_ user-facing change?
Yes. Function `substring` in Connect now properly handles column names.

### How was this patch tested?
New doctests.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#48135 from zhengruifeng/py_substring_fix.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
Showing 2 changed files with 62 additions and 11 deletions.
python/pyspark/sql/connect/functions/builtin.py (8 additions, 2 deletions):

```diff
@@ -2488,8 +2488,14 @@ def sentences(
 sentences.__doc__ = pysparkfuncs.sentences.__doc__
 
 
-def substring(str: "ColumnOrName", pos: int, len: int) -> Column:
-    return _invoke_function("substring", _to_col(str), lit(pos), lit(len))
+def substring(
+    str: "ColumnOrName",
+    pos: Union["ColumnOrName", int],
+    len: Union["ColumnOrName", int],
+) -> Column:
+    _pos = lit(pos) if isinstance(pos, int) else _to_col(pos)
+    _len = lit(len) if isinstance(len, int) else _to_col(len)
+    return _invoke_function("substring", _to_col(str), _pos, _len)
 
 
 substring.__doc__ = pysparkfuncs.substring.__doc__
```
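The fix is a small type dispatch: integer arguments stay literals, while everything else is resolved as a column. A self-contained sketch of the same idea using only public PySpark APIs (the real code uses Connect's internal `_to_col` and `_invoke_function`; the helper name here is hypothetical):

```
from typing import Union

from pyspark.sql import Column
import pyspark.sql.functions as sf


def _pos_or_len_as_col(value: Union[Column, str, int]) -> Column:
    """Integers become literals; strings are treated as column names."""
    if isinstance(value, int):
        return sf.lit(value)   # substring('s', 2, 3)     -> literal 2
    if isinstance(value, str):
        return sf.col(value)   # substring('s', 'p', 'l') -> column p
    return value               # already a Column, pass through
```

This mirrors `_pos = lit(pos) if isinstance(pos, int) else _to_col(pos)` in the diff above, since `_to_col` accepts both `Column` objects and column-name strings.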
python/pyspark/sql/functions/builtin.py (54 additions, 9 deletions):

```diff
@@ -11309,7 +11309,9 @@ def sentences(
 
 @_try_remote_functions
 def substring(
-    str: "ColumnOrName", pos: Union["ColumnOrName", int], len: Union["ColumnOrName", int]
+    str: "ColumnOrName",
+    pos: Union["ColumnOrName", int],
+    len: Union["ColumnOrName", int],
 ) -> Column:
     """
     Substring starts at `pos` and is of length `len` when str is String type or
@@ -11348,16 +11350,59 @@ def substring(
 
     Examples
     --------
+    Example 1: Using literal integers as arguments
+
+    >>> import pyspark.sql.functions as sf
     >>> df = spark.createDataFrame([('abcd',)], ['s',])
-    >>> df.select(substring(df.s, 1, 2).alias('s')).collect()
-    [Row(s='ab')]
+    >>> df.select('*', sf.substring(df.s, 1, 2)).show()
+    +----+------------------+
+    |   s|substring(s, 1, 2)|
+    +----+------------------+
+    |abcd|                ab|
+    +----+------------------+
+
+    Example 2: Using columns as arguments
+
+    >>> import pyspark.sql.functions as sf
+    >>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l'])
+    >>> df.select('*', sf.substring(df.s, 2, df.l)).show()
+    +-----+---+---+------------------+
+    |    s|  p|  l|substring(s, 2, l)|
+    +-----+---+---+------------------+
+    |Spark|  2|  3|               par|
+    +-----+---+---+------------------+
+
+    >>> df.select('*', sf.substring(df.s, df.p, 3)).show()
+    +-----+---+---+------------------+
+    |    s|  p|  l|substring(s, p, 3)|
+    +-----+---+---+------------------+
+    |Spark|  2|  3|               par|
+    +-----+---+---+------------------+
+
+    >>> df.select('*', sf.substring(df.s, df.p, df.l)).show()
+    +-----+---+---+------------------+
+    |    s|  p|  l|substring(s, p, l)|
+    +-----+---+---+------------------+
+    |Spark|  2|  3|               par|
+    +-----+---+---+------------------+
+
+    Example 3: Using column names as arguments
+
+    >>> import pyspark.sql.functions as sf
     >>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l'])
-    >>> df.select(substring(df.s, 2, df.l).alias('s')).collect()
-    [Row(s='par')]
-    >>> df.select(substring(df.s, df.p, 3).alias('s')).collect()
-    [Row(s='par')]
-    >>> df.select(substring(df.s, df.p, df.l).alias('s')).collect()
-    [Row(s='par')]
+    >>> df.select('*', sf.substring(df.s, 2, 'l')).show()
+    +-----+---+---+------------------+
+    |    s|  p|  l|substring(s, 2, l)|
+    +-----+---+---+------------------+
+    |Spark|  2|  3|               par|
+    +-----+---+---+------------------+
+
+    >>> df.select('*', sf.substring('s', 'p', 'l')).show()
+    +-----+---+---+------------------+
+    |    s|  p|  l|substring(s, p, l)|
+    +-----+---+---+------------------+
+    |Spark|  2|  3|               par|
+    +-----+---+---+------------------+
     """
     pos = _enum_to_value(pos)
     pos = lit(pos) if isinstance(pos, int) else pos
```
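With both files updated, Classic and Connect accept the same argument forms. A quick usage sketch mirroring the new doctests (assumes a running `spark` session):

```
import pyspark.sql.functions as sf

df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l'])

# Column objects, column names, and integer literals can be mixed freely;
# each of these evaluates to 'par':
df.select(sf.substring(df.s, df.p, df.l)).show()  # Column objects
df.select(sf.substring('s', 'p', 'l')).show()     # column names
df.select(sf.substring('s', 2, 3)).show()         # integer literals
```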
