Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Enable pmap progress bar with cpu backend / remove deprecated host_callback #1841

Merged
merged 8 commits into from
Aug 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 42 additions & 38 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import random
import re
from threading import Lock
import warnings

import numpy as np
Expand All @@ -18,7 +19,7 @@
import jax
from jax import device_put, jit, lax, vmap
from jax.core import Tracer
from jax.experimental import host_callback
from jax.experimental import io_callback
import jax.numpy as jnp

_DISABLE_CONTROL_FLOW_PRIM = False
Expand Down Expand Up @@ -201,58 +202,57 @@ def progress_bar_factory(num_samples, num_chains):

remainder = num_samples % print_rate

idx_counter = 0 # resource counter to assign chains to progress bars
tqdm_bars = {}
finished_chains = []
# lock serializes access to idx_counter since callbacks are multithreaded
# this prevents races that assign multiple chains to a progress bar
lock = Lock()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a comment for why this is needed?

Copy link
Contributor Author

@andrewdipper andrewdipper Aug 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done. Also added a lock around closing the chains. It appears it hasn't been an issue but could cause multiple closes

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify the purpose of Lock in more details? I'm not familiar with its usage. What happens if we dont use it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. The locking is only around idx_counter since after a chain has a chain id there isn't access to resources that are shared across threads.

for chain in range(num_chains):
tqdm_bars[chain] = tqdm_auto(range(num_samples), position=chain)
tqdm_bars[chain].set_description("Compiling.. ", refresh=True)

def _update_tqdm(arg, transform, device):
chain_match = _CHAIN_RE.search(str(device))
assert chain_match
chain = int(chain_match.group())
def _update_tqdm(increment, chain):
increment = int(increment)
chain = int(chain)
if chain == -1:
nonlocal idx_counter
with lock:
chain = idx_counter
idx_counter += 1
tqdm_bars[chain].set_description(f"Running chain {chain}", refresh=False)
tqdm_bars[chain].update(arg)

def _close_tqdm(arg, transform, device):
chain_match = _CHAIN_RE.search(str(device))
assert chain_match
chain = int(chain_match.group())
tqdm_bars[chain].update(arg)
finished_chains.append(chain)
if len(finished_chains) == num_chains:
for chain in range(num_chains):
tqdm_bars[chain].close()

def _update_progress_bar(iter_num):
tqdm_bars[chain].update(increment)
return chain

def _close_tqdm(increment, chain):
increment = int(increment)
chain = int(chain)
tqdm_bars[chain].update(increment)
tqdm_bars[chain].close()

def _update_progress_bar(iter_num, chain):
"""Updates tqdm progress bar of a JAX loop only if the iteration number is a multiple of the print_rate
Usage: carry = progress_bar((iter_num, print_rate), carry)
"""

_ = lax.cond(
chain = lax.cond(
iter_num == 1,
lambda _: host_callback.id_tap(
_update_tqdm, 0, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
lambda _: io_callback(_update_tqdm, jnp.array(0), 0, chain),
lambda _: chain,
operand=None,
)
_ = lax.cond(
chain = lax.cond(
iter_num % print_rate == 0,
lambda _: host_callback.id_tap(
_update_tqdm, print_rate, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
lambda _: io_callback(_update_tqdm, jnp.array(0), print_rate, chain),
lambda _: chain,
operand=None,
)
_ = lax.cond(
iter_num == num_samples,
lambda _: host_callback.id_tap(
_close_tqdm, remainder, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
lambda _: io_callback(_close_tqdm, None, remainder, chain),
lambda _: None,
operand=None,
)
return chain

def progress_bar_fori_loop(func):
"""Decorator that adds a progress bar to `body_fun` used in `lax.fori_loop`.
Expand All @@ -261,9 +261,10 @@ def progress_bar_fori_loop(func):
"""

def wrapper_progress_bar(i, vals):
result = func(i, vals)
_update_progress_bar(i + 1)
return result
(subvals, chain) = vals
result = func(i, subvals)
chain = _update_progress_bar(i + 1, chain)
return (result, chain)

return wrapper_progress_bar

Expand Down Expand Up @@ -378,8 +379,11 @@ def loop_fn(collection):

def loop_fn(collection):
return fori_loop(
0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
)
0,
upper,
_body_fn_pbar,
((init_val, collection, start_idx, thinning), -1), # -1 for chain id
)[0]

last_val, collection, _, _ = maybe_jit(loop_fn, donate_argnums=0)(collection)

Expand Down
Loading