WIP Allowing named worker-setups #2391

Open · wants to merge 17 commits into base: main
72 changes: 64 additions & 8 deletions distributed/client.py
@@ -103,6 +103,27 @@ def _del_global_client(c):
pass


def _serialize_named_func(named_func):
""" takes a function or a tuple of (name, func)

returns a tuple of (name, serialized_func)
"""
if isinstance(named_func, tuple):
name, func = named_func
else:
func = named_func
name = None

serialized = dumps(func)

if name is None:
h = tokenize(serialized)
name = funcname(func) + '-' + h

named_func = (name, serialized)
return named_func
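For illustration, a small sketch of the naming rule. It assumes `tokenize`, `funcname`, and a pickle-based `dumps` matching the helpers this module already uses; `init_env` is a made-up setup function:

```python
# Sketch only; run against this PR's distributed.client.
from dask.base import tokenize
from dask.utils import funcname
from distributed.client import _serialize_named_func  # added by this PR

def init_env(dask_worker):
    dask_worker.initialized = True

# Explicit name: the tuple form pins the callback name.
name, payload = _serialize_named_func(('init-env', init_env))
print(name)    # "init-env"

# No name: "<function name>-<token of the serialized bytes>".
name, payload = _serialize_named_func(init_env)
print(name)    # e.g. "init_env-8f1d0c..."
```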


class Future(WrappedKey):
""" A remotely running computation

@@ -3547,35 +3568,70 @@ def _get_task_stream(self, start=None, stop=None, count=None, plot=False,

@gen.coroutine
def _register_worker_callbacks(self, setup=None):

# prepare the callbacks and their names
if setup is not None:
setup = _serialize_named_func(setup)

responses = yield self.scheduler.register_worker_callbacks(setup=setup)

results = {}
for key, resp in responses.items():
if resp['status'] == 'OK':
results[key] = resp['result']
elif resp['status'] == 'error':
six.reraise(*clean_exception(**resp))

raise gen.Return(results)

def register_worker_callbacks(self, setup=None):
"""
Registers a setup callback function for all current and future workers.

This registers a new setup function for workers in this cluster.
The function will run immediately on all currently connected workers.
It will also be run upon connection by any workers that are added in the
future.
Multiple setup functions can be registered.
An optional callback name can be provided by passing a tuple as
the argument. We only keep one function version per callback name.
If the callback name is not given, a unique name is generated by
tokenizing the function.

The callback function should be idempotent, as it may be called several
times. The order in which callback functions are invoked is undefined.

If the function takes an input argument named ``dask_worker`` then
that variable will be populated with the worker itself.

Parameters
----------
setup : callable(dask_worker: Worker) -> None,
or tuple of (name, callable)
Function to register and run on all workers.
If a name is not given, then a name is generated from the callable.

"""
return self.sync(self._register_worker_callbacks, setup=setup)
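A usage sketch of this PR's API; the scheduler address and the `init_env` callback are hypothetical. The callback is written to be idempotent, as the docstring above requires:

```python
from distributed import Client

def init_env(dask_worker):
    # idempotent: safe to run more than once on the same worker
    if not hasattr(dask_worker, 'scratch'):
        dask_worker.scratch = {}

client = Client('tcp://127.0.0.1:8786')  # illustrative address

# Explicit name: later registrations under 'init-env' replace this one.
client.register_worker_callbacks(setup=('init-env', init_env))

# No name: a unique name is generated by tokenizing the function.
client.register_worker_callbacks(setup=init_env)
```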

def unregister_worker_callbacks(self, setup=None):
"""
Unregisters a worker callback registered via `register_worker_callbacks`.

See `register_worker_callbacks` for the definition of the arguments.

Parameters
----------
setup : callable(dask_worker: Worker) -> None,
or tuple of (name, callable)
The setup callback to remove; it will be matched by the name.
If a name is not given, then a name is generated from the callable.
"""
if setup is not None:
setup = _serialize_named_func(setup)

return self.sync(self.scheduler.unregister_worker_callbacks, setup=setup)
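And the matching teardown, continuing the hypothetical names above; per the docstring, matching is by name, so the function half of the tuple is ignored:

```python
# Remove the named callback; only the name is used for matching.
client.unregister_worker_callbacks(setup=('init-env', None))

# For a callback registered without a name, pass the same function so the
# generated name matches the one created at registration time.
client.unregister_worker_callbacks(setup=init_env)
```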


class Executor(Client):
""" Deprecated: see Client """
41 changes: 34 additions & 7 deletions distributed/scheduler.py
@@ -904,7 +904,7 @@ def __init__(
self.plugins = []
self.transition_log = deque(maxlen=dask.config.get('distributed.scheduler.transition-log-length'))
self.log = deque(maxlen=dask.config.get('distributed.scheduler.transition-log-length'))
self.worker_setups = {}

worker_handlers = {
'task-finished': self.handle_task_finished,
@@ -963,7 +963,8 @@ def __init__(
'heartbeat_worker': self.heartbeat_worker,
'get_task_status': self.get_task_status,
'get_task_stream': self.get_task_stream,
'register_worker_callbacks': self.register_worker_callbacks,
'unregister_worker_callbacks': self.unregister_worker_callbacks
}

self._transitions = {
@@ -3091,15 +3092,41 @@ def get_task_stream(self, comm=None, start=None, stop=None, count=None):

@gen.coroutine
def register_worker_callbacks(self, comm, setup=None):
""" Registers a setup function, and call it on every worker """
if setup is None:
raise gen.Return({})
"""
Registers a set of event driven callback functions
on workers for the given name.

setup must be a tuple of (name, serialized_function)
"""
Review comment (Member):

For spacing I recommend the following:

"""
Registers a set of event driven callback functions on workers for the given name  # <<--- move this down to avoid the slightly longer line

setup must be a tuple of (name, serialized_function)  # <<<--- move this left to the same level as """
"""

responses = {}

if setup is not None:
name, func = setup

oldfunc = self.worker_setups.get(name, "")

if oldfunc != func:
# store the setup function so it runs on workers that connect later.
self.worker_setups[name] = func

# trigger the setup function on the existing workers.
responses.update((yield self.broadcast(msg=dict(op='run', function=func))))

raise gen.Return(responses)
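The replace-by-name semantics above reduce to a small amount of dictionary bookkeeping; a scheduler-free sketch, with made-up names and byte strings:

```python
# Re-registering identical bytes under a name is a no-op; different bytes
# replace the stored version and would be re-broadcast to workers.
worker_setups = {}

def register(name, serialized_func):
    if worker_setups.get(name, "") == serialized_func:
        return False                       # unchanged: nothing to do
    worker_setups[name] = serialized_func  # runs on future workers
    return True                            # would broadcast 'run' here

assert register('init-env', b'pickled-v1') is True
assert register('init-env', b'pickled-v1') is False  # duplicate skipped
assert register('init-env', b'pickled-v2') is True   # new version re-run
```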

def unregister_worker_callbacks(self, comm, setup=None):
"""
Unregisters a set of event driven callback functions on workers
for the given name.

setup must be a tuple of (name, serialized_function).
The value of serialized_function is unused.

"""
if setup is not None:
name, func = setup
self.worker_setups.pop(name)
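One caveat: `dict.pop` without a default raises `KeyError` when the name was never registered, so unregistering an unknown callback would error here; a toy demonstration:

```python
worker_setups = {'init-env': b'pickled'}
worker_setups.pop('init-env')          # removes the entry
try:
    worker_setups.pop('missing')       # raises KeyError: 'missing'
except KeyError:
    pass
worker_setups.pop('missing', None)     # tolerant variant returns None
```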

#####################
# State Transitions #
#####################
69 changes: 55 additions & 14 deletions distributed/tests/test_worker.py
@@ -1203,69 +1203,110 @@ def test_custom_metrics(c, s, a, b):
@gen_cluster(client=True)
def test_register_worker_callbacks(c, s, a, b):
#preload function to run
def mystartup1(dask_worker):
if not hasattr(dask_worker, "init_variable"):
dask_worker.init_variable = 0
dask_worker.init_variable = dask_worker.init_variable + 1

def mystartup2():
import os
os.environ['MY_ENV_VALUE'] = 'WORKER_ENV_VALUE'
return "Env set."

#preload function to run
def mystartup3(dask_worker):
dask_worker.init_variable2 = 1

#Check that preload function has been run
def test_startup1(dask_worker):
return hasattr(dask_worker, 'init_variable')
# and dask_worker.init_variable == 1

def test_run_once(dask_worker):
return (hasattr(dask_worker, 'init_variable')
and dask_worker.init_variable == 1)

def test_startup2():
import os
return os.getenv('MY_ENV_VALUE', None) == 'WORKER_ENV_VALUE'

def test_startup3(dask_worker):
return hasattr(dask_worker, 'init_variable2')

# Nothing has been run yet
assert len(s.worker_setups) == 0
result = yield c.run(test_startup1)
assert list(result.values()) == [False] * 2
result = yield c.run(test_startup2)
assert list(result.values()) == [False] * 2

# Start a worker and check that startup is not run
worker = Worker(s.address, loop=s.loop)
yield worker._start()
result = yield c.run(test_startup1, workers=[worker.address])
assert list(result.values()) == [False]
yield worker._close()

# Add a preload function
yield c.register_worker_callbacks(setup=mystartup1)
assert len(s.worker_setups) == 1

# Add the same preload function, again
yield c.register_worker_callbacks(setup=mystartup1)
assert len(s.worker_setups) == 1

# Check it has been run on existing workers
result = yield c.run(test_startup1)
assert list(result.values()) == [True] * 2

result = yield c.run(test_run_once)
assert list(result.values()) == [True] * 2

# Start a worker and check the setup runs on it
worker = Worker(s.address, loop=s.loop)
yield worker._start()
result = yield c.run(test_startup1, workers=[worker.address])
assert list(result.values()) == [True]
yield worker._close()

# Register another preload function, twice with a name
yield c.register_worker_callbacks(setup=('mystartup2', mystartup2))
assert 'mystartup2' in s.worker_setups
yield c.register_worker_callbacks(setup=('mystartup2', mystartup2))
assert 'mystartup2' in s.worker_setups
assert len(s.worker_setups) == 2

# Check it has been run
result = yield c.run(test_startup2)
assert list(result.values()) == [True] * 2

# unregister a preload function
Comment from the PR author: I changed 'register_worker_callbacks' to return the registered names to allow unregistering callback functions with inferred names.

yield c.register_worker_callbacks(setup=mystartup3)
assert len(s.worker_setups) == 3
yield c.unregister_worker_callbacks(setup=mystartup3)
assert len(s.worker_setups) == 2

# unregister a preload function with name
yield c.register_worker_callbacks(setup=('mystartup3', mystartup3))
assert len(s.worker_setups) == 3
assert 'mystartup3' in s.worker_setups
yield c.unregister_worker_callbacks(setup=('mystartup3', None))
assert 'mystartup3' not in s.worker_setups
assert len(s.worker_setups) == 2

# Start a worker and check the setups run on it
worker = Worker(s.address, loop=s.loop)
yield worker._start()
result = yield c.run(test_startup1, workers=[worker.address])
assert list(result.values()) == [True]

result = yield c.run(test_startup2, workers=[worker.address])
assert list(result.values()) == [True]

# startup3 is not run, as it was unregistered before this worker was added.
result = yield c.run(test_startup3, workers=[worker.address])
assert list(result.values()) == [False]

yield worker._close()

# Final exception test
23 changes: 14 additions & 9 deletions distributed/worker.py
@@ -513,6 +513,8 @@ def __init__(self, scheduler_ip=None, scheduler_port=None,
io_loop=self.io_loop)
self.periodic_callbacks['profile-cycle'] = pc

self.worker_setups = {}

_global_workers.append(weakref.ref(self))

##################
@@ -610,23 +612,26 @@ def _register_with_scheduler(self):
raise ValueError("Unexpected response from register: %r" %
(response,))
else:

logger.info(' Registered to: %26s', self.scheduler.address)
logger.info('-' * 49)

# Retrieve eventual init functions (only worker-setups for now)
for name, function_bytes in response['worker-setups'].items():
self.worker_setups[name] = pickle.loads(function_bytes)

self.batched_stream = BatchedSend(interval='2ms', loop=self.loop)
self.batched_stream.start(comm)
self.periodic_callbacks['heartbeat'].start()
self.loop.add_callback(self.handle_scheduler, comm)

Comment from the PR author: I moved this to run the worker init functions after the heartbeat starts; I am also using the local worker_setups dictionary rather than the unpickled function directly.

# run eventual init functions (only worker-setups for now)
for name, setup_function in self.worker_setups.items():
if has_arg(setup_function, 'dask_worker'):
result = setup_function(dask_worker=self)
else:
result = setup_function()
logger.info('Init function %s (%s) ran: output=%s' % (name, setup_function, result))
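The `dask_worker` dispatch above can be exercised on its own; a sketch assuming `distributed.utils.has_arg`, which checks a callable's signature for a parameter of the given name (the worker object is stubbed):

```python
from distributed.utils import has_arg

def setup_with_worker(dask_worker):
    return 'worker is %r' % dask_worker

def setup_plain():
    return 'no worker argument needed'

class WorkerStub(object):
    pass

for fn in (setup_with_worker, setup_plain):
    if has_arg(fn, 'dask_worker'):
        print(fn(dask_worker=WorkerStub()))  # worker-aware callback
    else:
        print(fn())                          # zero-argument callback
```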

@gen.coroutine
def heartbeat(self):
if not self.heartbeat_active: