Support data subsampling (#734)
* support subsampling

* add example for subsample indices

* simplify the implementation

* fix lint issue

* add subsample primitive and more tests

* fix lint and add subsample to docs

* pin tfp nightly version

* use pypi for haiku

* add subsample gradient test

* add test for value in subsample_gradient test

* make sure that name is checked after type in replay handler
fehiepsi authored Sep 18, 2020
1 parent b4ba9f1 commit cce00d1
Showing 8 changed files with 281 additions and 34 deletions.
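Before the file-by-file diff, a minimal usage sketch of the feature this commit adds (not part of the commit itself). The dataset, model, and batch size below are hypothetical; the pattern follows the docstring examples further down:

import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist

data = jnp.ones(1000)  # hypothetical 1-D dataset of 1000 observations

def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    # Each run draws a random mini-batch of 100 indices; the plate rescales
    # the mini-batch log-likelihood by size / subsample_size = 10.
    with numpyro.plate("data", len(data), subsample_size=100):
        batch = numpyro.subsample(data, event_dim=0)
        numpyro.sample("obs", dist.Normal(loc, 1.0), obs=batch)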
4 changes: 4 additions & 0 deletions docs/source/primitives.rst
@@ -19,6 +19,10 @@ plate_stack
-----------
.. autofunction:: numpyro.primitives.plate_stack

subsample
---------
.. autofunction:: numpyro.primitives.subsample

deterministic
-------------
.. autofunction:: numpyro.primitives.deterministic
3 changes: 2 additions & 1 deletion numpyro/__init__.py
@@ -4,7 +4,7 @@
from numpyro import compat, diagnostics, distributions, handlers, infer, optim
from numpyro.distributions.distribution import enable_validation, validation_enabled
import numpyro.patch # noqa: F401
from numpyro.primitives import deterministic, factor, module, param, plate, plate_stack, sample
from numpyro.primitives import deterministic, factor, module, param, plate, plate_stack, sample, subsample
from numpyro.util import enable_x64, set_host_device_count, set_platform
from numpyro.version import __version__

@@ -28,6 +28,7 @@
'plate',
'plate_stack',
'sample',
'subsample',
'set_host_device_count',
'set_platform',
'validation_enabled',
36 changes: 29 additions & 7 deletions numpyro/contrib/funsor/enum_messenger.py
@@ -6,11 +6,12 @@
from enum import Enum

from jax import lax
import jax.numpy as np
import jax.numpy as jnp

import funsor
from numpyro.handlers import trace as OrigTraceMessenger
from numpyro.primitives import CondIndepStackFrame, Messenger, apply_stack
from numpyro.primitives import plate as OrigPlateMessenger

funsor.set_backend("jax")

@@ -438,14 +439,14 @@ class plate(GlobalNamedMessenger):
def __init__(self, name, size, subsample_size=None, dim=None):
self.name = name
self.size = size
self.subsample_size = size if subsample_size is None else subsample_size
if dim is not None and dim >= 0:
raise ValueError('dim arg must be negative.')
self.dim = dim
self.dim, indices = OrigPlateMessenger._subsample(self.name, self.size, subsample_size, dim)
self.subsample_size = indices.shape[0]
self._indices = funsor.Tensor(
funsor.ops.new_arange(funsor.tensor.get_default_prototype(), self.size),
OrderedDict([(self.name, funsor.bint(self.size))]),
self.size
indices,
OrderedDict([(self.name, funsor.bint(self.subsample_size))]),
self.subsample_size
)
super(plate, self).__init__(None)

@@ -487,6 +488,27 @@ def process_message(self, msg): # copied almost verbatim from plate
scale = 1. if msg['scale'] is None else msg['scale']
msg['scale'] = scale * self.size / self.subsample_size

def postprocess_message(self, msg):
if msg["type"] in ("subsample", "param") and self.dim is not None:
event_dim = msg["kwargs"].get("event_dim")
if event_dim is not None:
assert event_dim >= 0
dim = self.dim - event_dim
shape = msg["value"].shape
if len(shape) >= -dim and shape[dim] != 1:
if shape[dim] != self.size:
if msg["type"] == "param":
statement = "numpyro.param({}, ..., event_dim={})".format(msg["name"], event_dim)
else:
statement = "numpyro.subsample(..., event_dim={})".format(event_dim)
raise ValueError(
"Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}"
.format(self.name, self.size, self.dim, statement, shape))
if self.subsample_size < self.size:
value = msg["value"]
new_value = jnp.take(value, self._indices, dim)
msg["value"] = new_value


class enum(BaseEnumMessenger):
"""
@@ -514,7 +536,7 @@ def process_message(self, msg):
raise NotImplementedError("expand=True not implemented")

size = msg["fn"].enumerate_support(expand=False).shape[0]
raw_value = np.arange(0, size)
raw_value = jnp.arange(0, size)
funsor_value = funsor.Tensor(
raw_value,
OrderedDict([(msg["name"], funsor.bint(size))]),
9 changes: 6 additions & 3 deletions numpyro/handlers.py
@@ -192,7 +192,7 @@ def __init__(self, fn=None, guide_trace=None):
super(replay, self).__init__(fn)

def process_message(self, msg):
if msg['name'] in self.guide_trace and msg['type'] in ('sample', 'plate'):
if msg['type'] in ('sample', 'plate') and msg['name'] in self.guide_trace:
msg['value'] = self.guide_trace[msg['name']]['value']


@@ -547,7 +547,10 @@ def __init__(self, fn=None, rng_seed=None):

def process_message(self, msg):
if (msg['type'] == 'sample' and not msg['is_observed'] and
msg['kwargs']['rng_key'] is None) or msg['type'] == 'control_flow':
msg['kwargs']['rng_key'] is None) or msg['type'] in ['plate', 'control_flow']:
# no need to create a new key when value is available
if msg['value'] is not None:
return
self.rng_key, rng_key_sample = random.split(self.rng_key)
msg['kwargs']['rng_key'] = rng_key_sample

@@ -596,7 +599,7 @@ def __init__(self, fn=None, data=None, substitute_fn=None):
super(substitute, self).__init__(fn)

def process_message(self, msg):
if (msg['type'] not in ('sample', 'param')) or msg.get('_control_flow_done', False):
if (msg['type'] not in ('sample', 'param', 'plate')) or msg.get('_control_flow_done', False):
if msg['type'] == 'control_flow':
if self.data is not None:
msg['kwargs']['substitute_stack'].append(('substitute', self.data))
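One consequence of the handlers.py changes above, sketched here as a hypothetical snippet (not part of the commit): the seed handler now also processes 'plate' messages, supplying the PRNG key that _subsample_fn needs. A plate with subsample_size therefore has to run under a seed handler (or inside an inference routine that seeds the model), otherwise it would hit the "Missing random key to generate subsample indices" assertion:

import numpyro
from numpyro.handlers import seed

def model():
    with numpyro.plate("data", 1000, subsample_size=10) as idx:
        # idx is an array of 10 randomly permuted indices into [0, 1000)
        assert idx.shape == (10,)

with seed(rng_seed=0):
    model()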
129 changes: 109 additions & 20 deletions numpyro/primitives.py
@@ -5,7 +5,8 @@
from contextlib import ExitStack, contextmanager
import functools

from jax import lax
from jax import lax, random
import jax.numpy as jnp

import numpyro
from numpyro.distributions.discrete import PRNGIdentity
@@ -125,6 +126,14 @@ def param(name, init_value=None, **kwargs):
the onus of using this to initialize the optimizer is on the user /
inference algorithm, since there is no global parameter store in
NumPyro.
:param constraint: NumPyro constraint, defaults to ``constraints.real``.
:type constraint: numpyro.distributions.constraints.Constraint
:param int event_dim: (optional) number of rightmost dimensions unrelated
to batching. Dimensions to the left of this will be considered batch
dimensions; if the param statement is inside a subsampled plate, then
the corresponding batch dimensions of the parameter will be subsampled
accordingly. If unspecified, all dimensions will be considered event
dims and no subsampling will be performed.
:return: value for the parameter. Unless wrapped inside a
handler like :class:`~numpyro.handlers.substitute`, this will simply
return the initial value.
@@ -202,13 +211,24 @@ def module(name, nn, input_shape=None):
return functools.partial(nn_apply, nn_params)


def _subsample_fn(size, subsample_size, rng_key=None):
assert rng_key is not None, "Missing random key to generate subsample indices."
return random.permutation(rng_key, size)[:subsample_size]


class plate(Messenger):
"""
Construct for annotating conditionally independent variables. Within a
`plate` context manager, `sample` sites will be automatically broadcasted to
the size of the plate. Additionally, a scale factor might be applied by
certain inference algorithms if `subsample_size` is specified.
.. note:: This can be used to subsample minibatches of data::
with plate("data", len(data), subsample_size=100) as ind:
batch = data[ind]
assert len(batch) == 100
:param str name: Name of the plate.
:param int size: Size of the plate.
:param int subsample_size: Optional argument denoting the size of the mini-batch.
@@ -221,41 +241,48 @@ class plate(Messenger):
def __init__(self, name, size, subsample_size=None, dim=None):
self.name = name
self.size = size
self.subsample_size = size if subsample_size is None else subsample_size
if dim is not None and dim >= 0:
raise ValueError('dim arg must be negative.')
self.dim = dim
self._validate_and_set_dim()
self.dim, self._indices = self._subsample(
self.name, self.size, subsample_size, dim)
self.subsample_size = self._indices.shape[0]
super(plate, self).__init__()

def _validate_and_set_dim(self):
# XXX: different from Pyro, this method returns dim and indices
@staticmethod
def _subsample(name, size, subsample_size, dim):
msg = {
'type': 'plate',
'fn': identity,
'name': self.name,
'args': (None,),
'kwargs': {},
'value': None,
'fn': _subsample_fn,
'name': name,
'args': (size, subsample_size),
'kwargs': {'rng_key': None},
'value': (None
if (subsample_size is not None and size != subsample_size)
else jnp.arange(size)),
'scale': 1.0,
'cond_indep_stack': [],
}
apply_stack(msg)
subsample = msg['value']
if subsample_size is not None and subsample_size != subsample.shape[0]:
raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format(
subsample_size, len(subsample)) +
" Did you accidentally use different subsample_size in the model and guide?")
cond_indep_stack = msg['cond_indep_stack']
occupied_dims = {f.dim for f in cond_indep_stack}
dim = -1
while True:
if dim not in occupied_dims:
break
dim -= 1
if self.dim is None:
self.dim = dim
if dim is None:
new_dim = -1
while new_dim in occupied_dims:
new_dim -= 1
dim = new_dim
else:
assert self.dim not in occupied_dims
assert dim not in occupied_dims
return dim, subsample

def __enter__(self):
super().__enter__()
# XXX: JAX doesn't like slice index, so we cast to list
return list(range(self.subsample_size))
return self._indices

@staticmethod
def _get_batch_shape(cond_indep_stack):
@@ -291,6 +318,27 @@ def process_message(self, msg):
scale = 1. if msg['scale'] is None else msg['scale']
msg['scale'] = scale * self.size / self.subsample_size

def postprocess_message(self, msg):
if msg["type"] in ("subsample", "param") and self.dim is not None:
event_dim = msg["kwargs"].get("event_dim")
if event_dim is not None:
assert event_dim >= 0
dim = self.dim - event_dim
shape = jnp.shape(msg["value"])
if len(shape) >= -dim and shape[dim] != 1:
if shape[dim] != self.size:
if msg["type"] == "param":
statement = "numpyro.param({}, ..., event_dim={})".format(msg["name"], event_dim)
else:
statement = "numpyro.subsample(..., event_dim={})".format(event_dim)
raise ValueError(
"Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}"
.format(self.name, self.size, self.dim, statement, shape))
if self.subsample_size < self.size:
value = msg["value"]
new_value = jnp.take(value, self._indices, dim)
msg["value"] = new_value


@contextmanager
def plate_stack(prefix, sizes, rightmost_dim=-1):
@@ -322,3 +370,44 @@ def factor(name, log_factor):
unit_dist = numpyro.distributions.distribution.Unit(log_factor)
unit_value = unit_dist.sample(None)
sample(name, unit_dist, obs=unit_value)


def subsample(data, event_dim):
"""
EXPERIMENTAL Subsampling statement to subsample data based on enclosing
:class:`~numpyro.primitives.plate` s.
This is typically called on arguments to ``model()`` when subsampling is
performed automatically by :class:`~numpyro.primitives.plate` s by passing
``subsample_size`` kwarg. For example the following are equivalent::
# Version 1. using indexing
def model(data):
with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind:
data = data[ind]
# ...
# Version 2. using numpyro.subsample()
def model(data):
with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()):
data = numpyro.subsample(data, event_dim=0)
# ...
:param numpy.ndarray data: A tensor of batched data.
:param int event_dim: The event dimension of the data tensor. Dimensions to
the left are considered batch dimensions.
:returns: A subsampled version of ``data``
:rtype: ~numpy.ndarray
"""
if not _PYRO_STACK:
return data

assert isinstance(event_dim, int) and event_dim >= 0
initial_msg = {
'type': 'subsample',
'value': data,
'kwargs': {'event_dim': event_dim}
}

msg = apply_stack(initial_msg)
return msg['value']
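A hedged sketch (not part of the diff) of how the new event_dim argument to param interacts with a subsampled plate, following the docstring above; the shapes and site names are hypothetical. Both the data and the per-datapoint parameter should be cut down to the mini-batch along their leading batch dimension:

import jax.numpy as jnp

import numpyro
from numpyro.handlers import seed

data = jnp.ones((1000, 2))  # hypothetical: 1000 points, 2 features each

def guide(data):
    with numpyro.plate("data", data.shape[0], subsample_size=100):
        # event_dim=1 marks the trailing feature dimension as an event dim,
        # so only the leading (batch) dimension is subsampled.
        batch = numpyro.subsample(data, event_dim=1)
        loc = numpyro.param("loc", jnp.zeros((1000, 2)), event_dim=1)
        assert batch.shape == loc.shape == (100, 2)

with seed(rng_seed=0):
    guide(data)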
4 changes: 2 additions & 2 deletions setup.py
@@ -51,8 +51,8 @@
'ipython',
'isort',
'flax',
'dm-haiku @ https://github.com/deepmind/dm-haiku/archive/v0.0.2.zip',
'tfp-nightly', # TODO: change this to stable release or a specific nightly release
'dm-haiku',
'tfp-nightly==0.12.0.dev20200911', # TODO: change this to stable release or a specific nightly release
],
'examples': ['matplotlib', 'seaborn', 'graphviz'],
},
1 change: 1 addition & 0 deletions test/contrib/test_tfp.py
@@ -41,6 +41,7 @@ def test_independent():
@pytest.mark.filterwarnings("ignore:can't resolve package")
def test_transformed_distributions():
from tensorflow_probability.substrates.jax import bijectors as tfb

from numpyro.contrib.tfp import distributions as tfd

d = dist.TransformedDistribution(dist.Normal(0, 1), dist.transforms.ExpTransform())
(The diff for the remaining changed file is not rendered on this page.)
