Consistent worker Client instance in get_client
Fixes dask#4959

`get_client` was calling the private `Worker._get_client` method when it ran within a task. `_get_client` should really have been called `_make_client`, since it created a new client every time. The simplest correct thing to do instead would have been to use the `Worker.client` property, which caches this instance.

To still pass the `timeout` parameter through, though, I instead changed `Worker._get_client` to actually match its docstring and always return the same instance.
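
As a rough illustration of the guarantee this change provides (a sketch, not code from this commit; it assumes a throwaway LocalCluster), every task running on a multithreaded worker now sees one shared, cached client:

from distributed import Client, LocalCluster, get_client

def task(i):
    # Before the fix, concurrent calls inside tasks could each construct
    # a fresh Client; the worker now caches one instance and returns it.
    return get_client().id

if __name__ == "__main__":
    with LocalCluster(n_workers=1, threads_per_worker=4) as cluster:
        with Client(cluster) as client:
            ids = client.gather(client.map(task, range(100)))
            assert len(set(ids)) == 1  # one client per worker
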
gjoseph92 committed Oct 27, 2021 · 1 parent a1b67b8 · commit fa4763b
Showing 2 changed files with 73 additions and 50 deletions.
distributed/tests/test_worker.py · 23 additions, 0 deletions
@@ -44,6 +44,7 @@
     TaskStateMetadataPlugin,
     _LockedCommPool,
     captured_logger,
+    cluster,
     dec,
     div,
     gen_cluster,
@@ -965,6 +966,28 @@ def f(x):
     assert a._client is a_client
 
 
+@gen_cluster(client=True, nthreads=[("127.0.0.1", 4)])
+async def test_get_client_threadsafe(c, s, a):
+    def f(x):
+        return get_client().id
+
+    futures = c.map(f, range(100))
+    ids = await c.gather(futures)
+    assert len(set(ids)) == 1
+
+
+def test_get_client_threadsafe_sync():
+    def f(x):
+        return get_client().id
+
+    with cluster(nworkers=1, worker_kwargs={"nthreads": 4}) as (scheduler, workers):
+        with Client(scheduler["address"]) as client:
+            futures = client.map(f, range(100))
+            ids = client.gather(futures)
+            assert len(set(ids)) == 1
+            assert set(ids) != {client.id}
+
+
 def test_get_client_sync(client):
     def f(x):
         cc = get_client()
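
The behavior these tests pin down matters most when tasks use the worker client to launch further tasks. A hedged sketch of that pattern (standard distributed usage, not part of this diff) with get_client, secede, and rejoin:

from distributed import get_client, rejoin, secede

def fib(n):
    # Runs inside a task: get_client() returns the worker's cached client,
    # so every concurrent call on this worker shares a single instance.
    if n < 2:
        return n
    client = get_client()
    a = client.submit(fib, n - 1)
    b = client.submit(fib, n - 2)
    secede()  # release the worker thread while waiting on subtasks
    a, b = client.gather([a, b])
    rejoin()  # reacquire a thread before returning
    return a + b
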
distributed/worker.py · 50 additions, 50 deletions
@@ -449,7 +449,7 @@ def __init__(
         self.has_what = defaultdict(set)
         self.pending_data_per_worker = defaultdict(deque)
         self.nanny = nanny
-        self._lock = threading.Lock()
+        self._client_lock = threading.Lock()
 
         self.data_needed = []
 
@@ -3554,11 +3554,7 @@ def validate_state(self):
 
     @property
     def client(self) -> Client:
-        with self._lock:
-            if self._client:
-                return self._client
-            else:
-                return self._get_client()
+        return self._get_client()
 
     def _get_client(self, timeout=None) -> Client:
         """Get local client attached to this worker
@@ -3569,56 +3565,60 @@ def _get_client(self, timeout=None) -> Client:
         --------
         get_client
         """
+        with self._client_lock:
+            if self._client:
+                return self._client
 
-        if timeout is None:
-            timeout = dask.config.get("distributed.comm.timeouts.connect")
+            if timeout is None:
+                timeout = dask.config.get("distributed.comm.timeouts.connect")
 
-        timeout = parse_timedelta(timeout, "s")
+            timeout = parse_timedelta(timeout, "s")
 
-        try:
-            from .client import default_client
+            try:
+                from .client import default_client
 
-            client = default_client()
-        except ValueError:  # no clients found, need to make a new one
-            pass
-        else:
-            # must be lazy import otherwise cyclic import
-            from distributed.deploy.cluster import Cluster
+                client = default_client()
+            except ValueError:  # no clients found, need to make a new one
+                pass
+            else:
+                # must be lazy import otherwise cyclic import
+                from distributed.deploy.cluster import Cluster
 
-            if (
-                client.scheduler
-                and client.scheduler.address == self.scheduler.address
-                # The below conditions should only happen in case a second
-                # cluster is alive, e.g. if a submitted task spawned its own
-                # LocalCluster, see gh4565
-                or (
-                    isinstance(client._start_arg, str)
-                    and client._start_arg == self.scheduler.address
-                    or isinstance(client._start_arg, Cluster)
-                    and client._start_arg.scheduler_address == self.scheduler.address
-                )
-            ):
-                self._client = client
+                if (
+                    client.scheduler
+                    and client.scheduler.address == self.scheduler.address
+                    # The below conditions should only happen in case a second
+                    # cluster is alive, e.g. if a submitted task spawned its own
+                    # LocalCluster, see gh4565
+                    or (
+                        isinstance(client._start_arg, str)
+                        and client._start_arg == self.scheduler.address
+                        or isinstance(client._start_arg, Cluster)
+                        and client._start_arg.scheduler_address
+                        == self.scheduler.address
+                    )
+                ):
+                    self._client = client
 
-        if not self._client:
-            from .client import Client
+            if not self._client:
+                from .client import Client
 
-            asynchronous = self.loop is IOLoop.current()
-            self._client = Client(
-                self.scheduler,
-                loop=self.loop,
-                security=self.security,
-                set_as_default=True,
-                asynchronous=asynchronous,
-                direct_to_workers=True,
-                name="worker",
-                timeout=timeout,
-            )
-            Worker._initialized_clients.add(self._client)
-            if not asynchronous:
-                assert self._client.status == "running"
+                asynchronous = self.loop is IOLoop.current()
+                self._client = Client(
+                    self.scheduler,
+                    loop=self.loop,
+                    security=self.security,
+                    set_as_default=True,
+                    asynchronous=asynchronous,
+                    direct_to_workers=True,
+                    name="worker",
+                    timeout=timeout,
+                )
+                Worker._initialized_clients.add(self._client)
+                if not asynchronous:
+                    assert self._client.status == "running"
 
-        return self._client
+            return self._client
 
     def get_current_task(self):
         """Get the key of the task we are currently running
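
The reworked _get_client boils down to create-once-under-a-lock caching: take self._client_lock, return the cached client if present, otherwise build and cache it. A minimal standalone sketch of the same pattern (illustrative only; SharedClientCache and _make are hypothetical names, not distributed's API):

import threading

class SharedClientCache:
    """Create an expensive object once, under a lock, then reuse it."""

    def __init__(self):
        self._lock = threading.Lock()
        self._instance = None

    def get(self):
        with self._lock:  # serialize concurrent first calls
            if self._instance is None:
                self._instance = self._make()
            return self._instance  # every caller gets the same object

    def _make(self):
        # Stand-in for an expensive construction such as Client(...)
        return object()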
