Skip to content

Commit

Permalink
enable progressbar for multi-gpu (pyro-ppl#1849)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdipper authored Aug 18, 2024
1 parent d61f15c commit b19a83d
Showing 1 changed file with 0 additions and 8 deletions.
8 changes: 0 additions & 8 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,14 +325,6 @@ def fori_collect(
init_val_transformed = transform(init_val)
start_idx = lower + (upper - lower) % thinning
num_chains = progbar_opts.pop("num_chains", 1)
# host_callback does not work yet with multi-GPU platforms
# See: https://github.com/google/jax/issues/6447
if num_chains > 1 and jax.default_backend() == "gpu":
warnings.warn(
"We will disable progress bar because it does not work yet on multi-GPUs platforms.",
stacklevel=find_stack_level(),
)
progbar = False

@partial(maybe_jit, donate_argnums=2)
@cached_by(fori_collect, body_fun, transform)
Expand Down

0 comments on commit b19a83d

Please sign in to comment.