From a373ec7101f28240aa8bb5def053db3bc04f3d1f Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 11 Sep 2020 00:38:28 -0500 Subject: [PATCH 01/11] support subsampling --- numpyro/contrib/funsor/enum_messenger.py | 9 +++-- numpyro/handlers.py | 9 ++++- numpyro/primitives.py | 51 +++++++++++++++--------- 3 files changed, 45 insertions(+), 24 deletions(-) diff --git a/numpyro/contrib/funsor/enum_messenger.py b/numpyro/contrib/funsor/enum_messenger.py index 4d9c41c7d..9e6593a9c 100644 --- a/numpyro/contrib/funsor/enum_messenger.py +++ b/numpyro/contrib/funsor/enum_messenger.py @@ -10,7 +10,7 @@ import funsor from numpyro.handlers import trace as OrigTraceMessenger -from numpyro.primitives import CondIndepStackFrame, Messenger, apply_stack +from numpyro.primitives import CondIndepStackFrame, Messenger, apply_stack, plate as OrigPlateMessenger funsor.set_backend("jax") @@ -442,10 +442,11 @@ def __init__(self, name, size, subsample_size=None, dim=None): if dim is not None and dim >= 0: raise ValueError('dim arg must be negative.') self.dim = dim + _, indices = OrigPlateMessenger._subsample(self.name, self.size, self.subsample_size, dim) self._indices = funsor.Tensor( - funsor.ops.new_arange(funsor.tensor.get_default_prototype(), self.size), - OrderedDict([(self.name, funsor.bint(self.size))]), - self.size + indices, + OrderedDict([(self.name, funsor.bint(self.subsample_size))]), + self.subsample_size ) super(plate, self).__init__(None) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index e2affe8c8..805ca8f97 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -547,7 +547,12 @@ def __init__(self, fn=None, rng_seed=None): def process_message(self, msg): if (msg['type'] == 'sample' and not msg['is_observed'] and - msg['kwargs']['rng_key'] is None) or msg['type'] == 'control_flow': + msg['kwargs']['rng_key'] is None) or msg['type'] in ['plate', 'control_flow']: + # no need to split key if size = subsample_size + if msg['type'] == 'plate': + size, subsample_size = msg['args'] + if size == subsample_size: + return self.rng_key, rng_key_sample = random.split(self.rng_key) msg['kwargs']['rng_key'] = rng_key_sample @@ -596,7 +601,7 @@ def __init__(self, fn=None, data=None, substitute_fn=None): super(substitute, self).__init__(fn) def process_message(self, msg): - if (msg['type'] not in ('sample', 'param')) or msg.get('_control_flow_done', False): + if (msg['type'] not in ('sample', 'param', 'plate')) or msg.get('_control_flow_done', False): if msg['type'] == 'control_flow': if self.data is not None: msg['kwargs']['substitute_stack'].append(('substitute', self.data)) diff --git a/numpyro/primitives.py b/numpyro/primitives.py index ef6161a1c..d031c3642 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -5,7 +5,8 @@ from contextlib import ExitStack, contextmanager import functools -from jax import lax +from jax import lax, random +import jax.numpy as jnp import numpyro from numpyro.distributions.discrete import PRNGIdentity @@ -202,6 +203,14 @@ def module(name, nn, input_shape=None): return functools.partial(nn_apply, nn_params) +def _subsample_fn(size, subsample_size, rng_key=None): + if size == subsample_size: + return jnp.arange(size) + else: + assert rng_key is not None, "Missing random key to generate subsample indices." + return random.permutation(rng_key, size)[:subsample_size] + + class plate(Messenger): """ Construct for annotating conditionally independent variables. Within a @@ -224,38 +233,44 @@ def __init__(self, name, size, subsample_size=None, dim=None): self.subsample_size = size if subsample_size is None else subsample_size if dim is not None and dim >= 0: raise ValueError('dim arg must be negative.') - self.dim = dim - self._validate_and_set_dim() + self.dim, self._indices = self._subsample(self.name, self.size, self.subsample_size, dim) super(plate, self).__init__() - def _validate_and_set_dim(self): + # XXX: different from Pyro, this method returns dim and indices + @staticmethod + def _subsample(name, size, subsample_size, dim): msg = { 'type': 'plate', - 'fn': identity, - 'name': self.name, - 'args': (None,), - 'kwargs': {}, + 'fn': _subsample_fn, + 'name': name, + 'args': (size, subsample_size), + 'kwargs': {'rng_key': None}, 'value': None, 'scale': 1.0, 'cond_indep_stack': [], } apply_stack(msg) + subsample = msg['value'] + if subsample_size != subsample.shape[0]: + raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format( + subsample_size, len(subsample)) + + " Did you accidentally use different subsample_size in the model and guide?") cond_indep_stack = msg['cond_indep_stack'] occupied_dims = {f.dim for f in cond_indep_stack} - dim = -1 - while True: - if dim not in occupied_dims: - break - dim -= 1 - if self.dim is None: - self.dim = dim + if dim is None: + new_dim = -1 + while True: + if new_dim not in occupied_dims: + break + new_dim -= 1 + dim = new_dim else: - assert self.dim not in occupied_dims + assert dim not in occupied_dims + return dim, subsample def __enter__(self): super().__enter__() - # XXX: JAX doesn't like slice index, so we cast to list - return list(range(self.subsample_size)) + return self._indices @staticmethod def _get_batch_shape(cond_indep_stack): From 8a9dd387cacb9bfd4b438e4e2a4bc33eb4f9012e Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 11 Sep 2020 00:43:59 -0500 Subject: [PATCH 02/11] add example for subsample ind cs --- numpyro/primitives.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/numpyro/primitives.py b/numpyro/primitives.py index d031c3642..eacb173d8 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -218,6 +218,12 @@ class plate(Messenger): the size of the plate. Additionally, a scale factor might be applied by certain inference algorithms if `subsample_size` is specified. + .. note:: This can be used to subsample minibatches of data:: + + with plate("data", len(data), subsample_size=100) as ind: + batch = data[ind] + assert len(batch) == 100 + :param str name: Name of the plate. :param int size: Size of the plate. :param int subsample_size: Optional argument denoting the size of the mini-batch. From abef45aa1fbc2b616efea20e323ee0342c240b98 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 11 Sep 2020 01:39:02 -0500 Subject: [PATCH 03/11] simplify the implementation --- numpyro/handlers.py | 8 +++----- numpyro/primitives.py | 9 +++------ test/test_handlers.py | 19 +++++++++++++++++++ 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 805ca8f97..886c9b482 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -548,11 +548,9 @@ def __init__(self, fn=None, rng_seed=None): def process_message(self, msg): if (msg['type'] == 'sample' and not msg['is_observed'] and msg['kwargs']['rng_key'] is None) or msg['type'] in ['plate', 'control_flow']: - # no need to split key if size = subsample_size - if msg['type'] == 'plate': - size, subsample_size = msg['args'] - if size == subsample_size: - return + # no need to create a new key when value is available + if msg['value'] is not None: + return self.rng_key, rng_key_sample = random.split(self.rng_key) msg['kwargs']['rng_key'] = rng_key_sample diff --git a/numpyro/primitives.py b/numpyro/primitives.py index eacb173d8..091ef2658 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -204,11 +204,8 @@ def module(name, nn, input_shape=None): def _subsample_fn(size, subsample_size, rng_key=None): - if size == subsample_size: - return jnp.arange(size) - else: - assert rng_key is not None, "Missing random key to generate subsample indices." - return random.permutation(rng_key, size)[:subsample_size] + assert rng_key is not None, "Missing random key to generate subsample indices." + return random.permutation(rng_key, size)[:subsample_size] class plate(Messenger): @@ -251,7 +248,7 @@ def _subsample(name, size, subsample_size, dim): 'name': name, 'args': (size, subsample_size), 'kwargs': {'rng_key': None}, - 'value': None, + 'value': None if size != subsample_size else jnp.arange(size), 'scale': 1.0, 'cond_indep_stack': [], } diff --git a/test/test_handlers.py b/test/test_handlers.py index 859a5aa84..bbdf2c9ce 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -235,6 +235,25 @@ def test_plate(model): assert_allclose(jit_trace[name]['value'], site['value']) +def test_subsample_data(): + data = jnp.arange(100.) + subsample_size = 7 + with handlers.trace() as tr, handlers.seed(rng_seed=0): + with numpyro.plate("a", len(data), subsample_size=subsample_size) as idx: + assert data[idx].shape == (subsample_size,) + + +def test_subsample_substitute(): + data = jnp.arange(100.) + subsample_size = 7 + subsample = jnp.array([13, 3, 30, 4, 1, 68, 5]) + with handlers.trace() as tr, handlers.seed(rng_seed=0), handlers.substitute(data={"a": subsample}): + with numpyro.plate("a", len(data), subsample_size=subsample_size) as idx: + assert data[idx].shape == (subsample_size,) + assert_allclose(idx, subsample) + assert tr["a"]["kwargs"]["rng_key"] is None + + def test_messenger_fn_invalid(): with pytest.raises(ValueError, match="to be a Python callable object"): with numpyro.handlers.mask(False): From 7d3c0c86608ec98f6ee1ddbfac0a2e123b434596 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 11 Sep 2020 01:50:16 -0500 Subject: [PATCH 04/11] fix lint issue --- test/test_handlers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_handlers.py b/test/test_handlers.py index bbdf2c9ce..9422a4da4 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -238,7 +238,7 @@ def test_plate(model): def test_subsample_data(): data = jnp.arange(100.) subsample_size = 7 - with handlers.trace() as tr, handlers.seed(rng_seed=0): + with handlers.seed(rng_seed=0): with numpyro.plate("a", len(data), subsample_size=subsample_size) as idx: assert data[idx].shape == (subsample_size,) From 87c33ac7434e4e21dc61b1b70a0a97f90e970064 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 11 Sep 2020 23:56:07 -0500 Subject: [PATCH 05/11] add subsample primitive and more tests --- numpyro/__init__.py | 3 +- numpyro/contrib/funsor/enum_messenger.py | 29 +++++++-- numpyro/primitives.py | 81 ++++++++++++++++++++++-- test/contrib/test_tfp.py | 1 + test/test_handlers.py | 47 ++++++++++++++ 5 files changed, 152 insertions(+), 9 deletions(-) diff --git a/numpyro/__init__.py b/numpyro/__init__.py index a70783f4c..db2da6a29 100644 --- a/numpyro/__init__.py +++ b/numpyro/__init__.py @@ -4,7 +4,7 @@ from numpyro import compat, diagnostics, distributions, handlers, infer, optim from numpyro.distributions.distribution import enable_validation, validation_enabled import numpyro.patch # noqa: F401 -from numpyro.primitives import deterministic, factor, module, param, plate, plate_stack, sample +from numpyro.primitives import deterministic, factor, module, param, plate, plate_stack, sample, subsample from numpyro.util import enable_x64, set_host_device_count, set_platform from numpyro.version import __version__ @@ -28,6 +28,7 @@ 'plate', 'plate_stack', 'sample', + 'subsample', 'set_host_device_count', 'set_platform', 'validation_enabled', diff --git a/numpyro/contrib/funsor/enum_messenger.py b/numpyro/contrib/funsor/enum_messenger.py index 9e6593a9c..ca2dc317c 100644 --- a/numpyro/contrib/funsor/enum_messenger.py +++ b/numpyro/contrib/funsor/enum_messenger.py @@ -10,7 +10,8 @@ import funsor from numpyro.handlers import trace as OrigTraceMessenger -from numpyro.primitives import CondIndepStackFrame, Messenger, apply_stack, plate as OrigPlateMessenger +from numpyro.primitives import CondIndepStackFrame, Messenger, apply_stack +from numpyro.primitives import plate as OrigPlateMessenger funsor.set_backend("jax") @@ -438,11 +439,10 @@ class plate(GlobalNamedMessenger): def __init__(self, name, size, subsample_size=None, dim=None): self.name = name self.size = size - self.subsample_size = size if subsample_size is None else subsample_size if dim is not None and dim >= 0: raise ValueError('dim arg must be negative.') - self.dim = dim - _, indices = OrigPlateMessenger._subsample(self.name, self.size, self.subsample_size, dim) + self.dim, indices = OrigPlateMessenger._subsample(self.name, self.size, subsample_size, dim) + self.subsample_size = indices.shape[0] self._indices = funsor.Tensor( indices, OrderedDict([(self.name, funsor.bint(self.subsample_size))]), @@ -488,6 +488,27 @@ def process_message(self, msg): # copied almost verbatim from plate scale = 1. if msg['scale'] is None else msg['scale'] msg['scale'] = scale * self.size / self.subsample_size + def postprocess_message(self, msg): + if msg["type"] in ("subsample", "param") and self.dim is not None: + event_dim = msg["kwargs"].get("event_dim") + if event_dim is not None: + assert event_dim >= 0 + dim = self.dim - event_dim + shape = msg["value"].shape + if len(shape) >= -dim and shape[dim] != 1: + if shape[dim] != self.size: + if msg["type"] == "param": + statement = "numpyro.param({}, ..., event_dim={})".format(msg["name"], event_dim) + else: + statement = "numpyro.subsample(..., event_dim={})".format(event_dim) + raise ValueError( + "Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}" + .format(self.name, self.size, self.dim, statement, shape)) + if self.subsample_size < self.size: + value = msg["value"] + new_value = jnp.take_along_axis(value, self._indices, dim) + msg["value"] = new_value + class enum(BaseEnumMessenger): """ diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 091ef2658..4696d9fd1 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -126,6 +126,14 @@ def param(name, init_value=None, **kwargs): the onus of using this to initialize the optimizer is on the user / inference algorithm, since there is no global parameter store in NumPyro. + :param constraint: NumPyro constraint, defaults to ``constraints.real``. + :type constraint: numpyro.distributions.constraints.Constraint + :param int event_dim: (optional) number of rightmost dimensions unrelated + to batching. Dimension to the left of this will be considered batch + dimensions; if the param statement is inside a subsampled plate, then + corresponding batch dimensions of the parameter will be correspondingly + subsampled. If unspecified, all dimensions will be considered event + dims and no subsampling will be performed. :return: value for the parameter. Unless wrapped inside a handler like :class:`~numpyro.handlers.substitute`, this will simply return the initial value. @@ -233,10 +241,11 @@ class plate(Messenger): def __init__(self, name, size, subsample_size=None, dim=None): self.name = name self.size = size - self.subsample_size = size if subsample_size is None else subsample_size if dim is not None and dim >= 0: raise ValueError('dim arg must be negative.') - self.dim, self._indices = self._subsample(self.name, self.size, self.subsample_size, dim) + self.dim, self._indices = self._subsample( + self.name, self.size, subsample_size, dim) + self.subsample_size = self._indices.shape[0] super(plate, self).__init__() # XXX: different from Pyro, this method returns dim and indices @@ -248,13 +257,15 @@ def _subsample(name, size, subsample_size, dim): 'name': name, 'args': (size, subsample_size), 'kwargs': {'rng_key': None}, - 'value': None if size != subsample_size else jnp.arange(size), + 'value': (None + if (subsample_size is not None and size != subsample_size) + else jnp.arange(size)), 'scale': 1.0, 'cond_indep_stack': [], } apply_stack(msg) subsample = msg['value'] - if subsample_size != subsample.shape[0]: + if subsample_size is not None and subsample_size != subsample.shape[0]: raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format( subsample_size, len(subsample)) + " Did you accidentally use different subsample_size in the model and guide?") @@ -309,6 +320,27 @@ def process_message(self, msg): scale = 1. if msg['scale'] is None else msg['scale'] msg['scale'] = scale * self.size / self.subsample_size + def postprocess_message(self, msg): + if msg["type"] in ("subsample", "param") and self.dim is not None: + event_dim = msg["kwargs"].get("event_dim") + if event_dim is not None: + assert event_dim >= 0 + dim = self.dim - event_dim + shape = jnp.shape(msg["value"]) + if len(shape) >= -dim and shape[dim] != 1: + if shape[dim] != self.size: + if msg["type"] == "param": + statement = "numpyro.param({}, ..., event_dim={})".format(msg["name"], event_dim) + else: + statement = "numpyro.subsample(..., event_dim={})".format(event_dim) + raise ValueError( + "Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}" + .format(self.name, self.size, self.dim, statement, shape)) + if self.subsample_size < self.size: + value = msg["value"] + new_value = jnp.take(value, self._indices, dim) + msg["value"] = new_value + @contextmanager def plate_stack(prefix, sizes, rightmost_dim=-1): @@ -340,3 +372,44 @@ def factor(name, log_factor): unit_dist = numpyro.distributions.distribution.Unit(log_factor) unit_value = unit_dist.sample(None) sample(name, unit_dist, obs=unit_value) + + +def subsample(data, event_dim): + """ + EXPERIMENTAL Subsampling statement to subsample data based on enclosing + :class:`~numpyro.primitives.plate` s. + + This is typically called on arguments to ``model()`` when subsampling is + performed automatically by :class:`~numpyro.primitives.plate` s by passing + ``subsample_size`` kwarg. For example the following are equivalent:: + + # Version 1. using indexing + def model(data): + with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind: + data = data[ind] + # ... + + # Version 2. using numpyro.subsample() + def model(data): + with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()): + data = numpyro.subsample(data, event_dim=0) + # ... + + :param numpy.ndarray data: A tensor of batched data. + :param int event_dim: The event dimension of the data tensor. Dimensions to + the left are considered batch dimensions. + :returns: A subsampled version of ``data`` + :rtype: ~numpy.ndarray + """ + if not _PYRO_STACK: + return data + + assert isinstance(event_dim, int) and event_dim >= 0 + initial_msg = { + 'type': 'subsample', + 'value': data, + 'kwargs': {'event_dim': event_dim} + } + + msg = apply_stack(initial_msg) + return msg['value'] diff --git a/test/contrib/test_tfp.py b/test/contrib/test_tfp.py index 0370c9976..9078fc50a 100644 --- a/test/contrib/test_tfp.py +++ b/test/contrib/test_tfp.py @@ -41,6 +41,7 @@ def test_independent(): @pytest.mark.filterwarnings("ignore:can't resolve package") def test_transformed_distributions(): from tensorflow_probability.substrates.jax import bijectors as tfb + from numpyro.contrib.tfp import distributions as tfd d = dist.TransformedDistribution(dist.Normal(0, 1), dist.transforms.ExpTransform()) diff --git a/test/test_handlers.py b/test/test_handlers.py index 9422a4da4..f2e0603ff 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -218,6 +218,26 @@ def model_subsample_1(): assert xy.shape == (5, 1, 10) +def model_subsample_2(): + data = jnp.ones((10, 1, 20)) + outer = numpyro.plate('outer', data.shape[-1], subsample_size=10) + inner = numpyro.plate('inner', data.shape[-3], subsample_size=5, dim=-3) + with outer: + x = numpyro.sample('x', dist.Normal(0., 1.)) + assert x.shape == (10,) + with inner: + y = numpyro.sample('y', dist.Normal(0., 1.)) + assert y.shape == (5, 1, 1) + z = numpyro.deterministic('z', x ** 2) + assert z.shape == (10,) + + with outer, inner: + xy = numpyro.sample('xy', dist.Normal(0., 1.)) + assert xy.shape == (5, 1, 10) + subsample_data = numpyro.subsample(data, event_dim=0) + assert subsample_data.shape == (5, 1, 10) + + @pytest.mark.parametrize('model', [ model_nested_plates_0, model_nested_plates_1, @@ -225,6 +245,7 @@ def model_subsample_1(): model_nested_plates_3, model_dist_batch_shape, model_subsample_1, + model_subsample_2, ]) def test_plate(model): trace = handlers.trace(handlers.seed(model, random.PRNGKey(1))).get_trace() @@ -241,6 +262,19 @@ def test_subsample_data(): with handlers.seed(rng_seed=0): with numpyro.plate("a", len(data), subsample_size=subsample_size) as idx: assert data[idx].shape == (subsample_size,) + subsample_data = numpyro.subsample(data, event_dim=0) + assert subsample_data.shape == (subsample_size,) + + +def test_subsample_param(): + data = jnp.arange(100.) + subsample_size = 7 + with handlers.seed(rng_seed=0): + with numpyro.plate("a", len(data), subsample_size=subsample_size): + p0 = numpyro.param("p0", 0., event_dim=0) + assert jnp.shape(p0) == () + p = numpyro.param("p", 0.5 * jnp.ones(len(data)), event_dim=0) + assert len(p) == subsample_size def test_subsample_substitute(): @@ -254,6 +288,19 @@ def test_subsample_substitute(): assert tr["a"]["kwargs"]["rng_key"] is None +def test_subsample_replay(): + data = jnp.arange(100.) + subsample_size = 7 + + with handlers.trace() as guide_trace, handlers.seed(rng_seed=0): + with numpyro.plate("a", len(data), subsample_size=subsample_size) as idx: + pass + + with handlers.trace() as model, handlers.seed(rng_seed=0), handlers.replay(guide_trace=guide_trace): + with numpyro.plate("a", len(data)) as idx: + assert data[idx].shape == (subsample_size,) + + def test_messenger_fn_invalid(): with pytest.raises(ValueError, match="to be a Python callable object"): with numpyro.handlers.mask(False): From 91e3da7c3ac85383eb28756ebbc3e32a6a9f4ceb Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 12 Sep 2020 00:07:56 -0500 Subject: [PATCH 06/11] fix lint and add subsample to docs --- docs/source/primitives.rst | 4 ++++ numpyro/contrib/funsor/enum_messenger.py | 6 +++--- test/test_handlers.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/source/primitives.rst b/docs/source/primitives.rst index ef021e9e0..b1b7d1362 100644 --- a/docs/source/primitives.rst +++ b/docs/source/primitives.rst @@ -19,6 +19,10 @@ plate_stack ----------- .. autofunction:: numpyro.primitives.plate_stack +subsample +--------- +.. autofunction:: numpyro.primitives.subsample + deterministic ------------- .. autofunction:: numpyro.primitives.deterministic diff --git a/numpyro/contrib/funsor/enum_messenger.py b/numpyro/contrib/funsor/enum_messenger.py index ca2dc317c..bac7fc171 100644 --- a/numpyro/contrib/funsor/enum_messenger.py +++ b/numpyro/contrib/funsor/enum_messenger.py @@ -6,7 +6,7 @@ from enum import Enum from jax import lax -import jax.numpy as np +import jax.numpy as jnp import funsor from numpyro.handlers import trace as OrigTraceMessenger @@ -506,7 +506,7 @@ def postprocess_message(self, msg): .format(self.name, self.size, self.dim, statement, shape)) if self.subsample_size < self.size: value = msg["value"] - new_value = jnp.take_along_axis(value, self._indices, dim) + new_value = jnp.take(value, self._indices, dim) msg["value"] = new_value @@ -536,7 +536,7 @@ def process_message(self, msg): raise NotImplementedError("expand=True not implemented") size = msg["fn"].enumerate_support(expand=False).shape[0] - raw_value = np.arange(0, size) + raw_value = jnp.arange(0, size) funsor_value = funsor.Tensor( raw_value, OrderedDict([(msg["name"], funsor.bint(size))]), diff --git a/test/test_handlers.py b/test/test_handlers.py index f2e0603ff..2440ea7aa 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -296,7 +296,7 @@ def test_subsample_replay(): with numpyro.plate("a", len(data), subsample_size=subsample_size) as idx: pass - with handlers.trace() as model, handlers.seed(rng_seed=0), handlers.replay(guide_trace=guide_trace): + with handlers.seed(rng_seed=1), handlers.replay(guide_trace=guide_trace): with numpyro.plate("a", len(data)) as idx: assert data[idx].shape == (subsample_size,) From c944cbafdb686e3eed4baae6e763f99b411ff528 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 12 Sep 2020 22:02:28 -0500 Subject: [PATCH 07/11] pin tfp nightly version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c1b913105..1d1faf1c8 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ 'isort', 'flax', 'dm-haiku @ https://github.com/deepmind/dm-haiku/archive/v0.0.2.zip', - 'tfp-nightly', # TODO: change this to stable release or a specific nightly release + 'tfp-nightly==0.12.0.dev20200911', # TODO: change this to stable release or a specific nightly release ], 'examples': ['matplotlib', 'seaborn', 'graphviz'], }, From bade9f6d8a4c587f439a6d5fdbf6598f3be15abc Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sun, 13 Sep 2020 02:09:58 -0500 Subject: [PATCH 08/11] use pypi for haiku --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1d1faf1c8..5568382d0 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ 'ipython', 'isort', 'flax', - 'dm-haiku @ https://github.com/deepmind/dm-haiku/archive/v0.0.2.zip', + 'dm-haiku', 'tfp-nightly==0.12.0.dev20200911', # TODO: change this to stable release or a specific nightly release ], 'examples': ['matplotlib', 'seaborn', 'graphviz'], From 079c884cbbc4fd18cf5ba9947589e5ff05158cf3 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sun, 13 Sep 2020 23:05:25 -0500 Subject: [PATCH 09/11] add subsample gradient test --- numpyro/primitives.py | 4 +--- test/test_handlers.py | 55 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 4696d9fd1..788a438cb 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -273,9 +273,7 @@ def _subsample(name, size, subsample_size, dim): occupied_dims = {f.dim for f in cond_indep_stack} if dim is None: new_dim = -1 - while True: - if new_dim not in occupied_dims: - break + while new_dim in occupied_dims: new_dim -= 1 dim = new_dim else: diff --git a/test/test_handlers.py b/test/test_handlers.py index 2440ea7aa..6f13a092d 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -5,14 +5,16 @@ from numpy.testing import assert_allclose, assert_raises import pytest -from jax import jit, random, vmap +from jax import jit, random, tree_multimap, value_and_grad, vmap import jax.numpy as jnp import numpyro from numpyro import handlers import numpyro.distributions as dist from numpyro.distributions import constraints +from numpyro.infer import ELBO, SVI from numpyro.infer.util import log_density +import numpyro.optim as optim from numpyro.util import optional @@ -301,6 +303,57 @@ def test_subsample_replay(): assert data[idx].shape == (subsample_size,) +@pytest.mark.parametrize("scale", [1., 2.], ids=["unscaled", "scaled"]) +@pytest.mark.parametrize("subsample", [False, True], ids=["full", "subsample"]) +def test_subsample_gradient(scale, subsample): + data = jnp.array([-0.5, 2.0]) + subsample_size = 1 if subsample else len(data) + precision = 0.06 * scale + + def model(subsample): + with handlers.substitute(data={"data": subsample}): + with numpyro.plate("data", len(data), subsample_size) as ind: + x = data[ind] + z = numpyro.sample("z", dist.Normal(0, 1)) + numpyro.sample("x", dist.Normal(z, 1), obs=x) + + def guide(subsample): + scale = numpyro.param("scale", 1.) + with handlers.substitute(data={"data": subsample}): + with numpyro.plate("data", len(data), subsample_size): + loc = numpyro.param("loc", jnp.zeros(len(data)), event_dim=0) + numpyro.sample("z", dist.Normal(loc, scale)) + + if scale != 1.: + model = handlers.scale(model, scale=scale) + guide = handlers.scale(guide, scale=scale) + + num_particles = 50000 + optimizer = optim.Adam(0.1) + elbo = ELBO(num_particles=num_particles) + svi = SVI(model, guide, optimizer, loss=elbo) + svi_state = svi.init(random.PRNGKey(0), None) + params = svi.optim.get_params(svi_state.optim_state) + normalizer = 2 if subsample else 1 + if subsample_size == 1: + subsample = jnp.array([0]) + _, grads1 = value_and_grad(lambda x: svi.loss.loss( + svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample))(params) + subsample = jnp.array([1]) + _, grads2 = value_and_grad(lambda x: svi.loss.loss( + svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample))(params) + grads = tree_multimap(lambda *vals: vals[0] + vals[1], grads1, grads2) + else: + subsample = jnp.array([0, 1]) + _, grads = value_and_grad(lambda x: svi.loss.loss( + svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample))(params) + actual_grads = {name: grad / normalizer for name, grad in grads.items()} + expected_grads = {'loc': scale * jnp.array([0.5, -2.0]), 'scale': scale * jnp.array([2.0])} + assert actual_grads.keys() == expected_grads.keys() + for name in expected_grads: + assert_allclose(actual_grads[name], expected_grads[name], rtol=precision, atol=precision) + + def test_messenger_fn_invalid(): with pytest.raises(ValueError, match="to be a Python callable object"): with numpyro.handlers.mask(False): From 107c38e2a66054b2a2d4b9d4f63669ded8703c31 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 17 Sep 2020 12:23:41 -0500 Subject: [PATCH 10/11] add test for value in subsample_gradient test --- test/test_handlers.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/test/test_handlers.py b/test/test_handlers.py index 6f13a092d..9af925862 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -337,16 +337,23 @@ def guide(subsample): normalizer = 2 if subsample else 1 if subsample_size == 1: subsample = jnp.array([0]) - _, grads1 = value_and_grad(lambda x: svi.loss.loss( + loss1, grads1 = value_and_grad(lambda x: svi.loss.loss( svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample))(params) subsample = jnp.array([1]) - _, grads2 = value_and_grad(lambda x: svi.loss.loss( + loss2, grads2 = value_and_grad(lambda x: svi.loss.loss( svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample))(params) grads = tree_multimap(lambda *vals: vals[0] + vals[1], grads1, grads2) + loss = loss1 + loss2 else: subsample = jnp.array([0, 1]) - _, grads = value_and_grad(lambda x: svi.loss.loss( + loss, grads = value_and_grad(lambda x: svi.loss.loss( svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample))(params) + + actual_loss = loss / normalizer + expected_loss, _ = value_and_grad(lambda x: svi.loss.loss( + svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, None))(params) + assert_allclose(actual_loss, expected_loss, rtol=precision, atol=precision) + actual_grads = {name: grad / normalizer for name, grad in grads.items()} expected_grads = {'loc': scale * jnp.array([0.5, -2.0]), 'scale': scale * jnp.array([2.0])} assert actual_grads.keys() == expected_grads.keys() From b95bf9b4b19d9cfdf5a2d7c9434b321debe9fe36 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 18 Sep 2020 00:54:09 -0500 Subject: [PATCH 11/11] make sure that name is checked after type in replay handler --- numpyro/handlers.py | 2 +- test/test_handlers.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 886c9b482..b76a1b74a 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -192,7 +192,7 @@ def __init__(self, fn=None, guide_trace=None): super(replay, self).__init__(fn) def process_message(self, msg): - if msg['name'] in self.guide_trace and msg['type'] in ('sample', 'plate'): + if msg['type'] in ('sample', 'plate') and msg['name'] in self.guide_trace: msg['value'] = self.guide_trace[msg['name']]['value'] diff --git a/test/test_handlers.py b/test/test_handlers.py index 9af925862..54bb15793 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -295,12 +295,13 @@ def test_subsample_replay(): subsample_size = 7 with handlers.trace() as guide_trace, handlers.seed(rng_seed=0): - with numpyro.plate("a", len(data), subsample_size=subsample_size) as idx: + with numpyro.plate("a", len(data), subsample_size=subsample_size): pass with handlers.seed(rng_seed=1), handlers.replay(guide_trace=guide_trace): - with numpyro.plate("a", len(data)) as idx: - assert data[idx].shape == (subsample_size,) + with numpyro.plate("a", len(data)): + subsample_data = numpyro.subsample(data, event_dim=0) + assert subsample_data.shape == (subsample_size,) @pytest.mark.parametrize("scale", [1., 2.], ids=["unscaled", "scaled"])