[Draft] Enable pmap progress bar with cpu backend / remove deprecated host_callback #1841

Merged
merged 8 commits into from
Aug 9, 2024
Changes from 3 commits

57 changes: 34 additions & 23 deletions numpyro/util.py
@@ -9,6 +9,7 @@
 import os
 import random
 import re
+from threading import Lock
 import warnings

 import numpy as np
@@ -18,7 +19,7 @@
 import jax
 from jax import device_put, jit, lax, vmap
 from jax.core import Tracer
-from jax.experimental import host_callback
+from jax.experimental import io_callback
 import jax.numpy as jnp

 _DISABLE_CONTROL_FLOW_PRIM = False
@@ -201,24 +202,40 @@ def progress_bar_factory(num_samples, num_chains):

     remainder = num_samples % print_rate

+    idx_map = {}
Member

Could you add some comments here explaining how this works?

Contributor Author

Added, next to _calc_chain_idx.
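For readers following the thread, here is a minimal standalone sketch of the bookkeeping that idx_map appears to perform in this revision of the diff. It is an illustration under assumptions, not the code that was merged: the factory name make_chain_indexer is hypothetical, and the sketch uses dict.get and dict.pop where the diff uses try/except and del.

```python
from threading import Lock


def make_chain_indexer(num_chains):
    """Build a function that hands out 0..num_chains-1 per iteration number."""
    idx_map = {}   # iter_num -> number of callbacks that have already claimed an index
    lock = Lock()  # callbacks for different chains may arrive on different threads

    def calc_chain_idx(iter_num):
        with lock:
            idx = idx_map.get(iter_num, 0)
            if idx + 1 == num_chains:
                # Last chain to report for this iteration: forget the entry
                # so the map does not grow with the number of iterations.
                idx_map.pop(iter_num, None)
            else:
                idx_map[iter_num] = idx + 1
            return idx

    return calc_chain_idx, idx_map


calc_chain_idx, idx_map = make_chain_indexer(num_chains=3)
first_round = [calc_chain_idx(50) for _ in range(3)]    # three chains report iteration 50
second_round = [calc_chain_idx(100) for _ in range(3)]  # three chains report iteration 100
```

Each iteration number hands out every bar index exactly once, and the map is empty again once all chains have reported. One consequence, if the diff is read this way, is that a bar index follows arrival order for that iteration rather than a fixed device.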

     tqdm_bars = {}
     finished_chains = []
+    lock = Lock()
Member

Could you add a comment explaining why this is needed?

Contributor Author (@andrewdipper, Aug 6, 2024)

Done. I also added a lock around closing the chains. It doesn't appear to have been an issue so far, but without the lock the bars could be closed more than once.

Member

Could you clarify the purpose of Lock in more detail? I'm not familiar with its usage. What happens if we don't use it?

Contributor Author

Updated. The lock now only guards idx_counter: once a chain has been assigned a chain id, it no longer accesses any resource that is shared across threads.
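To make the answer to "what happens if we don't use it" concrete, the following is a small sketch of the pattern described above, assuming a shared counter named idx_counter as in the comment. The helper names make_chain_id_allocator and worker are hypothetical and are not taken from the PR. Reading the counter and incrementing it are two separate steps, so without the lock two callback threads could read the same value and end up driving the same progress bar.

```python
from threading import Lock, Thread


def make_chain_id_allocator():
    """Return a function that hands out 0, 1, 2, ... exactly once each."""
    lock = Lock()
    idx_counter = 0  # shared by every thread that calls claim()

    def claim():
        nonlocal idx_counter
        # The lock makes "read the counter, then bump it" a single atomic step.
        with lock:
            chain = idx_counter
            idx_counter += 1
        return chain

    return claim


num_chains = 8
claim_chain_id = make_chain_id_allocator()
assigned = []  # chain ids in the order threads obtained them


def worker():
    # Stands in for the first host callback issued by one chain.
    assigned.append(claim_chain_id())


threads = [Thread(target=worker) for _ in range(num_chains)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

With the lock in place every thread receives a distinct id, although the order in which ids are handed out still depends on thread scheduling.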

     for chain in range(num_chains):
         tqdm_bars[chain] = tqdm_auto(range(num_samples), position=chain)
         tqdm_bars[chain].set_description("Compiling.. ", refresh=True)

-    def _update_tqdm(arg, transform, device):
-        chain_match = _CHAIN_RE.search(str(device))
-        assert chain_match
-        chain = int(chain_match.group())
+    def _calc_chain_idx(iter_num):
+        with lock:
+            try:
+                idx = idx_map[iter_num]
+            except KeyError:
+                idx = 0
+                idx_map[iter_num] = 0
+
+            if idx + 1 == num_chains:
+                del idx_map[iter_num]
+            else:
+                idx_map[iter_num] += 1
+            return idx
+
+    def _update_tqdm(iter_num, increment):
+        iter_num = int(iter_num)
+        increment = int(increment)
+        chain = _calc_chain_idx(iter_num)
         tqdm_bars[chain].set_description(f"Running chain {chain}", refresh=False)
-        tqdm_bars[chain].update(arg)
+        tqdm_bars[chain].update(increment)

-    def _close_tqdm(arg, transform, device):
-        chain_match = _CHAIN_RE.search(str(device))
-        assert chain_match
-        chain = int(chain_match.group())
-        tqdm_bars[chain].update(arg)
+    def _close_tqdm(iter_num, increment):
+        iter_num = int(iter_num)
+        increment = int(increment)
+        chain = _calc_chain_idx(iter_num + 1)  # +1 so no collision in idx_map
+        tqdm_bars[chain].update(increment)
         finished_chains.append(chain)
         if len(finished_chains) == num_chains:
             for chain in range(num_chains):
@@ -231,26 +248,20 @@ def _update_progress_bar(iter_num):

         _ = lax.cond(
             iter_num == 1,
-            lambda _: host_callback.id_tap(
-                _update_tqdm, 0, result=iter_num, tap_with_device=True
-            ),
-            lambda _: iter_num,
+            lambda _: io_callback(_update_tqdm, None, -1, 0),
Member

I'm not sure why -1, 0 is used here.

Contributor Author

The -1 was a mistake; I changed it to 0, 0. The second 0 is consistent with the current implementation. I believe it is 0 instead of 1 so that the subsequent calls, which add print_rate and remainder, sum to the total. It does seem to be off by 1, though.
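The arithmetic behind the 0 increment can be checked with a small host-side simulation of the three conditions in _update_progress_bar. This is only a sketch: it assumes iter_num runs from 1 to num_samples inclusive, it takes print_rate as a parameter because its computation is not part of this diff, and the function name simulate_updates is hypothetical.

```python
def simulate_updates(num_samples, print_rate):
    """List the (iter_num, increment) pairs the three lax.cond branches would emit."""
    remainder = num_samples % print_rate
    updates = []
    for iter_num in range(1, num_samples + 1):
        if iter_num == 1:
            # First call: only relabels the bar, adds nothing to the count.
            updates.append((iter_num, 0))
        if iter_num % print_rate == 0:
            updates.append((iter_num, print_rate))
        if iter_num == num_samples:
            # Final call: adds whatever the print_rate steps did not cover.
            updates.append((iter_num, remainder))
    return updates


updates = simulate_updates(num_samples=105, print_rate=10)
total = sum(increment for _, increment in updates)
```

Under these assumptions the increments sum to num_samples, which supports the reading that the first call passes 0 so that the later print_rate and remainder updates account for the whole total.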

+            lambda _: None,
             operand=None,
         )
         _ = lax.cond(
             iter_num % print_rate == 0,
-            lambda _: host_callback.id_tap(
-                _update_tqdm, print_rate, result=iter_num, tap_with_device=True
-            ),
-            lambda _: iter_num,
+            lambda _: io_callback(_update_tqdm, None, iter_num, print_rate),
+            lambda _: None,
             operand=None,
         )
         _ = lax.cond(
             iter_num == num_samples,
-            lambda _: host_callback.id_tap(
-                _close_tqdm, remainder, result=iter_num, tap_with_device=True
-            ),
-            lambda _: iter_num,
+            lambda _: io_callback(_close_tqdm, None, iter_num, remainder),
+            lambda _: None,
             operand=None,
         )
