BUG: numpyro sampler cannot use progress bar when running on CPU #7426

Open
Mithrillion opened this issue Jul 24, 2024 · 0 comments

Mithrillion commented Jul 24, 2024

Describe the issue:

Everything works fine when sampling on the GPU. However, if I sample on the CPU by setting os.environ["JAX_PLATFORM_NAME"] = "cpu" and set cores to >1, the following error occurs as soon as progressbar=True (the default in pm.sample):

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/jax/experimental/host_callback.py:588, in _raise_if_using_outfeed_with_pjrt_c_api(backend)
    586 """Should be called whenever outfeed (or infeed) will be used."""
    587 if xb.using_pjrt_c_api(backend):
--> 588   raise NotImplementedError(
    589       "host_callback functionality isn't supported with PJRT C API. "
    590       "See https://jax.readthedocs.io/en/latest/debugging/index.html and "
    591       "https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html"
    592       " for alternatives. Please file a feature request at "
    593       "https://github.com/google/jax/issues if none of the alternatives are "
    594       "sufficient.")

NotImplementedError: host_callback functionality isn't supported with PJRT C API. See https://jax.readthedocs.io/en/latest/debugging/index.html and https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html for alternatives. Please file a feature request at https://github.com/google/jax/issues if none of the alternatives are sufficient.

Reproducible code example:

import os

os.environ["JAX_PLATFORM_NAME"] = "cpu"

import numpy as np
import pandas as pd
import arviz as az
import pymc as pm
import pytensor.tensor as pt

RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")

size = 200
true_intercept = 1
true_slope = 2

x = np.linspace(0, 1, size)
# y = a + b*x
true_regression_line = true_intercept + true_slope * x
# add noise
y = true_regression_line + rng.normal(scale=0.5, size=size)

data = pd.DataFrame(dict(x=x, y=y))


with pm.Model() as model:  # model specifications in PyMC are wrapped in a with-statement
    # Define priors
    sigma = pm.HalfCauchy("sigma", beta=10)
    intercept = pm.Normal("Intercept", 0, sigma=20)
    slope = pm.Normal("slope", 0, sigma=20)

    # Define likelihood
    likelihood = pm.Normal("y", mu=intercept + slope * x, sigma=sigma, observed=y)

with model:
    trace = pm.sample(
        2000,
        tune=1000,
        chains=8,
        cores=4,
        random_seed=RANDOM_SEED,
        nuts_sampler="numpyro",
    )

Error message:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
File /mnt/Nova/scripts/error_example.py:3
      1 # %%
      2 with model:
----> 3     trace = pm.sample(
      4         2000,
      5         tune=1000,
      6         chains=8,
      7         cores=4,
      8         random_seed=RANDOM_SEED,
      9         nuts_sampler="numpyro",
     10     )

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/pymc/sampling/mcmc.py:725, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    720         raise ValueError(
    721             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    722         )
    724     with joined_blas_limiter():
--> 725         return _sample_external_nuts(
    726             sampler=nuts_sampler,
    727             draws=draws,
    728             tune=tune,
    729             chains=chains,
    730             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    731             random_seed=random_seed,
    732             initvals=initvals,
    733             model=model,
    734             var_names=var_names,
    735             progressbar=progressbar,
    736             idata_kwargs=idata_kwargs,
    737             compute_convergence_checks=compute_convergence_checks,
    738             nuts_sampler_kwargs=nuts_sampler_kwargs,
    739             **kwargs,
    740         )
    742 if isinstance(step, list):
    743     step = CompoundStep(step)

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/pymc/sampling/mcmc.py:356, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    353 elif sampler in ("numpyro", "blackjax"):
    354     import pymc.sampling.jax as pymc_jax
--> 356     idata = pymc_jax.sample_jax_nuts(
    357         draws=draws,
    358         tune=tune,
    359         chains=chains,
    360         target_accept=target_accept,
    361         random_seed=random_seed,
    362         initvals=initvals,
    363         model=model,
    364         var_names=var_names,
    365         progressbar=progressbar,
    366         nuts_sampler=sampler,
    367         idata_kwargs=idata_kwargs,
    368         compute_convergence_checks=compute_convergence_checks,
    369         **nuts_sampler_kwargs,
    370     )
    371     return idata
    373 else:

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/pymc/sampling/jax.py:625, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    622     raise ValueError(f"{nuts_sampler=} not recognized")
    624 tic1 = datetime.now()
--> 625 raw_mcmc_samples, sample_stats, library = sampler_fn(
    626     model=model,
    627     target_accept=target_accept,
    628     tune=tune,
    629     draws=draws,
    630     chains=chains,
    631     chain_method=chain_method,
    632     progressbar=progressbar,
    633     random_seed=random_seed,
    634     initial_points=initial_points,
    635     nuts_kwargs=nuts_kwargs,
    636 )
    637 tic2 = datetime.now()
    639 if idata_kwargs is None:

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/pymc/sampling/jax.py:465, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
    462 if chains > 1:
    463     map_seed = jax.random.split(map_seed, chains)
--> 465 pmap_numpyro.run(
    466     map_seed,
    467     init_params=initial_points,
    468     extra_fields=(
    469         "num_steps",
    470         "potential_energy",
    471         "energy",
    472         "adapt_state.step_size",
    473         "accept_prob",
    474         "diverging",
    475     ),
    476 )
    478 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
    479 sample_stats = _numpyro_stats_to_dict(pmap_numpyro)

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/infer/mcmc.py:688, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    686     states, last_state = _laxmap(partial_map_fn, map_args)
    687 elif self.chain_method == "parallel":
--> 688     states, last_state = pmap(partial_map_fn)(map_args)
    689 elif callable(self.chain_method):
    690     states, last_state = self.chain_method(partial_map_fn)(map_args)

    [... skipping hidden 11 frame]

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/infer/mcmc.py:467, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites)
    461 collection_size = self._collection_params["collection_size"]
    462 collection_size = (
    463     collection_size
    464     if collection_size is None
    465     else collection_size // self.thinning
    466 )
--> 467 collect_vals = fori_collect(
    468     lower_idx,
    469     upper_idx,
    470     sample_fn,
    471     init_val,
    472     transform=_collect_fn(collect_fields, remove_sites),
    473     progbar=self.progress_bar,
    474     return_last_val=True,
    475     thinning=self.thinning,
    476     collection_size=collection_size,
    477     progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
    478     diagnostics_fn=diagnostics,
    479     num_chains=self.num_chains
    480     if (callable(self.chain_method) or self.chain_method == "parallel")
    481     else 1,
    482 )
    483 states, last_val = collect_vals
    484 # Get first argument of type `HMCState`

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/util.py:384, in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
    379     def loop_fn(collection):
    380         return fori_loop(
    381             0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
    382         )
--> 384     last_val, collection, _, _ = maybe_jit(loop_fn, donate_argnums=0)(collection)
    386 else:
    387     diagnostics_fn = progbar_opts.pop("diagnostics_fn", None)

    [... skipping hidden 11 frame]

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/util.py:380, in fori_collect.<locals>.loop_fn(collection)
    379 def loop_fn(collection):
--> 380     return fori_loop(
    381         0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
    382     )

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/util.py:147, in fori_loop(lower, upper, body_fun, init_val)
    145     return val
    146 else:
--> 147     return lax.fori_loop(lower, upper, body_fun, init_val)

    [... skipping hidden 12 frame]

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/util.py:265, in progress_bar_factory.<locals>.progress_bar_fori_loop.<locals>.wrapper_progress_bar(i, vals)
    263 def wrapper_progress_bar(i, vals):
    264     result = func(i, vals)
--> 265     _update_progress_bar(i + 1)
    266     return result

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/util.py:232, in progress_bar_factory.<locals>._update_progress_bar(iter_num)
    227 def _update_progress_bar(iter_num):
    228     """Updates tqdm progress bar of a JAX loop only if the iteration number is a multiple of the print_rate
    229     Usage: carry = progress_bar((iter_num, print_rate), carry)
    230     """
--> 232     _ = lax.cond(
    233         iter_num == 1,
    234         lambda _: host_callback.id_tap(
    235             _update_tqdm, 0, result=iter_num, tap_with_device=True
    236         ),
    237         lambda _: iter_num,
    238         operand=None,
    239     )
    240     _ = lax.cond(
    241         iter_num % print_rate == 0,
    242         lambda _: host_callback.id_tap(
   (...)
    246         operand=None,
    247     )
    248     _ = lax.cond(
    249         iter_num == num_samples,
    250         lambda _: host_callback.id_tap(
   (...)
    254         operand=None,
    255     )

    [... skipping hidden 12 frame]

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/numpyro/util.py:234, in progress_bar_factory.<locals>._update_progress_bar.<locals>.<lambda>(_)
    227 def _update_progress_bar(iter_num):
    228     """Updates tqdm progress bar of a JAX loop only if the iteration number is a multiple of the print_rate
    229     Usage: carry = progress_bar((iter_num, print_rate), carry)
    230     """
    232     _ = lax.cond(
    233         iter_num == 1,
--> 234         lambda _: host_callback.id_tap(
    235             _update_tqdm, 0, result=iter_num, tap_with_device=True
    236         ),
    237         lambda _: iter_num,
    238         operand=None,
    239     )
    240     _ = lax.cond(
    241         iter_num % print_rate == 0,
    242         lambda _: host_callback.id_tap(
   (...)
    246         operand=None,
    247     )
    248     _ = lax.cond(
    249         iter_num == num_samples,
    250         lambda _: host_callback.id_tap(
   (...)
    254         operand=None,
    255     )

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/jax/experimental/host_callback.py:683, in _deprecated_id_tap(tap_func, arg, result, tap_with_device, device_index, callback_flavor, **kwargs)
    680   for r in flat_results:
    681     dispatch.check_arg(r)
--> 683 call_res = _call(
    684     tap_func,
    685     arg,
    686     call_with_device=tap_with_device,
    687     result_shape=None,
    688     identity=True,
    689     device_index=device_index,
    690     callback_flavor=callback_flavor)
    692 if result is not None:
    693   return result

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/jax/experimental/host_callback.py:865, in _call(callback_func, arg, result_shape, call_with_device, device_index, identity, callback_flavor)
    855 def _call(callback_func: Callable,
    856           arg,
    857           *,
   (...)
    861           identity=False,
    862           callback_flavor=CallbackFlavor.IO_CALLBACK):
    863   if _HOST_CALLBACK_LEGACY.value:
    864     # Lazy initialization
--> 865     _initialize_outfeed_receiver(
    866         max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value)
    867   api.check_callable(callback_func)
    868   flat_args, arg_treedef = tree_util.tree_flatten(arg)

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/jax/experimental/host_callback.py:1886, in _initialize_outfeed_receiver(max_callback_queue_size_bytes)
   1884 clients_with_outfeed = [c for c in clients if _use_outfeed(c.platform)]
   1885 for client in clients_with_outfeed:
-> 1886   _raise_if_using_outfeed_with_pjrt_c_api(client)
   1887 if clients_with_outfeed:
   1888   devices_with_outfeed = list(
   1889     itertools.chain(*[backend.local_devices() for backend in clients_with_outfeed]))

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/jax/experimental/host_callback.py:588, in _raise_if_using_outfeed_with_pjrt_c_api(backend)
    586 """Should be called whenever outfeed (or infeed) will be used."""
    587 if xb.using_pjrt_c_api(backend):
--> 588   raise NotImplementedError(
    589       "host_callback functionality isn't supported with PJRT C API. "
    590       "See https://jax.readthedocs.io/en/latest/debugging/index.html and "
    591       "https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html"
    592       " for alternatives. Please file a feature request at "
    593       "https://github.com/google/jax/issues if none of the alternatives are "
    594       "sufficient.")

NotImplementedError: host_callback functionality isn't supported with PJRT C API. See https://jax.readthedocs.io/en/latest/debugging/index.html and https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html for alternatives. Please file a feature request at https://github.com/google/jax/issues if none of the alternatives are sufficient.
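The error message points to JAX's external-callback APIs as alternatives to host_callback. For illustration only, here is a minimal sketch (not numpyro's actual code or fix) of reporting loop progress from inside a jitted lax.fori_loop with jax.debug.callback instead of host_callback.id_tap. It mirrors the throttling visible in numpyro's _update_progress_bar in the traceback above (update at iteration 1, every print_rate iterations, and at the last iteration). NUM_ITERS, PRINT_RATE, reported, and _on_host are hypothetical names, and a plain list stands in for the tqdm bar.

```python
import jax
import jax.numpy as jnp
from jax import lax

NUM_ITERS = 50   # hypothetical number of loop iterations (numpyro: num_samples)
PRINT_RATE = 10  # hypothetical update interval (numpyro: print_rate)

reported = []  # host-side record of the iterations that triggered an update


def _on_host(iter_num):
    # Runs on the host via jax.debug.callback; iter_num arrives as a concrete array.
    iter_num = int(iter_num)
    # Same throttling idea as numpyro's _update_progress_bar.
    if iter_num == 1 or iter_num % PRINT_RATE == 0 or iter_num == NUM_ITERS:
        reported.append(iter_num)


def body(i, carry):
    new_carry = carry + i  # stand-in for one sampler step
    # jax.debug.callback takes the place of host_callback.id_tap here.
    jax.debug.callback(_on_host, i + 1)
    return new_carry


@jax.jit
def run(x0):
    return lax.fori_loop(0, NUM_ITERS, body, x0)


result = run(jnp.int32(0))
jax.effects_barrier()  # wait until all pending callbacks have executed on the host
print(int(result), sorted(reported))
```

Because debug callbacks are unordered by default, the sketch sorts the recorded iterations before printing them.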

PyMC version information:

pymc version: 5.16.2 (conda-forge)
jax version: 0.4.30 (cuda 12)
numpyro version: 0.15.1

Python: 3.11.6
OS: openSUSE Tumbleweed 6.9.9-1-default

conda (libmamba solver) env file:

name: ml_dev_cpu
channels:
  - conda-forge
  - pytorch
dependencies:
  - python>=3.11
  - pytorch::pytorch
  - pytorch::pytorch-cuda=12.1
  - cuda-version=12.1
  - cuda-nvcc=12.1
  - cxx-compiler
  - pymc>=5
  - numpyro
  - blackjax
  - nutpie
  - pip:
    - ortools
    - jax[cuda12]
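Since this environment installs jax[cuda12] while the repro selects the CPU through JAX_PLATFORM_NAME, a quick check of which backend JAX actually uses may help narrow things down. A minimal sketch, assuming it runs in a fresh interpreter before anything else has imported JAX. Note that it only reports the default backend and its devices; it does not list every backend JAX may have initialized internally.

```python
import os

# Must be set before JAX initializes its backends, i.e. before the import below.
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax

backend = jax.default_backend()  # name of the platform JAX dispatches to by default
devices = jax.devices()          # devices of that default backend
print("default backend:", backend)
print("devices:", devices)
```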

Context for the issue:

Some models sample faster on the CPU than on the GPU, so it is helpful to be able to switch back and forth between devices. Without the progress bar on the CPU, it takes me much longer to notice a badly specified model early in sampling.

I am not sure whether this is related to the Blackjax progress bar not working, but since this is a JAX error, could the two be related?
