From ef7a116828d1781dde35d422550ff369c548a6aa Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Fri, 7 Jun 2024 20:15:32 -0700 Subject: [PATCH] ready for CI, pending query.name fix --- python/pyspark/sql/streaming/listener.py | 8 +++++++- python/pyspark/sql/tests/streaming/test_streaming.py | 2 +- .../sql/tests/streaming/test_streaming_listener.py | 5 +---- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 86c8bc0258330..a69dc059b47f3 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -500,7 +500,13 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress": ) def __getitem__(self, key): - return getattr(self, key) + # Before Spark 4.0, StreamingQuery.lastProgress returns a dict, which casts id and runId + # to string. To prevent breaking change, also cast them to string when accessed with + # __getitem__. + if key == "id" or key == "runId": + return str(getattr(self, key)) + else: + return getattr(self, key) def __setitem__(self, key, value): internal_key = "_" + key diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py index 90cc6b3d29f8d..61f595d44c4de 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming.py +++ b/python/pyspark/sql/tests/streaming/test_streaming.py @@ -70,7 +70,7 @@ def test_streaming_progress(self): self.assertEqual(lastProgress["id"], query.id) # SPARK-48567 Use attribute to access fields in q.lastProgress self.assertEqual(lastProgress.name, query.name) - self.assertEqual(lastProgress.id, query.id) + self.assertEqual(str(lastProgress.id), query.id) new_name = "myNewQuery" lastProgress["name"] = new_name self.assertEqual(lastProgress.name, new_name) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index adbad6480a6cd..9276267322c36 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -121,7 +121,7 @@ def check_streaming_query_progress(self, progress, is_stateful): self.assertTrue(isinstance(progress.sink, SinkProgress)) self.check_sink_progress(progress.sink) - self.assertTrue(isinstance(progress.observedMetrics, dict)) + self.assertTrue(isinstance(progress.observedMetrics, Row)) def check_state_operator_progress(self, progress): """Check StateOperatorProgress""" @@ -264,9 +264,6 @@ def test_streaming_progress(self): for p in q.recentProgress: self.check_streaming_query_progress(p, True) - row = q.lastProgress.observedMetrics.get("my_event") - self.assertTrue(row["rc"] > 0) - self.assertTrue(row["erc"] > 0) finally: q.stop()