Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support data subsampling #734

Merged
merged 12 commits into from
Sep 18, 2020
9 changes: 5 additions & 4 deletions numpyro/contrib/funsor/enum_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import funsor
from numpyro.handlers import trace as OrigTraceMessenger
from numpyro.primitives import CondIndepStackFrame, Messenger, apply_stack
from numpyro.primitives import CondIndepStackFrame, Messenger, apply_stack, plate as OrigPlateMessenger

funsor.set_backend("jax")

Expand Down Expand Up @@ -442,10 +442,11 @@ def __init__(self, name, size, subsample_size=None, dim=None):
if dim is not None and dim >= 0:
raise ValueError('dim arg must be negative.')
self.dim = dim
_, indices = OrigPlateMessenger._subsample(self.name, self.size, self.subsample_size, dim)
self._indices = funsor.Tensor(
funsor.ops.new_arange(funsor.tensor.get_default_prototype(), self.size),
OrderedDict([(self.name, funsor.bint(self.size))]),
self.size
indices,
OrderedDict([(self.name, funsor.bint(self.subsample_size))]),
self.subsample_size
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
)
super(plate, self).__init__(None)

Expand Down
9 changes: 7 additions & 2 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,12 @@ def __init__(self, fn=None, rng_seed=None):

def process_message(self, msg):
if (msg['type'] == 'sample' and not msg['is_observed'] and
msg['kwargs']['rng_key'] is None) or msg['type'] == 'control_flow':
msg['kwargs']['rng_key'] is None) or msg['type'] in ['plate', 'control_flow']:
# no need to split key if size = subsample_size
if msg['type'] == 'plate':
size, subsample_size = msg['args']
if size == subsample_size:
return
self.rng_key, rng_key_sample = random.split(self.rng_key)
msg['kwargs']['rng_key'] = rng_key_sample

Expand Down Expand Up @@ -596,7 +601,7 @@ def __init__(self, fn=None, data=None, substitute_fn=None):
super(substitute, self).__init__(fn)

def process_message(self, msg):
if (msg['type'] not in ('sample', 'param')) or msg.get('_control_flow_done', False):
if (msg['type'] not in ('sample', 'param', 'plate')) or msg.get('_control_flow_done', False):
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
if msg['type'] == 'control_flow':
if self.data is not None:
msg['kwargs']['substitute_stack'].append(('substitute', self.data))
Expand Down
57 changes: 39 additions & 18 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from contextlib import ExitStack, contextmanager
import functools

from jax import lax
from jax import lax, random
import jax.numpy as jnp

import numpyro
from numpyro.distributions.discrete import PRNGIdentity
Expand Down Expand Up @@ -202,13 +203,27 @@ def module(name, nn, input_shape=None):
return functools.partial(nn_apply, nn_params)


def _subsample_fn(size, subsample_size, rng_key=None):
    """
    Produce the indices selected for a plate of ``size`` elements.

    :param int size: Total number of elements in the plate.
    :param int subsample_size: Number of indices to return.
    :param rng_key: Random key used to shuffle the indices. Only required
        when ``subsample_size`` differs from ``size``.
    :return: ``jnp.arange(size)`` when no subsampling is requested, otherwise
        the first ``subsample_size`` entries of a random permutation of
        ``range(size)``.
    """
    if size != subsample_size:
        # A genuine subsample needs randomness; the full range below does not.
        assert rng_key is not None, "Missing random key to generate subsample indices."
        shuffled = random.permutation(rng_key, size)
        return shuffled[:subsample_size]
    return jnp.arange(size)


class plate(Messenger):
"""
Construct for annotating conditionally independent variables. Within a
`plate` context manager, `sample` sites will be automatically broadcasted to
the size of the plate. Additionally, a scale factor might be applied by
certain inference algorithms if `subsample_size` is specified.

.. note:: This can be used to subsample minibatches of data::

with plate("data", len(data), subsample_size=100) as ind:
batch = data[ind]
assert len(batch) == 100

:param str name: Name of the plate.
:param int size: Size of the plate.
:param int subsample_size: Optional argument denoting the size of the mini-batch.
Expand All @@ -224,38 +239,44 @@ def __init__(self, name, size, subsample_size=None, dim=None):
self.subsample_size = size if subsample_size is None else subsample_size
if dim is not None and dim >= 0:
raise ValueError('dim arg must be negative.')
self.dim = dim
self._validate_and_set_dim()
self.dim, self._indices = self._subsample(self.name, self.size, self.subsample_size, dim)
super(plate, self).__init__()

def _validate_and_set_dim(self):
# XXX: different from Pyro, this method returns dim and indices
@staticmethod
def _subsample(name, size, subsample_size, dim):
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
msg = {
'type': 'plate',
'fn': identity,
'name': self.name,
'args': (None,),
'kwargs': {},
'fn': _subsample_fn,
'name': name,
'args': (size, subsample_size),
'kwargs': {'rng_key': None},
'value': None,
'scale': 1.0,
'cond_indep_stack': [],
}
apply_stack(msg)
subsample = msg['value']
if subsample_size != subsample.shape[0]:
raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format(
subsample_size, len(subsample)) +
" Did you accidentally use different subsample_size in the model and guide?")
cond_indep_stack = msg['cond_indep_stack']
occupied_dims = {f.dim for f in cond_indep_stack}
dim = -1
while True:
if dim not in occupied_dims:
break
dim -= 1
if self.dim is None:
self.dim = dim
if dim is None:
new_dim = -1
while True:
if new_dim not in occupied_dims:
break
new_dim -= 1
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
dim = new_dim
else:
assert self.dim not in occupied_dims
assert dim not in occupied_dims
return dim, subsample

def __enter__(self):
super().__enter__()
# XXX: JAX doesn't like slice index, so we cast to list
return list(range(self.subsample_size))
return self._indices

@staticmethod
def _get_batch_shape(cond_indep_stack):
Expand Down