diff --git a/python/pyspark/sql_tests.py b/python/pyspark/sql_tests.py index 3201f4ede6dbb..bda06b73a8566 100644 --- a/python/pyspark/sql_tests.py +++ b/python/pyspark/sql_tests.py @@ -20,10 +20,19 @@ individual modules. """ import os +import sys import pydoc import shutil import tempfile -import unittest + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ UserDefinedType, DoubleType @@ -83,18 +92,16 @@ def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) os.unlink(cls.tempdir.name) + cls.sqlCtx = SQLContext(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = cls.sc.parallelize(cls.testData) + cls.df = cls.sqlCtx.inferSchema(rdd) @classmethod def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) - def setUp(self): - self.sqlCtx = SQLContext(self.sc) - self.testData = [Row(key=i, value=str(i)) for i in range(100)] - rdd = self.sc.parallelize(self.testData) - self.df = self.sqlCtx.inferSchema(rdd) - def test_udf(self): self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index f0020b1a34de1..b5e28c498040b 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -23,7 +23,6 @@ from fileinput import input from glob import glob import os -import pydoc import re import shutil import subprocess