diff --git a/distributed/client.py b/distributed/client.py index 695c9dd39f2..b0f3f148331 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -51,7 +51,7 @@ connect, rpc, ) -from .diagnostics.plugin import UploadFile, WorkerPlugin +from .diagnostics.plugin import UploadFile, WorkerPlugin, _get_worker_plugin_name from .metrics import time from .protocol import to_serialize from .protocol.pickle import dumps, loads @@ -4013,6 +4013,7 @@ def register_worker_plugin(self, plugin=None, name=None, **kwargs): name : str, optional A name for the plugin. Registering a plugin with the same name will have no effect. + If plugin has no name attribute a random name is used. **kwargs : optional If you pass a class as the plugin, instead of a class instance, then the class will be instantiated with any extra keyword arguments. @@ -4049,12 +4050,66 @@ class will be instantiated with any extra keyword arguments. See Also -------- distributed.WorkerPlugin + unregister_worker_plugin """ if isinstance(plugin, type): plugin = plugin(**kwargs) + if name is None: + name = _get_worker_plugin_name(plugin) + + assert name + return self.sync(self._register_worker_plugin, plugin=plugin, name=name) + async def _unregister_worker_plugin(self, name): + responses = await self.scheduler.unregister_worker_plugin(name=name) + + for response in responses.values(): + if response["status"] == "error": + exc = response["exception"] + tb = response["traceback"] + raise exc.with_traceback(tb) + return responses + + def unregister_worker_plugin(self, name): + """Unregisters a lifecycle worker plugin + + This unregisters an existing worker plugin. As part of the unregistration process + the plugin's ``teardown`` method will be called. + + Parameters + ---------- + name : str + Name of the plugin to unregister. See the :meth:`Client.register_worker_plugin` + docstring for more information. + + Examples + -------- + >>> class MyPlugin(WorkerPlugin): + ... def __init__(self, *args, **kwargs): + ... pass # the constructor is up to you + ... def setup(self, worker: dask.distributed.Worker): + ... pass + ... def teardown(self, worker: dask.distributed.Worker): + ... pass + ... def transition(self, key: str, start: str, finish: str, **kwargs): + ... pass + ... def release_key(self, key: str, state: str, cause: Optional[str], reason: None, report: bool): + ... pass + ... def release_dep(self, dep: str, state: str, report: bool): + ... pass + + >>> plugin = MyPlugin(1, 2, 3) + >>> client.register_worker_plugin(plugin, name='foo') + >>> client.unregister_worker_plugin(name='foo') + + See Also + -------- + register_worker_plugin + """ + return self.sync(self._unregister_worker_plugin, name=name) + class _WorkerSetupPlugin(WorkerPlugin): """ This is used to support older setup functions as callbacks """ diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 58b5adbc585..80e1736ac28 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -3,6 +3,9 @@ import socket import subprocess import sys +import uuid + +from dask.utils import funcname logger = logging.getLogger(__name__) @@ -186,6 +189,15 @@ def release_dep(self, dep, state, report): """ +def _get_worker_plugin_name(plugin) -> str: + """Returns the worker plugin name. If plugin has no name attribute + a random name is used.""" + if hasattr(plugin, "name"): + return plugin.name + else: + return funcname(plugin) + "-" + str(uuid.uuid4()) + + class PipInstall(WorkerPlugin): """A Worker Plugin to pip install a set of packages diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index 858f8feedc2..393311b5d7a 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -53,6 +53,40 @@ async def test_create_with_client(c, s): assert worker._my_plugin_status == "teardown" +@gen_cluster(client=True, nthreads=[]) +async def test_remove_with_client(c, s): + await c.register_worker_plugin(MyPlugin(123), name="foo") + await c.register_worker_plugin(MyPlugin(546), name="bar") + + worker = await Worker(s.address, loop=s.loop) + # remove the 'foo' plugin + await c.unregister_worker_plugin("foo") + assert worker._my_plugin_status == "teardown" + + # check that on the scheduler registered worker plugins we only have 'bar' + assert len(s.worker_plugins) == 1 + assert "bar" in s.worker_plugins + + # check on the worker plugins that we only have 'bar' + assert len(worker.plugins) == 1 + assert "bar" in worker.plugins + + # let's remove 'bar' and we should have none worker plugins + await c.unregister_worker_plugin("bar") + assert worker._my_plugin_status == "teardown" + assert not s.worker_plugins + assert not worker.plugins + + +@gen_cluster(client=True, nthreads=[]) +async def test_remove_with_client_raises(c, s): + await c.register_worker_plugin(MyPlugin(123), name="foo") + + worker = await Worker(s.address, loop=s.loop) + with pytest.raises(ValueError, match="bar"): + await c.unregister_worker_plugin("bar") + + @gen_cluster(client=True, nthreads=[]) async def test_create_with_client_and_plugin_from_class(c, s): await c.register_worker_plugin(MyPlugin, data=456) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 0f68d02b876..fb9610c53d5 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3458,7 +3458,7 @@ def __init__( ) ) self.event_counts = defaultdict(int) - self.worker_plugins = [] + self.worker_plugins = dict() worker_handlers = { "task-finished": self.handle_task_finished, @@ -3525,6 +3525,7 @@ def __init__( "get_task_status": self.get_task_status, "get_task_stream": self.get_task_stream, "register_worker_plugin": self.register_worker_plugin, + "unregister_worker_plugin": self.unregister_worker_plugin, "adaptive_target": self.adaptive_target, "workers_to_close": self.workers_to_close, "subscribe_worker_status": self.subscribe_worker_status, @@ -6308,13 +6309,23 @@ def stop_task_metadata(self, comm=None, name=None): async def register_worker_plugin(self, comm, plugin, name=None): """ Registers a setup function, and call it on every worker """ - self.worker_plugins.append({"plugin": plugin, "name": name}) + self.worker_plugins[name] = plugin responses = await self.broadcast( msg=dict(op="plugin-add", plugin=plugin, name=name) ) return responses + async def unregister_worker_plugin(self, comm, name): + """ Unregisters a worker plugin""" + try: + worker_plugins = self.worker_plugins.pop(name) + except KeyError: + raise ValueError(f"The worker plugin {name} does not exists") + + responses = await self.broadcast(msg=dict(op="plugin-remove", name=name)) + return responses + def transition(self, key, finish: str, *args, **kwargs): """Transition a key from its current state to the finish state diff --git a/distributed/worker.py b/distributed/worker.py index eb87d797a86..bb641e3ff1f 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -7,7 +7,6 @@ import random import sys import threading -import uuid import warnings import weakref from collections import defaultdict, deque, namedtuple @@ -41,6 +40,7 @@ pingpong, send_recv, ) +from .diagnostics.plugin import _get_worker_plugin_name from .diskutils import WorkSpace from .http import get_handlers from .metrics import time @@ -680,6 +680,7 @@ def __init__( "actor_execute": self.actor_execute, "actor_attribute": self.actor_attribute, "plugin-add": self.plugin_add, + "plugin-remove": self.plugin_remove, "get_monitor_info": self.get_monitor_info, } @@ -921,8 +922,8 @@ async def _register_with_scheduler(self): else: await asyncio.gather( *[ - self.plugin_add(**plugin_kwargs) - for plugin_kwargs in response["worker-plugins"] + self.plugin_add(name=name, plugin=plugin) + for name, plugin in response["worker-plugins"].items() ] ) @@ -2551,11 +2552,9 @@ async def plugin_add(self, comm=None, plugin=None, name=None): with log_errors(pdb=False): if isinstance(plugin, bytes): plugin = pickle.loads(plugin) - if not name: - if hasattr(plugin, "name"): - name = plugin.name - else: - name = funcname(plugin) + "-" + str(uuid.uuid4()) + + if name is None: + name = _get_worker_plugin_name(plugin) assert name @@ -2576,6 +2575,21 @@ async def plugin_add(self, comm=None, plugin=None, name=None): return {"status": "OK"} + async def plugin_remove(self, comm=None, name=None): + with log_errors(pdb=False): + logger.info(f"Removing Worker plugin {name}") + try: + plugin = self.plugins.pop(name) + if hasattr(plugin, "teardown"): + result = plugin.teardown(worker=self) + if isawaitable(result): + result = await result + except Exception as e: + msg = error_message(e) + return msg + + return {"status": "OK"} + async def actor_execute( self, comm=None, actor=None, function=None, args=(), kwargs={} ):