-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Draft] Enable pmap progress bar with cpu backend / remove deprecated host_callback #1841
Changes from 3 commits
6f11983
ea92469
8df33ca
819416c
61c0d98
b2a2c35
d8e15e3
2cc83af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
import os | ||
import random | ||
import re | ||
from threading import Lock | ||
import warnings | ||
|
||
import numpy as np | ||
|
@@ -18,7 +19,7 @@ | |
import jax | ||
from jax import device_put, jit, lax, vmap | ||
from jax.core import Tracer | ||
from jax.experimental import host_callback | ||
from jax.experimental import io_callback | ||
import jax.numpy as jnp | ||
|
||
_DISABLE_CONTROL_FLOW_PRIM = False | ||
|
@@ -201,24 +202,40 @@ def progress_bar_factory(num_samples, num_chains): | |
|
||
remainder = num_samples % print_rate | ||
|
||
idx_map = {} | ||
tqdm_bars = {} | ||
finished_chains = [] | ||
lock = Lock() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you add a comment explaining why this is needed? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. Also added a lock around closing the chains. It doesn't appear to have been an issue so far, but it could cause multiple closes. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you clarify the purpose of Lock in more detail? I'm not familiar with its usage. What happens if we don't use it? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Updated. The locking is only around |
||
for chain in range(num_chains): | ||
tqdm_bars[chain] = tqdm_auto(range(num_samples), position=chain) | ||
tqdm_bars[chain].set_description("Compiling.. ", refresh=True) | ||
|
||
def _update_tqdm(arg, transform, device): | ||
chain_match = _CHAIN_RE.search(str(device)) | ||
assert chain_match | ||
chain = int(chain_match.group()) | ||
def _calc_chain_idx(iter_num): | ||
with lock: | ||
try: | ||
idx = idx_map[iter_num] | ||
except KeyError: | ||
idx = 0 | ||
idx_map[iter_num] = 0 | ||
|
||
if idx + 1 == num_chains: | ||
del idx_map[iter_num] | ||
else: | ||
idx_map[iter_num] += 1 | ||
return idx | ||
|
||
def _update_tqdm(iter_num, increment): | ||
iter_num = int(iter_num) | ||
increment = int(increment) | ||
chain = _calc_chain_idx(iter_num) | ||
tqdm_bars[chain].set_description(f"Running chain {chain}", refresh=False) | ||
tqdm_bars[chain].update(arg) | ||
tqdm_bars[chain].update(increment) | ||
|
||
def _close_tqdm(arg, transform, device): | ||
chain_match = _CHAIN_RE.search(str(device)) | ||
assert chain_match | ||
chain = int(chain_match.group()) | ||
tqdm_bars[chain].update(arg) | ||
def _close_tqdm(iter_num, increment): | ||
iter_num = int(iter_num) | ||
increment = int(increment) | ||
chain = _calc_chain_idx(iter_num + 1) # +1 so no collision in idx_map | ||
tqdm_bars[chain].update(increment) | ||
finished_chains.append(chain) | ||
if len(finished_chains) == num_chains: | ||
for chain in range(num_chains): | ||
|
@@ -231,26 +248,20 @@ def _update_progress_bar(iter_num): | |
|
||
_ = lax.cond( | ||
iter_num == 1, | ||
lambda _: host_callback.id_tap( | ||
_update_tqdm, 0, result=iter_num, tap_with_device=True | ||
), | ||
lambda _: iter_num, | ||
lambda _: io_callback(_update_tqdm, None, -1, 0), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure why -1, 0 is used here. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The -1 was a mistake; changed to 0, 0. The second 0 is consistent with the current implementation. I believe this is 0 instead of 1 so that subsequent calls using print_rate and remainder add up to the total. But it does seem to be off by 1 |
||
lambda _: None, | ||
operand=None, | ||
) | ||
_ = lax.cond( | ||
iter_num % print_rate == 0, | ||
lambda _: host_callback.id_tap( | ||
_update_tqdm, print_rate, result=iter_num, tap_with_device=True | ||
), | ||
lambda _: iter_num, | ||
lambda _: io_callback(_update_tqdm, None, iter_num, print_rate), | ||
lambda _: None, | ||
operand=None, | ||
) | ||
_ = lax.cond( | ||
iter_num == num_samples, | ||
lambda _: host_callback.id_tap( | ||
_close_tqdm, remainder, result=iter_num, tap_with_device=True | ||
), | ||
lambda _: iter_num, | ||
lambda _: io_callback(_close_tqdm, None, iter_num, remainder), | ||
lambda _: None, | ||
operand=None, | ||
) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add some comments here explaining how this works?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added by `_calc_chain_idx`.