Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
WweiL committed Jun 8, 2024
1 parent c73ebfc commit d967119
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 7 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/streaming/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def lastProgress(self) -> Optional[StreamingQueryProgress]:
cmd.last_progress = True
progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
if len(progress) > 0:
return StreamingQueryProgress.fromJson(json.loads(progress))
return StreamingQueryProgress.fromJson(json.loads(progress[-1]))
else:
return None

Expand Down
7 changes: 7 additions & 0 deletions python/pyspark/sql/streaming/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,10 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
def __getitem__(self, key):
    """
    Support item-style reads (``obj["name"]``) by looking up an attribute.

    The lookup is delegated to ``getattr``, so a ``key`` that does not name
    an attribute of this object raises ``AttributeError``.
    """
    value = getattr(self, key)
    return value

def __setitem__(self, key, value):
    """
    Support item-style writes (``obj["name"] = value``).

    The value is stored on the attribute whose name is ``key`` prefixed with
    a single underscore (``"name"`` -> ``_name``).
    NOTE(review): presumably these underscore attributes back the public
    read-only properties of the class -- confirm against the class body.
    """
    setattr(self, "_" + key, value)

@property
def id(self) -> uuid.UUID:
"""
Expand Down Expand Up @@ -655,6 +659,9 @@ def prettyJson(self) -> str:
def __str__(self) -> str:
    """Return the value of ``self.prettyJson`` as the string form of this object."""
    rendered = self.prettyJson
    return rendered

def __repr__(self) -> str:
    """Return the value of ``self.prettyJson`` as the ``repr`` of this object."""
    rendered = self.prettyJson
    return rendered


class StateOperatorProgress:
"""
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/sql/streaming/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def recentProgress(self) -> List[StreamingQueryProgress]:
>>> sq.stop()
"""
return [StreamingQueryProgress.fromJson(json.loads(p)) for p in self._jsq.recentProgress()]
return [StreamingQueryProgress.fromJObject(p) for p in self._jsq.recentProgress()]

@property
def lastProgress(self) -> Optional[StreamingQueryProgress]:
Expand Down Expand Up @@ -314,7 +314,7 @@ def lastProgress(self) -> Optional[StreamingQueryProgress]:
"""
lastProgress = self._jsq.lastProgress()
if lastProgress:
return StreamingQueryProgress.fromJson(json.loads(lastProgress.json()))
return StreamingQueryProgress.fromJObject(lastProgress)
else:
return None

Expand Down
28 changes: 27 additions & 1 deletion python/pyspark/sql/tests/streaming/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

class StreamingTestsMixin:
def test_streaming_query_functions_basic(self):
df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
query = (
df.writeStream.format("memory")
.queryName("test_streaming_query_functions_basic")
Expand All @@ -46,6 +46,7 @@ def test_streaming_query_functions_basic(self):
lastProgress = query.lastProgress
self.assertEqual(lastProgress["name"], query.name)
self.assertEqual(lastProgress["id"], query.id)
# SPARK-48567 Use attribute to access progress
self.assertTrue(any(p == lastProgress for p in recentProgress))
query.explain()

Expand All @@ -58,6 +59,31 @@ def test_streaming_query_functions_basic(self):
finally:
query.stop()

def test_streaming_progress(self):
    """
    Check item-style and attribute-style access on streaming progress objects.

    A text-source stream is written to the ``noop`` sink and drained with
    ``processAllAvailable``; the test then verifies that the last progress
    exposes ``name`` and ``id`` both via ``[...]`` and via attributes, that
    item assignment is visible through the attribute, and that the last
    progress compares equal to an entry of ``recentProgress``.
    """
    text_stream = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
    query = text_stream.writeStream.format("noop").start()
    try:
        query.processAllAvailable()
        progress_history = query.recentProgress
        latest = query.lastProgress

        # Item-style access.
        self.assertEqual(latest["name"], query.name)
        self.assertEqual(latest["id"], query.id)

        # SPARK-48567 Use attribute to access fields in q.lastProgress
        self.assertEqual(latest.name, query.name)
        self.assertEqual(latest.id, query.id)

        # A value written through item assignment must be readable as an attribute.
        renamed = "myNewQuery"
        latest["name"] = renamed
        self.assertEqual(latest.name, renamed)

        found = any(entry == latest for entry in progress_history)
        self.assertTrue(found)

    except Exception as err:
        failure_message = (
            "Streaming query functions sanity check shouldn't throw any error. "
            "Error message: " + str(err)
        )
        self.fail(failure_message)
    finally:
        # Always stop the query so the stream does not outlive the test.
        query.stop()

def test_stream_trigger(self):
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")

Expand Down
35 changes: 32 additions & 3 deletions python/pyspark/sql/tests/streaming/test_streaming_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,9 @@ def onQueryTerminated(self, event):
"my_event", count(lit(1)).alias("rc"), count(col("error")).alias("erc")
)

q = observed_ds.writeStream.format("console").start()
q = observed_ds.writeStream.format("noop").start()

while q.lastProgress is None or q.lastProgress["batchId"] == 0:
while q.lastProgress is None or q.lastProgress.batchId == 0:
q.awaitTermination(0.5)

time.sleep(5)
Expand All @@ -241,6 +241,35 @@ def onQueryTerminated(self, event):
q.stop()
self.spark.streams.removeListener(error_listener)

def test_streaming_progress(self):
    """
    Validate progress objects of a stateful query that carries observed metrics.

    Builds a rate-source stream with an observation named ``my_event``
    (single metric ``rc``) and a global aggregation, runs it against the
    ``noop`` sink until a batch after batch 0 has been reported, and then
    checks the last progress and every recent progress with
    ``check_streaming_query_progress`` as well as the observed ``rc`` metric.
    """
    # Test a fancier query with stateful operation and observed metrics
    df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
    df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
    df_stateful = df_observe.groupBy().count()  # make query stateful

    # Start the query *before* entering ``try``: if start-up fails there is
    # nothing to stop, and ``finally`` must not hit an unbound ``q`` (which
    # would replace the real error with a NameError/UnboundLocalError).
    q = (
        df_stateful.writeStream.format("noop")
        .queryName("test")
        .outputMode("update")
        .trigger(processingTime="5 seconds")
        .start()
    )
    try:
        # Wait until some batch after batch 0 has reported progress.
        while q.lastProgress is None or q.lastProgress.batchId == 0:
            q.awaitTermination(0.5)

        q.stop()

        self.check_streaming_query_progress(q.lastProgress, True)
        for p in q.recentProgress:
            self.check_streaming_query_progress(p, True)

        row = q.lastProgress.observedMetrics.get("my_event")
        # ``.get`` yields None when the observation is absent; fail with a
        # clear assertion instead of a TypeError on subscripting None.
        self.assertIsNotNone(row)
        # Only the "rc" metric is registered by the observe() call above, so
        # that is the only field asserted here.
        self.assertGreater(row["rc"], 0)
    finally:
        # Second stop() mirrors the original clean-up; it also covers the
        # case where the wait loop raised before the first stop().
        q.stop()


class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase):
def test_number_of_public_methods(self):
Expand Down Expand Up @@ -355,7 +384,7 @@ def verify(test_listener):
.start()
)
self.assertTrue(q.isActive)
time.sleep(10)
q.awaitTermination(10)
q.stop()

# Make sure all events are empty
Expand Down

0 comments on commit d967119

Please sign in to comment.