diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py
index a703742fa93ff..cc1e2e2201884 100644
--- a/python/pyspark/sql/connect/streaming/query.py
+++ b/python/pyspark/sql/connect/streaming/query.py
@@ -125,7 +125,7 @@ def lastProgress(self) -> Optional[StreamingQueryProgress]:
         cmd.last_progress = True
         progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
         if len(progress) > 0:
-            return StreamingQueryProgress.fromJson(json.loads(progress))
+            return StreamingQueryProgress.fromJson(json.loads(progress[-1]))
         else:
             return None
 
diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py
index 5a620d7172b84..86c8bc0258330 100644
--- a/python/pyspark/sql/streaming/listener.py
+++ b/python/pyspark/sql/streaming/listener.py
@@ -502,6 +502,10 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
     def __getitem__(self, key):
         return getattr(self, key)
 
+    def __setitem__(self, key, value):
+        internal_key = "_" + key
+        setattr(self, internal_key, value)
+
     @property
     def id(self) -> uuid.UUID:
         """
@@ -655,6 +659,9 @@ def prettyJson(self) -> str:
     def __str__(self) -> str:
         return self.prettyJson
 
+    def __repr__(self) -> str:
+        return self.prettyJson
+
 
 class StateOperatorProgress:
     """
diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py
index 3c5006fd3a247..9c47c895ae41e 100644
--- a/python/pyspark/sql/streaming/query.py
+++ b/python/pyspark/sql/streaming/query.py
@@ -283,7 +283,7 @@ def recentProgress(self) -> List[StreamingQueryProgress]:
         >>> sq.stop()
         """
-        return [StreamingQueryProgress.fromJson(json.loads(p)) for p in self._jsq.recentProgress()]
+        return [StreamingQueryProgress.fromJObject(p) for p in self._jsq.recentProgress()]
 
     @property
     def lastProgress(self) -> Optional[StreamingQueryProgress]:
@@ -314,7 +314,7 @@ def lastProgress(self) -> Optional[StreamingQueryProgress]:
         """
         lastProgress = self._jsq.lastProgress()
         if lastProgress:
-            return StreamingQueryProgress.fromJson(json.loads(lastProgress.json()))
+            return StreamingQueryProgress.fromJObject(lastProgress)
         else:
             return None
 
diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py
index 1799f0d1336e5..90cc6b3d29f8d 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming.py
@@ -28,7 +28,7 @@ class StreamingTestsMixin:
     def test_streaming_query_functions_basic(self):
-        df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
         query = (
             df.writeStream.format("memory")
             .queryName("test_streaming_query_functions_basic")
@@ -46,6 +46,7 @@ def test_streaming_query_functions_basic(self):
             lastProgress = query.lastProgress
             self.assertEqual(lastProgress["name"], query.name)
             self.assertEqual(lastProgress["id"], query.id)
+            # SPARK-48567 Use attribute to access progress
             self.assertTrue(any(p == lastProgress for p in recentProgress))
             query.explain()
@@ -58,6 +59,31 @@ def test_streaming_query_functions_basic(self):
         finally:
             query.stop()
 
+    def test_streaming_progress(self):
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+        query = df.writeStream.format("noop").start()
+        try:
+            query.processAllAvailable()
+            recentProgress = query.recentProgress
+            lastProgress = query.lastProgress
+            self.assertEqual(lastProgress["name"], query.name)
+            self.assertEqual(lastProgress["id"], query.id)
+            # SPARK-48567 Use attribute to access fields in q.lastProgress
+            self.assertEqual(lastProgress.name, query.name)
+            self.assertEqual(lastProgress.id, query.id)
+            new_name = "myNewQuery"
+            lastProgress["name"] = new_name
+            self.assertEqual(lastProgress.name, new_name)
+            self.assertTrue(any(p == lastProgress for p in recentProgress))
+
+        except Exception as e:
+            self.fail(
+                "Streaming query functions sanity check shouldn't throw any error. "
+                "Error message: " + str(e)
+            )
+        finally:
+            query.stop()
+
     def test_stream_trigger(self):
         df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
index 15f5575d36479..adbad6480a6cd 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -227,9 +227,9 @@ def onQueryTerminated(self, event):
             "my_event", count(lit(1)).alias("rc"), count(col("error")).alias("erc")
         )
 
-        q = observed_ds.writeStream.format("console").start()
+        q = observed_ds.writeStream.format("noop").start()
 
-        while q.lastProgress is None or q.lastProgress["batchId"] == 0:
+        while q.lastProgress is None or q.lastProgress.batchId == 0:
             q.awaitTermination(0.5)
 
         time.sleep(5)
@@ -241,6 +241,34 @@ def onQueryTerminated(self, event):
             q.stop()
             self.spark.streams.removeListener(error_listener)
 
+    def test_streaming_progress(self):
+        try:
+            # Test a fancier query with stateful operation and observed metrics
+            df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+            df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+            df_stateful = df_observe.groupBy().count()  # make query stateful
+            q = (
+                df_stateful.writeStream.format("noop")
+                .queryName("test")
+                .outputMode("update")
+                .trigger(processingTime="5 seconds")
+                .start()
+            )
+
+            while q.lastProgress is None or q.lastProgress.batchId == 0:
+                q.awaitTermination(0.5)
+
+            q.stop()
+
+            self.check_streaming_query_progress(q.lastProgress, True)
+            for p in q.recentProgress:
+                self.check_streaming_query_progress(p, True)
+
+            row = q.lastProgress.observedMetrics.get("my_event")
+            self.assertTrue(row["rc"] > 0)
+        finally:
+            q.stop()
+
 
 class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase):
     def test_number_of_public_methods(self):
@@ -355,7 +383,7 @@ def verify(test_listener):
                 .start()
             )
             self.assertTrue(q.isActive)
-            time.sleep(10)
+            q.awaitTermination(10)
             q.stop()
 
             # Make sure all events are empty
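
Note on the Connect-side fix in the first hunk: recent_progress_json is a list of JSON strings, newest last, so lastProgress must parse only the final element rather than the list itself. A minimal sketch of the corrected behavior, using hypothetical placeholder data in place of a real server response:

import json

# Hypothetical stand-in for the recent_progress_json list returned by the
# server, ordered oldest to newest.
recent_progress_json = ['{"batchId": 0}', '{"batchId": 1}', '{"batchId": 2}']

# Before the fix, json.loads(recent_progress_json) raised a TypeError because
# the argument is a list. After the fix, only the most recent entry is parsed.
last = json.loads(recent_progress_json[-1])
assert last["batchId"] == 2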
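The __setitem__ added to StreamingQueryProgress prefixes the key with an underscore because the class stores each field in an underscore-prefixed backing attribute that a property exposes, while __getitem__ reads through the property. A self-contained sketch of that pattern (a toy class, not the real StreamingQueryProgress):

class ProgressSketch:
    """Toy illustration of dict-style access over property-backed fields."""

    def __init__(self, name: str) -> None:
        self._name = name  # underscore-prefixed backing attribute

    @property
    def name(self) -> str:
        return self._name

    def __getitem__(self, key):
        # Key access goes through the public property of the same name.
        return getattr(self, key)

    def __setitem__(self, key, value):
        # Writes target the backing attribute, so the property sees the update.
        setattr(self, "_" + key, value)


p = ProgressSketch("q1")
assert p["name"] == p.name == "q1"
p["name"] = "myNewQuery"  # writes p._name
assert p.name == "myNewQuery"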
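Taken together, the changes make lastProgress and recentProgress return StreamingQueryProgress objects that still accept dict-style access, as the new tests exercise. A hedged usage sketch, assuming a local PySpark session and that it runs from the repository root so the bundled test-data path resolves:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.readStream.format("text").load("python/test_support/sql/streaming")
query = df.writeStream.format("noop").start()
try:
    query.processAllAvailable()
    progress = query.lastProgress  # now a StreamingQueryProgress, not a dict
    if progress is not None:
        print(progress.batchId)     # attribute access (new)
        print(progress["batchId"])  # dict-style access (still supported)
        print(progress)             # __repr__ now pretty-prints the JSON
finally:
    query.stop()
    spark.stop()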