diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py index 6fe2b89408014..1c2c04f2da54f 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py @@ -20,7 +20,11 @@ class StreamingParityTests(StreamingTestsMixin, ReusedConnectTestCase): - pass + def _assert_exception_tree_contains_msg(self, exception, msg): + self.assertTrue( + msg in exception._message, + "Exception tree doesn't contain the expected message: %s" % msg, + ) if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py index 2b9072c34befe..a7c22897096b6 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming.py +++ b/python/pyspark/sql/tests/streaming/test_streaming.py @@ -24,7 +24,6 @@ from pyspark.sql.functions import lit from pyspark.sql.types import StructType, StructField, IntegerType, StringType from pyspark.testing.sqlutils import ReusedSQLTestCase -from pyspark.errors.exceptions.connect import SparkConnectException class StreamingTestsMixin: @@ -295,26 +294,6 @@ def test_stream_exception(self): self.assertIsInstance(exception, StreamingQueryException) self._assert_exception_tree_contains_msg(exception, "ZeroDivisionError") - def _assert_exception_tree_contains_msg(self, exception, msg): - if isinstance(exception, SparkConnectException): - self._assert_exception_tree_contains_msg_connect(exception, msg) - else: - self._assert_exception_tree_contains_msg_default(exception, msg) - - def _assert_exception_tree_contains_msg_connect(self, exception, msg): - self.assertTrue( - msg in exception._message, - "Exception tree doesn't contain the expected message: %s" % msg, - ) - - def _assert_exception_tree_contains_msg_default(self, exception, msg): - e = exception - contains = msg in e._desc - while e._cause is not None and not contains: - e = e._cause - contains = msg in e._desc - self.assertTrue(contains, "Exception tree doesn't contain the expected message: %s" % msg) - def test_query_manager_get(self): df = self.spark.readStream.format("rate").load() for q in self.spark.streams.active: @@ -408,7 +387,13 @@ def test_streaming_with_temporary_view(self): class StreamingTests(StreamingTestsMixin, ReusedSQLTestCase): - pass + def _assert_exception_tree_contains_msg(self, exception, msg): + e = exception + contains = msg in e._desc + while e._cause is not None and not contains: + e = e._cause + contains = msg in e._desc + self.assertTrue(contains, "Exception tree doesn't contain the expected message: %s" % msg) if __name__ == "__main__":