Skip to content

Commit

Permalink
Unregister worker plugin (#4748)
Browse files Browse the repository at this point in the history
  • Loading branch information
ncclementi authored Apr 30, 2021
1 parent 97c66db commit 233ec88
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 11 deletions.
57 changes: 56 additions & 1 deletion distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
connect,
rpc,
)
from .diagnostics.plugin import UploadFile, WorkerPlugin
from .diagnostics.plugin import UploadFile, WorkerPlugin, _get_worker_plugin_name
from .metrics import time
from .protocol import to_serialize
from .protocol.pickle import dumps, loads
Expand Down Expand Up @@ -4013,6 +4013,7 @@ def register_worker_plugin(self, plugin=None, name=None, **kwargs):
name : str, optional
A name for the plugin.
Registering a plugin with the same name will have no effect.
If plugin has no name attribute a random name is used.
**kwargs : optional
If you pass a class as the plugin, instead of a class instance, then the
class will be instantiated with any extra keyword arguments.
Expand Down Expand Up @@ -4049,12 +4050,66 @@ class will be instantiated with any extra keyword arguments.
See Also
--------
distributed.WorkerPlugin
unregister_worker_plugin
"""
if isinstance(plugin, type):
plugin = plugin(**kwargs)

if name is None:
name = _get_worker_plugin_name(plugin)

assert name

return self.sync(self._register_worker_plugin, plugin=plugin, name=name)

async def _unregister_worker_plugin(self, name):
    """Ask the scheduler to unregister the worker plugin ``name``.

    Awaits ``self.scheduler.unregister_worker_plugin`` and scans the values
    of the returned mapping. The first value whose ``"status"`` is
    ``"error"`` has its ``"exception"`` re-raised with its ``"traceback"``
    attached; if there is none, the mapping is returned unchanged.
    """
    replies = await self.scheduler.unregister_worker_plugin(name=name)

    # Lazily find the first failed reply, if any.
    failed = next(
        (reply for reply in replies.values() if reply["status"] == "error"),
        None,
    )
    if failed is not None:
        error, trace = failed["exception"], failed["traceback"]
        raise error.with_traceback(trace)
    return replies

def unregister_worker_plugin(self, name):
    """Unregister a lifecycle worker plugin by name.

    Removes a previously registered worker plugin. As part of the
    unregistration process the plugin's ``teardown`` method is called.

    Parameters
    ----------
    name : str
        Name of the plugin to unregister. See the
        :meth:`Client.register_worker_plugin` docstring for more
        information.

    Examples
    --------
    >>> class MyPlugin(WorkerPlugin):
    ...     def __init__(self, *args, **kwargs):
    ...         pass  # the constructor is up to you
    ...     def setup(self, worker: dask.distributed.Worker):
    ...         pass
    ...     def teardown(self, worker: dask.distributed.Worker):
    ...         pass
    ...     def transition(self, key: str, start: str, finish: str, **kwargs):
    ...         pass
    ...     def release_key(self, key: str, state: str, cause: Optional[str], reason: None, report: bool):
    ...         pass
    ...     def release_dep(self, dep: str, state: str, report: bool):
    ...         pass

    >>> plugin = MyPlugin(1, 2, 3)
    >>> client.register_worker_plugin(plugin, name='foo')
    >>> client.unregister_worker_plugin(name='foo')

    See Also
    --------
    register_worker_plugin
    """
    # Hand the coroutine function to ``self.sync`` together with the name.
    dispatch = self.sync
    return dispatch(self._unregister_worker_plugin, name=name)


class _WorkerSetupPlugin(WorkerPlugin):
""" This is used to support older setup functions as callbacks """
Expand Down
12 changes: 12 additions & 0 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import socket
import subprocess
import sys
import uuid

from dask.utils import funcname

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -186,6 +189,15 @@ def release_dep(self, dep, state, report):
"""


def _get_worker_plugin_name(plugin) -> str:
    """Return the name to register ``plugin`` under.

    The plugin's own ``name`` attribute is used when it has one. Otherwise
    a name is built from ``funcname(plugin)``, a dash, and a fresh UUID4.
    """
    try:
        return plugin.name
    except AttributeError:
        # No ``name`` attribute: fall back to a generated, unique name.
        return funcname(plugin) + "-" + str(uuid.uuid4())


class PipInstall(WorkerPlugin):
"""A Worker Plugin to pip install a set of packages
Expand Down
34 changes: 34 additions & 0 deletions distributed/diagnostics/tests/test_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,40 @@ async def test_create_with_client(c, s):
assert worker._my_plugin_status == "teardown"


@gen_cluster(client=True, nthreads=[])
async def test_remove_with_client(c, s):
    """Unregistering plugins by name through the client removes them from
    both the scheduler's and the worker's plugin collections."""
    # register two plugins under distinct names before any worker exists
    await c.register_worker_plugin(MyPlugin(123), name="foo")
    await c.register_worker_plugin(MyPlugin(546), name="bar")

    worker = await Worker(s.address, loop=s.loop)
    # remove the 'foo' plugin
    await c.unregister_worker_plugin("foo")
    assert worker._my_plugin_status == "teardown"

    # check that the scheduler's registered worker plugins now only contain 'bar'
    assert len(s.worker_plugins) == 1
    assert "bar" in s.worker_plugins

    # check that the worker's plugins now only contain 'bar'
    assert len(worker.plugins) == 1
    assert "bar" in worker.plugins

    # remove 'bar' as well; no worker plugins should remain anywhere
    await c.unregister_worker_plugin("bar")
    assert worker._my_plugin_status == "teardown"
    assert not s.worker_plugins
    assert not worker.plugins


@gen_cluster(client=True, nthreads=[])
async def test_remove_with_client_raises(c, s):
    """Unregistering a name that was never registered raises a ValueError
    whose message mentions that name."""
    await c.register_worker_plugin(MyPlugin(123), name="foo")

    # start a worker against the scheduler's address before the removal attempt
    worker = await Worker(s.address, loop=s.loop)
    # only 'foo' was registered, so asking to remove 'bar' must fail
    with pytest.raises(ValueError, match="bar"):
        await c.unregister_worker_plugin("bar")


@gen_cluster(client=True, nthreads=[])
async def test_create_with_client_and_plugin_from_class(c, s):
await c.register_worker_plugin(MyPlugin, data=456)
Expand Down
15 changes: 13 additions & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3458,7 +3458,7 @@ def __init__(
)
)
self.event_counts = defaultdict(int)
self.worker_plugins = []
self.worker_plugins = dict()

worker_handlers = {
"task-finished": self.handle_task_finished,
Expand Down Expand Up @@ -3525,6 +3525,7 @@ def __init__(
"get_task_status": self.get_task_status,
"get_task_stream": self.get_task_stream,
"register_worker_plugin": self.register_worker_plugin,
"unregister_worker_plugin": self.unregister_worker_plugin,
"adaptive_target": self.adaptive_target,
"workers_to_close": self.workers_to_close,
"subscribe_worker_status": self.subscribe_worker_status,
Expand Down Expand Up @@ -6308,13 +6309,23 @@ def stop_task_metadata(self, comm=None, name=None):

async def register_worker_plugin(self, comm, plugin, name=None):
""" Registers a setup function, and call it on every worker """
self.worker_plugins.append({"plugin": plugin, "name": name})
self.worker_plugins[name] = plugin

responses = await self.broadcast(
msg=dict(op="plugin-add", plugin=plugin, name=name)
)
return responses

async def unregister_worker_plugin(self, comm, name):
    """Unregister a worker plugin and ask every worker to remove it.

    Parameters
    ----------
    comm
        Not used by this handler.
    name : str
        Name the plugin was registered under.

    Returns
    -------
    The result of broadcasting the ``"plugin-remove"`` operation.

    Raises
    ------
    ValueError
        If no plugin named ``name`` is registered on the scheduler. Nothing
        is broadcast in that case.
    """
    try:
        # Only the removal matters; the stored plugin object is not needed.
        del self.worker_plugins[name]
    except KeyError:
        # The KeyError adds nothing for the caller, so drop it from the chain.
        raise ValueError(f"The worker plugin {name} does not exist") from None

    return await self.broadcast(msg=dict(op="plugin-remove", name=name))

def transition(self, key, finish: str, *args, **kwargs):
"""Transition a key from its current state to the finish state
Expand Down
30 changes: 22 additions & 8 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import random
import sys
import threading
import uuid
import warnings
import weakref
from collections import defaultdict, deque, namedtuple
Expand Down Expand Up @@ -41,6 +40,7 @@
pingpong,
send_recv,
)
from .diagnostics.plugin import _get_worker_plugin_name
from .diskutils import WorkSpace
from .http import get_handlers
from .metrics import time
Expand Down Expand Up @@ -680,6 +680,7 @@ def __init__(
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
}

Expand Down Expand Up @@ -921,8 +922,8 @@ async def _register_with_scheduler(self):
else:
await asyncio.gather(
*[
self.plugin_add(**plugin_kwargs)
for plugin_kwargs in response["worker-plugins"]
self.plugin_add(name=name, plugin=plugin)
for name, plugin in response["worker-plugins"].items()
]
)

Expand Down Expand Up @@ -2551,11 +2552,9 @@ async def plugin_add(self, comm=None, plugin=None, name=None):
with log_errors(pdb=False):
if isinstance(plugin, bytes):
plugin = pickle.loads(plugin)
if not name:
if hasattr(plugin, "name"):
name = plugin.name
else:
name = funcname(plugin) + "-" + str(uuid.uuid4())

if name is None:
name = _get_worker_plugin_name(plugin)

assert name

Expand All @@ -2576,6 +2575,21 @@ async def plugin_add(self, comm=None, plugin=None, name=None):

return {"status": "OK"}

async def plugin_remove(self, comm=None, name=None):
    """Remove the plugin registered as ``name`` from this worker.

    The plugin is popped from ``self.plugins``; if it has a ``teardown``
    attribute, that is called with ``worker=self`` and awaited when the
    result is awaitable. Any exception raised along the way is converted
    with ``error_message`` and returned instead of propagating. On success
    ``{"status": "OK"}`` is returned.
    """
    with log_errors(pdb=False):
        logger.info(f"Removing Worker plugin {name}")
        try:
            removed = self.plugins.pop(name)
            if hasattr(removed, "teardown"):
                outcome = removed.teardown(worker=self)
                if isawaitable(outcome):
                    await outcome
        except Exception as e:
            # Report the failure to the caller as an error payload.
            return error_message(e)
        else:
            return {"status": "OK"}

async def actor_execute(
self, comm=None, actor=None, function=None, args=(), kwargs={}
):
Expand Down

0 comments on commit 233ec88

Please sign in to comment.