Make omnistaging work with validate_args=True #775

Merged: 6 commits, Oct 14, 2020
4 changes: 3 additions & 1 deletion numpyro/contrib/tfp/distributions.py
@@ -1,6 +1,8 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import numpy as np

from jax.dtypes import canonicalize_dtype
import jax.numpy as jnp
from tensorflow_probability.substrates.jax import bijectors as tfb
@@ -22,7 +24,7 @@ def _get_codomain(bijector):
return constraints.positive
elif bijector.__class__.__name__ == "GeneralizedPareto":
loc, scale, concentration = bijector.loc, bijector.scale, bijector.concentration
if not_jax_tracer(concentration) and jnp.all(concentration < 0):
if not_jax_tracer(concentration) and np.all(concentration < 0):
return constraints.interval(loc, loc + scale / jnp.abs(concentration))
# XXX: here we suppose concentration > 0
# which is not true in general, but should cover enough usage cases
45 changes: 15 additions & 30 deletions numpyro/distributions/conjugate.py
@@ -7,10 +7,9 @@

from numpyro.distributions import constraints
from numpyro.distributions.continuous import Beta, Dirichlet, Gamma
from numpyro.distributions.discrete import Binomial, Multinomial, Poisson
from numpyro.distributions.discrete import BinomialProbs, MultinomialProbs, Poisson
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import promote_shapes, validate_sample
from numpyro.util import not_jax_tracer


def _log_beta_1(alpha, value):
@@ -35,20 +34,23 @@ class BetaBinomial(Distribution):
'total_count': constraints.nonnegative_integer}
has_enumerate_support = True
is_discrete = True
enumerate_support = BinomialProbs.enumerate_support

def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
self.concentration1, self.concentration0, self.total_count = promote_shapes(
concentration1, concentration0, total_count
)
batch_shape = lax.broadcast_shapes(jnp.shape(concentration1), jnp.shape(concentration0),
jnp.shape(total_count))
self.concentration1 = jnp.broadcast_to(concentration1, batch_shape)
self.concentration0 = jnp.broadcast_to(concentration0, batch_shape)
self.total_count, = promote_shapes(total_count, shape=batch_shape)
self._beta = Beta(self.concentration1, self.concentration0)
concentration1 = jnp.broadcast_to(concentration1, batch_shape)
concentration0 = jnp.broadcast_to(concentration0, batch_shape)
self._beta = Beta(concentration1, concentration0)
super(BetaBinomial, self).__init__(batch_shape, validate_args=validate_args)

def sample(self, key, sample_shape=()):
key_beta, key_binom = random.split(key)
probs = self._beta.sample(key_beta, sample_shape)
return Binomial(self.total_count, probs).sample(key_binom)
return BinomialProbs(total_count=self.total_count, probs=probs).sample(key_binom)

@validate_sample
def log_prob(self, value):
@@ -68,18 +70,6 @@ def variance(self):
def support(self):
return constraints.integer_interval(0, self.total_count)

def enumerate_support(self, expand=True):
total_count = jnp.amax(self.total_count)
if not_jax_tracer(total_count):
# NB: the error can't be raised if inhomogeneous issue happens when tracing
if jnp.amin(self.total_count) != total_count:
raise NotImplementedError("Inhomogeneous total count not supported"
" by `enumerate_support`.")
values = jnp.arange(total_count + 1).reshape((-1,) + (1,) * len(self.batch_shape))
if expand:
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values


class DirichletMultinomial(Distribution):
r"""
@@ -102,22 +92,18 @@ def __init__(self, concentration, total_count=1, validate_args=None):
raise ValueError("`concentration` parameter must be at least one-dimensional.")

batch_shape = lax.broadcast_shapes(jnp.shape(concentration)[:-1], jnp.shape(total_count))
self.concentration = jnp.broadcast_to(concentration, batch_shape + jnp.shape(concentration)[-1:])
self._dirichlet = Dirichlet(self.concentration)
concentration_shape = batch_shape + jnp.shape(concentration)[-1:]
self.concentration, = promote_shapes(concentration, shape=concentration_shape)
self.total_count, = promote_shapes(total_count, shape=batch_shape)
concentration = jnp.broadcast_to(self.concentration, concentration_shape)
self._dirichlet = Dirichlet(concentration)
super().__init__(
self._dirichlet.batch_shape, self._dirichlet.event_shape, validate_args=validate_args)

def sample(self, key, sample_shape=()):
key_dirichlet, key_multinom = random.split(key)
probs = self._dirichlet.sample(key_dirichlet, sample_shape)
total_count = jnp.amax(self.total_count)
if not_jax_tracer(total_count):
# NB: the error can't be raised if inhomogeneous issue happens when tracing
if jnp.amin(self.total_count) != total_count:
raise NotImplementedError("Inhomogeneous total count not supported"
" by `sample`.")
return Multinomial(total_count, probs).sample(key_multinom)
return MultinomialProbs(total_count=self.total_count, probs=probs).sample(key_multinom)

@validate_sample
def log_prob(self, value):
@@ -157,9 +143,8 @@ class GammaPoisson(Distribution):
is_discrete = True

def __init__(self, concentration, rate=1., validate_args=None):
self.concentration, self.rate = promote_shapes(concentration, rate)
self._gamma = Gamma(concentration, rate)
self.concentration = self._gamma.concentration
self.rate = self._gamma.rate
super(GammaPoisson, self).__init__(self._gamma.batch_shape, validate_args=validate_args)

def sample(self, key, sample_shape=()):
22 changes: 14 additions & 8 deletions numpyro/distributions/constraints.py
@@ -50,7 +50,9 @@
'Constraint',
]

import jax.numpy as jnp
import numpy as np

import jax.numpy


class Constraint(object):
@@ -79,6 +81,7 @@ def __call__(self, x):

class _CorrCholesky(Constraint):
def __call__(self, x):
jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
tril = jnp.tril(x)
lower_triangular = jnp.all(jnp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1)
positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1)
@@ -89,6 +92,7 @@ def __call__(self, x):

class _CorrMatrix(Constraint):
def __call__(self, x):
jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
# check for symmetric
symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1), axis=-1)
# check for the smallest eigenvalue is positive
@@ -129,7 +133,7 @@ def __init__(self, lower_bound, upper_bound):
self.upper_bound = upper_bound

def __call__(self, x):
return (x >= self.lower_bound) & (x <= self.upper_bound) & (x == jnp.floor(x))
return (x >= self.lower_bound) & (x <= self.upper_bound) & (x % 1 == 0)


class _IntegerGreaterThan(Constraint):
@@ -151,6 +155,7 @@ def __call__(self, x):

class _LowerCholesky(Constraint):
def __call__(self, x):
jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
tril = jnp.tril(x)
lower_triangular = jnp.all(jnp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1)
positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1)
@@ -162,16 +167,17 @@ def __init__(self, upper_bound):
self.upper_bound = upper_bound

def __call__(self, x):
return jnp.all(x >= 0, axis=-1) & (jnp.sum(x, -1) == self.upper_bound)
return (x >= 0).all(axis=-1) & (x.sum(axis=-1) == self.upper_bound)


class _OrderedVector(Constraint):
def __call__(self, x):
return jnp.all(x[..., 1:] > x[..., :-1], axis=-1)
return (x[..., 1:] > x[..., :-1]).all(axis=-1)


class _PositiveDefinite(Constraint):
def __call__(self, x):
jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
# check for symmetric
symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1), axis=-1)
# check for the smallest eigenvalue is positive
@@ -182,18 +188,18 @@ def __call__(self, x):
class _Real(Constraint):
def __call__(self, x):
# XXX: consider to relax this condition to [-inf, inf] interval
return jnp.isfinite(x)
return (x == x) & (x != float('inf')) & (x != float('-inf'))


class _RealVector(Constraint):
def __call__(self, x):
return jnp.all(jnp.isfinite(x), axis=-1)
return ((x == x) & (x != float('inf')) & (x != float('-inf'))).all(axis=-1)


class _Simplex(Constraint):
def __call__(self, x):
x_sum = jnp.sum(x, axis=-1)
return jnp.all(x >= 0, axis=-1) & (x_sum < 1 + 1e-6) & (x_sum > 1 - 1e-6)
x_sum = x.sum(axis=-1)
return (x >= 0).all(axis=-1) & (x_sum < 1 + 1e-6) & (x_sum > 1 - 1e-6)


# TODO: Make types consistent
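The per-call dispatch added to the constraint checks above (jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy) keeps numpy inputs in numpy, so the result is a concrete boolean that parameter validation can inspect eagerly, while jax arrays still go through jax.numpy. A minimal sketch of the pattern, using a hypothetical positive_vector helper rather than one of the constraints from the diff:

import numpy as np
import jax
import jax.numpy

def positive_vector(x):
    # hypothetical helper showing the same dispatch as in constraints.py:
    # numpy in, numpy out; jax in, jax out
    jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
    return jnp.all(x > 0, axis=-1)

print(positive_vector(np.array([1.0, 2.0])))         # numpy bool, usable in an `if`
print(positive_vector(jax.numpy.array([1.0, 2.0])))  # jax array
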
36 changes: 19 additions & 17 deletions numpyro/distributions/continuous.py
@@ -54,10 +54,11 @@ class Beta(Distribution):
support = constraints.unit_interval

def __init__(self, concentration1, concentration0, validate_args=None):
self.concentration1, self.concentration0 = promote_shapes(concentration1, concentration0)
batch_shape = lax.broadcast_shapes(jnp.shape(concentration1), jnp.shape(concentration0))
self.concentration1 = jnp.broadcast_to(concentration1, batch_shape)
self.concentration0 = jnp.broadcast_to(concentration0, batch_shape)
self._dirichlet = Dirichlet(jnp.stack([self.concentration1, self.concentration0],
concentration1 = jnp.broadcast_to(concentration1, batch_shape)
concentration0 = jnp.broadcast_to(concentration0, batch_shape)
self._dirichlet = Dirichlet(jnp.stack([concentration1, concentration0],
axis=-1))
super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

@@ -687,8 +688,8 @@ class MultivariateNormal(Distribution):

def __init__(self, loc=0., covariance_matrix=None, precision_matrix=None, scale_tril=None,
validate_args=None):
if jnp.isscalar(loc):
loc = jnp.expand_dims(loc, axis=-1)
if jnp.ndim(loc) == 0:
loc, = promote_shapes(loc, shape=(1,))
# temporary append a new axis to loc
loc = loc[..., jnp.newaxis]
if covariance_matrix is not None:
@@ -704,7 +705,7 @@ def __init__(self, loc=0., covariance_matrix=None, precision_matrix=None, scale_
' must be specified.')
batch_shape = lax.broadcast_shapes(jnp.shape(loc)[:-2], jnp.shape(self.scale_tril)[:-2])
event_shape = jnp.shape(self.scale_tril)[-1:]
self.loc = jnp.broadcast_to(jnp.squeeze(loc, axis=-1), batch_shape + event_shape)
self.loc = loc[..., 0]
super(MultivariateNormal, self).__init__(batch_shape=batch_shape,
event_shape=event_shape,
validate_args=validate_args)
Expand All @@ -731,7 +732,7 @@ def precision_matrix(self):

@property
def mean(self):
return self.loc
return jnp.broadcast_to(self.loc, self.shape())

@property
def variance(self):
@@ -817,7 +818,7 @@ def __init__(self, loc, cov_factor, cov_diag, validate_args=None):

loc, cov_factor, cov_diag = promote_shapes(loc[..., jnp.newaxis], cov_factor, cov_diag[..., jnp.newaxis])
batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(cov_factor), jnp.shape(cov_diag))[:-2]
self.loc = jnp.broadcast_to(loc[..., 0], batch_shape + event_shape)
self.loc = loc[..., 0]
self.cov_factor = cov_factor
cov_diag = cov_diag[..., 0]
self.cov_diag = cov_diag
@@ -937,22 +938,23 @@ class Pareto(TransformedDistribution):
arg_constraints = {'scale': constraints.positive, 'alpha': constraints.positive}

def __init__(self, scale, alpha, validate_args=None):
self.scale, self.alpha = promote_shapes(scale, alpha)
batch_shape = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(alpha))
self.scale, self.alpha = jnp.broadcast_to(scale, batch_shape), jnp.broadcast_to(alpha, batch_shape)
base_dist = Exponential(self.alpha)
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
scale, alpha = jnp.broadcast_to(scale, batch_shape), jnp.broadcast_to(alpha, batch_shape)
base_dist = Exponential(alpha)
transforms = [ExpTransform(), AffineTransform(loc=0, scale=scale)]
super(Pareto, self).__init__(base_dist, transforms, validate_args=validate_args)

@property
def mean(self):
# mean is inf for alpha <= 1
a = lax.div(self.alpha * self.scale, (self.alpha - 1))
a = jnp.divide(self.alpha * self.scale, (self.alpha - 1))
return jnp.where(self.alpha <= 1, jnp.inf, a)

@property
def variance(self):
# var is inf for alpha <= 2
a = lax.div((self.scale ** 2) * self.alpha, (self.alpha - 1) ** 2 * (self.alpha - 2))
a = jnp.divide((self.scale ** 2) * self.alpha, (self.alpha - 1) ** 2 * (self.alpha - 2))
return jnp.where(self.alpha <= 2, jnp.inf, a)

# override the default behaviour to save computations
@@ -971,9 +973,9 @@ class StudentT(Distribution):

def __init__(self, df, loc=0., scale=1., validate_args=None):
batch_shape = lax.broadcast_shapes(jnp.shape(df), jnp.shape(loc), jnp.shape(scale))
self.df = jnp.broadcast_to(df, batch_shape)
self.loc, self.scale = promote_shapes(loc, scale, shape=batch_shape)
self._chi2 = Chi2(self.df)
self.df, self.loc, self.scale = promote_shapes(df, loc, scale, shape=batch_shape)
df = jnp.broadcast_to(df, batch_shape)
self._chi2 = Chi2(df)
super(StudentT, self).__init__(batch_shape, validate_args=validate_args)

def sample(self, key, sample_shape=()):
@@ -997,7 +999,7 @@ def mean(self):

@property
def variance(self):
var = jnp.where(self.df > 2, self.scale ** 2 * self.df / (self.df - 2.0), jnp.inf)
var = jnp.where(self.df > 2, jnp.divide(self.scale ** 2 * self.df, self.df - 2.0), jnp.inf)
var = jnp.where(self.df <= 1, jnp.nan, var)
return jnp.broadcast_to(var, self.batch_shape)

27 changes: 11 additions & 16 deletions numpyro/distributions/discrete.py
@@ -183,12 +183,14 @@ def support(self):
return constraints.integer_interval(0, self.total_count)

def enumerate_support(self, expand=True):
total_count = jnp.amax(self.total_count)
if not_jax_tracer(total_count):
if not_jax_tracer(self.total_count):
total_count = np.amax(self.total_count)
Member
Same as my comment below: is this np needed because self.total_count could be a Python scalar? In that case, does it make sense to use jnp.array or device_put everywhere? That way, we could freely use jnp functions instead of a mix of np and jnp, which becomes hard to reason about.
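A minimal sketch of the distinction in question (illustrative only, not part of the original thread; it assumes total_count is a concrete numpy value): under omnistaging, jax.numpy reductions are staged into tracers inside jit even when their inputs are concrete, while plain numpy reductions stay concrete, so the eager check can still run.

import numpy as np
from jax import jit
import jax.numpy as jnp
from numpyro.util import not_jax_tracer

total_count = np.array([3, 3])  # concrete numpy value, e.g. a model constant

def f():
    # numpy keeps the reduction concrete, so the branch guarded by
    # not_jax_tracer can run even while tracing under jit
    assert not_jax_tracer(np.amax(total_count))
    # jax.numpy stages the same reduction out as a tracer under omnistaging
    assert not not_jax_tracer(jnp.amax(total_count))

jit(f)()  # passes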

# NB: the error can't be raised if inhomogeneous issue happens when tracing
if jnp.amin(self.total_count) != total_count:
if np.amin(self.total_count) != total_count:
raise NotImplementedError("Inhomogeneous total count not supported"
" by `enumerate_support`.")
else:
total_count = jnp.amax(self.total_count)
values = jnp.arange(total_count + 1).reshape((-1,) + (1,) * len(self.batch_shape))
if expand:
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
@@ -200,6 +202,7 @@ class BinomialLogits(Distribution):
'total_count': constraints.nonnegative_integer}
has_enumerate_support = True
is_discrete = True
enumerate_support = BinomialProbs.enumerate_support

def __init__(self, logits, total_count=1, validate_args=None):
self.logits, self.total_count = promote_shapes(logits, total_count)
@@ -235,18 +238,6 @@ def variance(self):
def support(self):
return constraints.integer_interval(0, self.total_count)

def enumerate_support(self, expand=True):
total_count = jnp.amax(self.total_count)
if not_jax_tracer(total_count):
# NB: the error can't be raised if inhomogeneous issue happens when tracing
if jnp.amin(self.total_count) != total_count:
raise NotImplementedError("Inhomogeneous total count not supported"
" by `enumerate_support`.")
values = jnp.arange(total_count + 1).reshape((-1,) + (1,) * len(self.batch_shape))
if expand:
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values


def Binomial(total_count=1, probs=None, logits=None, validate_args=None):
if probs is not None:
@@ -421,7 +412,11 @@ class OrderedLogistic(CategoricalProbs):
'cutpoints': constraints.ordered_vector}

def __init__(self, predictor, cutpoints, validate_args=None):
predictor, self.cutpoints = promote_shapes(jnp.expand_dims(predictor, -1), cutpoints)
if jnp.ndim(predictor) == 0:
predictor, = promote_shapes(predictor, shape=(1,))
else:
predictor = predictor[..., None]
predictor, self.cutpoints = promote_shapes(predictor, cutpoints)
self.predictor = predictor[..., 0]
cumulative_probs = expit(cutpoints - predictor)
# add two boundary points 0 and 1
8 changes: 5 additions & 3 deletions numpyro/distributions/distribution.py
@@ -29,6 +29,8 @@
from contextlib import contextmanager
import warnings

import numpy as np

from jax import lax, tree_util
import jax.numpy as jnp

@@ -150,9 +152,9 @@ def __init__(self, batch_shape=(), event_shape=(), validate_args=None):
continue
if is_dependent(constraint):
continue # skip constraints that cannot be checked
is_valid = jnp.all(constraint(getattr(self, param)))
is_valid = constraint(getattr(self, param))
if not_jax_tracer(is_valid):
if not is_valid:
if not np.all(is_valid):
Member
For my understanding, if we were to always cast a python bool to a numpy or jax bool type (i.e. the default mask would be a jnp.array(True)), we would not have needed any change here. Is that correct?

Member Author (@fehiepsi, Oct 13, 2020)
We need this change because, under jit, jnp.all(True) and device_put(True) will create a tracer, so the validation code won't activate.

from jax import device_put, jit
import jax.numpy as jnp
from numpyro.util import not_jax_tracer

def f():
    assert not not_jax_tracer(jnp.all(True))
    assert not not_jax_tracer(device_put(True))

jit(f)()  # pass

You are right that if we cast the mask in the _validate_sample method to a device array, then the error would be fixed. In that case, no validation for parameters/samples is activated under jit. That is fine for single-chain MCMC because we have code that traces the model outside of the jit context, so if there is some wrong specification, the users will get the warning/error. But for multi-chain MCMC, no validation will be activated because we run everything under pmap.

Member
Thanks for explaining about pmap; I missed that. In that case, is it correct to say that for distributions that have all parameters as jax device arrays (no Python scalars), there won't be any validation performed under pmap, because not_jax_tracer(param) will be False?

Member Author (@fehiepsi)
Yes, that's true.
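
A small sketch of that behaviour (illustrative only; jit is used here, but pmap traces its arguments in the same way, and validate is a hypothetical stand-in for the constraint check in __init__):

import numpy as np
from jax import jit
import jax.numpy as jnp
from numpyro.util import not_jax_tracer

def validate(param):
    is_valid = param > 0              # stand-in for constraint(param)
    if not_jax_tracer(is_valid):      # False for traced args, so the check is skipped
        if not np.all(is_valid):
            raise ValueError("got invalid parameter")
    return param

try:
    validate(np.array([1.0, -1.0]))    # concrete input: the check runs and raises
except ValueError as e:
    print(e)

jit(validate)(jnp.array([1.0, -1.0]))  # traced input: the check is silently skipped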

raise ValueError("{} distribution got invalid {} parameter.".format(
self.__class__.__name__, param))
super(Distribution, self).__init__()
@@ -255,7 +257,7 @@ def variance(self):
def _validate_sample(self, value):
mask = self.support(value)
if not_jax_tracer(mask):
if not jnp.all(mask):
if not np.all(mask):
warnings.warn('Out-of-support values provided to log prob method. '
'The value argument should be within the support.')
return mask