From f874e2674bf67dc68453a81b1a0260f0d2722888 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Mon, 12 Aug 2024 12:59:52 -0700 Subject: [PATCH] enable progressbar for multi-gpu --- numpyro/util.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/numpyro/util.py b/numpyro/util.py index 3744c282c..be8b46b0d 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -325,14 +325,6 @@ def fori_collect( init_val_transformed = transform(init_val) start_idx = lower + (upper - lower) % thinning num_chains = progbar_opts.pop("num_chains", 1) - # host_callback does not work yet with multi-GPU platforms - # See: https://github.com/google/jax/issues/6447 - if num_chains > 1 and jax.default_backend() == "gpu": - warnings.warn( - "We will disable progress bar because it does not work yet on multi-GPUs platforms.", - stacklevel=find_stack_level(), - ) - progbar = False @partial(maybe_jit, donate_argnums=2) @cached_by(fori_collect, body_fun, transform)