Skip to content

Commit

Permalink
yarn: add env name param to backend (#5)
Browse files Browse the repository at this point in the history
* yarn: add env name param to backend

* yarn: polishing env param

* yarn: even more parametrizable conda env path
  • Loading branch information
aabadie authored Jun 8, 2017
1 parent 37cdbf0 commit 11c3b0d
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 32 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ doc/CHANGES.rst
doc/README.rst
# Coverage report
.coverage
coverage.xml
# Pytest
.cache
# Python cache
Expand Down
14 changes: 9 additions & 5 deletions joblibhadoop/yarn/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@

from joblib._parallel_backends import ThreadingBackend
from joblib.my_exceptions import WorkerInterrupt
from .pool import YarnPool
from .pool import YarnPool, JOBLIB_YARN_DEFAULT_CONDA_ENV


# Interrupt exception types that the YARN backend lets propagate to the
# caller (see YarnBackend.get_exceptions).
JOBLIB_YARN_INTERRUPTS = [KeyboardInterrupt, WorkerInterrupt]
# Deprecated alias kept for backward compatibility with code importing the
# pre-rename name.
__interrupts__ = JOBLIB_YARN_INTERRUPTS


class YarnBackend(ThreadingBackend):
"""The YARN backend class."""

def __init__(self, env=JOBLIB_YARN_DEFAULT_CONDA_ENV, packages=None,
             clear_env=False):
    """Constructor.

    Parameters
    ----------
    env : str
        Name of the conda environment shipped to the YARN containers.
    packages : list of str, optional
        Extra conda packages installed into the environment.
    clear_env : bool
        When True, rebuild the conda environment from scratch.
    """
    # Avoid the shared mutable default argument pitfall: each instance
    # gets a fresh list instead of a module-level [].
    self.packages = [] if packages is None else packages
    self._pool = None
    self.parallel = None
    self.env = env
    self.clear_env = clear_env

def effective_n_jobs(self, n_jobs):
"""Return the number of effective jobs running in the backend."""
Expand All @@ -30,12 +33,13 @@ def effective_n_jobs(self, n_jobs):
def configure(self, n_jobs, parallel=None, **backend_args):
    """Initialize the backend.

    Creates the YarnPool with the env/packages/clear_env options stored
    on this backend and returns the effective number of jobs.
    """
    n_jobs = self.effective_n_jobs(n_jobs)
    # Single pool construction, forwarding the env configuration.
    self._pool = YarnPool(processes=n_jobs, env=self.env,
                          packages=self.packages, clear_env=self.clear_env)
    self.parallel = parallel
    return n_jobs

def get_exceptions(self):
    """Return the list of interrupts supported by the backend."""
    # We are using multiprocessing-style workers, so we also want to
    # capture KeyboardInterrupts.
    return JOBLIB_YARN_INTERRUPTS
51 changes: 32 additions & 19 deletions joblibhadoop/yarn/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,52 +12,65 @@
from .remotepool import RemotePool, RemoteWorker
from ..resources import conda_environment_filename

# Scratch directory used as the default root for conda environments;
# overridable through the JOBLIB_TEMP_FOLDER environment variable.
TEMP_DIR = os.environ.get('JOBLIB_TEMP_FOLDER', tempfile.gettempdir())

JOBLIB_YARN_WORKER = 'joblib-yarn-worker'
# Deprecated environment name; superseded by JOBLIB_YARN_DEFAULT_CONDA_ENV
# and kept only for backward compatibility.
JOBLIB_YARN_CONDA_ENV = 'conda_env'
JOBLIB_YARN_DEFAULT_CONDA_ENV = 'joblib_yarn_conda_env'
JOBLIB_YARN_DEFAULT_CONDA_ROOT = TEMP_DIR

# Shell command templates; {} slots are (env prefix path, spec/packages).
CONDA_ENV_CREATE_COMMAND = 'conda env create -p {} --file={}'
CONDA_ENV_INSTALL_COMMAND = 'conda install -y -q -p {} {}'

def _create_conda_env(env, env_root_path, packages, clear):
    """Create (and zip) a conda environment to pass to Knit.

    Parameters
    ----------
    env : str
        Name of the conda environment (directory and zip basename).
    env_root_path : str
        Directory under which the environment is created.
    packages : list of str
        Extra conda packages installed into the environment.
    clear : bool
        When True, remove any existing environment and rebuild it.
    """
    env_dir = os.path.join(env_root_path, env)
    env_file = env_dir + '.zip'

    if clear:
        # Remove an existing env directory so it is rebuilt from scratch.
        shutil.rmtree(env_dir, ignore_errors=True)

    if os.path.isfile(env_file):
        if clear:
            os.remove(env_file)
        else:
            # Skip if env already exists and clear is not required.
            return

    if not os.path.isdir(env_dir):
        # Create the conda environment from the packaged spec file, then
        # install any extra packages requested by the caller.
        os.system(CONDA_ENV_CREATE_COMMAND.format(
            env_dir, conda_environment_filename()))
        if packages:
            os.system(CONDA_ENV_INSTALL_COMMAND.format(
                env_dir, ' '.join(packages)))

    # Archive the conda environment so Knit can ship it to the containers.
    shutil.make_archive(env_dir, 'zip', root_dir=env_root_path, base_dir=env)


class YarnPool(RemotePool):
    """The YARN pool manager.

    Builds (or reuses) a zipped conda environment, ships it to YARN via
    Knit, and starts remote joblib workers that connect back to this
    pool's server address.
    """

    def __init__(self, processes=None, port=0, authkey=None,
                 env=JOBLIB_YARN_DEFAULT_CONDA_ENV,
                 env_root_path=JOBLIB_YARN_DEFAULT_CONDA_ROOT,
                 packages=None, clear_env=False):
        """Constructor.

        Parameters
        ----------
        processes : int, optional
            Number of YARN containers / workers to start.
        port : int
            Port for the pool server (0 picks a free port).
        authkey : bytes, optional
            Authentication key shared with the remote workers.
        env : str
            Name of the conda environment to ship.
        env_root_path : str
            Directory under which the environment is created and zipped.
        packages : list of str, optional
            Extra conda packages installed into the environment.
        clear_env : bool
            When True, rebuild the conda environment from scratch.
        """
        super(YarnPool, self).__init__(processes=processes,
                                       port=port,
                                       authkey=authkey,
                                       workerscript=JOBLIB_YARN_WORKER)
        self.stopping = False
        self.knit = Knit(autodetect=True)
        # Fresh list per instance instead of a shared mutable default.
        packages = [] if packages is None else packages
        _create_conda_env(env, env_root_path, packages, clear_env)
        cmd = ('$PYTHON_BIN $CONDA_PREFIX/bin/{} --host {} --port {} --key {}'
               .format(JOBLIB_YARN_WORKER,
                       socket.gethostname(),
                       self.server.address[1],
                       self.authkey))
        # Ship the zip created under env_root_path — not TEMP_DIR, which
        # may differ when a custom root path is given.
        self.app_id = self.knit.start(
            cmd, num_containers=self._processes,
            env='{}.zip'.format(os.path.join(env_root_path, env)))
        self.thread = Thread(target=self._monitor_appid)
        # Fix typo ('deamon'): the misspelled attribute was silently
        # ignored, leaving the monitoring thread non-daemonic.
        self.thread.daemon = True
        self.thread.start()
Expand Down
50 changes: 42 additions & 8 deletions joblibhadoop/yarn/tests/test_yarn_backend.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,65 @@
"""Test the yarn parallel backend."""

import os
import os.path
import shutil
import tempfile
from math import sqrt

import pytest
from pytest import mark

from joblib import (Parallel, delayed,
register_parallel_backend, parallel_backend)
from joblibhadoop.yarn import YarnBackend
from joblibhadoop.yarn.backend import __interrupts__
from joblibhadoop.yarn.backend import JOBLIB_YARN_INTERRUPTS
from joblibhadoop.yarn.pool import (_create_conda_env,
JOBLIB_YARN_DEFAULT_CONDA_ENV,
JOBLIB_YARN_DEFAULT_CONDA_ROOT)

# Namenode host of the test cluster. setup.cfg (pytest-env) defaults this
# to 'localhost'; mirror that default so the module imports outside CI.
JOBLIB_HDFS_NAMENODE = os.environ.get('JOBLIB_HDFS_NAMENODE', 'localhost')


# Tests decorated with this marker need a real nodemanager and are skipped
# when running against localhost.
skip_localhost = pytest.mark.skipif(JOBLIB_HDFS_NAMENODE == 'localhost',
                                    reason="Cannot use nodemanager from "
                                           "localhost")


@mark.parametrize('packages', [[], ['pandas']])
def test_create_conda_env(tmpdir, packages):
    """Test conda env creation works as expected."""
    assert tempfile.gettempdir() == JOBLIB_YARN_DEFAULT_CONDA_ROOT

    env = JOBLIB_YARN_DEFAULT_CONDA_ENV
    env_dir = tmpdir.join(env).strpath
    root_dir = tmpdir.strpath
    env_file = env_dir + '.zip'

    try:
        # Fresh creation, forced rebuild (clear=True), then cached reuse
        # must all leave both the env directory and its zip in place.
        for clear in (False, True, False):
            _create_conda_env(env, root_dir, packages, clear)
            assert os.path.isdir(env_dir)
            assert os.path.isfile(env_file)
    finally:
        # Clean up even when an assertion above fails, so repeated runs
        # do not inherit stale artifacts.
        shutil.rmtree(env_dir, ignore_errors=True)
        if os.path.isfile(env_file):
            os.remove(env_file)


def test_supported_interrupt():
    """Verify the list of supported interrupts is correct."""
    register_parallel_backend('yarn', YarnBackend)

    backend = YarnBackend()
    # The backend must expose exactly the module-level interrupt list.
    assert backend.get_exceptions() == JOBLIB_YARN_INTERRUPTS


def test_parallel_invalid_njobs_raises_value_error():
Expand All @@ -32,14 +68,12 @@ def test_parallel_invalid_njobs_raises_value_error():

with pytest.raises(ValueError) as excinfo:
with parallel_backend('yarn'):
result = Parallel(verbose=100)(
delayed(sqrt)(i**2) for i in range(100))
Parallel(verbose=100)(delayed(sqrt)(i**2) for i in range(100))
assert 'n_jobs < 0 is not implemented yet' in str(excinfo.value)

with pytest.raises(ValueError) as excinfo:
with parallel_backend('yarn', n_jobs=0):
result = Parallel(verbose=100)(
delayed(sqrt)(i**2) for i in range(100))
Parallel(verbose=100)(delayed(sqrt)(i**2) for i in range(100))
assert 'n_jobs == 0 in Parallel has no meaning' in str(excinfo.value)


Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ addopts =
--pep8
--cov=joblibhadoop
--cov-report=xml
--cov-report=term
env =
D:JOBLIB_HDFS_NAMENODE=localhost
testpaths = joblibhadoop

0 comments on commit 11c3b0d

Please sign in to comment.