Describe the issue:
Everything works fine when sampling on the GPU. However, if I sample on the CPU by setting `os.environ["JAX_PLATFORM_NAME"] = "cpu"` and set `cores` to more than 1, the following error occurs as soon as I set `progressbar=True` (the default in `pm.sample`):
NotImplementedError: host_callback functionality isn't supported with PJRT C API. See https://jax.readthedocs.io/en/latest/debugging/index.html and https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html for alternatives. Please file a feature request at https://github.com/google/jax/issues if none of the alternatives are sufficient.
Reproducible code example:
import os

os.environ["JAX_PLATFORM_NAME"] = "cpu"

import numpy as np
import pandas as pd
import arviz as az
import pymc as pm
import pytensor.tensor as pt

RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")

size = 200
true_intercept = 1
true_slope = 2
x = np.linspace(0, 1, size)
# y = a + b*x
true_regression_line = true_intercept + true_slope * x
# add noise
y = true_regression_line + rng.normal(scale=0.5, size=size)

data = pd.DataFrame(dict(x=x, y=y))

with pm.Model() as model:  # model specifications in PyMC are wrapped in a with-statement
    # Define priors
    sigma = pm.HalfCauchy("sigma", beta=10)
    intercept = pm.Normal("Intercept", 0, sigma=20)
    slope = pm.Normal("slope", 0, sigma=20)
    # Define likelihood
    likelihood = pm.Normal("y", mu=intercept + slope * x, sigma=sigma, observed=y)

with model:
    trace = pm.sample(
        2000,
        tune=1000,
        chains=8,
        cores=4,
        random_seed=RANDOM_SEED,
        nuts_sampler="numpyro",
    )
Error message:
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
File /mnt/Nova/scripts/error_example.py:3
1 # %%
2 with model:
----> 3 trace = pm.sample(
4 2000,
5 tune=1000,
6 chains=8,
7 cores=4,
8 random_seed=RANDOM_SEED,
9 nuts_sampler="numpyro",
10 )
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/pymc/sampling/mcmc.py:725, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
720 raise ValueError(
721 "Model can not be sampled with NUTS alone. Your model is probably not continuous."
722 )
724 with joined_blas_limiter():
--> 725 return _sample_external_nuts(
726 sampler=nuts_sampler,
727 draws=draws,
728 tune=tune,
729 chains=chains,
730 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
731 random_seed=random_seed,
732 initvals=initvals,
733 model=model,
734 var_names=var_names,
735 progressbar=progressbar,
736 idata_kwargs=idata_kwargs,
737 compute_convergence_checks=compute_convergence_checks,
738 nuts_sampler_kwargs=nuts_sampler_kwargs,
739 **kwargs,
740 )
742 if isinstance(step, list):
743 step = CompoundStep(step)
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/pymc/sampling/mcmc.py:356, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
353 elif sampler in ("numpyro", "blackjax"):
354 import pymc.sampling.jax as pymc_jax
--> 356 idata = pymc_jax.sample_jax_nuts(
357 draws=draws,
358 tune=tune,
359 chains=chains,
360 target_accept=target_accept,
361 random_seed=random_seed,
362 initvals=initvals,
363 model=model,
364 var_names=var_names,
365 progressbar=progressbar,
366 nuts_sampler=sampler,
367 idata_kwargs=idata_kwargs,
368 compute_convergence_checks=compute_convergence_checks,
369 **nuts_sampler_kwargs,
370 )
371 return idata
373 else:
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/pymc/sampling/jax.py:625, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
622 raise ValueError(f"{nuts_sampler=} not recognized")
624 tic1 = datetime.now()
--> 625 raw_mcmc_samples, sample_stats, library = sampler_fn(
626 model=model,
627 target_accept=target_accept,
628 tune=tune,
629 draws=draws,
630 chains=chains,
631 chain_method=chain_method,
632 progressbar=progressbar,
633 random_seed=random_seed,
634 initial_points=initial_points,
635 nuts_kwargs=nuts_kwargs,
636 )
637 tic2 = datetime.now()
639 if idata_kwargs is None:
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/pymc/sampling/jax.py:465, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
462 if chains > 1:
463 map_seed = jax.random.split(map_seed, chains)
--> 465 pmap_numpyro.run(
466 map_seed,
467 init_params=initial_points,
468 extra_fields=(
469 "num_steps",
470 "potential_energy",
471 "energy",
472 "adapt_state.step_size",
473 "accept_prob",
474 "diverging",
475 ),
476 )
478 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
479 sample_stats = _numpyro_stats_to_dict(pmap_numpyro)
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/infer/mcmc.py:688, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
686 states, last_state = _laxmap(partial_map_fn, map_args)
687 elif self.chain_method == "parallel":
--> 688 states, last_state = pmap(partial_map_fn)(map_args)
689 elif callable(self.chain_method):
690 states, last_state = self.chain_method(partial_map_fn)(map_args)
[... skipping hidden 11 frame]
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/infer/mcmc.py:467, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites)
461 collection_size = self._collection_params["collection_size"]
462 collection_size = (
463 collection_size
464 if collection_size is None
465 else collection_size // self.thinning
466 )
--> 467 collect_vals = fori_collect(
468 lower_idx,
469 upper_idx,
470 sample_fn,
471 init_val,
472 transform=_collect_fn(collect_fields, remove_sites),
473 progbar=self.progress_bar,
474 return_last_val=True,
475 thinning=self.thinning,
476 collection_size=collection_size,
477 progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
478 diagnostics_fn=diagnostics,
479 num_chains=self.num_chains
480 if (callable(self.chain_method) or self.chain_method == "parallel")
481 else 1,
482 )
483 states, last_val = collect_vals
484 # Get first argument of type `HMCState`
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/util.py:384, in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
379 def loop_fn(collection):
380 return fori_loop(
381 0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
382 )
--> 384 last_val, collection, _, _ = maybe_jit(loop_fn, donate_argnums=0)(collection)
386 else:
387 diagnostics_fn = progbar_opts.pop("diagnostics_fn", None)
[... skipping hidden 11 frame]
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/util.py:380, in fori_collect.<locals>.loop_fn(collection)
379 def loop_fn(collection):
--> 380 return fori_loop(
381 0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
382 )
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/util.py:147, in fori_loop(lower, upper, body_fun, init_val)
145 return val
146 else:
--> 147 return lax.fori_loop(lower, upper, body_fun, init_val)
[... skipping hidden 12 frame]
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/util.py:265, in progress_bar_factory.<locals>.progress_bar_fori_loop.<locals>.wrapper_progress_bar(i, vals)
263 def wrapper_progress_bar(i, vals):
264 result = func(i, vals)
--> 265 _update_progress_bar(i + 1)
266 return result
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/util.py:232, in progress_bar_factory.<locals>._update_progress_bar(iter_num)
227 def _update_progress_bar(iter_num):
228 """Updates tqdm progress bar of a JAX loop only if the iteration number is a multiple of the print_rate
229 Usage: carry = progress_bar((iter_num, print_rate), carry)
230 """
--> 232 _ = lax.cond(
233 iter_num == 1,
234 lambda _: host_callback.id_tap(
235 _update_tqdm, 0, result=iter_num, tap_with_device=True
236 ),
237 lambda _: iter_num,
238 operand=None,
239 )
240 _ = lax.cond(
241 iter_num % print_rate == 0,
242 lambda _: host_callback.id_tap(
(...)
246 operand=None,
247 )
248 _ = lax.cond(
249 iter_num == num_samples,
250 lambda _: host_callback.id_tap(
(...)
254 operand=None,
255 )
[... skipping hidden 12 frame]
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/util.py:234, in progress_bar_factory.<locals>._update_progress_bar.<locals>.<lambda>(_)
227 def _update_progress_bar(iter_num):
228 """Updates tqdm progress bar of a JAX loop only if the iteration number is a multiple of the print_rate
229 Usage: carry = progress_bar((iter_num, print_rate), carry)
230 """
232 _ = lax.cond(
233 iter_num == 1,
--> 234 lambda _: host_callback.id_tap(
235 _update_tqdm, 0, result=iter_num, tap_with_device=True
236 ),
237 lambda _: iter_num,
238 operand=None,
239 )
240 _ = lax.cond(
241 iter_num % print_rate == 0,
242 lambda _: host_callback.id_tap(
(...)
246 operand=None,
247 )
248 _ = lax.cond(
249 iter_num == num_samples,
250 lambda _: host_callback.id_tap(
(...)
254 operand=None,
255 )
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/jax/experimental/host_callback.py:683, in _deprecated_id_tap(tap_func, arg, result, tap_with_device, device_index, callback_flavor, **kwargs)
680 for r in flat_results:
681 dispatch.check_arg(r)
--> 683 call_res = _call(
684 tap_func,
685 arg,
686 call_with_device=tap_with_device,
687 result_shape=None,
688 identity=True,
689 device_index=device_index,
690 callback_flavor=callback_flavor)
692 if result is not None:
693 return result
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/jax/experimental/host_callback.py:865, in _call(callback_func, arg, result_shape, call_with_device, device_index, identity, callback_flavor)
855 def _call(callback_func: Callable,
856 arg,
857 *,
(...)
861 identity=False,
862 callback_flavor=CallbackFlavor.IO_CALLBACK):
863 if _HOST_CALLBACK_LEGACY.value:
864 # Lazy initialization
--> 865 _initialize_outfeed_receiver(
866 max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value)
867 api.check_callable(callback_func)
868 flat_args, arg_treedef = tree_util.tree_flatten(arg)
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/jax/experimental/host_callback.py:1886, in _initialize_outfeed_receiver(max_callback_queue_size_bytes)
1884 clients_with_outfeed = [c for c in clients if _use_outfeed(c.platform)]
1885 for client in clients_with_outfeed:
-> 1886 _raise_if_using_outfeed_with_pjrt_c_api(client)
1887 if clients_with_outfeed:
1888 devices_with_outfeed = list(
1889 itertools.chain(*[backend.local_devices() for backend in clients_with_outfeed]))
File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/jax/experimental/host_callback.py:588, in _raise_if_using_outfeed_with_pjrt_c_api(backend)
586 """Should be called whenever outfeed (or infeed) will be used."""
587 if xb.using_pjrt_c_api(backend):
--> 588 raise NotImplementedError(
589 "host_callback functionality isn't supported with PJRT C API. "
590 "See https://jax.readthedocs.io/en/latest/debugging/index.html and "
591 "https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html"
592 " for alternatives. Please file a feature request at "
593 "https://github.com/google/jax/issues if none of the alternatives are "
594 "sufficient.")
NotImplementedError: host_callback functionality isn't supported with PJRT C API. See https://jax.readthedocs.io/en/latest/debugging/index.html and https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html for alternatives. Please file a feature request at https://github.com/google/jax/issues if none of the alternatives are sufficient.
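The numpyro frames in the traceback show why the progress bar is the trigger: `_update_progress_bar` wraps three `host_callback.id_tap` calls in `lax.cond` predicates (`iter_num == 1`, `iter_num % print_rate == 0` and `iter_num == num_samples`). Since `lax.cond` traces both of its branches, `id_tap` is reached while the loop is being traced, and that is where `_initialize_outfeed_receiver` raises, so the failure happens before the loop is compiled or run. The pure-Python sketch below only models which iterations would fire a host callback under those three predicates; `callback_iterations` is a hypothetical helper, not numpyro code, and the `print_rate` values are illustrative.

```python
def callback_iterations(num_samples, print_rate):
    """Return the 1-based iterations at which at least one progress callback fires.

    Hypothetical helper: it mirrors the three lax.cond predicates visible in the
    traceback, not what each callback does (those bodies are elided there).
    """
    fired = []
    for iter_num in range(1, num_samples + 1):
        if iter_num == 1 or iter_num % print_rate == 0 or iter_num == num_samples:
            fired.append(iter_num)
    return fired


print(callback_iterations(3000, 1000))  # [1, 1000, 2000, 3000]
```

Only a handful of iterations contact the host at run time, but every `id_tap` call site is traced regardless of its predicate, which is why the error appears even before the first iteration.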
Context for the issue:
Some models sample faster on the CPU than on the GPU, so it is helpful to be able to switch back and forth between devices. Without the progress bar on the CPU, it takes me much longer to detect bad model specifications early in sampling.
I am not sure whether this is related to the BlackJAX progress bar not working, but since this is a JAX error, it may be related as well.
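The error message points to JAX's external-callback documentation for alternatives to `host_callback`. A minimal sketch, assuming the installed JAX provides `jax.debug.callback` and `jax.effects_barrier`, of a compiled loop that reports progress to the host every `print_rate` iterations without going through `host_callback`. `_report_progress`, `make_body` and `run_loop` are hypothetical names, and the loop body only counts iterations as a stand-in for a sampler step; this is not how numpyro or PyMC implement their progress bars.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Host-side record of progress updates; a stand-in for a tqdm bar (illustrative only).
progress_log = []


def _report_progress(iter_num):
    # Runs on the host, outside the compiled computation.
    progress_log.append(int(iter_num))


def make_body(print_rate):
    def body_fn(i, carry):
        carry = carry + 1  # stand-in for one sampler step: just count iterations
        # Call back to the host only every `print_rate` iterations.
        # lax.cond traces both branches, but only the taken branch runs at execution time.
        lax.cond(
            (i + 1) % print_rate == 0,
            lambda n: jax.debug.callback(_report_progress, n),
            lambda n: None,
            i + 1,
        )
        return carry

    return body_fn


def run_loop(num_iters, print_rate):
    loop = jax.jit(lambda: lax.fori_loop(0, num_iters, make_body(print_rate), jnp.int32(0)))
    result = loop()
    jax.effects_barrier()  # block until all pending host callbacks have run
    return result


total = run_loop(1000, 100)
```

The host is contacted only on iterations where the predicate holds. Whether this pattern behaves the same under `pmap` with several chains on the CPU backend is not something this sketch checks.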
PyMC version information:
pymc version: 5.16.2 (conda-forge)
jax version: 0.4.30 (cuda 12)
numpyro version: 0.15.1
Python: 3.11.6
OS: openSUSE Tumbleweed 6.9.9-1-default
conda (libmamba solver) env file: