Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Checklist for 0.7 release #1082

Merged
merged 14 commits into from
Jul 11, 2021
24 changes: 12 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,22 +237,22 @@ conda install -c conda-forge numpyro
- Provide the `rng_key` argument to `numpyro.sample`. e.g. `numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))`.
- Wrap the code in a `seed` handler, used either as a context manager or as a function that wraps over the original callable. e.g.

```python
with handlers.seed(rng_seed=0): # random.PRNGKey(0) is used
x = numpyro.sample('x', dist.Beta(1, 1)) # uses a PRNGKey split from random.PRNGKey(0)
y = numpyro.sample('y', dist.Bernoulli(x)) # uses different PRNGKey split from the last one
```
```python
with handlers.seed(rng_seed=0): # random.PRNGKey(0) is used
x = numpyro.sample('x', dist.Beta(1, 1)) # uses a PRNGKey split from random.PRNGKey(0)
y = numpyro.sample('y', dist.Bernoulli(x)) # uses different PRNGKey split from the last one
```
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This indentation is an attempt to make python code rendered correctly here. This only happens on readthedocs, I can't reproduce in my system.


, or as a higher order function:

```python
def fn():
x = numpyro.sample('x', dist.Beta(1, 1))
y = numpyro.sample('y', dist.Bernoulli(x))
return y
```python
def fn():
x = numpyro.sample('x', dist.Beta(1, 1))
y = numpyro.sample('y', dist.Bernoulli(x))
return y

print(handlers.seed(fn, rng_seed=0)())
```
print(handlers.seed(fn, rng_seed=0)())
```

2. Can I use the same Pyro model for doing inference in NumPyro?

Expand Down
2 changes: 1 addition & 1 deletion docker/dev/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04
# note that this image uses Python 3.8
ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04 \
# declare the cuda version for pulling appropriate jaxlib wheel
JAXLIB_CUDA=112
JAXLIB_CUDA=111
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no longer CUDA 112 version because the 111 built will work for the remaining versions in the 11x series.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks for claryfying!


# install python3 and pip on top of the base Ubuntu image
# unlike for release, we need to install git and setuptools too
Expand Down
13 changes: 2 additions & 11 deletions docker/release/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,7 @@ FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04
# declare the image name
# note that this image uses Python 3.8
ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04 \
# declare what jaxlib, jax, and numpyro versions to use
# right now this is a manual process - in the future it should be automated
# if a CI/CD system is expected to pass in these arguments
# the dockerfile should be modified accordingly
JAXLIB_CUDA=112 \
JAXLIB_VERSION=0.1.62 \
JAX_VERSION=0.2.10 \
NUMPYRO_VERSION=0.6.0
JAXLIB_CUDA=111

# install python3 and pip on top of the base Ubuntu image
RUN apt update && \
Expand All @@ -26,8 +19,6 @@ ENV PATH=/root/.local/bin:$PATH

# install python packages via pip
RUN pip3 install --user \
numpyro==${NUMPYRO_VERSION} \
jax==${JAX_VERSION} \
# we pull wheels from google's api as per https://github.com/google/jax#installation
# the pre-compiled wheels that google provides work for now. This may change in the future (and necessitate building from source)
jaxlib==${JAXLIB_VERSION}+cuda${JAXLIB_CUDA} -f https://storage.googleapis.com/jax-releases/jax_releases.html
numpyro[cuda${JAXLIB_CUDA}] -f https://storage.googleapis.com/jax-releases/jax_releases.html
12 changes: 6 additions & 6 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
dm-haiku
flax
funsor
jax>=0.1.65
jaxlib>=0.1.45
jaxns==0.0.7
optax==0.0.6
funsor>=0.4.1
jax>=0.2.11
jaxlib>=0.1.62
jaxns>=0.0.7
optax>=0.0.6
nbsphinx>=0.8.5
sphinx-gallery
tfp-nightly<=0.14.0.dev20210608 # TODO: change this to tensorflow-probability when it is stable
tensorflow_probability>=0.13
tqdm
48 changes: 0 additions & 48 deletions docs/source/api.rst

This file was deleted.

3 changes: 3 additions & 0 deletions docs/source/contrib.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
Contributed Code
================

Nested Sampling
~~~~~~~~~~~~~~~

Expand Down
Loading