Allow pickle autoguide #1169

Merged (3 commits) on Oct 8, 2021
10 changes: 10 additions & 0 deletions numpyro/distributions/transforms.py
@@ -106,6 +106,16 @@ def inverse_shape(self, shape):
        """
        return shape

    # Allow for pickle serialization of transforms.
@fehiepsi (Member, Author) commented on Sep 25, 2021:

This follows the Pyro patch for PyTorch transforms.

    def __getstate__(self):
        attrs = {}
        for k, v in self.__dict__.items():
            if isinstance(v, weakref.ref):
                attrs[k] = None
            else:
                attrs[k] = v
        return attrs


class _InverseTransform(Transform):
    def __init__(self, transform):
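For illustration, a minimal sketch of the behavior this enables (assuming, per the comment above, that a transform caches its inverse as a weakref.ref in the PyTorch/Pyro style; ExpTransform is just an arbitrary concrete transform):

import pickle

from numpyro.distributions.transforms import ExpTransform

t = ExpTransform()
_ = t.inv  # accessing .inv may cache the inverse on `t` as a weakref.ref
t2 = pickle.loads(pickle.dumps(t))  # __getstate__ replaces weakref entries with None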
38 changes: 36 additions & 2 deletions numpyro/infer/autoguide.py
@@ -4,14 +4,14 @@
# Adapted from pyro.infer.autoguide
from abc import ABC, abstractmethod
from contextlib import ExitStack
from functools import partial
import warnings

import numpy as np

import jax
from jax import grad, hessian, lax, random, tree_map
from jax.experimental import stax
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp

import numpyro
@@ -104,6 +104,11 @@ def _create_plates(self, *args, **kwargs):
                )
        return self.plates

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("plates", None)
@fehiepsi (Member, Author) commented on Sep 25, 2021:

plates is generated during execution, so we don't need to cache it. Actually, we typically call the guide under jit, so plates will hold abstract values, which cannot be pickled.

        return state

    @abstractmethod
    def __call__(self, *args, **kwargs):
        """
@@ -450,6 +455,34 @@ def median(self, params):
        return locs


def _unravel_dict(x_flat, shape_dict):
    """Return `x` from the flattened version `x_flat`. Shape information
    of each item in `x` is defined in `shape_dict`.
    """
    assert jnp.ndim(x_flat) == 1
    assert isinstance(shape_dict, dict)
    x = {}
    curr_pos = next_pos = 0
    for name, shape in shape_dict.items():
        next_pos = curr_pos + int(np.prod(shape))
        x[name] = x_flat[curr_pos:next_pos].reshape(shape)
        curr_pos = next_pos
    assert next_pos == x_flat.shape[0]
    return x


def _ravel_dict(x):
    """Return the flattened version of `x` and the shape of each item in `x`."""
    assert isinstance(x, dict)
    shape_dict = {}
    x_flat = []
    for name, value in x.items():
        shape_dict[name] = jnp.shape(value)
        x_flat.append(value.reshape(-1))
    x_flat = jnp.concatenate(x_flat) if x_flat else jnp.zeros((0,))
    return x_flat, shape_dict
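To make the helpers' contract concrete, a small round-trip sketch (hypothetical values; importing the private names is for illustration only):

import jax.numpy as jnp

from numpyro.infer.autoguide import _ravel_dict, _unravel_dict

init_locs = {"loc": jnp.zeros((2, 3)), "scale": jnp.ones(4)}
flat, shape_dict = _ravel_dict(init_locs)  # flat.shape == (10,); shape_dict == {"loc": (2, 3), "scale": (4,)}
restored = _unravel_dict(flat, shape_dict)
assert restored["loc"].shape == (2, 3) and restored["scale"].shape == (4,)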


class AutoContinuous(AutoGuide):
"""
Base class for implementations of continuous-valued Automatic
@@ -474,7 +507,8 @@ class AutoContinuous(AutoGuide):

    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)
        self._init_latent, unpack_latent = ravel_pytree(self._init_locs)
        self._init_latent, shape_dict = _ravel_dict(self._init_locs)
        unpack_latent = partial(_unravel_dict, shape_dict=shape_dict)
        # this is to match the behavior of Pyro, where we can apply
        # unpack_latent for a batch of samples
        self._unpack_latent = UnpackTransform(unpack_latent)
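The hunk above swaps jax.flatten_util.ravel_pytree for _ravel_dict plus a functools.partial over the module-level _unravel_dict; the point, presumably, is that such a partial can be pickled, whereas the unravel function previously stored inside UnpackTransform could not. A minimal sketch of that property (hypothetical shapes, mirroring the code above):

import pickle
from functools import partial

import jax.numpy as jnp

from numpyro.infer.autoguide import _unravel_dict

shape_dict = {"loc": (2, 3), "scale": (4,)}  # as produced by _ravel_dict
unpack_latent = partial(_unravel_dict, shape_dict=shape_dict)
flat = jnp.arange(10.0)  # 2*3 + 4 latent values
assert unpack_latent(flat)["loc"].shape == (2, 3)
pickle.dumps(unpack_latent)  # a partial over a module-level function pickles fine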
32 changes: 32 additions & 0 deletions test/test_pickle.py
@@ -3,6 +3,7 @@

import pickle

import numpy as np
import pytest

from jax import random, test_util
@@ -16,10 +17,13 @@
    MCMC,
    NUTS,
    SA,
    SVI,
    BarkerMH,
    DiscreteHMCGibbs,
    MixedHMC,
    Predictive,
)
from numpyro.infer.autoguide import AutoDelta, AutoDiagonalNormal, AutoNormal


def normal_model():
@@ -59,3 +63,31 @@ def test_pickle_hmcecs():
    mcmc.run(random.PRNGKey(0))
    pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
    test_util.check_close(mcmc.get_samples(), pickled_mcmc.get_samples())


def poisson_regression(x, N):
    rate = numpyro.sample("param", dist.Gamma(1.0, 1.0))
    batch_size = len(x) if x is not None else None
    with numpyro.plate("batch", N, batch_size):
        numpyro.sample("x", dist.Poisson(rate), obs=x)


@pytest.mark.parametrize("guide_class", [AutoDelta, AutoDiagonalNormal, AutoNormal])
def test_pickle_autoguide(guide_class):
    x = np.random.poisson(1.0, size=(100,))

    guide = guide_class(poisson_regression)
    optim = numpyro.optim.Adam(1e-2)
    svi = SVI(poisson_regression, guide, optim, numpyro.infer.Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(1), 3, x, len(x))
    pickled_guide = pickle.loads(pickle.dumps(guide))

    predictive = Predictive(
        poisson_regression,
        guide=pickled_guide,
        params=svi_result.params,
        num_samples=1,
        return_sites=["param", "x"],
    )
    samples = predictive(random.PRNGKey(1), None, 1)
    assert set(samples.keys()) == {"param", "x"}
A contributor commented:
Maybe better to assert that pickled_guide returns the same samples as guide instead of just verifying that all sites are present?
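A sketch of the stronger check suggested here (it reuses guide, svi_result, and samples from test_pickle_autoguide above, and compares against a Predictive built from the original, unpickled guide with the same PRNG key):

expected = Predictive(
    poisson_regression,
    guide=guide,
    params=svi_result.params,
    num_samples=1,
    return_sites=["param", "x"],
)(random.PRNGKey(1), None, 1)
test_util.check_close(samples, expected)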