From d067fc6c1635dfe7730223021e912e78637bb791 Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Wed, 19 Jun 2024 12:15:02 +0900
Subject: [PATCH] Revert "[SPARK-48567][SS] StreamingQuery.lastProgress should
 return the actual StreamingQueryProgress"

This reverts commit 042804ad545c88afe69c149b25baea00fc213708.
---
 python/pyspark/sql/connect/streaming/query.py |   9 +-
 python/pyspark/sql/streaming/listener.py      | 228 +++++++-----------
 python/pyspark/sql/streaming/query.py         |  13 +-
 .../sql/tests/streaming/test_streaming.py     |  44 +---
 .../streaming/test_streaming_listener.py      |  32 +--
 5 files changed, 99 insertions(+), 227 deletions(-)

diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py
index cc1e2e2201884..98ecdc4966c75 100644
--- a/python/pyspark/sql/connect/streaming/query.py
+++ b/python/pyspark/sql/connect/streaming/query.py
@@ -33,7 +33,6 @@
     QueryProgressEvent,
     QueryIdleEvent,
     QueryTerminatedEvent,
-    StreamingQueryProgress,
 )
 from pyspark.sql.streaming.query import (
     StreamingQuery as PySparkStreamingQuery,
@@ -111,21 +110,21 @@ def status(self) -> Dict[str, Any]:
     status.__doc__ = PySparkStreamingQuery.status.__doc__
 
     @property
-    def recentProgress(self) -> List[StreamingQueryProgress]:
+    def recentProgress(self) -> List[Dict[str, Any]]:
         cmd = pb2.StreamingQueryCommand()
         cmd.recent_progress = True
         progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
-        return [StreamingQueryProgress.fromJson(json.loads(p)) for p in progress]
+        return [json.loads(p) for p in progress]
 
     recentProgress.__doc__ = PySparkStreamingQuery.recentProgress.__doc__
 
     @property
-    def lastProgress(self) -> Optional[StreamingQueryProgress]:
+    def lastProgress(self) -> Optional[Dict[str, Any]]:
         cmd = pb2.StreamingQueryCommand()
         cmd.last_progress = True
         progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
         if len(progress) > 0:
-            return StreamingQueryProgress.fromJson(json.loads(progress[-1]))
+            return json.loads(progress[-1])
         else:
             return None
 
diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py
index 6cc2cc3fa2b86..2aa63cdb91ab6 100644
--- a/python/pyspark/sql/streaming/listener.py
+++ b/python/pyspark/sql/streaming/listener.py
@@ -397,13 +397,10 @@ def errorClassOnException(self) -> Optional[str]:
         return self._errorClassOnException
 
 
-class StreamingQueryProgress(dict):
+class StreamingQueryProgress:
     """
     .. versionadded:: 3.4.0
 
-    .. versionchanged:: 4.0.0
-        Becomes a subclass of dict
-
     Notes
     -----
     This API is evolving.
@@ -429,25 +426,23 @@ def __init__(
         jprogress: Optional["JavaObject"] = None,
         jdict: Optional[Dict[str, Any]] = None,
     ):
-        super().__init__(
-            id=id,
-            runId=runId,
-            name=name,
-            timestamp=timestamp,
-            batchId=batchId,
-            batchDuration=batchDuration,
-            durationMs=durationMs,
-            eventTime=eventTime,
-            stateOperators=stateOperators,
-            sources=sources,
-            sink=sink,
-            numInputRows=numInputRows,
-            inputRowsPerSecond=inputRowsPerSecond,
-            processedRowsPerSecond=processedRowsPerSecond,
-            observedMetrics=observedMetrics,
-        )
         self._jprogress: Optional["JavaObject"] = jprogress
         self._jdict: Optional[Dict[str, Any]] = jdict
+        self._id: uuid.UUID = id
+        self._runId: uuid.UUID = runId
+        self._name: Optional[str] = name
+        self._timestamp: str = timestamp
+        self._batchId: int = batchId
+        self._batchDuration: int = batchDuration
+        self._durationMs: Dict[str, int] = durationMs
+        self._eventTime: Dict[str, str] = eventTime
+        self._stateOperators: List[StateOperatorProgress] = stateOperators
+        self._sources: List[SourceProgress] = sources
+        self._sink: SinkProgress = sink
+        self._numInputRows: int = numInputRows
+        self._inputRowsPerSecond: float = inputRowsPerSecond
+        self._processedRowsPerSecond: float = processedRowsPerSecond
+        self._observedMetrics: Dict[str, Row] = observedMetrics
 
     @classmethod
     def fromJObject(cls, jprogress: "JavaObject") -> "StreamingQueryProgress":
@@ -494,11 +489,9 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
             stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]],
             sources=[SourceProgress.fromJson(s) for s in j["sources"]],
             sink=SinkProgress.fromJson(j["sink"]),
-            numInputRows=j["numInputRows"] if "numInputRows" in j else None,
-            inputRowsPerSecond=j["inputRowsPerSecond"] if "inputRowsPerSecond" in j else None,
-            processedRowsPerSecond=j["processedRowsPerSecond"]
-            if "processedRowsPerSecond" in j
-            else None,
+            numInputRows=j["numInputRows"],
+            inputRowsPerSecond=j["inputRowsPerSecond"],
+            processedRowsPerSecond=j["processedRowsPerSecond"],
             observedMetrics={
                 k: Row(*row_dict.keys())(*row_dict.values())  # Assume no nested rows
                 for k, row_dict in j["observedMetrics"].items()
@@ -513,10 +506,7 @@ def id(self) -> uuid.UUID:
         A unique query id that persists across restarts. See
         py:meth:`~pyspark.sql.streaming.StreamingQuery.id`.
         """
-        # Before Spark 4.0, StreamingQuery.lastProgress returns a dict, which casts id and runId
-        # to string. But here they are UUID.
-        # To prevent breaking change, do not cast them to string when accessed with attribute.
-        return super().__getitem__("id")
+        return self._id
 
     @property
     def runId(self) -> uuid.UUID:
@@ -524,24 +514,21 @@ def runId(self) -> uuid.UUID:
         A query id that is unique for every start/restart. See
         py:meth:`~pyspark.sql.streaming.StreamingQuery.runId`.
         """
-        # Before Spark 4.0, StreamingQuery.lastProgress returns a dict, which casts id and runId
-        # to string. But here they are UUID.
-        # To prevent breaking change, do not cast them to string when accessed with attribute.
-        return super().__getitem__("runId")
+        return self._runId
 
     @property
     def name(self) -> Optional[str]:
         """
         User-specified name of the query, `None` if not specified.
         """
-        return self["name"]
+        return self._name
 
     @property
     def timestamp(self) -> str:
         """
         The timestamp to start a query.
         """
-        return self["timestamp"]
+        return self._timestamp
 
     @property
     def batchId(self) -> int:
@@ -551,21 +538,21 @@ def batchId(self) -> int:
         Similarly, when there is no data to be processed, the batchId will not be
         incremented.
         """
-        return self["batchId"]
+        return self._batchId
 
     @property
     def batchDuration(self) -> int:
         """
         The process duration of each batch.
         """
-        return self["batchDuration"]
+        return self._batchDuration
 
     @property
     def durationMs(self) -> Dict[str, int]:
         """
         The amount of time taken to perform various operations in milliseconds.
         """
-        return self["durationMs"]
+        return self._durationMs
 
     @property
     def eventTime(self) -> Dict[str, str]:
@@ -583,21 +570,21 @@ def eventTime(self) -> Dict[str, str]:
         All timestamps are in ISO8601 format, i.e. UTC timestamps.
         """
-        return self["eventTime"]
+        return self._eventTime
 
     @property
     def stateOperators(self) -> List["StateOperatorProgress"]:
         """
         Information about operators in the query that store state.
         """
-        return self["stateOperators"]
+        return self._stateOperators
 
     @property
     def sources(self) -> List["SourceProgress"]:
         """
         detailed statistics on data being read from each of the streaming sources.
         """
-        return self["sources"]
+        return self._sources
 
     @property
     def sink(self) -> "SinkProgress":
         """
@@ -605,41 +592,32 @@ def sink(self) -> "SinkProgress":
         A unique query id that persists across restarts. See
         py:meth:`~pyspark.sql.streaming.StreamingQuery.id`.
         """
-        return self["sink"]
+        return self._sink
 
     @property
     def observedMetrics(self) -> Dict[str, Row]:
-        return self["observedMetrics"]
+        return self._observedMetrics
 
     @property
     def numInputRows(self) -> int:
         """
         The aggregate (across all sources) number of records processed in a trigger.
         """
-        if self["numInputRows"] is not None:
-            return self["numInputRows"]
-        else:
-            return sum(s.numInputRows for s in self.sources)
+        return self._numInputRows
 
     @property
     def inputRowsPerSecond(self) -> float:
         """
         The aggregate (across all sources) rate of data arriving.
         """
-        if self["inputRowsPerSecond"] is not None:
-            return self["inputRowsPerSecond"]
-        else:
-            return sum(s.inputRowsPerSecond for s in self.sources)
+        return self._inputRowsPerSecond
 
     @property
     def processedRowsPerSecond(self) -> float:
         """
         The aggregate (across all sources) rate at which Spark is processing data.
         """
-        if self["processedRowsPerSecond"] is not None:
-            return self["processedRowsPerSecond"]
-        else:
-            return sum(s.processedRowsPerSecond for s in self.sources)
+        return self._processedRowsPerSecond
 
     @property
     def json(self) -> str:
@@ -663,29 +641,14 @@ def prettyJson(self) -> str:
         else:
             return json.dumps(self._jdict, indent=4)
 
-    def __getitem__(self, key: str) -> Any:
-        # Before Spark 4.0, StreamingQuery.lastProgress returns a dict, which casts id and runId
-        # to string. But here they are UUID.
-        # To prevent breaking change, also cast them to string when accessed with __getitem__.
-        if key == "id" or key == "runId":
-            return str(super().__getitem__(key))
-        else:
-            return super().__getitem__(key)
-
     def __str__(self) -> str:
         return self.prettyJson
 
-    def __repr__(self) -> str:
-        return self.prettyJson
-
 
-class StateOperatorProgress(dict):
+class StateOperatorProgress:
     """
     .. versionadded:: 3.4.0
 
-    .. versionchanged:: 4.0.0
-        Becomes a subclass of dict
-
     Notes
     -----
     This API is evolving.
@@ -708,22 +671,20 @@ def __init__(
         jprogress: Optional["JavaObject"] = None,
         jdict: Optional[Dict[str, Any]] = None,
     ):
-        super().__init__(
-            operatorName=operatorName,
-            numRowsTotal=numRowsTotal,
-            numRowsUpdated=numRowsUpdated,
-            numRowsRemoved=numRowsRemoved,
-            allUpdatesTimeMs=allUpdatesTimeMs,
-            allRemovalsTimeMs=allRemovalsTimeMs,
-            commitTimeMs=commitTimeMs,
-            memoryUsedBytes=memoryUsedBytes,
-            numRowsDroppedByWatermark=numRowsDroppedByWatermark,
-            numShufflePartitions=numShufflePartitions,
-            numStateStoreInstances=numStateStoreInstances,
-            customMetrics=customMetrics,
-        )
         self._jprogress: Optional["JavaObject"] = jprogress
         self._jdict: Optional[Dict[str, Any]] = jdict
+        self._operatorName: str = operatorName
+        self._numRowsTotal: int = numRowsTotal
+        self._numRowsUpdated: int = numRowsUpdated
+        self._numRowsRemoved: int = numRowsRemoved
+        self._allUpdatesTimeMs: int = allUpdatesTimeMs
+        self._allRemovalsTimeMs: int = allRemovalsTimeMs
+        self._commitTimeMs: int = commitTimeMs
+        self._memoryUsedBytes: int = memoryUsedBytes
+        self._numRowsDroppedByWatermark: int = numRowsDroppedByWatermark
+        self._numShufflePartitions: int = numShufflePartitions
+        self._numStateStoreInstances: int = numStateStoreInstances
+        self._customMetrics: Dict[str, int] = customMetrics
 
     @classmethod
     def fromJObject(cls, jprogress: "JavaObject") -> "StateOperatorProgress":
@@ -763,51 +724,51 @@ def fromJson(cls, j: Dict[str, Any]) -> "StateOperatorProgress":
 
     @property
     def operatorName(self) -> str:
-        return self["operatorName"]
+        return self._operatorName
 
     @property
     def numRowsTotal(self) -> int:
-        return self["numRowsTotal"]
+        return self._numRowsTotal
 
     @property
     def numRowsUpdated(self) -> int:
-        return self["numRowsUpdated"]
+        return self._numRowsUpdated
 
     @property
     def allUpdatesTimeMs(self) -> int:
-        return self["allUpdatesTimeMs"]
+        return self._allUpdatesTimeMs
 
     @property
     def numRowsRemoved(self) -> int:
-        return self["numRowsRemoved"]
+        return self._numRowsRemoved
 
     @property
     def allRemovalsTimeMs(self) -> int:
-        return self["allRemovalsTimeMs"]
+        return self._allRemovalsTimeMs
 
     @property
     def commitTimeMs(self) -> int:
-        return self["commitTimeMs"]
+        return self._commitTimeMs
 
     @property
     def memoryUsedBytes(self) -> int:
-        return self["memoryUsedBytes"]
+        return self._memoryUsedBytes
 
     @property
     def numRowsDroppedByWatermark(self) -> int:
-        return self["numRowsDroppedByWatermark"]
+        return self._numRowsDroppedByWatermark
 
     @property
     def numShufflePartitions(self) -> int:
-        return self["numShufflePartitions"]
+        return self._numShufflePartitions
 
     @property
     def numStateStoreInstances(self) -> int:
-        return self["numStateStoreInstances"]
+        return self._numStateStoreInstances
 
     @property
-    def customMetrics(self) -> dict:
-        return self["customMetrics"]
+    def customMetrics(self) -> Dict[str, int]:
+        return self._customMetrics
 
     @property
     def json(self) -> str:
@@ -834,17 +795,11 @@ def prettyJson(self) -> str:
     def __str__(self) -> str:
         return self.prettyJson
 
-    def __repr__(self) -> str:
-        return self.prettyJson
-
 
-class SourceProgress(dict):
+class SourceProgress:
     """
     .. versionadded:: 3.4.0
 
-    .. versionchanged:: 4.0.0
-        Becomes a subclass of dict
-
     Notes
     -----
     This API is evolving.
@@ -863,18 +818,16 @@ def __init__(
         jprogress: Optional["JavaObject"] = None,
         jdict: Optional[Dict[str, Any]] = None,
     ) -> None:
-        super().__init__(
-            description=description,
-            startOffset=startOffset,
-            endOffset=endOffset,
-            latestOffset=latestOffset,
-            numInputRows=numInputRows,
-            inputRowsPerSecond=inputRowsPerSecond,
-            processedRowsPerSecond=processedRowsPerSecond,
-            metrics=metrics,
-        )
         self._jprogress: Optional["JavaObject"] = jprogress
         self._jdict: Optional[Dict[str, Any]] = jdict
+        self._description: str = description
+        self._startOffset: str = startOffset
+        self._endOffset: str = endOffset
+        self._latestOffset: str = latestOffset
+        self._numInputRows: int = numInputRows
+        self._inputRowsPerSecond: float = inputRowsPerSecond
+        self._processedRowsPerSecond: float = processedRowsPerSecond
+        self._metrics: Dict[str, str] = metrics
 
     @classmethod
     def fromJObject(cls, jprogress: "JavaObject") -> "SourceProgress":
@@ -909,53 +862,53 @@ def description(self) -> str:
         """
         Description of the source.
         """
-        return self["description"]
+        return self._description
 
     @property
     def startOffset(self) -> str:
         """
         The starting offset for data being read.
         """
-        return self["startOffset"]
+        return self._startOffset
 
     @property
     def endOffset(self) -> str:
         """
         The ending offset for data being read.
         """
-        return self["endOffset"]
+        return self._endOffset
 
     @property
     def latestOffset(self) -> str:
         """
         The latest offset from this source.
         """
-        return self["latestOffset"]
+        return self._latestOffset
 
     @property
     def numInputRows(self) -> int:
         """
         The number of records read from this source.
         """
-        return self["numInputRows"]
+        return self._numInputRows
 
     @property
     def inputRowsPerSecond(self) -> float:
         """
         The rate at which data is arriving from this source.
         """
-        return self["inputRowsPerSecond"]
+        return self._inputRowsPerSecond
 
     @property
     def processedRowsPerSecond(self) -> float:
         """
         The rate at which data from this source is being processed by Spark.
         """
-        return self["processedRowsPerSecond"]
+        return self._processedRowsPerSecond
 
     @property
-    def metrics(self) -> dict:
-        return self["metrics"]
+    def metrics(self) -> Dict[str, str]:
+        return self._metrics
 
     @property
     def json(self) -> str:
@@ -982,17 +935,11 @@ def prettyJson(self) -> str:
     def __str__(self) -> str:
         return self.prettyJson
 
-    def __repr__(self) -> str:
-        return self.prettyJson
-
 
-class SinkProgress(dict):
+class SinkProgress:
     """
     .. versionadded:: 3.4.0
 
-    .. versionchanged:: 4.0.0
-        Becomes a subclass of dict
-
     Notes
     -----
     This API is evolving.
@@ -1006,13 +953,11 @@ def __init__(
         jprogress: Optional["JavaObject"] = None,
         jdict: Optional[Dict[str, Any]] = None,
     ) -> None:
-        super().__init__(
-            description=description,
-            numOutputRows=numOutputRows,
-            metrics=metrics,
-        )
         self._jprogress: Optional["JavaObject"] = jprogress
         self._jdict: Optional[Dict[str, Any]] = jdict
+        self._description: str = description
+        self._numOutputRows: int = numOutputRows
+        self._metrics: Dict[str, str] = metrics
 
     @classmethod
     def fromJObject(cls, jprogress: "JavaObject") -> "SinkProgress":
@@ -1037,7 +982,7 @@ def description(self) -> str:
         """
         Description of the source.
         """
-        return self["description"]
+        return self._description
 
     @property
     def numOutputRows(self) -> int:
         """
@@ -1045,11 +990,11 @@ def numOutputRows(self) -> int:
         Number of rows written to the sink or -1 for Continuous Mode (temporarily)
         or Sink V1 (until decommissioned).
         """
-        return self["numOutputRows"]
+        return self._numOutputRows
 
     @property
     def metrics(self) -> Dict[str, str]:
-        return self["metrics"]
+        return self._metrics
 
     @property
     def json(self) -> str:
@@ -1076,9 +1021,6 @@ def prettyJson(self) -> str:
     def __str__(self) -> str:
         return self.prettyJson
 
-    def __repr__(self) -> str:
-        return self.prettyJson
-
 
 def _test() -> None:
     import sys
diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py
index 916f96a5b2c2f..d3d58da3562b6 100644
--- a/python/pyspark/sql/streaming/query.py
+++ b/python/pyspark/sql/streaming/query.py
@@ -22,10 +22,7 @@
 from pyspark.errors.exceptions.captured import (
     StreamingQueryException as CapturedStreamingQueryException,
 )
-from pyspark.sql.streaming.listener import (
-    StreamingQueryListener,
-    StreamingQueryProgress,
-)
+from pyspark.sql.streaming.listener import StreamingQueryListener
 
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
@@ -254,7 +251,7 @@ def status(self) -> Dict[str, Any]:
         return json.loads(self._jsq.status().json())
 
     @property
-    def recentProgress(self) -> List[StreamingQueryProgress]:
+    def recentProgress(self) -> List[Dict[str, Any]]:
         """
         Returns an array of the most recent [[StreamingQueryProgress]] updates for this query.
         The number of progress updates retained for each stream is configured by Spark session
@@ -283,10 +280,10 @@ def recentProgress(self) -> List[StreamingQueryProgress]:
 
         >>> sq.stop()
         """
-        return [StreamingQueryProgress.fromJObject(p) for p in self._jsq.recentProgress()]
+        return [json.loads(p.json()) for p in self._jsq.recentProgress()]
 
     @property
-    def lastProgress(self) -> Optional[StreamingQueryProgress]:
+    def lastProgress(self) -> Optional[Dict[str, Any]]:
         """
         Returns the most recent :class:`StreamingQueryProgress` update of this streaming query or
         None if there were no progress updates
@@ -314,7 +311,7 @@ def lastProgress(self) -> Optional[StreamingQueryProgress]:
         """
         lastProgress = self._jsq.lastProgress()
         if lastProgress:
-            return StreamingQueryProgress.fromJObject(lastProgress)
+            return json.loads(lastProgress.json())
         else:
             return None
 
diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py
index 00d1fbf538850..e284d052d9ae2 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming.py
@@ -29,7 +29,7 @@
 
 class StreamingTestsMixin:
     def test_streaming_query_functions_basic(self):
-        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+        df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
         query = (
             df.writeStream.format("memory")
             .queryName("test_streaming_query_functions_basic")
@@ -43,8 +43,8 @@ def test_streaming_query_functions_basic(self):
             self.assertEqual(query.exception(), None)
             self.assertFalse(query.awaitTermination(1))
             query.processAllAvailable()
-            lastProgress = query.lastProgress
             recentProgress = query.recentProgress
+            lastProgress = query.lastProgress
             self.assertEqual(lastProgress["name"], query.name)
             self.assertEqual(lastProgress["id"], query.id)
             self.assertTrue(any(p == lastProgress for p in recentProgress))
@@ -59,46 +59,6 @@ def test_streaming_query_functions_basic(self):
         finally:
             query.stop()
 
-    def test_streaming_progress(self):
-        """
-        Should be able to access fields using attributes in lastProgress / recentProgress
-        e.g. q.lastProgress.id
-        """
-        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
-        query = df.writeStream.format("noop").start()
-        try:
-            query.processAllAvailable()
-            lastProgress = query.lastProgress
-            recentProgress = query.recentProgress
-            self.assertEqual(lastProgress["name"], query.name)
-            # Return str when accessed using dict get.
-            self.assertEqual(lastProgress["id"], query.id)
-            # SPARK-48567 Use attribute to access fields in q.lastProgress
-            self.assertEqual(lastProgress.name, query.name)
-            # Return uuid when accessed using attribute.
-            self.assertEqual(str(lastProgress.id), query.id)
-            self.assertTrue(any(p == lastProgress for p in recentProgress))
-            self.assertTrue(lastProgress.numInputRows > 0)
-            # Also access source / sink progress with attributes
-            self.assertTrue(len(lastProgress.sources) > 0)
-            self.assertTrue(lastProgress.sources[0].numInputRows > 0)
-            self.assertTrue(lastProgress["sources"][0]["numInputRows"] > 0)
-            self.assertTrue(lastProgress.sink.numOutputRows > 0)
-            self.assertTrue(lastProgress["sink"]["numOutputRows"] > 0)
-            # In Python, for historical reasons, changing field value
-            # in StreamingQueryProgress is allowed.
-            new_name = "myNewQuery"
-            lastProgress["name"] = new_name
-            self.assertEqual(lastProgress.name, new_name)
-
-        except Exception as e:
-            self.fail(
-                "Streaming query functions sanity check shouldn't throw any error. "
-                "Error message: " + str(e)
-            )
-        finally:
-            query.stop()
-
     def test_streaming_query_name_edge_case(self):
         # Query name should be None when not specified
         q1 = self.spark.readStream.format("rate").load().writeStream.format("noop").start()
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
index 0f13450849c57..762fc335b56ad 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -227,9 +227,9 @@ def onQueryTerminated(self, event):
             "my_event", count(lit(1)).alias("rc"), count(col("error")).alias("erc")
         )
 
-        q = observed_ds.writeStream.format("noop").start()
+        q = observed_ds.writeStream.format("console").start()
 
-        while q.lastProgress is None or q.lastProgress.batchId == 0:
+        while q.lastProgress is None or q.lastProgress["batchId"] == 0:
             q.awaitTermination(0.5)
 
         time.sleep(5)
@@ -241,32 +241,6 @@ def onQueryTerminated(self, event):
             q.stop()
             self.spark.streams.removeListener(error_listener)
 
-    def test_streaming_progress(self):
-        try:
-            # Test a fancier query with stateful operation and observed metrics
-            df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-            df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
-            df_stateful = df_observe.groupBy().count()  # make query stateful
-            q = (
-                df_stateful.writeStream.format("noop")
-                .queryName("test")
-                .outputMode("update")
-                .trigger(processingTime="5 seconds")
-                .start()
-            )
-
-            while q.lastProgress is None or q.lastProgress.batchId == 0:
-                q.awaitTermination(0.5)
-
-            q.stop()
-
-            self.check_streaming_query_progress(q.lastProgress, True)
-            for p in q.recentProgress:
-                self.check_streaming_query_progress(p, True)
-
-        finally:
-            q.stop()
-
 
 class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase):
     def test_number_of_public_methods(self):
@@ -381,7 +355,7 @@ def verify(test_listener):
             .start()
         )
         self.assertTrue(q.isActive)
-        q.awaitTermination(10)
+        time.sleep(10)
         q.stop()
 
         # Make sure all events are empty
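
Note (editorial, not part of the patch): a minimal sketch of the user-facing behavior this
revert restores. With the revert applied, StreamingQuery.lastProgress and
StreamingQuery.recentProgress once again return plain dicts parsed from the progress JSON,
so fields are read by key; the attribute access (progress.batchId) exercised by the deleted
tests was the SPARK-48567 behavior being reverted. The query name "demo" and the rate-source
settings below are illustrative assumptions, not taken from the patch.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Any small streaming source works here; "rate" matches the reverted tests.
    df = spark.readStream.format("rate").option("rowsPerSecond", 10).load()
    query = df.writeStream.format("memory").queryName("demo").start()
    query.processAllAvailable()

    progress = query.lastProgress  # plain dict, or None before any progress
    if progress is not None:
        # Dict-style key access; "id" is a string in the dict, not a uuid.UUID.
        print(progress["batchId"], progress["numInputRows"], progress["id"])

    query.stop()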