Make omnistaging work with validate_args=True #775

Merged · 6 commits · Oct 14, 2020
Changes from 1 commit
4 changes: 3 additions & 1 deletion numpyro/contrib/tfp/distributions.py
@@ -1,6 +1,8 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

+import numpy as np
+
from jax.dtypes import canonicalize_dtype
import jax.numpy as jnp
from tensorflow_probability.substrates.jax import bijectors as tfb
@@ -22,7 +24,7 @@ def _get_codomain(bijector):
        return constraints.positive
    elif bijector.__class__.__name__ == "GeneralizedPareto":
        loc, scale, concentration = bijector.loc, bijector.scale, bijector.concentration
-        if not_jax_tracer(concentration) and jnp.all(concentration < 0):
+        if not_jax_tracer(concentration) and np.all(concentration < 0):
            return constraints.interval(loc, loc + scale / jnp.abs(concentration))
        # XXX: here we suppose concentration > 0
        # which is not true in general, but should cover enough usage cases
26 changes: 4 additions & 22 deletions numpyro/distributions/conjugate.py
@@ -7,10 +7,9 @@

from numpyro.distributions import constraints
from numpyro.distributions.continuous import Beta, Dirichlet, Gamma
-from numpyro.distributions.discrete import Binomial, Multinomial, Poisson
+from numpyro.distributions.discrete import BinomialProbs, MultinomialProbs, Poisson
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import promote_shapes, validate_sample
-from numpyro.util import not_jax_tracer


def _log_beta_1(alpha, value):
@@ -35,6 +34,7 @@ class BetaBinomial(Distribution):
                       'total_count': constraints.nonnegative_integer}
    has_enumerate_support = True
    is_discrete = True
+    enumerate_support = BinomialProbs.enumerate_support

    def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
        batch_shape = lax.broadcast_shapes(jnp.shape(concentration1), jnp.shape(concentration0),
@@ -48,7 +48,7 @@ def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
    def sample(self, key, sample_shape=()):
        key_beta, key_binom = random.split(key)
        probs = self._beta.sample(key_beta, sample_shape)
-        return Binomial(self.total_count, probs).sample(key_binom)
+        return BinomialProbs(total_count=self.total_count, probs=probs).sample(key_binom)

    @validate_sample
    def log_prob(self, value):
@@ -68,18 +68,6 @@ def variance(self):
    def support(self):
        return constraints.integer_interval(0, self.total_count)

-    def enumerate_support(self, expand=True):
-        total_count = jnp.amax(self.total_count)
-        if not_jax_tracer(total_count):
-            # NB: the error can't be raised if inhomogeneous issue happens when tracing
-            if jnp.amin(self.total_count) != total_count:
-                raise NotImplementedError("Inhomogeneous total count not supported"
-                                          " by `enumerate_support`.")
-        values = jnp.arange(total_count + 1).reshape((-1,) + (1,) * len(self.batch_shape))
-        if expand:
-            values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
-        return values


class DirichletMultinomial(Distribution):
r"""
@@ -111,13 +99,7 @@ def __init__(self, concentration, total_count=1, validate_args=None):
    def sample(self, key, sample_shape=()):
        key_dirichlet, key_multinom = random.split(key)
        probs = self._dirichlet.sample(key_dirichlet, sample_shape)
-        total_count = jnp.amax(self.total_count)
-        if not_jax_tracer(total_count):
-            # NB: the error can't be raised if inhomogeneous issue happens when tracing
-            if jnp.amin(self.total_count) != total_count:
-                raise NotImplementedError("Inhomogeneous total count not supported"
-                                          " by `sample`.")
-        return Multinomial(total_count, probs).sample(key_multinom)
+        return MultinomialProbs(total_count=self.total_count, probs=probs).sample(key_multinom)

    @validate_sample
    def log_prob(self, value):
21 changes: 6 additions & 15 deletions numpyro/distributions/discrete.py
@@ -183,12 +183,14 @@ def support(self):
        return constraints.integer_interval(0, self.total_count)

    def enumerate_support(self, expand=True):
-        total_count = jnp.amax(self.total_count)
-        if not_jax_tracer(total_count):
+        if not_jax_tracer(self.total_count):
+            total_count = np.amax(self.total_count)
Member: Same as my comment below - is this np needed because self.total_count could be a Python scalar? In that case, does it make sense to use jnp.array or device_put everywhere? That way, we could freely use jnp functions instead of a mix of np and jnp, which becomes hard to reason about.

            # NB: the error can't be raised if inhomogeneous issue happens when tracing
-            if jnp.amin(self.total_count) != total_count:
+            if np.amin(self.total_count) != total_count:
                raise NotImplementedError("Inhomogeneous total count not supported"
                                          " by `enumerate_support`.")
+        else:
+            total_count = jnp.amax(self.total_count)
        values = jnp.arange(total_count + 1).reshape((-1,) + (1,) * len(self.batch_shape))
        if expand:
            values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
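
To see why the np call above matters, here is a minimal sketch (not part of the PR): under jit with omnistaging, jax.numpy calls return tracers even for concrete inputs such as Python scalars, while plain numpy keeps a concrete value that the not_jax_tracer guard can inspect.

import numpy as np

from jax import jit
import jax.numpy as jnp
from numpyro.util import not_jax_tracer

def check(total_count=3):
    # jax.numpy produces a tracer inside jit, even for a Python scalar ...
    assert not not_jax_tracer(jnp.amax(total_count))
    # ... while plain numpy stays concrete, so eager checks can still run.
    assert not_jax_tracer(np.amax(total_count))

jit(check)()  # both assertions hold at trace time
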
@@ -200,6 +202,7 @@ class BinomialLogits(Distribution):
                       'total_count': constraints.nonnegative_integer}
    has_enumerate_support = True
    is_discrete = True
+    enumerate_support = BinomialProbs.enumerate_support

    def __init__(self, logits, total_count=1, validate_args=None):
        self.logits, self.total_count = promote_shapes(logits, total_count)
@@ -235,18 +238,6 @@ def variance(self):
    def support(self):
        return constraints.integer_interval(0, self.total_count)

-    def enumerate_support(self, expand=True):
-        total_count = jnp.amax(self.total_count)
-        if not_jax_tracer(total_count):
-            # NB: the error can't be raised if inhomogeneous issue happens when tracing
-            if jnp.amin(self.total_count) != total_count:
-                raise NotImplementedError("Inhomogeneous total count not supported"
-                                          " by `enumerate_support`.")
-        values = jnp.arange(total_count + 1).reshape((-1,) + (1,) * len(self.batch_shape))
-        if expand:
-            values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
-        return values


def Binomial(total_count=1, probs=None, logits=None, validate_args=None):
    if probs is not None:
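The enumerate_support = BinomialProbs.enumerate_support assignments above lean on a standard Python 3 idiom: a function looked up on a class is a plain function, so another class can adopt it as a method by simple assignment. A minimal sketch with hypothetical classes:

class Counter:
    def describe(self):
        # `self` is whichever instance the call goes through, so any
        # class with a `count` attribute can reuse this implementation.
        return "count={}".format(self.count)

class SharedCounter:
    count = 0
    describe = Counter.describe  # adopt the same function as a method

assert SharedCounter().describe() == "count=0"
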
8 changes: 5 additions & 3 deletions numpyro/distributions/distribution.py
@@ -29,6 +29,8 @@
from contextlib import contextmanager
import warnings

+import numpy as np
+
from jax import lax, tree_util
import jax.numpy as jnp

@@ -138,9 +140,9 @@ def __init__(self, batch_shape=(), event_shape=(), validate_args=None):
                    continue
                if is_dependent(constraint):
                    continue  # skip constraints that cannot be checked
-                is_valid = jnp.all(constraint(getattr(self, param)))
+                is_valid = constraint(getattr(self, param))
                if not_jax_tracer(is_valid):
-                    if not is_valid:
+                    if not np.all(is_valid):
Member: For my understanding, if we were to always cast a Python bool to a numpy or jax bool type (i.e. the default mask would be jnp.array(True)), we would not have needed any change here. Is that correct?

Member Author (@fehiepsi, Oct 13, 2020): We need this change because, under jit, jnp.all(True) and device_put(True) create tracers, so the validation code won't activate:

from jax import device_put, jit
import jax.numpy as jnp
from numpyro.util import not_jax_tracer

def f():
    assert not not_jax_tracer(jnp.all(True))
    assert not not_jax_tracer(device_put(True))

jit(f)()  # pass

You are right that if we cast mask in the _validate_sample method to a device array, the error would be fixed. But then no validation for parameters/samples is activated under jit. That is fine for single-chain MCMC because we trace the model outside of the jit context, so users still get the warning/error for a wrong specification. For multi-chain MCMC, however, no validation would be activated because we run everything under pmap.

Member: Thanks for explaining about pmap, I missed that. In that case, is it correct to say that for distributions that have all parameters as jax device arrays (no Python scalars), there won't be any validation performed under pmap, because not_jax_tracer(param) will be false?

Member Author: Yes, that's true.
                        raise ValueError("{} distribution got invalid {} parameter.".format(
                            self.__class__.__name__, param))
        super(Distribution, self).__init__()
@@ -243,7 +245,7 @@ def variance(self):
    def _validate_sample(self, value):
        mask = self.support(value)
        if not_jax_tracer(mask):
-            if not jnp.all(mask):
+            if not np.all(mask):
                warnings.warn('Out-of-support values provided to log prob method. '
                              'The value argument should be within the support.')
        return mask
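
Both hunks above converge on the same pattern: compute the constraint or support mask, and validate eagerly only when the mask is a concrete (non-tracer) value, reducing it with np.all so it stays concrete. A minimal sketch (validate_positive is a hypothetical helper, not numpyro API):

import numpy as np

from numpyro.distributions import constraints
from numpyro.util import not_jax_tracer

def validate_positive(value, name="scale"):
    is_valid = constraints.positive(value)  # boolean mask; a tracer under jit
    if not_jax_tracer(is_valid):
        # np.all keeps a concrete bool here; jnp.all would return a
        # tracer under jit and this branch would be skipped silently.
        if not np.all(is_valid):
            raise ValueError("got invalid {} parameter.".format(name))

validate_positive(np.ones(2))     # passes
# validate_positive(-np.ones(2))  # would raise ValueError
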
8 changes: 5 additions & 3 deletions numpyro/distributions/transforms.py
@@ -4,6 +4,8 @@
import math
import warnings

+import numpy as np
+
from jax import ops, tree_flatten, tree_map, vmap
from jax.dtypes import canonicalize_dtype
from jax.flatten_util import ravel_pytree
@@ -96,19 +98,19 @@ def codomain(self):
        elif self.domain is constraints.real_vector:
            return constraints.real_vector
        elif isinstance(self.domain, constraints.greater_than):
-            if not_jax_tracer(self.scale) and jnp.all(self.scale < 0):
+            if not_jax_tracer(self.scale) and np.all(self.scale < 0):
                return constraints.less_than(self(self.domain.lower_bound))
            # we suppose scale > 0 for any tracer
            else:
                return constraints.greater_than(self(self.domain.lower_bound))
        elif isinstance(self.domain, constraints.less_than):
-            if not_jax_tracer(self.scale) and jnp.all(self.scale < 0):
+            if not_jax_tracer(self.scale) and np.all(self.scale < 0):
                return constraints.greater_than(self(self.domain.upper_bound))
            # we suppose scale > 0 for any tracer
            else:
                return constraints.less_than(self(self.domain.upper_bound))
        elif isinstance(self.domain, constraints.interval):
-            if not_jax_tracer(self.scale) and jnp.all(self.scale < 0):
+            if not_jax_tracer(self.scale) and np.all(self.scale < 0):
                return constraints.interval(self(self.domain.upper_bound),
                                            self(self.domain.lower_bound))
            else:
4 changes: 3 additions & 1 deletion numpyro/infer/autoguide.py
@@ -6,6 +6,8 @@
from contextlib import ExitStack
import warnings

+import numpy as np
+
from jax import hessian, lax, random, tree_map
from jax.experimental import stax
from jax.flatten_util import ravel_pytree
@@ -647,7 +649,7 @@ def loss_fn(z):
        precision = hessian(loss_fn)(loc)
        scale_tril = cholesky_of_inverse(precision)
        if not_jax_tracer(scale_tril):
-            if jnp.any(jnp.isnan(scale_tril)):
+            if np.any(np.isnan(scale_tril)):
                warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior"
                              " samples from AutoLaplaceApproxmiation will be constant (equal to"
                              " the MAP point).")
16 changes: 16 additions & 0 deletions test/test_distributions.py
@@ -748,6 +748,22 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
        d.log_prob(oob_samples)


+def test_omnistaging_invalid_param():
Member: We have tests that check whether arg constraints are satisfied, but I suppose we do not have tests that check validation for out-of-support values in distributions? If so, it would be good to add tests for all relevant distributions and do a patch release. WDYT?

Member: This test is fine, but my concern is whether we have missed something like this for other distributions.

Member Author: Yeah, I think we already have some tests that check for both constraints and support. Let me also add a jit check for them.

Member Author: Actually, you raised a very good point. It is just lucky that my test passed for the positive constraint. For many other constraints, no warning is raised! :( The issue is that if a constraint uses jax.numpy for the check, the output will be a traced array (under jit), regardless of whether the input is a DeviceArray or not.

+    def f(x):
+        return dist.LogNormal(x, -np.ones(2), validate_args=True).log_prob(0)
+
+    with pytest.raises(ValueError, match="got invalid"):
+        jax.jit(f)(0)


+def test_omnistaging_invalid_sample():
+    def f(x):
+        return dist.LogNormal(x, np.ones(2), validate_args=True).log_prob(-1)
+
+    with pytest.warns(UserWarning, match="Out-of-support"):
+        jax.jit(f)(0)
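
A minimal sketch (not part of the PR) of the failure mode described in the review thread above: a support check written with jax.numpy yields a tracer under jit even for a concrete input, so the not_jax_tracer guard never fires and validation is silently skipped:

import numpy as np

from jax import jit
import jax.numpy as jnp
from numpyro.util import not_jax_tracer

def f():
    value = np.ones(2)             # concrete input
    jnp_mask = jnp.all(value > 0)  # traced under jit -> check is skipped
    np_mask = np.all(value > 0)    # stays a concrete numpy bool
    assert not not_jax_tracer(jnp_mask)
    assert not_jax_tracer(np_mask)

jit(f)()
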


def test_categorical_log_prob_grad():
    data = jnp.repeat(jnp.arange(3), 10)
