Skip to content

Commit

Permalink
Add numpyro.collapse (#773)
Browse files Browse the repository at this point in the history
* add collapse messenger

* add dirichlet multinomial

* add dirichlet distribution

* revert change at funsor

* move collapse to the main handler

* remove legacy code

* add more tests

* fix a bug

* pass the test

* fix the bug

* bump setup dependent versions

* fix failing tests when signatures of NumPyro classes have been changed due to Meta class

* match new coercion pattern and enable beta_bernoulli

* fix lint
  • Loading branch information
fehiepsi authored Oct 7, 2020
1 parent 1f2df20 commit 78716a7
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 15 deletions.
8 changes: 8 additions & 0 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ block
:show-inheritance:
:member-order: bysource

collapse
--------
.. autoclass:: numpyro.handlers.collapse
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

condition
---------
.. autoclass:: numpyro.handlers.condition
Expand Down
26 changes: 21 additions & 5 deletions numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,23 @@ def _transform_to_bijector_constraint(constraint):
return BijectorTransform(constraint.bijector)


class TFPDistributionMixin(NumPyroDistribution):
_TFPDistributionMeta = type(tfd.Distribution)


# XXX: we create this metaclass to avoid a metaclass conflict between TFP and NumPyro Distribution
class _TFPMixinMeta(_TFPDistributionMeta, type(NumPyroDistribution)):
    def __init__(cls, name, bases, dct):
        # XXX: _TFPDistributionMeta.__init__ registers cls as a PyTree.
        # For some reason, when TFPDistributionMixin is created with this
        # metaclass it ends up being registered as a PyTree twice, which JAX
        # does not allow. For that single class we therefore start the MRO
        # lookup *after* _TFPDistributionMeta, so its __init__ (and hence the
        # PyTree registration) is skipped; every other class goes through the
        # regular cooperative chain.
        if name != "TFPDistributionMixin":
            lookup_after = _TFPMixinMeta
        else:
            lookup_after = _TFPDistributionMeta
        super(lookup_after, cls).__init__(name, bases, dct)


class TFPDistributionMixin(NumPyroDistribution, metaclass=_TFPMixinMeta):
"""
A mixin layer to make TensorFlow Probability (TFP) distribution compatible
with NumPyro internal.
Expand Down Expand Up @@ -118,11 +134,11 @@ def is_discrete(self):
return self.support is None


class InverseGamma(tfd.InverseGamma):
class InverseGamma(tfd.InverseGamma, TFPDistributionMixin):
    # Maps each constructor parameter name to the constraint declared for it.
    arg_constraints = {"concentration": constraints.positive, "scale": constraints.positive}


class OneHotCategorical(tfd.OneHotCategorical):
class OneHotCategorical(tfd.OneHotCategorical, TFPDistributionMixin):
arg_constraints = {"logits": constraints.real_vector}
has_enumerate_support = True
support = constraints.simplex
Expand All @@ -137,11 +153,11 @@ def enumerate_support(self, expand=True):
return values


class OrderedLogistic(tfd.OrderedLogistic):
class OrderedLogistic(tfd.OrderedLogistic, TFPDistributionMixin):
    # Maps each constructor parameter name to the constraint declared for it.
    arg_constraints = {"cutpoints": constraints.ordered_vector, "loc": constraints.real}


class Pareto(tfd.Pareto):
class Pareto(tfd.Pareto, TFPDistributionMixin):
    # Maps each constructor parameter name to the constraint declared for it.
    arg_constraints = {"concentration": constraints.positive, "scale": constraints.positive}


Expand Down
14 changes: 13 additions & 1 deletion numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,19 @@ def validation_enabled(is_validate=True):
enable_validation(distribution_validation_status)


class Distribution(object):
# Global, ordered list of coercion hooks. Each hook is called as
# ``hook(cls, args, kwargs)`` whenever a class using ``DistributionMeta`` is
# instantiated; handlers push/pop hooks here to intercept construction.
COERCIONS = []


class DistributionMeta(type):
    """
    Metaclass that lets registered coercion hooks intercept instantiation.
    """

    def __call__(cls, *args, **kwargs):
        # Hooks are consulted lazily and in registration order; the first one
        # that returns something other than ``None`` short-circuits the normal
        # construction and its result is handed back to the caller as-is.
        attempts = (hook(cls, args, kwargs) for hook in COERCIONS)
        coerced = next((obj for obj in attempts if obj is not None), None)
        if coerced is None:
            return super().__call__(*args, **kwargs)
        return coerced


class Distribution(metaclass=DistributionMeta):
"""
Base class for probability distributions in NumPyro. The design largely
follows from :mod:`torch.distributions`.
Expand Down
71 changes: 70 additions & 1 deletion numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,13 @@
import jax.numpy as jnp

import numpyro
from numpyro.primitives import Messenger, apply_stack
from numpyro.distributions.distribution import COERCIONS
from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack, plate
from numpyro.util import not_jax_tracer

__all__ = [
'block',
'collapse',
'condition',
'lift',
'mask',
Expand Down Expand Up @@ -242,6 +244,73 @@ def process_message(self, msg):
msg['stop'] = True


class collapse(trace):
    """
    EXPERIMENTAL Collapses all sites in the context by lazily sampling and
    attempting to use conjugacy relations. If no conjugacy is known this will
    fail. Code using the results of sample sites must be written to accept
    Funsors rather than Tensors. This requires ``funsor`` to be installed.
    """
    # Coercion hook shared by every instance. It is created on first
    # instantiation so that ``funsor`` is only imported when ``collapse`` is
    # actually used.
    _coerce = None

    def __init__(self, *args, **kwargs):
        if collapse._coerce is None:
            import funsor
            from funsor.distribution import CoerceDistributionToFunsor
            funsor.set_backend("jax")
            collapse._coerce = CoerceDistributionToFunsor("jax")
        super().__init__(*args, **kwargs)

    def process_message(self, msg):
        """
        For sample sites, substitute the site name for a missing value and
        set ``msg["stop"]`` when the site is funsor-valued or delayed.
        """
        from funsor.terms import Funsor

        if msg["type"] == "sample":
            if msg["value"] is None:
                # No value yet: use the site name as a placeholder so the
                # site can be identified by name in ``__exit__``.
                msg["value"] = msg["name"]

            if isinstance(msg["fn"], Funsor) or isinstance(msg["value"], (str, Funsor)):
                msg["stop"] = True

    def __enter__(self):
        # Plates that are already active when the context is entered are
        # remembered so that ``__exit__`` does not eliminate them.
        self.preserved_plates = frozenset(h.name for h in _PYRO_STACK
                                          if isinstance(h, plate))
        COERCIONS.append(self._coerce)
        return super().__enter__()

    def __exit__(self, *args, **kwargs):
        """
        Unregister the coercion hook and, if the context body finished
        normally, turn the recorded sites into a single ``numpyro.factor``.
        """
        import funsor

        _coerce = COERCIONS.pop()
        assert _coerce is self._coerce
        super().__exit__(*args, **kwargs)

        # If the ``with`` body raised, the trace is incomplete. Attempting to
        # collapse it would either fail with a secondary error that hides the
        # real one or emit a bogus factor, so bail out and let the original
        # exception propagate untouched.
        exc_type = args[0] if args else kwargs.get("exc_type")
        if exc_type is not None:
            return

        # Convert delayed statements to pyro.factor()
        reduced_vars = []
        log_prob_terms = []
        plates = frozenset()
        for name, site in self.trace.items():
            if not site["is_observed"]:
                reduced_vars.append(name)
            dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]}
            fn = funsor.to_funsor(site["fn"], funsor.Real, dim_to_name)
            value = site["value"]
            if not isinstance(value, str):
                # Concrete values are converted to funsors; string values are
                # the name placeholders set in ``process_message``.
                value = funsor.to_funsor(site["value"], fn.inputs["value"], dim_to_name)
            log_prob_terms.append(fn(value=value))
            plates |= frozenset(f.name for f in site["cond_indep_stack"])
        assert log_prob_terms, "nothing to collapse"
        # Only plates opened inside this context are eliminated.
        reduced_plates = plates - self.preserved_plates
        log_prob = funsor.sum_product.sum_product(
            funsor.ops.logaddexp,
            funsor.ops.add,
            log_prob_terms,
            eliminate=frozenset(reduced_vars) | reduced_plates,
            plates=plates,
        )
        # The resulting factor is registered under the name of the first
        # unobserved site.
        name = reduced_vars[0]
        numpyro.factor(name, log_prob.data)


class condition(Messenger):
"""
Conditions unobserved sample sites to values from `data` or `condition_fn`.
Expand Down
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
author_email='[email protected]',
install_requires=[
# TODO: pin to a specific version for the release (until JAX's API becomes stable)
'jax==0.2',
'jax>=0.2',
# check min version here: https://github.com/google/jax/blob/master/jax/lib/__init__.py#L26
'jaxlib==0.1.55',
'jaxlib>=0.1.55',
'tqdm',
],
extras_require={
Expand All @@ -47,12 +47,12 @@
'pyro-api>=0.1.1'
],
'dev': [
'funsor',
'funsor @ git+https://github.com/pyro-ppl/funsor.git@6575ac2c3f7ac25a6a1e5f1107b1b7c072edd992',
'ipython',
'isort',
'flax',
'dm-haiku',
'tfp-nightly==0.12.0.dev20200930', # TODO: change this to stable release or a specific nightly release
'tfp-nightly', # TODO: change this to stable release or a specific nightly release
],
'examples': ['matplotlib', 'seaborn', 'graphviz', 'arviz'],
},
Expand Down
13 changes: 11 additions & 2 deletions test/contrib/test_tfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from numpy.testing import assert_allclose
import pytest

from jax import random
from jax import jit, random
import jax.numpy as jnp

import numpyro
Expand All @@ -28,7 +28,16 @@ def test_api_consistent():
if type(numpyro_dist).__name__ == "function":
numpyro_dist = getattr(numpyro.distributions, name + "Logits")
for p in tfp_dist.arg_constraints:
assert p in dict(inspect.signature(tfp_dist).parameters)
assert p in inspect.getfullargspec(tfp_dist.__init__)[0]


def test_dist_pytree():
    # Smoke test: a TFP-backed distribution can be returned from a jitted
    # function; the test passes as long as the call does not raise.
    from numpyro.contrib.tfp import distributions as tfd

    make_normal = jit(lambda loc: tfd.Normal(loc, 1))
    make_normal(0)


@pytest.mark.filterwarnings("ignore:can't resolve package")
Expand Down
4 changes: 2 additions & 2 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def test_sample_gradient(jax_dist, sp_dist, params):
if not jax_dist.reparametrized_params:
pytest.skip('{} not reparametrized.'.format(jax_dist.__name__))

dist_args = [p.name for p in inspect.signature(jax_dist).parameters.values()]
dist_args = [p for p in inspect.getfullargspec(jax_dist.__init__)[0][1:]]
params_dict = dict(zip(dist_args[:len(params)], params))
nonrepara_params_dict = {k: v for k, v in params_dict.items()
if k not in jax_dist.reparametrized_params}
Expand Down Expand Up @@ -701,7 +701,7 @@ def test_mean_var(jax_dist, sp_dist, params):
(2, 3),
])
def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
dist_args = [p.name for p in inspect.signature(jax_dist).parameters.values()]
dist_args = [p for p in inspect.getfullargspec(jax_dist.__init__)[0][1:]]

valid_params, oob_params = list(params), list(params)
key = random.PRNGKey(1)
Expand Down
80 changes: 80 additions & 0 deletions test/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,3 +509,83 @@ def model():
with handlers.seed(rng_seed=1):
with handlers.lift(prior=dist.Normal(0, 1)):
model()


def test_collapse_beta_binomial():
    """
    A Beta prior collapsed against a Binomial likelihood should behave like
    an explicit BetaBinomial likelihood with the same parameters.
    """
    total_count = 10
    data = 3.

    def model1():
        # Collapsed version: Beta prior + Binomial likelihood inside ``collapse``.
        c1 = numpyro.param("c1", 0.5, constraint=dist.constraints.positive)
        c0 = numpyro.param("c0", 1.5, constraint=dist.constraints.positive)
        with handlers.collapse():
            probs = numpyro.sample("probs", dist.Beta(c1, c0))
            numpyro.sample("obs", dist.Binomial(total_count, probs), obs=data)

    def model2():
        # Reference version: the marginal written out explicitly.
        c1 = numpyro.param("c1", 0.5, constraint=dist.constraints.positive)
        c0 = numpyro.param("c0", 1.5, constraint=dist.constraints.positive)
        numpyro.sample("obs", dist.BetaBinomial(c1, c0, total_count),
                       obs=data)

    trace1 = handlers.trace(model1).get_trace()
    trace2 = handlers.trace(model2).get_trace()
    # The collapsed model exposes a single site named after the first
    # unobserved site ("probs"); the inner "obs" site is not visible outside.
    assert "probs" in trace1
    assert "obs" not in trace1
    assert "probs" not in trace2
    assert "obs" in trace2

    svi1 = SVI(model1, lambda: None, numpyro.optim.Adam(1), Trace_ELBO())
    # Compare against the explicit BetaBinomial model. Building both SVI
    # objects from ``model1`` would make the checks below trivially true.
    svi2 = SVI(model2, lambda: None, numpyro.optim.Adam(1), Trace_ELBO())
    svi_state1 = svi1.init(random.PRNGKey(0))
    svi_state2 = svi2.init(random.PRNGKey(0))
    params1 = svi1.get_params(svi_state1)
    params2 = svi2.get_params(svi_state2)
    assert_allclose(params1["c1"], params2["c1"])
    assert_allclose(params1["c0"], params2["c0"])

    # After one optimisation step both models must have moved identically.
    params1 = svi1.get_params(svi1.update(svi_state1)[0])
    params2 = svi2.get_params(svi2.update(svi_state2)[0])
    assert_allclose(params1["c1"], params2["c1"])
    assert_allclose(params1["c0"], params2["c0"])


def test_collapse_beta_bernoulli():
    # Smoke test: initialise SVI and take one step on a model whose Beta
    # prior and Bernoulli likelihood are both inside ``collapse``.
    observed = 0.

    def model():
        concentration = numpyro.sample("c", dist.Gamma(1, 1))
        with handlers.collapse():
            success_prob = numpyro.sample("probs", dist.Beta(concentration, 2))
            numpyro.sample("obs", dist.Bernoulli(success_prob), obs=observed)

    def guide():
        # Gamma variational family over the only latent left outside the
        # collapsed block.
        gamma_a = numpyro.param("a", 1., constraint=constraints.positive)
        gamma_b = numpyro.param("b", 1., constraint=constraints.positive)
        numpyro.sample("c", dist.Gamma(gamma_a, gamma_b))

    inference = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
    initial_state = inference.init(random.PRNGKey(0))
    inference.update(initial_state)


@pytest.mark.xfail(reason="missing pattern in Funsor")
def test_collapse_beta_binomial_plate():
    # Same smoke test as the Beta-Bernoulli case, but with a plated Binomial
    # likelihood inside the collapsed block (currently expected to fail).
    observations = np.array([0., 1., 5., 5.])

    def model():
        concentration = numpyro.sample("c", dist.Gamma(1, 1))
        with handlers.collapse():
            success_prob = numpyro.sample("probs", dist.Beta(concentration, 2))
            with numpyro.plate("plate", len(observations)):
                numpyro.sample("obs", dist.Binomial(10, success_prob),
                               obs=observations)

    def guide():
        gamma_a = numpyro.param("a", 1., constraint=constraints.positive)
        gamma_b = numpyro.param("b", 1., constraint=constraints.positive)
        numpyro.sample("c", dist.Gamma(gamma_a, gamma_b))

    inference = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
    initial_state = inference.init(random.PRNGKey(0))
    inference.update(initial_state)

0 comments on commit 78716a7

Please sign in to comment.