Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support data subsampling #734

Merged
merged 12 commits into from
Sep 18, 2020
4 changes: 4 additions & 0 deletions docs/source/primitives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ plate_stack
-----------
.. autofunction:: numpyro.primitives.plate_stack

subsample
---------
.. autofunction:: numpyro.primitives.subsample

deterministic
-------------
.. autofunction:: numpyro.primitives.deterministic
Expand Down
3 changes: 2 additions & 1 deletion numpyro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from numpyro import compat, diagnostics, distributions, handlers, infer, optim
from numpyro.distributions.distribution import enable_validation, validation_enabled
import numpyro.patch # noqa: F401
from numpyro.primitives import deterministic, factor, module, param, plate, plate_stack, sample
from numpyro.primitives import deterministic, factor, module, param, plate, plate_stack, sample, subsample
from numpyro.util import enable_x64, set_host_device_count, set_platform
from numpyro.version import __version__

Expand All @@ -28,6 +28,7 @@
'plate',
'plate_stack',
'sample',
'subsample',
'set_host_device_count',
'set_platform',
'validation_enabled',
Expand Down
36 changes: 29 additions & 7 deletions numpyro/contrib/funsor/enum_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from enum import Enum

from jax import lax
import jax.numpy as np
import jax.numpy as jnp

import funsor
from numpyro.handlers import trace as OrigTraceMessenger
from numpyro.primitives import CondIndepStackFrame, Messenger, apply_stack
from numpyro.primitives import plate as OrigPlateMessenger

funsor.set_backend("jax")

Expand Down Expand Up @@ -438,14 +439,14 @@ class plate(GlobalNamedMessenger):
def __init__(self, name, size, subsample_size=None, dim=None):
self.name = name
self.size = size
self.subsample_size = size if subsample_size is None else subsample_size
if dim is not None and dim >= 0:
raise ValueError('dim arg must be negative.')
self.dim = dim
self.dim, indices = OrigPlateMessenger._subsample(self.name, self.size, subsample_size, dim)
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
self.subsample_size = indices.shape[0]
self._indices = funsor.Tensor(
funsor.ops.new_arange(funsor.tensor.get_default_prototype(), self.size),
OrderedDict([(self.name, funsor.bint(self.size))]),
self.size
indices,
OrderedDict([(self.name, funsor.bint(self.subsample_size))]),
self.subsample_size
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
)
super(plate, self).__init__(None)

Expand Down Expand Up @@ -487,6 +488,27 @@ def process_message(self, msg): # copied almost verbatim from plate
scale = 1. if msg['scale'] is None else msg['scale']
msg['scale'] = scale * self.size / self.subsample_size

def postprocess_message(self, msg):
    """
    Validate and, if needed, subsample the value of a ``subsample`` or
    ``param`` message along this plate's dimension.

    Nothing happens unless the message type is ``"subsample"`` or
    ``"param"``, this plate has a ``dim``, and the message kwargs carry a
    non-``None`` ``event_dim``.

    :param dict msg: the message being post-processed; ``msg["value"]`` is
        replaced in place when subsampling applies.
    :raises ValueError: if the value has this plate's axis with a length
        that is neither 1 nor ``self.size``.
    """
    if msg["type"] in ("subsample", "param") and self.dim is not None:
        event_dim = msg["kwargs"].get("event_dim")
        if event_dim is not None:
            assert event_dim >= 0
            # The plate axis sits to the left of the `event_dim` rightmost axes.
            dim = self.dim - event_dim
            # jnp.shape also accepts values without a `.shape` attribute
            # (e.g. Python scalars), matching numpyro.primitives.plate.
            shape = jnp.shape(msg["value"])
            # Only act when the value actually has that axis and it is not
            # a broadcastable singleton.
            if len(shape) >= -dim and shape[dim] != 1:
                if shape[dim] != self.size:
                    if msg["type"] == "param":
                        statement = "numpyro.param({}, ..., event_dim={})".format(msg["name"], event_dim)
                    else:
                        statement = "numpyro.subsample(..., event_dim={})".format(event_dim)
                    raise ValueError(
                        "Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}"
                        .format(self.name, self.size, self.dim, statement, shape))
                if self.subsample_size < self.size:
                    # Keep only the entries selected by this plate's indices.
                    value = msg["value"]
                    new_value = jnp.take(value, self._indices, dim)
                    msg["value"] = new_value


class enum(BaseEnumMessenger):
"""
Expand Down Expand Up @@ -514,7 +536,7 @@ def process_message(self, msg):
raise NotImplementedError("expand=True not implemented")

size = msg["fn"].enumerate_support(expand=False).shape[0]
raw_value = np.arange(0, size)
raw_value = jnp.arange(0, size)
funsor_value = funsor.Tensor(
raw_value,
OrderedDict([(msg["name"], funsor.bint(size))]),
Expand Down
7 changes: 5 additions & 2 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,10 @@ def __init__(self, fn=None, rng_seed=None):

def process_message(self, msg):
if (msg['type'] == 'sample' and not msg['is_observed'] and
msg['kwargs']['rng_key'] is None) or msg['type'] == 'control_flow':
msg['kwargs']['rng_key'] is None) or msg['type'] in ['plate', 'control_flow']:
# no need to create a new key when value is available
if msg['value'] is not None:
return
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
self.rng_key, rng_key_sample = random.split(self.rng_key)
msg['kwargs']['rng_key'] = rng_key_sample

Expand Down Expand Up @@ -596,7 +599,7 @@ def __init__(self, fn=None, data=None, substitute_fn=None):
super(substitute, self).__init__(fn)

def process_message(self, msg):
if (msg['type'] not in ('sample', 'param')) or msg.get('_control_flow_done', False):
if (msg['type'] not in ('sample', 'param', 'plate')) or msg.get('_control_flow_done', False):
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
if msg['type'] == 'control_flow':
if self.data is not None:
msg['kwargs']['substitute_stack'].append(('substitute', self.data))
Expand Down
129 changes: 109 additions & 20 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from contextlib import ExitStack, contextmanager
import functools

from jax import lax
from jax import lax, random
import jax.numpy as jnp

import numpyro
from numpyro.distributions.discrete import PRNGIdentity
Expand Down Expand Up @@ -125,6 +126,14 @@ def param(name, init_value=None, **kwargs):
the onus of using this to initialize the optimizer is on the user /
inference algorithm, since there is no global parameter store in
NumPyro.
:param constraint: NumPyro constraint, defaults to ``constraints.real``.
:type constraint: numpyro.distributions.constraints.Constraint
:param int event_dim: (optional) number of rightmost dimensions unrelated
to batching. Dimensions to the left of these will be considered batch
dimensions; if the param statement is inside a subsampled plate, then
the corresponding batch dimensions of the parameter will be subsampled
accordingly. If unspecified, all dimensions will be considered event
dims and no subsampling will be performed.
:return: value for the parameter. Unless wrapped inside a
handler like :class:`~numpyro.handlers.substitute`, this will simply
return the initial value.
Expand Down Expand Up @@ -202,13 +211,24 @@ def module(name, nn, input_shape=None):
return functools.partial(nn_apply, nn_params)


def _subsample_fn(size, subsample_size, rng_key=None):
assert rng_key is not None, "Missing random key to generate subsample indices."
return random.permutation(rng_key, size)[:subsample_size]
fritzo marked this conversation as resolved.
Show resolved Hide resolved


class plate(Messenger):
"""
Construct for annotating conditionally independent variables. Within a
`plate` context manager, `sample` sites will be automatically broadcasted to
the size of the plate. Additionally, a scale factor might be applied by
certain inference algorithms if `subsample_size` is specified.

.. note:: This can be used to subsample minibatches of data::

with plate("data", len(data), subsample_size=100) as ind:
batch = data[ind]
assert len(batch) == 100

:param str name: Name of the plate.
:param int size: Size of the plate.
:param int subsample_size: Optional argument denoting the size of the mini-batch.
Expand All @@ -221,41 +241,48 @@ class plate(Messenger):
def __init__(self, name, size, subsample_size=None, dim=None):
self.name = name
self.size = size
self.subsample_size = size if subsample_size is None else subsample_size
if dim is not None and dim >= 0:
raise ValueError('dim arg must be negative.')
self.dim = dim
self._validate_and_set_dim()
self.dim, self._indices = self._subsample(
self.name, self.size, subsample_size, dim)
self.subsample_size = self._indices.shape[0]
super(plate, self).__init__()

def _validate_and_set_dim(self):
# XXX: different from Pyro, this method returns dim and indices
@staticmethod
def _subsample(name, size, subsample_size, dim):
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
msg = {
'type': 'plate',
'fn': identity,
'name': self.name,
'args': (None,),
'kwargs': {},
'value': None,
'fn': _subsample_fn,
'name': name,
'args': (size, subsample_size),
'kwargs': {'rng_key': None},
'value': (None
if (subsample_size is not None and size != subsample_size)
else jnp.arange(size)),
'scale': 1.0,
'cond_indep_stack': [],
}
apply_stack(msg)
subsample = msg['value']
if subsample_size is not None and subsample_size != subsample.shape[0]:
raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format(
subsample_size, len(subsample)) +
" Did you accidentally use different subsample_size in the model and guide?")
cond_indep_stack = msg['cond_indep_stack']
occupied_dims = {f.dim for f in cond_indep_stack}
dim = -1
while True:
if dim not in occupied_dims:
break
dim -= 1
if self.dim is None:
self.dim = dim
if dim is None:
new_dim = -1
while new_dim in occupied_dims:
new_dim -= 1
dim = new_dim
else:
assert self.dim not in occupied_dims
assert dim not in occupied_dims
return dim, subsample

def __enter__(self):
super().__enter__()
# XXX: JAX doesn't like slice index, so we cast to list
return list(range(self.subsample_size))
return self._indices

@staticmethod
def _get_batch_shape(cond_indep_stack):
Expand Down Expand Up @@ -291,6 +318,27 @@ def process_message(self, msg):
scale = 1. if msg['scale'] is None else msg['scale']
msg['scale'] = scale * self.size / self.subsample_size

def postprocess_message(self, msg):
    """
    Check and subsample ``subsample``/``param`` values along this plate.

    The message is left untouched unless its type is ``"subsample"`` or
    ``"param"``, this plate has a ``dim``, and ``event_dim`` is present
    (and not ``None``) in the message kwargs.

    :param dict msg: message to post-process; its ``"value"`` entry is
        overwritten when the plate is subsampled.
    :raises ValueError: if the value's plate axis has a length other than
        1 or ``self.size``.
    """
    if msg["type"] not in ("subsample", "param") or self.dim is None:
        return

    event_dim = msg["kwargs"].get("event_dim")
    if event_dim is None:
        return
    assert event_dim >= 0

    # Plate axis, counted from the right, after skipping the event axes.
    axis = self.dim - event_dim
    value_shape = jnp.shape(msg["value"])

    # Value lacks the plate axis, or the axis is a singleton: nothing to do.
    if len(value_shape) < -axis or value_shape[axis] == 1:
        return

    if value_shape[axis] != self.size:
        statement = (
            "numpyro.param({}, ..., event_dim={})".format(msg["name"], event_dim)
            if msg["type"] == "param"
            else "numpyro.subsample(..., event_dim={})".format(event_dim)
        )
        raise ValueError(
            "Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}"
            .format(self.name, self.size, self.dim, statement, value_shape))

    if self.subsample_size < self.size:
        # Select the entries at this plate's subsample indices.
        msg["value"] = jnp.take(msg["value"], self._indices, axis)


@contextmanager
def plate_stack(prefix, sizes, rightmost_dim=-1):
Expand Down Expand Up @@ -322,3 +370,44 @@ def factor(name, log_factor):
unit_dist = numpyro.distributions.distribution.Unit(log_factor)
unit_value = unit_dist.sample(None)
sample(name, unit_dist, obs=unit_value)


def subsample(data, event_dim):
    """
    EXPERIMENTAL Subsampling statement that subsamples ``data`` according to
    the enclosing :class:`~numpyro.primitives.plate` s.

    It is typically applied to arguments of ``model()`` when subsampling is
    done automatically by a :class:`~numpyro.primitives.plate` that was given
    the ``subsample_size`` kwarg. The two versions below are equivalent::

        # Version 1. using indexing
        def model(data):
            with numpyro.plate("data", len(data), subsample_size=10, dim=-data.ndim) as ind:
                data = data[ind]
                # ...

        # Version 2. using numpyro.subsample()
        def model(data):
            with numpyro.plate("data", len(data), subsample_size=10, dim=-data.ndim):
                data = numpyro.subsample(data, event_dim=0)
                # ...

    When no handler is active, ``data`` is returned unchanged.

    :param numpy.ndarray data: A tensor of batched data.
    :param int event_dim: The event dimension of the data tensor. Dimensions to
        the left are considered batch dimensions.
    :returns: A subsampled version of ``data``
    :rtype: ~numpy.ndarray
    """
    if _PYRO_STACK:
        assert isinstance(event_dim, int) and event_dim >= 0
        # Send the data through the handler stack; enclosing plates may
        # replace the message value with a subsampled one.
        site = apply_stack({
            'type': 'subsample',
            'value': data,
            'kwargs': {'event_dim': event_dim},
        })
        return site['value']

    # Empty handler stack: nothing can subsample the data.
    return data
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@
'ipython',
'isort',
'flax',
'dm-haiku @ https://github.com/deepmind/dm-haiku/archive/v0.0.2.zip',
'tfp-nightly', # TODO: change this to stable release or a specific nightly release
'dm-haiku',
'tfp-nightly==0.12.0.dev20200911', # TODO: change this to stable release or a specific nightly release
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
],
'examples': ['matplotlib', 'seaborn', 'graphviz'],
},
Expand Down
1 change: 1 addition & 0 deletions test/contrib/test_tfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_independent():
@pytest.mark.filterwarnings("ignore:can't resolve package")
def test_transformed_distributions():
from tensorflow_probability.substrates.jax import bijectors as tfb

from numpyro.contrib.tfp import distributions as tfd

d = dist.TransformedDistribution(dist.Normal(0, 1), dist.transforms.ExpTransform())
Expand Down
Loading