
pmap seems to drastically improve performance in the example notebook #251

Closed
elanmart opened this issue Jul 10, 2022 · 6 comments · Fixed by #273
Labels
documentation, enhancement

Comments

@elanmart
Contributor

Description

I am running blackjax in WSL2 on a 32-core CPU.

I was playing with the example notebook (https://blackjax-devs.github.io/blackjax/examples/Introduction.html#),
and noticed that the CPU utilization is actually quite low when running multiple chains.

I have modified the code by first running

import numpyro as npr

# This must run before JAX initializes its backend, so that XLA exposes
# 32 host (CPU) devices instead of the default single one.
npr.util.set_host_device_count(32)
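
To check that JAX actually picks up the extra devices (the call above only takes effect if it runs before anything else touches JAX):

import jax
print(jax.local_device_count())  # should print 32 after the call above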

Then I re-used the inference loop from the single-chain example, but instead of using vmap I used pmap to parallelize the execution:

import jax

# inference_loop, initial_states, nuts and num_chains are defined in the
# notebook above.
rng_key = jax.random.PRNGKey(0)
keys = jax.random.split(rng_key, num_chains)  # one PRNG key per chain

# Map over the keys and initial states (axis 0); the step function and the
# number of samples are static arguments shared by all chains.
inference_loop = jax.pmap(
    inference_loop, in_axes=(0, 0, None, None), static_broadcasted_argnums=(2, 3)
)

states = inference_loop(keys, initial_states, nuts.step, 1_000)
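
(As far as I understand, pmap assigns each chain to its own device, so num_chains cannot exceed jax.local_device_count(); hence the set_host_device_count(32) call above.)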

And this seems to cut the running time from over 2 minutes to about 3 seconds:

# vmap
Wall time: 2min 10s
# pmap
Wall time: 2.91 s

Am I doing something wrong here, or should the example actually be adjusted to use pmap?

Reproducing

See full notebooks here:
https://gist.github.com/elanmart/810f1964738b0ddd8f108b17b7969f82

Setup

Python implementation: CPython
Python version       : 3.9.12
IPython version      : 8.4.0

jax     : 0.3.14
jaxlib  : 0.3.14
blackjax: 0.8.2

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.10.102.1-microsoft-standard-WSL2
Machine     : x86_64
Processor   : x86_64
CPU cores   : 32
Architecture: 64bit
@junpenglao
Member

It is expected that pmap is faster than vmap when using NUTS: with vmap, at each sampling step we are waiting for the chain with the longest leapfrog trajectory. This could also explain the low CPU utilization: most of the time, the other chains have finished their sample and are just waiting for the few chains with a large number of leapfrog steps.
But in this case the speed difference is pretty huge, likely because of the poor performance of an un-tuned NUTS.
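
The lock-step effect is easy to see on a toy function with a data-dependent loop length (a minimal sketch, not blackjax code; count_down is a made-up stand-in for a NUTS step with a variable number of leapfrog steps):

import jax
import jax.numpy as jnp

def count_down(n):
    # Loop whose length depends on the input, like NUTS's
    # data-dependent number of leapfrog steps.
    def cond(carry):
        i, acc = carry
        return i < n

    def body(carry):
        i, acc = carry
        return i + 1, acc + 1.0

    _, acc = jax.lax.while_loop(cond, body, (0, 0.0))
    return acc

# Under vmap the batched while_loop keeps iterating until the *longest*
# element is done; finished elements are masked but still pay for every
# extra iteration. pmap instead runs each element on its own device,
# so a short element finishes early.
print(jax.vmap(count_down)(jnp.array([3, 10_000])))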

@junpenglao
Member

junpenglao commented Jul 11, 2022

And yes, I think it is a great idea to add a pmap example! I think we don't have a lot of those currently.

@rlouf
Member

rlouf commented Jul 11, 2022

I agree with what @junpenglao said: I would expect that to happen with NUTS; each step can only be as fast as the slowest chain, and these delays can add up to quite a lot after a few thousand steps, as opposed to pmap, which runs the chains completely independently.

Would you like to add an example with pmap at the end of this notebook, with a short explanation of the difference, @elanmart?

@elanmart
Contributor Author

elanmart commented Jul 11, 2022

Thanks a lot for the explanation! I'll be happy to open a PR adding a small section with pmap.

@rlouf added the documentation and enhancement labels on Jul 12, 2022
@elanmart
Contributor Author

Apologies, I forgot about this issue. I'll get to it this week and open a PR once #243 is merged.

@rlouf
Member

rlouf commented Aug 29, 2022

No problem! Thank you for letting us know.
