Replace register_worker_callbacks with worker plugins #2453

Merged · 10 commits · May 22, 2019
92 changes: 80 additions & 12 deletions distributed/client.py
@@ -89,6 +89,7 @@
parse_timedelta,
shutting_down,
Any,
has_keyword,
)
from .versions import get_versions

@@ -3847,17 +3848,6 @@ def _get_task_stream(
else:
raise gen.Return(msgs)

@gen.coroutine
def _register_worker_callbacks(self, setup=None):
responses = yield self.scheduler.register_worker_callbacks(setup=dumps(setup))
results = {}
for key, resp in responses.items():
if resp["status"] == "OK":
results[key] = resp["result"]
elif resp["status"] == "error":
six.reraise(*clean_exception(**resp))
raise gen.Return(results)

def register_worker_callbacks(self, setup=None):
"""
Registers a setup callback function for all current and future workers.
@@ -3876,7 +3866,85 @@ def register_worker_callbacks(self, setup=None):
setup : callable(dask_worker: Worker) -> None
Function to register and run on all workers
"""
return self.sync(self._register_worker_callbacks, setup=setup)
return self.register_worker_plugin(_WorkerSetupPlugin(setup))

@gen.coroutine
def _register_worker_plugin(self, plugin=None, name=None):
responses = yield self.scheduler.register_worker_plugin(
plugin=dumps(plugin), name=name
)
for response in responses.values():
if response["status"] == "error":
exc = response["exception"]
typ = type(exc)
tb = response["traceback"]
six.reraise(typ, exc, tb)
raise gen.Return(responses)

def register_worker_plugin(self, plugin=None, name=None):
"""
Registers a lifecycle worker plugin for all current and future workers.

Member:
General question: do you have thoughts on using ..versionadded:: <version> directives? Don't need to do it here necessarily, but I ask since this is being referenced from the dask documentation, which may be on a different release cycle.

Member (Author):
No strong thoughts. I've never used them much, either as an author or reader.

My plan was just to merge the doc PR after this gets released.

Happy with whatever though.

Member (Author):
Ah, I see that you've since merged the doc PR :) Shouldn't be a big deal either way.
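For reference, a hedged sketch of what such a directive could look like at the top of this docstring; the placeholder version string is illustrative and not part of this PR:

    def register_worker_plugin(self, plugin=None, name=None):
        """
        Registers a lifecycle worker plugin for all current and future workers.

        .. versionadded:: <next-release>
        """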

This registers a new object to handle setup and teardown for workers in
this cluster. The plugin will instantiate itself on all currently
connected workers. It will also be run on any worker that connects in
the future.

The plugin should be an object with ``setup`` and ``teardown`` methods.
It must be serializable with the pickle or cloudpickle modules.

If the plugin has a ``name`` attribute, or if the ``name=`` keyword is
used, then that will control idempotency: if a plugin with that name has
already been registered, any future plugins with that name will not run.

For alternatives to plugins, you may also wish to look into preload
scripts.

Parameters
----------
plugin: object
The plugin object to pass to the workers
name: str, optional
A name for the plugin.
Registering a plugin with the same name will have no effect.

Examples
--------
>>> class MyPlugin:
... def __init__(self, *args, **kwargs):
... pass # the constructor is up to you
... def setup(self, worker: dask.distributed.Worker):
... pass
... def teardown(self, worker: dask.distributed.Worker):
... pass

>>> plugin = MyPlugin(1, 2, 3)
>>> client.register_worker_plugin(plugin)

You can get access to the plugin with the ``get_worker`` function

>>> client.register_worker_plugin(other_plugin, name='my-plugin')
>>> def f():
... worker = get_worker()
... plugin = worker.plugins['my-plugin']
... return plugin.my_state

>>> future = client.run(f)
"""
return self.sync(self._register_worker_plugin, plugin=plugin, name=name)


class _WorkerSetupPlugin(object):
""" This is used to support older setup functions as callbacks """

def __init__(self, setup):
self._setup = setup

def setup(self, worker):
if has_keyword(self._setup, "dask_worker"):
return self._setup(dask_worker=worker)
else:
return self._setup()


class Executor(Client):
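To make the compatibility path concrete, here is a hedged sketch of how _WorkerSetupPlugin dispatches a legacy setup callback. old_style_setup, no_arg_setup, and FakeWorker are illustrative stand-ins, not part of the PR:

    from distributed.client import _WorkerSetupPlugin

    def old_style_setup(dask_worker):
        # legacy callbacks may accept the worker via a ``dask_worker`` keyword
        dask_worker.my_flag = True

    def no_arg_setup():
        return "ran"

    class FakeWorker(object):
        # stand-in object; a real Worker instance is passed in practice
        pass

    worker = FakeWorker()
    _WorkerSetupPlugin(old_style_setup).setup(worker)   # has_keyword finds ``dask_worker``
    assert worker.my_flag

    assert _WorkerSetupPlugin(no_arg_setup).setup(FakeWorker()) == "ran"   # no keyword, called bare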
17 changes: 8 additions & 9 deletions distributed/scheduler.py
@@ -1002,7 +1002,7 @@ def __init__(
self.log = deque(
maxlen=dask.config.get("distributed.scheduler.transition-log-length")
)
self.worker_setups = []
self.worker_plugins = []

worker_handlers = {
"task-finished": self.handle_task_finished,
@@ -1061,7 +1061,7 @@ def __init__(
"heartbeat_worker": self.heartbeat_worker,
"get_task_status": self.get_task_status,
"get_task_stream": self.get_task_stream,
"register_worker_callbacks": self.register_worker_callbacks,
"register_worker_plugin": self.register_worker_plugin,
}

self._transitions = {
@@ -1509,7 +1509,7 @@ def add_worker(
"status": "OK",
"time": time(),
"heartbeat-interval": heartbeat_interval(len(self.workers)),
"worker-setups": self.worker_setups,
"worker-plugins": self.worker_plugins,
}
)
yield self.handle_worker(comm=comm, worker=address)
@@ -3406,14 +3406,13 @@ def get_task_stream(self, comm=None, start=None, stop=None, count=None):
return ts.collect(start=start, stop=stop, count=count)

@gen.coroutine
def register_worker_callbacks(self, comm, setup=None):
def register_worker_plugin(self, comm, plugin, name=None):
""" Registers a setup function, and call it on every worker """
if setup is None:
raise gen.Return({})

self.worker_setups.append(setup)
self.worker_plugins.append(plugin)

responses = yield self.broadcast(msg=dict(op="run", function=setup))
responses = yield self.broadcast(
msg=dict(op="plugin-add", plugin=plugin, name=name)
)
raise gen.Return(responses)

#####################
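Putting the client, scheduler, and worker pieces together, a hedged end-to-end sketch; CounterPlugin and the _counter attribute are illustrative names, and a local cluster is assumed rather than any particular deployment:

    from dask.distributed import Client, LocalCluster, get_worker

    class CounterPlugin(object):
        name = "counter"               # a name makes repeat registration a no-op

        def setup(self, worker):
            worker._counter = 0        # stash per-worker state on the Worker object

        def teardown(self, worker):
            pass

    if __name__ == "__main__":
        cluster = LocalCluster(n_workers=2, threads_per_worker=1)
        client = Client(cluster)

        # The client pickles the plugin, the scheduler appends it to
        # ``worker_plugins`` and broadcasts a ``plugin-add`` message.
        client.register_worker_plugin(CounterPlugin())

        def read_counter():
            return get_worker()._counter

        print(client.run(read_counter))   # {worker address: 0, worker address: 0}

        client.close()
        cluster.close()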
7 changes: 3 additions & 4 deletions distributed/tests/test_worker.py
@@ -1312,7 +1312,6 @@ def test_startup2():
return os.getenv("MY_ENV_VALUE", None) == "WORKER_ENV_VALUE"

# Nothing has been run yet
assert len(s.worker_setups) == 0
result = yield c.run(test_import)
assert list(result.values()) == [False] * 2
result = yield c.run(test_startup2)
@@ -1327,7 +1326,6 @@ def test_startup2():
# Add a preload function
response = yield c.register_worker_callbacks(setup=mystartup)
assert len(response) == 2
assert len(s.worker_setups) == 1

# Check it has been run on existing worker
result = yield c.run(test_import)
@@ -1342,7 +1340,6 @@ def test_startup2():
# Register another preload function
response = yield c.register_worker_callbacks(setup=mystartup2)
assert len(response) == 2
assert len(s.worker_setups) == 2

# Check it has been run
result = yield c.run(test_startup2)
@@ -1356,7 +1353,9 @@ def test_startup2():
assert list(result.values()) == [True]
yield worker.close()

# Final exception test

@gen_cluster(client=True)
def test_register_worker_callbacks_err(c, s, a, b):
with pytest.raises(ZeroDivisionError):
yield c.register_worker_callbacks(setup=lambda: 1 / 0)

68 changes: 68 additions & 0 deletions distributed/tests/test_worker_plugins.py
@@ -0,0 +1,68 @@
from distributed.utils_test import gen_cluster
from distributed import Worker


class MyPlugin:
name = "MyPlugin"

def __init__(self, data):
self.data = data

def setup(self, worker):
assert isinstance(worker, Worker)
self.worker = worker
self.worker._my_plugin_status = "setup"
self.worker._my_plugin_data = self.data

def teardown(self, worker):
assert isinstance(worker, Worker)
self.worker._my_plugin_status = "teardown"


@gen_cluster(client=True, ncores=[])
def test_create_with_client(c, s):
yield c.register_worker_plugin(MyPlugin(123))

worker = Worker(s.address, loop=s.loop)
yield worker._start()
assert worker._my_plugin_status == "setup"
assert worker._my_plugin_data == 123

yield worker._close()
assert worker._my_plugin_status == "teardown"


@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]})
def test_create_on_construction(c, s, a, b):
assert len(a.plugins) == len(b.plugins) == 1
assert a._my_plugin_status == "setup"
assert a._my_plugin_data == 5


@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]})
def test_idempotence_with_name(c, s, a, b):
Member:
Are the names of these two tests flipped (this one doesn't use plugin names, the other does)?

Member (Author):
I think that we're good here. The MyPlugin class has a name attribute, which is what causes the idempotence here.

In the test below, we change the name intentionally.

Member:
Ah, understood.

a._my_plugin_data = 100

yield c.register_worker_plugin(MyPlugin(5))

assert a._my_plugin_data == 100 # call above has no effect


@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]})
def test_duplicate_with_no_name(c, s, a, b):
assert len(a.plugins) == len(b.plugins) == 1

plugin = MyPlugin(10)
plugin.name = "other-name"

yield c.register_worker_plugin(plugin)

assert len(a.plugins) == len(b.plugins) == 2

assert a._my_plugin_data == 10

yield c.register_worker_plugin(plugin)
assert len(a.plugins) == len(b.plugins) == 2

yield c.register_worker_plugin(plugin, name="foo")
assert len(a.plugins) == len(b.plugins) == 3
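The plugins= keyword exercised through worker_kwargs above is also available when constructing a worker directly; pending plugins are added during Worker._start(), before the worker registers with the scheduler. A hedged sketch that reuses the MyPlugin class from this test file; the scheduler address is illustrative:

    from tornado import gen
    from distributed import Worker

    @gen.coroutine
    def start_worker_with_plugin():
        worker = Worker("tcp://127.0.0.1:8786", plugins=[MyPlugin(5)])
        yield worker._start()          # runs plugin.setup(worker=...) for each pending plugin
        raise gen.Return(worker)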
57 changes: 47 additions & 10 deletions distributed/worker.py
@@ -10,6 +10,7 @@
import random
import threading
import sys
import uuid
import warnings
import weakref
import psutil
@@ -307,6 +308,7 @@ def __init__(
protocol=None,
dashboard_address=None,
nanny=None,
plugins=(),
low_level_profiler=dask.config.get("distributed.worker.profile.low-level"),
**kwargs
):
@@ -576,6 +578,7 @@ def __init__(
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
}

stream_handlers = {
@@ -638,6 +641,9 @@ def __init__(
)
self.periodic_callbacks["profile-cycle"] = pc

self.plugins = {}
self._pending_plugins = plugins

Worker._instances.add(self)

##################
@@ -763,16 +769,9 @@ def _register_with_scheduler(self):
if response["status"] != "OK":
raise ValueError("Unexpected response from register: %r" % (response,))
else:
# Retrieve eventual init functions and run them
for function_bytes in response["worker-setups"]:
setup_function = pickle.loads(function_bytes)
if has_arg(setup_function, "dask_worker"):
result = setup_function(dask_worker=self)
else:
result = setup_function()
logger.info(
"Init function %s ran: output=%s" % (setup_function, result)
)
yield [
self.plugin_add(plugin=plugin) for plugin in response["worker-plugins"]
]

logger.info(" Registered to: %26s", self.scheduler.address)
logger.info("-" * 49)
@@ -968,6 +967,9 @@ def _start(self, addr_or_port=0):

setproctitle("dask-worker [%s]" % self.address)

yield [self.plugin_add(plugin=plugin) for plugin in self._pending_plugins]
self._pending_plugins = ()

yield self._register_with_scheduler()

self.start_periodic_callbacks()
@@ -998,6 +1000,12 @@ def close(self, report=True, timeout=10, nanny=True, executor_wait=True):
self.status = "closing"
setproctitle("dask-worker [closing]")

yield [
plugin.teardown(self)
for plugin in self.plugins.values()
if hasattr(plugin, "teardown")
]

self.stop()
for pc in self.periodic_callbacks.values():
pc.stop()
@@ -2206,6 +2214,35 @@ def run(self, comm, function, args=(), wait=True, kwargs=None):
def run_coroutine(self, comm, function, args=(), kwargs=None, wait=True):
return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait)

@gen.coroutine
def plugin_add(self, comm=None, plugin=None, name=None):
with log_errors(pdb=False):
if isinstance(plugin, bytes):
plugin = pickle.loads(plugin)
if not name:
if hasattr(plugin, "name"):
name = plugin.name
else:
name = funcname(plugin) + "-" + str(uuid.uuid4())

assert name

if name in self.plugins:
return {"status": "repeat"}
else:
self.plugins[name] = plugin

logger.info("Starting Worker plugin %s" % name)
try:
result = plugin.setup(worker=self)
if isinstance(result, gen.Future):
result = yield result
except Exception as e:
msg = error_message(e)
return msg
else:
return {"status": "OK"}

@gen.coroutine
def actor_execute(self, comm=None, actor=None, function=None, args=(), kwargs={}):
separate_thread = kwargs.pop("separate_thread", True)
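Because plugin_add yields the result of setup when it is a future, a plugin may do asynchronous work during registration. A hedged sketch; AsyncSetupPlugin and the _resource attribute are illustrative:

    from tornado import gen

    class AsyncSetupPlugin(object):
        name = "async-setup"

        @gen.coroutine
        def _prepare(self, worker):
            yield gen.sleep(0.1)          # stand-in for opening a connection or warming a cache
            worker._resource = "ready"

        def setup(self, worker):
            # returning the coroutine's future lets plugin_add wait for it to finish
            return self._prepare(worker)

        def teardown(self, worker):
            worker._resource = None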