[SPARK-3786] [PySpark] speedup tests
This patch speeds up the PySpark tests: it reuses the SparkContext in tests.py and mllib/tests.py to cut the overhead of creating a SparkContext for every test, and removes some test cases that did not make sense. It also improves the performance of some cases, such as MergerTests and SorterTests.

before this patch:

real	21m27.320s
user	4m42.967s
sys	0m17.343s

after this patch:

real	9m47.541s
user	2m12.947s
sys	0m14.543s

That cuts the total running time by more than half.
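
The heart of the change is a base class that creates one SparkContext for a whole test class (in setUpClass/tearDownClass) rather than one per test method. A minimal sketch of the pattern, mirroring the ReusedPySparkTestCase added in the diff below; ExampleTests is a hypothetical subclass for illustration:

```python
import unittest

from pyspark.context import SparkContext


class ReusedPySparkTestCase(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        # One shared SparkContext for every test in the class, instead of
        # starting and stopping a fresh context around each test method.
        cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2)

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()


class ExampleTests(ReusedPySparkTestCase):  # hypothetical subclass

    def test_count(self):
        self.assertEqual(100, self.sc.parallelize(range(100)).count())
```

Tests that need a fresh context per method (for example ProfilerTests and WorkerTests in the diff below) still extend the per-test PySparkTestCase.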

Author: Davies Liu <[email protected]>

Closes apache#2646 from davies/tests and squashes the following commits:

c54de60 [Davies Liu] revert change about memory limit
6a2a4b0 [Davies Liu] refactor of tests, speedup 100%
davies authored and JoshRosen committed Oct 6, 2014
1 parent 20ea54c commit 4f01265
Showing 4 changed files with 82 additions and 91 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/mllib/tests.py
@@ -32,7 +32,7 @@
from pyspark.serializers import PickleSerializer
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
from pyspark.tests import PySparkTestCase
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase


_have_scipy = False
5 changes: 2 additions & 3 deletions python/pyspark/shuffle.py
@@ -396,7 +396,6 @@ def _external_items(self):
for v in self.data.iteritems():
yield v
self.data.clear()
gc.collect()

# remove the merged partition
for j in range(self.spills):
@@ -428,7 +427,7 @@ def _recursive_merged_items(self, start):
subdirs = [os.path.join(d, "parts", str(i))
for d in self.localdirs]
m = ExternalMerger(self.agg, self.memory_limit, self.serializer,
subdirs, self.scale * self.partitions)
subdirs, self.scale * self.partitions, self.partitions)
m.pdata = [{} for _ in range(self.partitions)]
limit = self._next_limit()

@@ -486,7 +485,7 @@ def sorted(self, iterator, key=None, reverse=False):
goes above the limit.
"""
global MemoryBytesSpilled, DiskBytesSpilled
batch = 10
batch = 100
chunks, current_chunk = [], []
iterator = iter(iterator)
while True:
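
For context on the `batch = 100` change above, a hedged sketch of how ExternalSorter is exercised (assuming the pyspark.shuffle API as of this patch; SorterTests in tests.py below follows the same pattern):

```python
import random

from pyspark.shuffle import ExternalSorter

l = list(range(1024))
random.shuffle(l)

# Sort under a 1 MB memory cap: when memory use goes above the limit,
# sorted() spills sorted chunks to disk and merges them back lazily.
sorter = ExternalSorter(1)
assert sorted(l) == list(sorter.sorted(iter(l)))
```

The `batch` value appears to control how many elements are pulled from the iterator between memory checks, so a larger starting batch reduces that per-element overhead.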
92 changes: 42 additions & 50 deletions python/pyspark/tests.py
@@ -67,10 +67,10 @@
SPARK_HOME = os.environ["SPARK_HOME"]


class TestMerger(unittest.TestCase):
class MergerTests(unittest.TestCase):

def setUp(self):
self.N = 1 << 16
self.N = 1 << 14
self.l = [i for i in xrange(self.N)]
self.data = zip(self.l, self.l)
self.agg = Aggregator(lambda x: [x],
@@ -115,15 +115,15 @@ def test_medium_dataset(self):
sum(xrange(self.N)) * 3)

def test_huge_dataset(self):
m = ExternalMerger(self.agg, 10)
m = ExternalMerger(self.agg, 10, partitions=3)
m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
self.N * 10)
m._cleanup()


class TestSorter(unittest.TestCase):
class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
l = range(1024)
random.shuffle(l)
@@ -244,16 +244,25 @@ def tearDown(self):
sys.path = self._old_sys_path


class TestCheckpoint(PySparkTestCase):
class ReusedPySparkTestCase(unittest.TestCase):

@classmethod
def setUpClass(cls):
cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2)

@classmethod
def tearDownClass(cls):
cls.sc.stop()


class CheckpointTests(ReusedPySparkTestCase):

def setUp(self):
PySparkTestCase.setUp(self)
self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(self.checkpointDir.name)
self.sc.setCheckpointDir(self.checkpointDir.name)

def tearDown(self):
PySparkTestCase.tearDown(self)
shutil.rmtree(self.checkpointDir.name)

def test_basic_checkpointing(self):
@@ -288,7 +297,7 @@ def test_checkpoint_and_restore(self):
self.assertEquals([1, 2, 3, 4], recovered.collect())


class TestAddFile(PySparkTestCase):
class AddFileTests(PySparkTestCase):

def test_add_py_file(self):
# To ensure that we're actually testing addPyFile's effects, check that
@@ -354,7 +363,7 @@ def func(x):
self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())


class TestRDDFunctions(PySparkTestCase):
class RDDTests(ReusedPySparkTestCase):

def test_id(self):
rdd = self.sc.parallelize(range(10))
@@ -365,12 +374,6 @@ def test_id(self):
self.assertEqual(id + 1, id2)
self.assertEqual(id2, rdd2.id())

def test_failed_sparkcontext_creation(self):
# Regression test for SPARK-1550
self.sc.stop()
self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
self.sc = SparkContext("local")

def test_save_as_textfile_with_unicode(self):
# Regression test for SPARK-970
x = u"\u00A1Hola, mundo!"
@@ -636,7 +639,7 @@ def test_distinct(self):
self.assertEquals(result.count(), 3)


class TestProfiler(PySparkTestCase):
class ProfilerTests(PySparkTestCase):

def setUp(self):
self._old_sys_path = list(sys.path)
@@ -666,10 +669,9 @@ def heavy_foo(x):
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))


class TestSQL(PySparkTestCase):
class SQLTests(ReusedPySparkTestCase):

def setUp(self):
PySparkTestCase.setUp(self)
self.sqlCtx = SQLContext(self.sc)

def test_udf(self):
@@ -754,27 +756,19 @@ def test_serialize_nested_array_and_map(self):
self.assertEqual("2", row.d)


class TestIO(PySparkTestCase):

def test_stdout_redirection(self):
import subprocess

def func(x):
subprocess.check_call('ls', shell=True)
self.sc.parallelize([1]).foreach(func)
class InputFormatTests(ReusedPySparkTestCase):

@classmethod
def setUpClass(cls):
ReusedPySparkTestCase.setUpClass()
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(cls.tempdir.name)
cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc)

class TestInputFormat(PySparkTestCase):

def setUp(self):
PySparkTestCase.setUp(self)
self.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(self.tempdir.name)
self.sc._jvm.WriteInputFormatTestDataGenerator.generateData(self.tempdir.name, self.sc._jsc)

def tearDown(self):
PySparkTestCase.tearDown(self)
shutil.rmtree(self.tempdir.name)
@classmethod
def tearDownClass(cls):
ReusedPySparkTestCase.tearDownClass()
shutil.rmtree(cls.tempdir.name)

def test_sequencefiles(self):
basepath = self.tempdir.name
@@ -954,15 +948,13 @@ def test_converters(self):
self.assertEqual(maps, em)


class TestOutputFormat(PySparkTestCase):
class OutputFormatTests(ReusedPySparkTestCase):

def setUp(self):
PySparkTestCase.setUp(self)
self.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(self.tempdir.name)

def tearDown(self):
PySparkTestCase.tearDown(self)
shutil.rmtree(self.tempdir.name, ignore_errors=True)

def test_sequencefiles(self):
@@ -1243,8 +1235,7 @@ def test_malformed_RDD(self):
basepath + "/malformed/sequence"))


class TestDaemon(unittest.TestCase):

class DaemonTests(unittest.TestCase):
def connect(self, port):
from socket import socket, AF_INET, SOCK_STREAM
sock = socket(AF_INET, SOCK_STREAM)
@@ -1290,7 +1281,7 @@ def test_termination_sigterm(self):
self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))


class TestWorker(PySparkTestCase):
class WorkerTests(PySparkTestCase):

def test_cancel_task(self):
temp = tempfile.NamedTemporaryFile(delete=True)
@@ -1342,11 +1333,6 @@ def run():
rdd = self.sc.parallelize(range(100), 1)
self.assertEqual(100, rdd.map(str).count())

def test_fd_leak(self):
N = 1100 # fd limit is 1024 by default
rdd = self.sc.parallelize(range(N), N)
self.assertEquals(N, rdd.count())

def test_after_exception(self):
def raise_exception(_):
raise Exception()
@@ -1379,7 +1365,7 @@ def test_accumulator_when_reuse_worker(self):
self.assertEqual(sum(range(100)), acc1.value)


class TestSparkSubmit(unittest.TestCase):
class SparkSubmitTests(unittest.TestCase):

def setUp(self):
self.programDir = tempfile.mkdtemp()
@@ -1492,6 +1478,8 @@ def test_single_script_on_cluster(self):
|sc = SparkContext()
|print sc.parallelize([1, 2, 3]).map(foo).collect()
""")
# this will fail if you have different spark.executor.memory
# in conf/spark-defaults.conf
proc = subprocess.Popen(
[self.sparkSubmit, "--master", "local-cluster[1,1,512]", script],
stdout=subprocess.PIPE)
@@ -1500,7 +1488,11 @@ def test_single_script_on_cluster(self):
self.assertIn("[2, 4, 6]", out)


class ContextStopTests(unittest.TestCase):
class ContextTests(unittest.TestCase):

def test_failed_sparkcontext_creation(self):
# Regression test for SPARK-1550
self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))

def test_stop(self):
sc = SparkContext()
74 changes: 37 additions & 37 deletions python/run-tests
@@ -34,7 +34,7 @@ rm -rf metastore warehouse
function run_test() {
echo "Running test: $1"

SPARK_TESTING=1 "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log
SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log

FAILED=$((PIPESTATUS[0]||$FAILED))

@@ -48,6 +48,37 @@ function run_test() {
fi
}

function run_core_tests() {
echo "Run core tests ..."
run_test "pyspark/rdd.py"
run_test "pyspark/context.py"
run_test "pyspark/conf.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py"
run_test "pyspark/shuffle.py"
run_test "pyspark/tests.py"
}

function run_sql_tests() {
echo "Run sql tests ..."
run_test "pyspark/sql.py"
}

function run_mllib_tests() {
echo "Run mllib tests ..."
run_test "pyspark/mllib/classification.py"
run_test "pyspark/mllib/clustering.py"
run_test "pyspark/mllib/linalg.py"
run_test "pyspark/mllib/random.py"
run_test "pyspark/mllib/recommendation.py"
run_test "pyspark/mllib/regression.py"
run_test "pyspark/mllib/stat.py"
run_test "pyspark/mllib/tree.py"
run_test "pyspark/mllib/util.py"
run_test "pyspark/mllib/tests.py"
}

echo "Running PySpark tests. Output is in python/unit-tests.log."

export PYSPARK_PYTHON="python"
@@ -60,49 +91,18 @@ fi
echo "Testing with Python version:"
$PYSPARK_PYTHON --version

run_test "pyspark/rdd.py"
run_test "pyspark/context.py"
run_test "pyspark/conf.py"
run_test "pyspark/sql.py"
# These tests are included in the module-level docs, and so must
# be handled on a higher level rather than within the python file.
export PYSPARK_DOC_TEST=1
run_test "pyspark/broadcast.py"
run_test "pyspark/accumulators.py"
run_test "pyspark/serializers.py"
unset PYSPARK_DOC_TEST
run_test "pyspark/shuffle.py"
run_test "pyspark/tests.py"
run_test "pyspark/mllib/classification.py"
run_test "pyspark/mllib/clustering.py"
run_test "pyspark/mllib/linalg.py"
run_test "pyspark/mllib/random.py"
run_test "pyspark/mllib/recommendation.py"
run_test "pyspark/mllib/regression.py"
run_test "pyspark/mllib/stat.py"
run_test "pyspark/mllib/tests.py"
run_test "pyspark/mllib/tree.py"
run_test "pyspark/mllib/util.py"
run_core_tests
run_sql_tests
run_mllib_tests

# Try to test with PyPy
if [ $(which pypy) ]; then
export PYSPARK_PYTHON="pypy"
echo "Testing with PyPy version:"
$PYSPARK_PYTHON --version

run_test "pyspark/rdd.py"
run_test "pyspark/context.py"
run_test "pyspark/conf.py"
run_test "pyspark/sql.py"
# These tests are included in the module-level docs, and so must
# be handled on a higher level rather than within the python file.
export PYSPARK_DOC_TEST=1
run_test "pyspark/broadcast.py"
run_test "pyspark/accumulators.py"
run_test "pyspark/serializers.py"
unset PYSPARK_DOC_TEST
run_test "pyspark/shuffle.py"
run_test "pyspark/tests.py"
run_core_tests
run_sql_tests
fi

if [[ $FAILED == 0 ]]; then
