diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f66423e05e9b6..fb592989ce4b0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -908,7 +908,7 @@ def test_stopiteration_in_udf(self): def foo(x): raise StopIteration() - with self.assertRaises(Py4JJavaError) as cm: + with self.assertRaises(Py4JJavaError): self.spark.range(0, 1000).withColumn('v', udf(foo)).show() def test_validate_column_types(self): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 383fdde59aad0..18f88cee89e1f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1279,28 +1279,22 @@ def test_pipe_unicode(self): def test_stopiteration_in_client_code(self): - def a_rdd(keyed=False): - return self.sc.parallelize( - ((x % 2, x) if keyed else x) - for x in range(10) - ) - def stopit(*x): raise StopIteration() - def do_test(action, *args, **kwargs): - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - action(*args, **kwargs) - - do_test(a_rdd().map(stopit).collect) - do_test(a_rdd().filter(stopit).collect) - do_test(a_rdd().cartesian(a_rdd()).flatMap(stopit).collect) - do_test(a_rdd().foreach, stopit) - do_test(a_rdd(keyed=True).reduceByKeyLocally, stopit) - do_test(a_rdd().reduce, stopit) - do_test(a_rdd().fold, 0, stopit) - do_test(a_rdd().aggregate, 0, stopit, lambda *x: 1) - do_test(a_rdd().aggregate, 0, lambda *x: 1, stopit) + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + exc = Py4JJavaError, RuntimeError + + self.assertRaises(exc, seq_rdd.map(stopit).collect) + self.assertRaises(exc, seq_rdd.filter(stopit).collect) + self.assertRaises(exc, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + self.assertRaises(exc, seq_rdd.foreach, stopit) + self.assertRaises(exc, keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaises(exc, seq_rdd.reduce, stopit) + self.assertRaises(exc, seq_rdd.fold, 0, stopit) + self.assertRaises(exc, seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaises(exc, seq_rdd.aggregate, 0, lambda *x: 1, stopit) class ProfilerTests(PySparkTestCase):