Don't pile up context_meter callbacks #7961

Merged · 1 commit · Jul 5, 2023
57 changes: 44 additions & 13 deletions distributed/metrics.py
@@ -189,10 +189,10 @@ class ContextMeter:
        A->B comms: network-write 0.567 seconds
    """

-    _callbacks: ContextVar[list[Callable[[Hashable, float, str], None]]]
+    _callbacks: ContextVar[dict[Hashable, Callable[[Hashable, float, str], None]]]

    def __init__(self):
-        self._callbacks = ContextVar(f"MetricHook<{id(self)}>._callbacks", default=[])
+        self._callbacks = ContextVar(f"MetricHook<{id(self)}>._callbacks", default={})

    def __reduce__(self):
        assert self is context_meter, "Found copy of singleton"
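Reviewer note: `default={}` works here only because the dict is treated as immutable — `add_callback` below copies it before writing, so resetting the ContextVar token cleanly restores the previous callback set. A standalone sketch of that copy-on-write pattern (names are illustrative, not from this diff):

```python
from contextvars import ContextVar

var: ContextVar[dict] = ContextVar("var", default={})

outer = var.get()
tok = var.set({**outer, "k": "callback"})  # write a copy; never mutate in place
try:
    assert "k" in var.get()
finally:
    var.reset(tok)  # restores the previous, unmodified dict
assert "k" not in var.get()
```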
@@ -204,13 +204,28 @@ def _unpickle_singleton():

    @contextmanager
    def add_callback(
-        self, callback: Callable[[Hashable, float, str], None]
+        self,
+        callback: Callable[[Hashable, float, str], None],
+        *,
+        key: Hashable | None = None,
    ) -> Iterator[None]:
        """Add a callback when entering the context and remove it when exiting it.
        The callback must accept the same parameters as :meth:`digest_metric`.
+
+        Parameters
+        ----------
+        callback: Callable
+            ``f(label, value, unit)`` to be executed
+        key: Hashable, optional
+            Unique key for the callback. If two nested calls to ``add_callback`` use the
+            same key, suppress the outermost callback.
        """
+        if key is None:
+            key = object()
        cbs = self._callbacks.get()
-        tok = self._callbacks.set(cbs + [callback])
+        cbs = cbs.copy()
+        cbs[key] = callback
+        tok = self._callbacks.set(cbs)
        try:
            yield
        finally:
@@ -221,7 +236,7 @@ def digest_metric(self, label: Hashable, value: float, unit: str) -> None:
        metric.
        """
        cbs = self._callbacks.get()
-        for cb in cbs:
+        for cb in cbs.values():
            cb(label, value, unit)
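To make the `key=` semantics concrete, a minimal usage sketch (the `"ledger"` key and labels are made up): while the inner context is active, only the innermost callback registered under a given key fires; the outer one resumes once the inner context exits.

```python
from distributed.metrics import context_meter

seen = []

with context_meter.add_callback(
    lambda l, v, u: seen.append(("outer", l)), key="ledger"
):
    with context_meter.add_callback(
        lambda l, v, u: seen.append(("inner", l)), key="ledger"
    ):
        context_meter.digest_metric("x", 1.0, "seconds")  # inner callback only
    context_meter.digest_metric("y", 1.0, "seconds")  # outer callback again

assert seen == [("inner", "x"), ("outer", "y")]
```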

    @contextmanager
@@ -234,9 +249,10 @@ def meter(
    ) -> Iterator[MeterOutput]:
        """Convenience context manager or decorator which calls func() before and after
        the wrapped code, calculates the delta, and finally calls :meth:`digest_metric`.
-        It also subtracts any other calls to :meth:`meter` or :meth:`digest_metric` with
-        the same unit performed within the context, so that the total is strictly
-        additive.
+
+        If unit=='seconds', it also subtracts any other calls to :meth:`meter` or
+        :meth:`digest_metric` with the same unit performed within the context, so that
+        the total is strictly additive.

        Parameters
        ----------
@@ -256,10 +272,19 @@ def meter(
        nested calls to :meth:`meter`, then delta (for seconds only) is reduced by the
        inner metrics, to a minimum of ``floor``.
        """
+        if unit != "seconds":
+            try:
+                with meter(func, floor=floor) as m:
+                    yield m
+            finally:
+                self.digest_metric(label, m.delta, unit)
+            return
+
+        # If unit=="seconds", subtract time metered from the sub-contexts
        offsets = []

        def callback(label2: Hashable, value2: float, unit2: str) -> None:
-            if unit2 == unit == "seconds":
+            if unit2 == unit:
                # This must be threadsafe to support callbacks invoked from
                # distributed.utils.offload; '+=' on a float would not be threadsafe!
                offsets.append(value2)
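A sketch of the seconds-only subtraction the docstring describes (sleeps and labels are illustrative; this assumes `label` is the first positional argument of `meter`):

```python
import time

from distributed.metrics import context_meter

with context_meter.add_callback(lambda l, v, u: print(l, round(v, 1), u)):
    with context_meter.meter("total"):  # unit defaults to "seconds"
        with context_meter.meter("inner-io"):
            time.sleep(0.1)  # attributed to "inner-io"
        time.sleep(0.1)  # the only time left attributed to "total"

# Prints approximately:
#   inner-io 0.1 seconds
#   total 0.1 seconds   (0.2s wall clock minus the 0.1s already metered)
```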
@@ -316,14 +341,20 @@ def __init__(self, func: Callable[[], float] = timemod.perf_counter):
        self.start = func()
        self.metrics = []

+    def _callback(self, label: Hashable, value: float, unit: str) -> None:
+        self.metrics.append((label, value, unit))
+
    @contextmanager
-    def record(self) -> Iterator[None]:
+    def record(self, *, key: Hashable | None = None) -> Iterator[None]:
        """Ingest metrics logged with :meth:`ContextMeter.digest_metric` or
        :meth:`ContextMeter.meter` and temporarily store them in :ivar:`metrics`.
+
+        Parameters
+        ----------
+        key: Hashable, optional
+            See :meth:`ContextMeter.add_callback`
        """
-        with context_meter.add_callback(
-            lambda label, value, unit: self.metrics.append((label, value, unit))
-        ):
+        with context_meter.add_callback(self._callback, key=key):
            yield

    def finalize(
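A usage sketch of the reworked `record()` (the key and labels are illustrative): re-entering with the same key replaces the outer callback rather than stacking a second copy, so each metric lands in `ledger.metrics` exactly once.

```python
from distributed.metrics import DelayedMetricsLedger, context_meter

ledger = DelayedMetricsLedger()

with ledger.record(key="async-instruction"):
    context_meter.digest_metric("compute", 0.5, "seconds")
    # Same ledger, same key: replaces the callback instead of adding another
    with ledger.record(key="async-instruction"):
        context_meter.digest_metric("deserialize", 0.1, "seconds")

assert ledger.metrics == [
    ("compute", 0.5, "seconds"),
    ("deserialize", 0.1, "seconds"),
]
```

Without the key, the nested `record()` would register a second callback and `("deserialize", 0.1, "seconds")` would be appended twice.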
65 changes: 57 additions & 8 deletions distributed/tests/test_metrics.py
@@ -77,25 +77,34 @@ def test_meter_floor(kwargs, delta):


def test_context_meter():
-    it = iter([123, 124])
+    it = iter([123, 124, 125, 126])
    cbs = []

    with metrics.context_meter.add_callback(lambda l, v, u: cbs.append((l, v, u))):
-        with metrics.context_meter.meter("m1", func=lambda: next(it)) as m:
-            assert m.start == 123
-            assert math.isnan(m.stop)
-            assert math.isnan(m.delta)
+        with metrics.context_meter.meter("m1", func=lambda: next(it)) as m1:
+            assert m1.start == 123
+            assert math.isnan(m1.stop)
+            assert math.isnan(m1.delta)
+        with metrics.context_meter.meter("m2", func=lambda: next(it), unit="foo") as m2:
+            assert m2.start == 125
+            assert math.isnan(m2.stop)
+            assert math.isnan(m2.delta)

+        metrics.context_meter.digest_metric("m1", 2, "seconds")
        metrics.context_meter.digest_metric("m1", 1, "foo")

    # Not recorded - out of context
    metrics.context_meter.digest_metric("m1", 123, "foo")

-    assert m.start == 123
-    assert m.stop == 124
-    assert m.delta == 1
+    assert m1.start == 123
+    assert m1.stop == 124
+    assert m1.delta == 1
+    assert m2.start == 125
+    assert m2.stop == 126
+    assert m2.delta == 1
    assert cbs == [
        ("m1", 1, "seconds"),
+        ("m2", 1, "foo"),
+        ("m1", 2, "seconds"),
        ("m1", 1, "foo"),
    ]
@@ -199,3 +208,43 @@ def test_delayed_metrics_ledger():
("foo", 10, "bytes"),
("other", 20, "seconds"),
]


+def test_context_meter_keyed():
+    cbs = []
+
+    def cb(tag, key):
+        return metrics.context_meter.add_callback(
+            lambda l, v, u: cbs.append((tag, l)), key=key
+        )
+
+    with cb("x", key="x"), cb("y", key="y"):
+        metrics.context_meter.digest_metric("l1", 1, "u")
+        with cb("z", key="x"):
+            metrics.context_meter.digest_metric("l2", 2, "u")
+        metrics.context_meter.digest_metric("l3", 3, "u")
+
+    assert cbs == [
+        ("x", "l1"),
+        ("y", "l1"),
+        ("z", "l2"),
+        ("y", "l2"),
+        ("x", "l3"),
+        ("y", "l3"),
+    ]
+
+
+def test_delayed_metrics_ledger_keyed():
+    l1 = metrics.DelayedMetricsLedger()
+    l2 = metrics.DelayedMetricsLedger()
+    l3 = metrics.DelayedMetricsLedger()
+
+    with l1.record(key="x"), l2.record(key="y"):
+        metrics.context_meter.digest_metric("l1", 1, "u")
+        with l3.record(key="x"):
+            metrics.context_meter.digest_metric("l2", 2, "u")
+        metrics.context_meter.digest_metric("l3", 3, "u")
+
+    assert l1.metrics == [("l1", 1, "u"), ("l3", 3, "u")]
+    assert l2.metrics == [("l1", 1, "u"), ("l2", 2, "u"), ("l3", 3, "u")]
+    assert l3.metrics == [("l2", 2, "u")]
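The callback order asserted in `test_context_meter_keyed` falls out of plain dict semantics — overwriting an existing key keeps its original insertion slot, which is why `("z", "l2")` fires before `("y", "l2")`:

```python
cbs = {"x": "cb_x", "y": "cb_y"}  # callbacks registered in the outer contexts
inner = cbs.copy()
inner["x"] = "cb_z"  # same key: replaces cb_x but keeps its position
assert list(inner.values()) == ["cb_z", "cb_y"]
```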
19 changes: 19 additions & 0 deletions distributed/tests/test_worker_metrics.py
@@ -618,3 +618,22 @@ async def test_new_metrics_during_heartbeat(c, s, a):
    assert a.digests_total["execute", span.id, "x", "test", "test"] == n
    assert s.cumulative_worker_metrics["execute", "x", "test", "test"] == n
    assert span.cumulative_worker_metrics["execute", "x", "test", "test"] == n


+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)],
+    config={"distributed.scheduler.worker-saturation": float("inf")},
+)
+async def test_delayed_ledger_is_not_reentrant(c, s, a):
+    """https://github.com/dask/distributed/issues/7949
+
+    Test that, when there's a long chain of task done -> task start events,
+    the callbacks added by the delayed ledger don't pile up on top of each other.
+    """
+
+    def f(_):
+        return len(context_meter._callbacks.get())
+
+    out = await c.gather(c.map(f, range(1000)))
+    assert max(out) < 10
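For contrast, a sketch of the pile-up from #7949 that this test guards against (it peeks at the private `_callbacks` variable, just like the test above):

```python
from contextlib import ExitStack

from distributed.metrics import context_meter

with ExitStack() as stack:
    for _ in range(100):
        # No key: each registration gets a fresh object() key and piles up
        stack.enter_context(context_meter.add_callback(lambda l, v, u: None))
    assert len(context_meter._callbacks.get()) == 100

with ExitStack() as stack:
    for _ in range(100):
        # Shared key: nested registrations replace each other
        stack.enter_context(
            context_meter.add_callback(lambda l, v, u: None, key="ledger")
        )
    assert len(context_meter._callbacks.get()) == 1
```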
9 changes: 6 additions & 3 deletions distributed/worker_state_machine.py
@@ -3621,7 +3621,7 @@ def _start_async_instruction( # type: ignore[valid-type]

@wraps(func)
async def wrapper() -> StateMachineEvent:
with ledger.record():
with ledger.record(key="async-instruction"):
return await func(*args, **kwargs)

task = asyncio.create_task(wrapper(), name=task_name)
@@ -3650,8 +3650,11 @@ def _finish_async_instruction(
            logger.exception("async instruction handlers should never raise!")
            raise

-        with ledger.record():
-            # Capture metric events in _transition_to_memory()
+        # Capture metric events in _transition_to_memory()
+        # As this may trigger calls to _start_async_instruction for more tasks,
+        # make sure we don't endlessly pile up context_meter callbacks by specifying
+        # the same key as in _start_async_instruction.
+        with ledger.record(key="async-instruction"):
Inline comment from the PR author (collaborator):
As an aside, this whole incident would have been prevented by

    self.handle_stimulus(stim)
    self._finalize_metrics(stim, ledger, span_id)