Replace register_worker_callbacks with worker plugins (#2453)
* Add worker plugins

* add docstring

* Replace legacy worker_callbacks with worker_plugins

* add and test name keyword

* fix missing import

* black

* respond to feedback

* Handle errors again

* Expand docstring
mrocklin authored and TomAugspurger committed May 22, 2019
1 parent 28ce1ed commit 6134c75
Showing 5 changed files with 206 additions and 35 deletions.
92 changes: 80 additions & 12 deletions distributed/client.py
@@ -89,6 +89,7 @@
     parse_timedelta,
     shutting_down,
     Any,
+    has_keyword,
 )
 from .versions import get_versions
@@ -3854,17 +3855,6 @@ def _get_task_stream(
         else:
             raise gen.Return(msgs)
 
-    @gen.coroutine
-    def _register_worker_callbacks(self, setup=None):
-        responses = yield self.scheduler.register_worker_callbacks(setup=dumps(setup))
-        results = {}
-        for key, resp in responses.items():
-            if resp["status"] == "OK":
-                results[key] = resp["result"]
-            elif resp["status"] == "error":
-                six.reraise(*clean_exception(**resp))
-        raise gen.Return(results)
-
     def register_worker_callbacks(self, setup=None):
         """
         Registers a setup callback function for all current and future workers.
@@ -3883,7 +3873,85 @@ def register_worker_callbacks(self, setup=None):
         setup : callable(dask_worker: Worker) -> None
             Function to register and run on all workers
         """
-        return self.sync(self._register_worker_callbacks, setup=setup)
+        return self.register_worker_plugin(_WorkerSetupPlugin(setup))
+
+    @gen.coroutine
+    def _register_worker_plugin(self, plugin=None, name=None):
+        responses = yield self.scheduler.register_worker_plugin(
+            plugin=dumps(plugin), name=name
+        )
+        for response in responses.values():
+            if response["status"] == "error":
+                exc = response["exception"]
+                typ = type(exc)
+                tb = response["traceback"]
+                six.reraise(typ, exc, tb)
+        raise gen.Return(responses)
+
+    def register_worker_plugin(self, plugin=None, name=None):
+        """
+        Registers a lifecycle worker plugin for all current and future workers.
+
+        This registers a new object to handle setup and teardown for workers
+        in this cluster. The plugin will instantiate itself on all currently
+        connected workers. It will also be run on any worker that connects in
+        the future.
+
+        The plugin should be an object with ``setup`` and ``teardown`` methods.
+        It must be serializable with the pickle or cloudpickle modules.
+
+        If the plugin has a ``name`` attribute, or if the ``name=`` keyword is
+        used, then that controls idempotency: if a plugin with that name has
+        already been registered then any future plugins will not run.
+
+        For alternatives to plugins, you may also wish to look into preload
+        scripts.
+
+        Parameters
+        ----------
+        plugin: object
+            The plugin object to pass to the workers
+        name: str, optional
+            A name for the plugin.
+            Registering a plugin with the same name will have no effect.
+
+        Examples
+        --------
+        >>> class MyPlugin:
+        ...     def __init__(self, *args, **kwargs):
+        ...         pass  # the constructor is up to you
+        ...     def setup(self, worker: dask.distributed.Worker):
+        ...         pass
+        ...     def teardown(self, worker: dask.distributed.Worker):
+        ...         pass

+        >>> plugin = MyPlugin(1, 2, 3)
+        >>> client.register_worker_plugin(plugin)
+
+        You can get access to the plugin with the ``get_worker`` function
+
+        >>> client.register_worker_plugin(other_plugin, name='my-plugin')
+        >>> def f():
+        ...     worker = get_worker()
+        ...     plugin = worker.plugins['my-plugin']
+        ...     return plugin.my_state
+
+        >>> future = client.run(f)
+        """
+        return self.sync(self._register_worker_plugin, plugin=plugin, name=name)
+
+
+class _WorkerSetupPlugin(object):
+    """ This is used to support older setup functions as callbacks """
+
+    def __init__(self, setup):
+        self._setup = setup
+
+    def setup(self, worker):
+        if has_keyword(self._setup, "dask_worker"):
+            return self._setup(dask_worker=worker)
+        else:
+            return self._setup()
+
+
 class Executor(Client):
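For orientation, here is a minimal sketch of how the legacy call maps onto the new API after this change. The scheduler address and the body of the setup function are invented for illustration; only `register_worker_callbacks`, `register_worker_plugin`, and the name-based idempotency come from this commit.

```python
from dask.distributed import Client

client = Client("tcp://127.0.0.1:8786")  # hypothetical scheduler address

# Legacy style: a bare setup callback. After this commit it is wrapped
# in _WorkerSetupPlugin and registered through the plugin machinery.
def setup(dask_worker):
    dask_worker.my_state = 1

client.register_worker_callbacks(setup=setup)

# New style: an object with setup/teardown, registered directly.
class StatePlugin:
    name = "state-plugin"  # a named plugin registers at most once

    def setup(self, worker):
        worker.my_state = 1

    def teardown(self, worker):
        del worker.my_state

client.register_worker_plugin(StatePlugin())
```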
17 changes: 8 additions & 9 deletions distributed/scheduler.py
@@ -1003,7 +1003,7 @@ def __init__(
         self.log = deque(
             maxlen=dask.config.get("distributed.scheduler.transition-log-length")
         )
-        self.worker_setups = []
+        self.worker_plugins = []
 
         worker_handlers = {
             "task-finished": self.handle_task_finished,
@@ -1062,7 +1062,7 @@ def __init__(
             "heartbeat_worker": self.heartbeat_worker,
             "get_task_status": self.get_task_status,
             "get_task_stream": self.get_task_stream,
-            "register_worker_callbacks": self.register_worker_callbacks,
+            "register_worker_plugin": self.register_worker_plugin,
         }
 
         self._transitions = {
@@ -1510,7 +1510,7 @@ def add_worker(
                 "status": "OK",
                 "time": time(),
                 "heartbeat-interval": heartbeat_interval(len(self.workers)),
-                "worker-setups": self.worker_setups,
+                "worker-plugins": self.worker_plugins,
             }
         )
         yield self.handle_worker(comm=comm, worker=address)
@@ -3407,14 +3407,13 @@ def get_task_stream(self, comm=None, start=None, stop=None, count=None):
         return ts.collect(start=start, stop=stop, count=count)
 
     @gen.coroutine
-    def register_worker_callbacks(self, comm, setup=None):
+    def register_worker_plugin(self, comm, plugin, name=None):
         """ Registers a setup function, and call it on every worker """
-        if setup is None:
-            raise gen.Return({})
-
-        self.worker_setups.append(setup)
+        self.worker_plugins.append(plugin)
 
-        responses = yield self.broadcast(msg=dict(op="run", function=setup))
+        responses = yield self.broadcast(
+            msg=dict(op="plugin-add", plugin=plugin, name=name)
+        )
         raise gen.Return(responses)
 
     #####################
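The scheduler's `broadcast` returns one response per worker, keyed by worker address, and the client-side coroutine above re-raises the first error it finds. A sketch of that handling under assumed data — the addresses and the `ZeroDivisionError` are illustrative, not from this commit:

```python
import six

# Illustrative response mapping, shaped like what broadcast returns:
# each worker answers the "plugin-add" message with a status dict.
responses = {
    "tcp://127.0.0.1:36071": {"status": "OK"},
    "tcp://127.0.0.1:42193": {
        "status": "error",
        "exception": ZeroDivisionError("division by zero"),
        "traceback": None,  # a real traceback object in practice
    },
}

for response in responses.values():
    if response["status"] == "error":
        exc = response["exception"]
        six.reraise(type(exc), exc, response["traceback"])
```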
7 changes: 3 additions & 4 deletions distributed/tests/test_worker.py
@@ -1312,7 +1312,6 @@ def test_startup2():
         return os.getenv("MY_ENV_VALUE", None) == "WORKER_ENV_VALUE"
 
     # Nothing has been run yet
-    assert len(s.worker_setups) == 0
     result = yield c.run(test_import)
     assert list(result.values()) == [False] * 2
     result = yield c.run(test_startup2)
@@ -1327,7 +1326,6 @@ def test_startup2():
     # Add a preload function
     response = yield c.register_worker_callbacks(setup=mystartup)
     assert len(response) == 2
-    assert len(s.worker_setups) == 1
 
     # Check it has been run on existing worker
     result = yield c.run(test_import)
@@ -1342,7 +1340,6 @@ def test_startup2():
     # Register another preload function
     response = yield c.register_worker_callbacks(setup=mystartup2)
     assert len(response) == 2
-    assert len(s.worker_setups) == 2
 
     # Check it has been run
     result = yield c.run(test_startup2)
@@ -1356,7 +1353,9 @@ def test_startup2():
     assert list(result.values()) == [True]
     yield worker.close()
 
-    # Final exception test
+
+@gen_cluster(client=True)
+def test_register_worker_callbacks_err(c, s, a, b):
     with pytest.raises(ZeroDivisionError):
         yield c.register_worker_callbacks(setup=lambda: 1 / 0)
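These tests still exercise the legacy entry point, which now routes through the plugin machinery. A sketch of the equivalence, assuming a connected `client`; the `mystartup` body is reconstructed from the test's intent, not quoted from it:

```python
from distributed.client import _WorkerSetupPlugin

def mystartup(dask_worker):
    dask_worker.init_variable = 1

# The legacy call used in the tests above...
client.register_worker_callbacks(setup=mystartup)

# ...is now equivalent to registering the compatibility shim. The shim
# defines no `name` attribute, so each registration gets a fresh
# funcname + uuid name and repeated calls are not deduplicated.
client.register_worker_plugin(_WorkerSetupPlugin(mystartup))
```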
68 changes: 68 additions & 0 deletions distributed/tests/test_worker_plugins.py
@@ -0,0 +1,68 @@
+from distributed.utils_test import gen_cluster
+from distributed import Worker
+
+
+class MyPlugin:
+    name = "MyPlugin"
+
+    def __init__(self, data):
+        self.data = data
+
+    def setup(self, worker):
+        assert isinstance(worker, Worker)
+        self.worker = worker
+        self.worker._my_plugin_status = "setup"
+        self.worker._my_plugin_data = self.data
+
+    def teardown(self, worker):
+        assert isinstance(worker, Worker)
+        self.worker._my_plugin_status = "teardown"
+
+
+@gen_cluster(client=True, ncores=[])
+def test_create_with_client(c, s):
+    yield c.register_worker_plugin(MyPlugin(123))
+
+    worker = Worker(s.address, loop=s.loop)
+    yield worker._start()
+    assert worker._my_plugin_status == "setup"
+    assert worker._my_plugin_data == 123
+
+    yield worker._close()
+    assert worker._my_plugin_status == "teardown"
+
+
+@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]})
+def test_create_on_construction(c, s, a, b):
+    assert len(a.plugins) == len(b.plugins) == 1
+    assert a._my_plugin_status == "setup"
+    assert a._my_plugin_data == 5
+
+
+@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]})
+def test_idempotence_with_name(c, s, a, b):
+    a._my_plugin_data = 100
+
+    yield c.register_worker_plugin(MyPlugin(5))
+
+    assert a._my_plugin_data == 100  # call above has no effect
+
+
+@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]})
+def test_duplicate_with_no_name(c, s, a, b):
+    assert len(a.plugins) == len(b.plugins) == 1
+
+    plugin = MyPlugin(10)
+    plugin.name = "other-name"
+
+    yield c.register_worker_plugin(plugin)
+
+    assert len(a.plugins) == len(b.plugins) == 2
+
+    assert a._my_plugin_data == 10
+
+    yield c.register_worker_plugin(plugin)
+    assert len(a.plugins) == len(b.plugins) == 2
+
+    yield c.register_worker_plugin(plugin, name="foo")
+    assert len(a.plugins) == len(b.plugins) == 3
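As a practical illustration of the setup/teardown lifecycle these tests exercise, here is a hypothetical plugin that opens a per-worker resource on startup and releases it on close; the sqlite file path is made up for this sketch:

```python
import sqlite3

class ConnectionPlugin:
    name = "sqlite-connection"  # idempotent: registers once per cluster

    def setup(self, worker):
        # Runs on every current and future worker; each worker gets
        # its own connection, stashed on the worker object.
        worker.db = sqlite3.connect("/tmp/example.db")

    def teardown(self, worker):
        # Runs when the worker closes, per the close() change below.
        worker.db.close()
```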
57 changes: 47 additions & 10 deletions distributed/worker.py
@@ -10,6 +10,7 @@
 import random
 import threading
 import sys
+import uuid
 import warnings
 import weakref
 import psutil
@@ -307,6 +308,7 @@ def __init__(
         protocol=None,
         dashboard_address=None,
         nanny=None,
+        plugins=(),
         low_level_profiler=dask.config.get("distributed.worker.profile.low-level"),
         **kwargs
     ):
@@ -576,6 +578,7 @@ def __init__(
             "versions": self.versions,
             "actor_execute": self.actor_execute,
             "actor_attribute": self.actor_attribute,
+            "plugin-add": self.plugin_add,
         }
 
         stream_handlers = {
@@ -638,6 +641,9 @@ def __init__(
         )
         self.periodic_callbacks["profile-cycle"] = pc
 
+        self.plugins = {}
+        self._pending_plugins = plugins
+
         Worker._instances.add(self)
 
         ##################
@@ -763,16 +769,9 @@ def _register_with_scheduler(self):
         if response["status"] != "OK":
             raise ValueError("Unexpected response from register: %r" % (response,))
         else:
-            # Retrieve eventual init functions and run them
-            for function_bytes in response["worker-setups"]:
-                setup_function = pickle.loads(function_bytes)
-                if has_arg(setup_function, "dask_worker"):
-                    result = setup_function(dask_worker=self)
-                else:
-                    result = setup_function()
-                logger.info(
-                    "Init function %s ran: output=%s" % (setup_function, result)
-                )
+            yield [
+                self.plugin_add(plugin=plugin) for plugin in response["worker-plugins"]
+            ]
 
         logger.info("        Registered to: %26s", self.scheduler.address)
         logger.info("-" * 49)
@@ -968,6 +967,9 @@ def _start(self, addr_or_port=0):
 
         setproctitle("dask-worker [%s]" % self.address)
 
+        yield [self.plugin_add(plugin=plugin) for plugin in self._pending_plugins]
+        self._pending_plugins = ()
+
         yield self._register_with_scheduler()
 
         self.start_periodic_callbacks()
@@ -998,6 +1000,12 @@ def close(self, report=True, timeout=10, nanny=True, executor_wait=True):
             self.status = "closing"
             setproctitle("dask-worker [closing]")
 
+            yield [
+                plugin.teardown(self)
+                for plugin in self.plugins.values()
+                if hasattr(plugin, "teardown")
+            ]
+
             self.stop()
             for pc in self.periodic_callbacks.values():
                 pc.stop()
@@ -2206,6 +2214,35 @@ def run(self, comm, function, args=(), wait=True, kwargs=None):
     def run_coroutine(self, comm, function, args=(), kwargs=None, wait=True):
         return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait)
 
+    @gen.coroutine
+    def plugin_add(self, comm=None, plugin=None, name=None):
+        with log_errors(pdb=False):
+            if isinstance(plugin, bytes):
+                plugin = pickle.loads(plugin)
+            if not name:
+                if hasattr(plugin, "name"):
+                    name = plugin.name
+                else:
+                    name = funcname(plugin) + "-" + str(uuid.uuid4())
+
+            assert name
+
+            if name in self.plugins:
+                return {"status": "repeat"}
+            else:
+                self.plugins[name] = plugin
+
+                logger.info("Starting Worker plugin %s" % name)
+                try:
+                    result = plugin.setup(worker=self)
+                    if isinstance(result, gen.Future):
+                        result = yield result
+                except Exception as e:
+                    msg = error_message(e)
+                    return msg
+                else:
+                    return {"status": "OK"}
+
     @gen.coroutine
     def actor_execute(self, comm=None, actor=None, function=None, args=(), kwargs={}):
         separate_thread = kwargs.pop("separate_thread", True)
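Since `plugin_add` yields the result of `setup` when it is a `gen.Future`, a plugin's setup can be a Tornado coroutine and the worker will wait for it during startup. A minimal sketch, with the sleep standing in for real asynchronous warm-up work:

```python
from tornado import gen

class AsyncSetupPlugin:
    name = "async-setup"

    @gen.coroutine
    def setup(self, worker):
        # Calling a @gen.coroutine method returns a Future, which
        # plugin_add detects and yields before reporting "OK".
        yield gen.sleep(0.1)

    def teardown(self, worker):
        pass
```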
