[SPARK-48569][SS][CONNECT] Handle edge cases in query.name
### What changes were proposed in this pull request?

1. In Spark Connect, when a streaming query name is not specified, `query.name` should return None. Without this patch it returns an empty string.
2. In classic Spark, a streaming query's name cannot be set to an empty string. This check was missing in Spark Connect; this patch adds it back (see the sketch below).
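
For illustration, a minimal sketch of the intended behavior on Spark Connect after this change (assuming an active SparkSession bound to `spark`; the `rate` source and `noop` sink are just placeholders):

```python
from pyspark.errors import PySparkValueError

# 1. A query started without a name now reports None instead of "".
query = spark.readStream.format("rate").load().writeStream.format("noop").start()
assert query.name is None
query.stop()

# 2. An empty query name is now rejected in Spark Connect, matching classic Spark.
try:
    spark.readStream.format("rate").load().writeStream.queryName("")
except PySparkValueError:
    pass  # raised with error class VALUE_NOT_NON_EMPTY_STR
```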

### Why are the changes needed?

Edge case handling.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added unit test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46920 from WweiL/SPARK-48569-query-name-None.

Authored-by: Wei Liu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
WweiL authored and HyukjinKwon committed Jun 10, 2024
1 parent 3857a9d commit ec6db63
Showing 3 changed files with 32 additions and 4 deletions.
9 changes: 8 additions & 1 deletion python/pyspark/sql/connect/streaming/readwriter.py
@@ -446,6 +446,11 @@ def partitionBy(self, *cols: str) -> "DataStreamWriter":  # type: ignore[misc]
     partitionBy.__doc__ = PySparkDataStreamWriter.partitionBy.__doc__
 
     def queryName(self, queryName: str) -> "DataStreamWriter":
+        if not queryName or type(queryName) != str or len(queryName.strip()) == 0:
+            raise PySparkValueError(
+                error_class="VALUE_NOT_NON_EMPTY_STR",
+                message_parameters={"arg_name": "queryName", "arg_value": str(queryName)},
+            )
         self._write_proto.query_name = queryName
         return self
 
@@ -605,7 +610,9 @@ def _start_internal(
             session=self._session,
             queryId=start_result.query_id.id,
             runId=start_result.query_id.run_id,
-            name=start_result.name,
+            # A Streaming Query cannot have empty string as name
+            # Spark throws error in that case, so this cast is safe
+            name=start_result.name if start_result.name != "" else None,
         )
 
         if start_result.HasField("query_started_event_json"):
6 changes: 3 additions & 3 deletions python/pyspark/sql/streaming/query.py
@@ -114,7 +114,7 @@ def runId(self) -> str:
     @property
     def name(self) -> str:
         """
-        Returns the user-specified name of the query, or null if not specified.
+        Returns the user-specified name of the query, or None if not specified.
         This name can be specified in the `org.apache.spark.sql.streaming.DataStreamWriter`
         as `dataframe.writeStream.queryName("query").start()`.
         This name, if set, must be unique across all active queries.
@@ -127,14 +127,14 @@ def name(self) -> str:
         Returns
         -------
         str
-            The user-specified name of the query, or null if not specified.
+            The user-specified name of the query, or None if not specified.
 
         Examples
         --------
         >>> sdf = spark.readStream.format("rate").load()
         >>> sq = sdf.writeStream.format('memory').queryName('this_query').start()
 
-        Get the user-specified name of the query, or null if not specified.
+        Get the user-specified name of the query, or None if not specified.
 
         >>> sq.name
         'this_query'
21 changes: 21 additions & 0 deletions python/pyspark/sql/tests/streaming/test_streaming.py
@@ -24,6 +24,7 @@
 from pyspark.sql.functions import lit
 from pyspark.sql.types import StructType, StructField, IntegerType, StringType
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.errors import PySparkValueError
 
 
 class StreamingTestsMixin:
@@ -58,6 +59,26 @@ def test_streaming_query_functions_basic(self):
         finally:
             query.stop()
 
+    def test_streaming_query_name_edge_case(self):
+        # Query name should be None when not specified
+        q1 = self.spark.readStream.format("rate").load().writeStream.format("noop").start()
+        self.assertEqual(q1.name, None)
+
+        # Cannot set query name to be an empty string
+        error_thrown = False
+        try:
+            (
+                self.spark.readStream.format("rate")
+                .load()
+                .writeStream.format("noop")
+                .queryName("")
+                .start()
+            )
+        except PySparkValueError:
+            error_thrown = True
+
+        self.assertTrue(error_thrown)
+
     def test_stream_trigger(self):
         df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
 
