From b106f6f664b306e4f0f256011e479ea6e62138da Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 25 Sep 2021 10:48:38 -0400 Subject: [PATCH 1/3] allow pickle ravelpytree --- numpyro/infer/autoguide.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 4a89c76a5..7090db1a0 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -4,6 +4,7 @@ # Adapted from pyro.infer.autoguide from abc import ABC, abstractmethod from contextlib import ExitStack +from functools import partial import warnings import numpy as np @@ -450,6 +451,29 @@ def median(self, params): return locs +def unravel_pytree(x_flat, treedef): + assert jnp.shape(x_flat) == 1 + x = {} + curr_pos = next_pos = 0 + for name, shape in treedef.items(): + next_pos = curr_pos + np.prod(shape) + x[name] = x_flat[curr_pos:next_pos].reshape(shape) + curr_pos = next_pos + assert next_pos == x_flat.shape[0] + return x + + +def ravel_pytree(x): + assert isinstance(x, dict) + treedef = {} + x_flat = [] + for name, value in x.items(): + treedef[name] = jnp.shape(value) + x_flat.append(value.reshape(-1)) + x_flat = jnp.concatenate(x_flat) if x_flat else jnp.zeros((0,)) + return x_flat, treedef + + class AutoContinuous(AutoGuide): """ Base class for implementations of continuous-valued Automatic @@ -474,7 +498,8 @@ class AutoContinuous(AutoGuide): def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) - self._init_latent, unpack_latent = ravel_pytree(self._init_locs) + self._init_latent, treedef = ravel_pytree(self._init_locs) + unpack_latent = partial(unravel_pytree, treedef=treedef) # this is to match the behavior of Pyro, where we can apply # unpack_latent for a batch of samples self._unpack_latent = UnpackTransform(unpack_latent) From f89244746b1b689eb1b464e8853fe55f2d52dcb9 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 25 Sep 2021 16:32:12 -0400 Subject: [PATCH 2/3] make pickle work and add tests --- numpyro/distributions/transforms.py | 10 +++++++++ numpyro/infer/autoguide.py | 31 +++++++++++++++++---------- test/test_pickle.py | 33 ++++++++++++++++++++++++++--- 3 files changed, 60 insertions(+), 14 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 443a5b8fc..1fffe42ac 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -106,6 +106,16 @@ def inverse_shape(self, shape): """ return shape + # Allow for pickle serialization of transforms. + def __getstate__(self): + attrs = {} + for k, v in self.__dict__.items(): + if isinstance(v, weakref.ref): + attrs[k] = None + else: + attrs[k] = v + return attrs + class _InverseTransform(Transform): def __init__(self, transform): diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 7090db1a0..01be6640e 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -12,7 +12,6 @@ import jax from jax import grad, hessian, lax, random, tree_map from jax.experimental import stax -from jax.flatten_util import ravel_pytree import jax.numpy as jnp import numpyro @@ -105,6 +104,11 @@ def _create_plates(self, *args, **kwargs): ) return self.plates + def __getstate__(self): + state = self.__dict__.copy() + state.pop("plates", None) + return state + @abstractmethod def __call__(self, *args, **kwargs): """ @@ -451,27 +455,32 @@ def median(self, params): return locs -def unravel_pytree(x_flat, treedef): - assert jnp.shape(x_flat) == 1 +def _unravel_dict(x_flat, shape_dict): + """Return `x` from the flatten version `x_flat`. Shape information + of each item in `x` is defined in `shape_dict`. + """ + assert jnp.ndim(x_flat) == 1 + assert isinstance(shape_dict, dict) x = {} curr_pos = next_pos = 0 - for name, shape in treedef.items(): - next_pos = curr_pos + np.prod(shape) + for name, shape in shape_dict.items(): + next_pos = curr_pos + int(np.prod(shape)) x[name] = x_flat[curr_pos:next_pos].reshape(shape) curr_pos = next_pos assert next_pos == x_flat.shape[0] return x -def ravel_pytree(x): +def _ravel_dict(x): + """Return the flatten version of `x` and shapes of each item in `x`.""" assert isinstance(x, dict) - treedef = {} + shape_dict = {} x_flat = [] for name, value in x.items(): - treedef[name] = jnp.shape(value) + shape_dict[name] = jnp.shape(value) x_flat.append(value.reshape(-1)) x_flat = jnp.concatenate(x_flat) if x_flat else jnp.zeros((0,)) - return x_flat, treedef + return x_flat, shape_dict class AutoContinuous(AutoGuide): @@ -498,8 +507,8 @@ class AutoContinuous(AutoGuide): def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) - self._init_latent, treedef = ravel_pytree(self._init_locs) - unpack_latent = partial(unravel_pytree, treedef=treedef) + self._init_latent, shape_dict = _ravel_dict(self._init_locs) + unpack_latent = partial(_unravel_dict, shape_dict=shape_dict) # this is to match the behavior of Pyro, where we can apply # unpack_latent for a batch of samples self._unpack_latent = UnpackTransform(unpack_latent) diff --git a/test/test_pickle.py b/test/test_pickle.py index bf0d79793..b4f20c1ef 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -3,6 +3,7 @@ import pickle +import numpy as np import pytest from jax import random, test_util @@ -11,15 +12,18 @@ import numpyro import numpyro.distributions as dist from numpyro.infer import ( + BarkerMH, + DiscreteHMCGibbs, HMC, HMCECS, MCMC, + MixedHMC, NUTS, + Predictive, SA, - BarkerMH, - DiscreteHMCGibbs, - MixedHMC, + SVI, ) +from numpyro.infer.autoguide import AutoDiagonalNormal, AutoDelta, AutoNormal def normal_model(): @@ -59,3 +63,26 @@ def test_pickle_hmcecs(): mcmc.run(random.PRNGKey(0)) pickled_mcmc = pickle.loads(pickle.dumps(mcmc)) test_util.check_close(mcmc.get_samples(), pickled_mcmc.get_samples()) + + +def poisson_regression(x, N): + rate = numpyro.sample('param', dist.Gamma(1., 1.)) + batch_size = len(x) if x is not None else None + with numpyro.plate('batch', N, batch_size): + numpyro.sample('x', dist.Poisson(rate), obs=x) + +@pytest.mark.parametrize("guide_class", [AutoDelta, AutoDiagonalNormal, AutoNormal]) +def test_pickle_autoguide(guide_class): + x = np.random.poisson(1.0, size=(100,)) + d = 2 + + guide = guide_class(poisson_regression) + optim = numpyro.optim.Adam(1e-2) + svi = SVI(poisson_regression, guide, optim, numpyro.infer.Trace_ELBO()) + svi_result = svi.run(random.PRNGKey(1), 3, x, len(x)) + pickled_guide = pickle.loads(pickle.dumps(guide)) + + predictive = Predictive(poisson_regression, guide=pickled_guide, params=svi_result.params, + num_samples=1, return_sites=['param', 'x']) + samples = predictive(random.PRNGKey(1), None, 1) + assert set(samples.keys()) == {'param', 'x'} From 3f97f2c9a211b2837131d82b45b5b81d00192cee Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 25 Sep 2021 16:33:15 -0400 Subject: [PATCH 3/3] fix lint --- test/test_pickle.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/test/test_pickle.py b/test/test_pickle.py index b4f20c1ef..f5ee768c2 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -12,18 +12,18 @@ import numpyro import numpyro.distributions as dist from numpyro.infer import ( - BarkerMH, - DiscreteHMCGibbs, HMC, HMCECS, MCMC, - MixedHMC, NUTS, - Predictive, SA, SVI, + BarkerMH, + DiscreteHMCGibbs, + MixedHMC, + Predictive, ) -from numpyro.infer.autoguide import AutoDiagonalNormal, AutoDelta, AutoNormal +from numpyro.infer.autoguide import AutoDelta, AutoDiagonalNormal, AutoNormal def normal_model(): @@ -66,15 +66,15 @@ def test_pickle_hmcecs(): def poisson_regression(x, N): - rate = numpyro.sample('param', dist.Gamma(1., 1.)) + rate = numpyro.sample("param", dist.Gamma(1.0, 1.0)) batch_size = len(x) if x is not None else None - with numpyro.plate('batch', N, batch_size): - numpyro.sample('x', dist.Poisson(rate), obs=x) + with numpyro.plate("batch", N, batch_size): + numpyro.sample("x", dist.Poisson(rate), obs=x) + @pytest.mark.parametrize("guide_class", [AutoDelta, AutoDiagonalNormal, AutoNormal]) def test_pickle_autoguide(guide_class): x = np.random.poisson(1.0, size=(100,)) - d = 2 guide = guide_class(poisson_regression) optim = numpyro.optim.Adam(1e-2) @@ -82,7 +82,12 @@ def test_pickle_autoguide(guide_class): svi_result = svi.run(random.PRNGKey(1), 3, x, len(x)) pickled_guide = pickle.loads(pickle.dumps(guide)) - predictive = Predictive(poisson_regression, guide=pickled_guide, params=svi_result.params, - num_samples=1, return_sites=['param', 'x']) + predictive = Predictive( + poisson_regression, + guide=pickled_guide, + params=svi_result.params, + num_samples=1, + return_sites=["param", "x"], + ) samples = predictive(random.PRNGKey(1), None, 1) - assert set(samples.keys()) == {'param', 'x'} + assert set(samples.keys()) == {"param", "x"}