Allow pickle autoguide #1169
Changes to the autoguide module:
@@ -4,14 +4,14 @@
 # Adapted from pyro.infer.autoguide
 from abc import ABC, abstractmethod
 from contextlib import ExitStack
+from functools import partial
 import warnings

 import numpy as np

 import jax
 from jax import grad, hessian, lax, random, tree_map
 from jax.experimental import stax
-from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp

 import numpyro
@@ -104,6 +104,11 @@ def _create_plates(self, *args, **kwargs):
         )
         return self.plates

+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state.pop("plates", None)
+        return state
+
     @abstractmethod
     def __call__(self, *args, **kwargs):
         """
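The `plates` attribute is rebuilt by `_create_plates` on every guide invocation, so it is safe to drop before serialization; presumably it is excluded because its runtime contents do not pickle cleanly. A minimal sketch of the `__getstate__` pattern (illustrative class, not numpyro code):

import pickle

# Sketch of the pattern above: drop state that does not pickle and is
# rebuilt on demand; pickle restores the rest of __dict__ as usual.
class Guide:
    def __init__(self):
        self.prefix = "auto"
        self.plates = {"batch": lambda: None}  # stand-in for unpicklable state

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("plates", None)
        return state

restored = pickle.loads(pickle.dumps(Guide()))
assert restored.prefix == "auto" and not hasattr(restored, "plates")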
@@ -450,6 +455,34 @@ def median(self, params):
         return locs


+def _unravel_dict(x_flat, shape_dict):
+    """Return `x` from the flattened version `x_flat`. Shape information
+    of each item in `x` is defined in `shape_dict`.
+    """
+    assert jnp.ndim(x_flat) == 1
+    assert isinstance(shape_dict, dict)
+    x = {}
+    curr_pos = next_pos = 0
+    for name, shape in shape_dict.items():
+        next_pos = curr_pos + int(np.prod(shape))
+        x[name] = x_flat[curr_pos:next_pos].reshape(shape)
+        curr_pos = next_pos
+    assert next_pos == x_flat.shape[0]
+    return x
+
+
+def _ravel_dict(x):
+    """Return the flattened version of `x` and shapes of each item in `x`."""
+    assert isinstance(x, dict)
+    shape_dict = {}
+    x_flat = []
+    for name, value in x.items():
+        shape_dict[name] = jnp.shape(value)
+        x_flat.append(value.reshape(-1))
+    x_flat = jnp.concatenate(x_flat) if x_flat else jnp.zeros((0,))
+    return x_flat, shape_dict
+
+
 class AutoContinuous(AutoGuide):
     """
     Base class for implementations of continuous-valued Automatic
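A quick round-trip check of what the two helpers do (standalone sketch; `_ravel_dict` and `_unravel_dict` are assumed in scope, copied from the diff above):

import jax.numpy as jnp

params = {"loc": jnp.zeros((2, 3)), "scale": jnp.ones(4)}
flat, shapes = _ravel_dict(params)
assert flat.shape == (10,)  # 2 * 3 + 4 entries in one flat vector
assert shapes == {"loc": (2, 3), "scale": (4,)}

restored = _unravel_dict(flat, shapes)
assert all(jnp.array_equal(params[k], restored[k]) for k in params)

Unlike the unravel closure returned by `jax.flatten_util.ravel_pytree`, the `shape_dict` produced here is plain data, which is the property the guide needs in order to be serializable.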
@@ -474,7 +507,8 @@ class AutoContinuous(AutoGuide):

     def _setup_prototype(self, *args, **kwargs):
         super()._setup_prototype(*args, **kwargs)
-        self._init_latent, unpack_latent = ravel_pytree(self._init_locs)
+        self._init_latent, shape_dict = _ravel_dict(self._init_locs)
+        unpack_latent = partial(_unravel_dict, shape_dict=shape_dict)
         # this is to match the behavior of Pyro, where we can apply
         # unpack_latent for a batch of samples
         self._unpack_latent = UnpackTransform(unpack_latent)
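The switch to `functools.partial` is the heart of the fix: pickle serializes a partial of a module-level function by reference, whereas the locally defined unravel function returned by `ravel_pytree` cannot be pickled. A small illustration of that difference (not numpyro code):

import pickle
from functools import partial

def unravel(x_flat, shape_dict):  # module-level: picklable by reference
    return shape_dict

def make_unravel(shape_dict):
    def unravel_closure(x_flat):  # local function: pickle refuses it
        return shape_dict
    return unravel_closure

pickle.dumps(partial(unravel, shape_dict={"loc": (2,)}))  # works

try:
    pickle.dumps(make_unravel({"loc": (2,)}))
except (AttributeError, pickle.PicklingError):
    print("local closures cannot be pickled")  # this branch runs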
Changes to the pickle tests:
@@ -3,6 +3,7 @@

 import pickle

+import numpy as np
 import pytest

 from jax import random, test_util
@@ -16,10 +17,13 @@
     MCMC,
     NUTS,
     SA,
+    SVI,
     BarkerMH,
     DiscreteHMCGibbs,
     MixedHMC,
+    Predictive,
 )
+from numpyro.infer.autoguide import AutoDelta, AutoDiagonalNormal, AutoNormal


 def normal_model():
@@ -59,3 +63,31 @@ def test_pickle_hmcecs():
     mcmc.run(random.PRNGKey(0))
     pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
     test_util.check_close(mcmc.get_samples(), pickled_mcmc.get_samples())
+
+
+def poisson_regression(x, N):
+    rate = numpyro.sample("param", dist.Gamma(1.0, 1.0))
+    batch_size = len(x) if x is not None else None
+    with numpyro.plate("batch", N, batch_size):
+        numpyro.sample("x", dist.Poisson(rate), obs=x)
+
+
+@pytest.mark.parametrize("guide_class", [AutoDelta, AutoDiagonalNormal, AutoNormal])
+def test_pickle_autoguide(guide_class):
+    x = np.random.poisson(1.0, size=(100,))
+
+    guide = guide_class(poisson_regression)
+    optim = numpyro.optim.Adam(1e-2)
+    svi = SVI(poisson_regression, guide, optim, numpyro.infer.Trace_ELBO())
+    svi_result = svi.run(random.PRNGKey(1), 3, x, len(x))
+    pickled_guide = pickle.loads(pickle.dumps(guide))
+
+    predictive = Predictive(
+        poisson_regression,
+        guide=pickled_guide,
+        params=svi_result.params,
+        num_samples=1,
+        return_sites=["param", "x"],
+    )
+    samples = predictive(random.PRNGKey(1), None, 1)
+    assert set(samples.keys()) == {"param", "x"}
Review comment (on the final assertion): Maybe better to assert that …

Review comment: This follows Pyro patch for PyTorch transforms.
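Beyond the test, the change enables saving a trained guide to disk and reloading it later. A hedged end-to-end sketch (model, hyperparameters, and file name are illustrative):

import pickle

import numpy as np
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

def model(x=None, N=1):
    rate = numpyro.sample("param", dist.Gamma(1.0, 1.0))
    with numpyro.plate("batch", N, len(x) if x is not None else None):
        numpyro.sample("x", dist.Poisson(rate), obs=x)

x = np.random.poisson(1.0, size=(100,))
guide = AutoNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 100, x, len(x))

with open("guide.pkl", "wb") as f:  # illustrative path
    pickle.dump(guide, f)
with open("guide.pkl", "rb") as f:
    loaded_guide = pickle.load(f)

predictive = Predictive(
    model,
    guide=loaded_guide,
    params=svi_result.params,
    num_samples=10,
    return_sites=["param"],
)
print(predictive(random.PRNGKey(1), None, 1)["param"])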