get local directory by SPARK_LOCAL_DIR
support multiple local directories
davies committed Jul 21, 2014
1 parent 57ee7ef commit 24cec6a
Showing 4 changed files with 55 additions and 30 deletions.
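The patch threads spark.local.dir through to Python workers as the SPARK_LOCAL_DIR environment variable, so the external merger can spread its spill files across several disks instead of hard-coding /tmp/pyspark/merge. A minimal usage sketch, assuming a build that includes this commit; the directory paths and the job are illustrative, not part of the change:

# Sketch: configure several comma-separated local directories; PythonRDD
# forwards them to each Python worker as SPARK_LOCAL_DIR (see the Scala
# hunk below). Paths are placeholders.
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .setAppName("spill-dirs-sketch")
        .set("spark.local.dir", "/mnt/disk1/spark,/mnt/disk2/spark"))
sc = SparkContext(conf=conf)

# A shuffle-heavy job whose Python-side merge may spill to those directories.
counts = sc.parallelize(range(100000)) \
           .map(lambda x: (x % 100, 1)) \
           .reduceByKey(lambda a, b: a + b)
print(counts.count())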
@@ -57,7 +57,9 @@ private[spark] class PythonRDD[T: ClassTag](
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
     val startTime = System.currentTimeMillis
     val env = SparkEnv.get
-    val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
+    val localdir = env.conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))
+    val worker: Socket = env.createPythonWorker(pythonExec,
+      envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir))
 
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = new WriterThread(env, worker, split, context)
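On the JVM side the only change is that createPythonWorker now receives SPARK_LOCAL_DIR in the worker's environment, taken from spark.local.dir and falling back to java.io.tmpdir. A hedged way to confirm the handoff from the driver, assuming an existing SparkContext named sc built with this patch:

# Sketch: read SPARK_LOCAL_DIR inside a Python worker and return it to the
# driver. `sc` is assumed to exist; the env var is set by PythonRDD.compute.
import os

def worker_local_dir(_):
    return os.environ.get("SPARK_LOCAL_DIR", "<not set>")

print(sc.parallelize([0], 1).map(worker_local_dir).collect())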
3 changes: 1 addition & 2 deletions python/pyspark/rdd.py
@@ -1268,8 +1268,7 @@ def combineLocally(iterator):
                  in ('true', '1', 'yes'))
         memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory") or "512m")
         def _mergeCombiners(iterator):
-            # TODO: workdir
-            merger = ExternalHashMapMerger(mergeCombiners, memory, serializer=serializer)\
+            merger = ExternalHashMapMerger(mergeCombiners, memory, serializer)\
                 if spill else MapMerger(mergeCombiners)
             merger.merge(iterator)
             return merger.iteritems()
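For context, combineByKey picks the spilling ExternalHashMapMerger only when spark.shuffle.spill is enabled, and its memory budget comes from spark.python.worker.memory (default 512m); this hunk simply passes the serializer positionally now that the path argument is gone. A hedged sketch of exercising that code path, again assuming a SparkContext named sc with those two settings in its conf:

# Sketch: combineByKey is the patched call site; with spark.shuffle.spill
# enabled, its merge step goes through ExternalHashMapMerger.
sums = sc.parallelize([(i % 10, 1) for i in range(100000)]).combineByKey(
    lambda v: v,             # createCombiner
    lambda c, v: c + v,      # mergeValue
    lambda c1, c2: c1 + c2)  # mergeCombiners
print(sums.take(3))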
77 changes: 50 additions & 27 deletions python/pyspark/shuffle.py
@@ -91,17 +91,35 @@ class ExternalHashMapMerger(Merger):
     PARTITIONS = 64
     BATCH = 10000
 
-    def __init__(self, combiner, memory_limit=512, path="/tmp/pyspark/merge",
-                 serializer=None, scale=1):
+    def __init__(self, combiner, memory_limit=512, serializer=None,
+                 localdirs=None, scale=1):
         self.combiner = combiner
         self.memory_limit = memory_limit
-        self.path = os.path.join(path, str(os.getpid()))
         self.serializer = serializer or BatchedSerializer(AutoSerializer(), 1024)
+        self.localdirs = localdirs or self._get_dirs()
         self.scale = scale
         self.data = {}
         self.pdata = []
         self.spills = 0
 
+    def _get_dirs(self):
+        path = os.environ.get("SPARK_LOCAL_DIR", "/tmp/spark")
+        dirs = path.split(",")
+        localdirs = []
+        for d in dirs:
+            d = os.path.join(d, "merge", str(os.getpid()))
+            try:
+                os.makedirs(d)
+                localdirs.append(d)
+            except IOError:
+                pass
+        if not localdirs:
+            raise IOError("no writable directories: " + path)
+        return localdirs
+
+    def _get_spill_dir(self, n):
+        return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))
+
     @property
     def used_memory(self):
         return get_used_memory()
@@ -144,7 +162,7 @@ def _partitioned_merge(self, iterator, limit):
                     limit = self.next_limit
 
     def _first_spill(self):
-        path = os.path.join(self.path, str(self.spills))
+        path = self._get_spill_dir(self.spills)
         if not os.path.exists(path):
             os.makedirs(path)
         streams = [open(os.path.join(path, str(i)), 'w')
@@ -159,7 +177,7 @@ def _first_spill(self):
         self.spills += 1
 
     def _spill(self):
-        path = os.path.join(self.path, str(self.spills))
+        path = self._get_spill_dir(self.spills)
         if not os.path.exists(path):
             os.makedirs(path)
         for i in range(self.PARTITIONS):
@@ -181,23 +199,29 @@ def _external_items(self):
             self._spill()
         hard_limit = self.next_limit
 
-        for i in range(self.PARTITIONS):
-            self.data = {}
-            for j in range(self.spills):
-                p = os.path.join(self.path, str(j), str(i))
-                self.merge(self.serializer.load_stream(open(p)), check=False)
-
-                if j > 0 and self.used_memory > hard_limit and j < self.spills - 1:
-                    self.data.clear() # will read from disk again
-                    for v in self._recursive_merged_items(i):
-                        yield v
-                    return
-
-            for v in self.data.iteritems():
-                yield v
-            self.data.clear()
-
-        shutil.rmtree(self.path, True)
+        try:
+            for i in range(self.PARTITIONS):
+                self.data = {}
+                for j in range(self.spills):
+                    path = self._get_spill_dir(j)
+                    p = os.path.join(path, str(i))
+                    self.merge(self.serializer.load_stream(open(p)), check=False)
+
+                    if j > 0 and self.used_memory > hard_limit and j < self.spills - 1:
+                        self.data.clear() # will read from disk again
+                        for v in self._recursive_merged_items(i):
+                            yield v
+                        return
+
+                for v in self.data.iteritems():
+                    yield v
+                self.data.clear()
+        finally:
+            self._cleanup()
+
+    def _cleanup(self):
+        for d in self.localdirs:
+            shutil.rmtree(d, True)
 
     def _recursive_merged_items(self, start):
         assert not self.data
@@ -206,14 +230,15 @@ def _recursive_merged_items(self, start):
             self._spill()
 
         for i in range(start, self.PARTITIONS):
+            subdirs = [os.path.join(d, 'merge', str(i)) for d in self.localdirs]
             m = ExternalHashMapMerger(self.combiner, self.memory_limit,
-                    os.path.join(self.path, 'merge', str(i)),
-                    self.serializer, scale=self.scale * self.PARTITIONS)
+                    self.serializer, subdirs, self.scale * self.PARTITIONS)
             m.pdata = [{} for _ in range(self.PARTITIONS)]
             limit = self.next_limit
 
             for j in range(self.spills):
-                p = os.path.join(self.path, str(j), str(i))
+                path = self._get_spill_dir(j)
+                p = os.path.join(path, str(i))
                 m._partitioned_merge(self.serializer.load_stream(open(p)), 0)
                 if m.used_memory > limit:
                     m._spill()
@@ -222,8 +247,6 @@ def _recursive_merged_items(self, start):
             for v in m._external_items():
                 yield v
 
-        shutil.rmtree(self.path, True)
-
 
 if __name__ == '__main__':
     import doctest
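Taken together, _get_dirs splits the comma-separated SPARK_LOCAL_DIR list (defaulting to /tmp/spark), _get_spill_dir distributes successive spills across those directories round-robin, and _cleanup removes them when merging finishes or fails. A hedged, standalone sketch of driving the class directly with explicit localdirs, mirroring the test below; the temporary directories and the 10 MB limit are illustrative:

# Sketch: use ExternalHashMapMerger with explicit local directories, the same
# way _recursive_merged_items and the updated test pass them in.
import tempfile
from pyspark.shuffle import ExternalHashMapMerger

dirs = [tempfile.mkdtemp() for _ in range(2)]   # stand-ins for SPARK_LOCAL_DIR entries
m = ExternalHashMapMerger(lambda a, b: a + b, memory_limit=10, localdirs=dirs)
m.merge((i % 1000, 1) for i in range(100000))   # low limit, so this should spill
print(sum(v for _, v in m.iteritems()))         # expected: 100000
m._cleanup()                                    # removes the spill directories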
1 change: 1 addition & 0 deletions python/pyspark/tests.py
@@ -79,6 +79,7 @@ def test_huge_dataset(self):
         m.merge(map(lambda (k,v): (k, [str(v)]), self.data) * 10)
         self.assertTrue(m.spills >= 1)
         self.assertEqual(sum(len(v) for k,v in m._recursive_merged_items(0)), self.N * 10)
+        m._cleanup()
 
 
 class PySparkTestCase(unittest.TestCase):
