Skip to content

Commit

Permalink
Assert exactly one of parameters is specified and update shape infere…
Browse files Browse the repository at this point in the history
…nce.
  • Loading branch information
tillahoffmann committed Apr 13, 2024
1 parent ec86dc3 commit f2b0cad
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 25 deletions.
23 changes: 17 additions & 6 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
)
from numpyro.distributions.util import (
add_diag,
assert_one_of,
betainc,
betaincinv,
cholesky_of_inverse,
Expand Down Expand Up @@ -1442,6 +1443,11 @@ def __init__(
scale_tril=None,
validate_args=None,
):
assert_one_of(
covariance_matrix=covariance_matrix,
precision_matrix=precision_matrix,
scale_tril=scale_tril,
)
if jnp.ndim(loc) == 0:
(loc,) = promote_shapes(loc, shape=(1,))
# temporarily append a new axis to loc
Expand All @@ -1454,11 +1460,6 @@ def __init__(
self.scale_tril = cholesky_of_inverse(self.precision_matrix)
elif scale_tril is not None:
loc, self.scale_tril = promote_shapes(loc, scale_tril)
else:
raise ValueError(
"One of `covariance_matrix`, `precision_matrix`, `scale_tril`"
" must be specified."
)
batch_shape = lax.broadcast_shapes(
jnp.shape(loc)[:-2], jnp.shape(self.scale_tril)[:-2]
)
Expand Down Expand Up @@ -1515,12 +1516,17 @@ def variance(self):
def infer_shapes(
    loc=(), covariance_matrix=None, precision_matrix=None, scale_tril=None
):
    """
    Infer the batch and event shapes from the shapes of the parameters.

    All arguments are treated as shapes rather than arrays: they are sliced
    and passed to ``lax.broadcast_shapes``.

    :param loc: shape of the location parameter; its last dimension is the
        initial event shape and the leading dimensions the initial batch shape.
    :param covariance_matrix: shape of the covariance matrix, or ``None``.
    :param precision_matrix: shape of the precision matrix, or ``None``.
    :param scale_tril: shape of the lower-triangular scale factor, or ``None``.
    :return: tuple ``(batch_shape, event_shape)``.
    :raises ValueError: if not exactly one of `covariance_matrix`,
        `precision_matrix`, `scale_tril` is specified.
    """
    assert_one_of(
        covariance_matrix=covariance_matrix,
        precision_matrix=precision_matrix,
        scale_tril=scale_tril,
    )
    batch_shape, event_shape = loc[:-1], loc[-1:]
    for matrix in [covariance_matrix, precision_matrix, scale_tril]:
        if matrix is not None:
            # The leading dimensions of the matrix take part in the batch
            # shape; its last dimension must agree with the event size.
            batch_shape = lax.broadcast_shapes(batch_shape, matrix[:-2])
            event_shape = lax.broadcast_shapes(event_shape, matrix[-1:])
    # Single exit point after the loop; no early return from inside it.
    return batch_shape, event_shape


def _is_sparse(A):
Expand Down Expand Up @@ -2647,6 +2653,11 @@ def __init__(
*,
validate_args=None,
):
assert_one_of(
scale_matrix=scale_matrix,
rate_matrix=rate_matrix,
scale_tril=scale_tril,
)
concentration = jnp.asarray(concentration)[..., None, None]
if scale_matrix is not None:
concentration, self.scale_matrix = promote_shapes(
Expand Down
3 changes: 2 additions & 1 deletion numpyro/distributions/directional.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
assert_one_of,
lazy_property,
promote_shapes,
safe_normalize,
Expand Down Expand Up @@ -349,7 +350,7 @@ def __init__(
weighted_correlation=None,
validate_args=None,
):
assert (correlation is None) != (weighted_correlation is None)
assert_one_of(correlation=correlation, weighted_correlation=weighted_correlation)

if weighted_correlation is not None:
correlation = weighted_correlation * jnp.sqrt(
Expand Down
21 changes: 7 additions & 14 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from numpyro.distributions import constraints, transforms
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
assert_one_of,
binary_cross_entropy_with_logits,
binomial,
categorical,
Expand Down Expand Up @@ -151,12 +152,11 @@ def enumerate_support(self, expand=True):


def Bernoulli(probs=None, logits=None, *, validate_args=None):
    """
    Build a Bernoulli distribution from either `probs` or `logits`.

    :param probs: passed to :class:`BernoulliProbs` when specified.
    :param logits: passed to :class:`BernoulliLogits` when specified.
    :param validate_args: forwarded unchanged to the selected class.
    :return: a :class:`BernoulliProbs` instance if `probs` is given,
        otherwise a :class:`BernoulliLogits` instance.
    :raises ValueError: if not exactly one of `probs` and `logits` is
        specified.
    """
    # Exactly one parameterization is not None past this point, so no
    # trailing "nothing specified" branch is needed.
    assert_one_of(probs=probs, logits=logits)
    if probs is not None:
        return BernoulliProbs(probs, validate_args=validate_args)
    return BernoulliLogits(logits, validate_args=validate_args)


class BinomialProbs(Distribution):
Expand Down Expand Up @@ -284,12 +284,11 @@ def support(self):


def Binomial(total_count=1, probs=None, logits=None, *, validate_args=None):
    """
    Build a Binomial distribution from either `probs` or `logits`.

    :param total_count: forwarded as the second positional argument of the
        selected class.
    :param probs: passed to :class:`BinomialProbs` when specified.
    :param logits: passed to :class:`BinomialLogits` when specified.
    :param validate_args: forwarded unchanged to the selected class.
    :return: a :class:`BinomialProbs` instance if `probs` is given,
        otherwise a :class:`BinomialLogits` instance.
    :raises ValueError: if not exactly one of `probs` and `logits` is
        specified.
    """
    # Exactly one parameterization is not None past this point, so no
    # trailing "nothing specified" branch is needed.
    assert_one_of(probs=probs, logits=logits)
    if probs is not None:
        return BinomialProbs(probs, total_count, validate_args=validate_args)
    return BinomialLogits(logits, total_count, validate_args=validate_args)


class CategoricalProbs(Distribution):
Expand Down Expand Up @@ -395,12 +394,11 @@ def enumerate_support(self, expand=True):


def Categorical(probs=None, logits=None, *, validate_args=None):
    """
    Build a Categorical distribution from either `probs` or `logits`.

    :param probs: passed to :class:`CategoricalProbs` when specified.
    :param logits: passed to :class:`CategoricalLogits` when specified.
    :param validate_args: forwarded unchanged to the selected class.
    :return: a :class:`CategoricalProbs` instance if `probs` is given,
        otherwise a :class:`CategoricalLogits` instance.
    :raises ValueError: if not exactly one of `probs` and `logits` is
        specified.
    """
    # Exactly one parameterization is not None past this point, so no
    # trailing "nothing specified" branch is needed.
    assert_one_of(probs=probs, logits=logits)
    if probs is not None:
        return CategoricalProbs(probs, validate_args=validate_args)
    return CategoricalLogits(logits, validate_args=validate_args)


class DiscreteUniform(Distribution):
Expand Down Expand Up @@ -648,6 +646,7 @@ def Multinomial(
:param int total_count_max: the maximum number of trials,
i.e. `max(total_count)`
"""
assert_one_of(probs=probs, logits=logits)
if probs is not None:
return MultinomialProbs(
probs,
Expand All @@ -662,8 +661,6 @@ def Multinomial(
total_count_max=total_count_max,
validate_args=validate_args,
)
else:
raise ValueError("One of `probs` or `logits` must be specified.")


class Poisson(Distribution):
Expand Down Expand Up @@ -815,10 +812,7 @@ def ZeroInflatedDistribution(
:param numpy.ndarray gate: probability of extra zeros given via a Bernoulli distribution.
:param numpy.ndarray gate_logits: logits of extra zeros given via a Bernoulli distribution.
"""
if (gate is None) == (gate_logits is None):
raise ValueError(
"Either `gate` or `gate_logits` must be specified, but not both."
)
assert_one_of(gate=gate, gate_logits=gate_logits)
if gate is not None:
return ZeroInflatedProbs(base_dist, gate, validate_args=validate_args)
else:
Expand Down Expand Up @@ -916,9 +910,8 @@ def variance(self):


def Geometric(probs=None, logits=None, *, validate_args=None):
    """
    Build a Geometric distribution from either `probs` or `logits`.

    :param probs: passed to :class:`GeometricProbs` when specified.
    :param logits: passed to :class:`GeometricLogits` when specified.
    :param validate_args: forwarded unchanged to the selected class.
    :return: a :class:`GeometricProbs` instance if `probs` is given,
        otherwise a :class:`GeometricLogits` instance.
    :raises ValueError: if not exactly one of `probs` and `logits` is
        specified.
    """
    # Exactly one parameterization is not None past this point, so no
    # trailing "nothing specified" branch is needed.
    assert_one_of(probs=probs, logits=logits)
    if probs is not None:
        return GeometricProbs(probs, validate_args=validate_args)
    return GeometricLogits(logits, validate_args=validate_args)
5 changes: 3 additions & 2 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,9 @@ def infer_shapes(cls, *args, **kwargs):
# Assumes distribution is univariate.
batch_shapes = []
for name, shape in kwargs.items():
event_dim = cls.arg_constraints.get(name, constraints.real).event_dim
batch_shapes.append(shape[: len(shape) - event_dim])
if shape is not None:
event_dim = cls.arg_constraints.get(name, constraints.real).event_dim
batch_shapes.append(shape[: len(shape) - event_dim])
batch_shape = lax.broadcast_shapes(*batch_shapes) if batch_shapes else ()
event_shape = ()
return batch_shape, event_shape
Expand Down
11 changes: 11 additions & 0 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,17 @@ def is_prng_key(key):
return False


def assert_one_of(**kwargs):
    """
    Check that one, and only one, of the keyword arguments is not ``None``.

    :param kwargs: candidate arguments keyed by name; a value of ``None``
        means the argument was not specified.
    :raises ValueError: if no argument, or more than one argument, is
        specified. The message lists all candidate names and the names that
        were actually specified.
    """
    specified = []
    for key, value in kwargs.items():
        if value is not None:
            specified.append(key)
    if len(specified) == 1:
        return
    raise ValueError(
        f"Exactly one of {list(kwargs)} must be specified; got {specified}."
    )


# This is sourced from: torch.distributions.util.py
#
# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Expand Down
11 changes: 9 additions & 2 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,8 +1233,15 @@ def test_dist_shape(jax_dist, sp_dist, params, prepend_shape):
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_infer_shapes(jax_dist, sp_dist, params):
shapes = tuple(getattr(p, "shape", ()) for p in params)
shapes = tuple(x() if callable(x) else x for x in shapes)
shapes = []
for param in params:
if param is None:
shapes.append(None)
continue
shape = getattr(param, "shape", ())
if callable(shape):
shape = shape()
shapes.append(shape)
jax_dist = jax_dist(*params)
try:
expected_batch_shape, expected_event_shape = type(jax_dist).infer_shapes(
Expand Down

0 comments on commit f2b0cad

Please sign in to comment.