Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add constraints.greater_than_eq, constraints.positive_semidefinite, constraints.nonnegative #1793

Merged
merged 2 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion numpyro/contrib/funsor/infer_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import defaultdict
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
import functools
import re
Expand Down Expand Up @@ -220,6 +220,10 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
log_prob = scale * log_prob

dim_to_name = site["infer"]["dim_to_name"]

if all(dim == 1 for dim in log_prob.shape) and dim_to_name == OrderedDict():
log_prob = log_prob.squeeze()

log_prob_factor = funsor.to_funsor(
log_prob, output=funsor.Real, dim_to_name=dim_to_name
)
Expand Down
38 changes: 38 additions & 0 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"corr_matrix",
"dependent",
"greater_than",
"greater_than_eq",
"integer_interval",
"integer_greater_than",
"interval",
Expand All @@ -42,9 +43,11 @@
"less_than",
"lower_cholesky",
"multinomial",
"nonnegative",
"nonnegative_integer",
"positive",
"positive_definite",
"positive_semidefinite",
"positive_integer",
"real",
"real_vector",
Expand Down Expand Up @@ -291,11 +294,26 @@ def __eq__(self, other):
return jnp.array_equal(self.lower_bound, other.lower_bound)


class _GreaterThanEq(_GreaterThan):
def __call__(self, x):
return x >= self.lower_bound

def __eq__(self, other):
if not isinstance(other, _GreaterThanEq):
return False
return jnp.array_equal(self.lower_bound, other.lower_bound)


class _Positive(_SingletonConstraint, _GreaterThan):
def __init__(self):
super().__init__(0.0)


class _Nonnegative(_SingletonConstraint, _GreaterThanEq):
def __init__(self):
super().__init__(0.0)


class _IndependentConstraint(Constraint):
"""
Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
Expand Down Expand Up @@ -614,6 +632,23 @@ def feasible_like(self, prototype):
)


class _PositiveSemiDefinite(_SingletonConstraint):
event_dim = 2

def __call__(self, x):
jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
# check for symmetric
symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1))
# check for the smallest eigenvalue is nonnegative
nonnegative = jnp.linalg.eigh(x)[0][..., 0] >= 0
return symmetric & nonnegative

def feasible_like(self, prototype):
return jax.numpy.broadcast_to(
jax.numpy.eye(prototype.shape[-1]), prototype.shape
)


class _PositiveOrderedVector(_SingletonConstraint):
"""
Constrains to a positive real-valued tensor where the elements are monotonically
Expand Down Expand Up @@ -731,6 +766,7 @@ def tree_flatten(self):
corr_matrix = _CorrMatrix()
dependent = _Dependent()
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
independent = _IndependentConstraint
integer_interval = _IntegerInterval
Expand All @@ -740,10 +776,12 @@ def tree_flatten(self):
lower_cholesky = _LowerCholesky()
scaled_unit_lower_cholesky = _ScaledUnitLowerCholesky()
multinomial = _Multinomial
nonnegative = _Nonnegative()
nonnegative_integer = _IntegerNonnegative()
ordered_vector = _OrderedVector()
positive = _Positive()
positive_definite = _PositiveDefinite()
positive_semidefinite = _PositiveSemiDefinite()
positive_integer = _IntegerPositive()
positive_ordered_vector = _PositiveOrderedVector()
real = _Real()
Expand Down
3 changes: 3 additions & 0 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1512,11 +1512,13 @@ def _transform_to_corr_matrix(constraint):


@biject_to.register(type(constraints.positive))
@biject_to.register(type(constraints.nonnegative))
def _transform_to_positive(constraint):
return ExpTransform()


@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
return ComposeTransform(
[
Expand Down Expand Up @@ -1586,6 +1588,7 @@ def _transform_to_ordered_vector(constraint):


@biject_to.register(constraints.positive_definite)
@biject_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
return ComposeTransform([LowerCholeskyTransform(), CholeskyTransform().inv])

Expand Down
15 changes: 15 additions & 0 deletions test/contrib/test_funsor.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,3 +575,18 @@ def gmm(data):
with pytest.raises(Exception):
mcmc.run(random.PRNGKey(2), data)
assert len(_PYRO_STACK) == 0


@pytest.mark.parametrize(
"i_size, j_size, k_size", [(1, 1, 1), (1, 2, 1), (2, 1, 1), (1, 1, 2)]
)
def test_singleton_plate_works(i_size, j_size, k_size):
def model():
with numpyro.plate("i", i_size, dim=-3):
with numpyro.plate("j", j_size, dim=-2):
with numpyro.plate("k", k_size, dim=-1):
numpyro.sample("a", dist.Normal())

model = enum(numpyro.handlers.seed(model, rng_seed=0), first_available_dim=-4)

log_density(model, (), {}, {})
3 changes: 3 additions & 0 deletions test/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
"l1_ball": constraints.l1_ball,
"lower_cholesky": constraints.lower_cholesky,
"scaled_unit_lower_cholesky": constraints.scaled_unit_lower_cholesky,
"nonnegative": constraints.nonnegative,
"nonnegative_integer": constraints.nonnegative_integer,
"ordered_vector": constraints.ordered_vector,
"positive": constraints.positive,
"positive_definite": constraints.positive_definite,
"positive_semidefinite": constraints.positive_semidefinite,
"positive_integer": constraints.positive_integer,
"positive_ordered_vector": constraints.positive_ordered_vector,
"real": constraints.real,
Expand All @@ -48,6 +50,7 @@ class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])):
type(constraints.dependent), (), dict(is_discrete=True, event_dim=2)
),
"greater_than": T(constraints.greater_than, (_a(0.0),), dict()),
"greater_than_eq": T(constraints.greater_than_eq, (_a(0.0),), dict()),
"less_than": T(constraints.less_than, (_a(-1.0),), dict()),
"independent": T(
constraints.independent,
Expand Down
3 changes: 3 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,13 +362,15 @@ def test_batched_recursive_linear_transform():
(constraints.corr_cholesky, (10, 10)),
(constraints.corr_matrix, (21,)),
(constraints.greater_than(3), ()),
(constraints.greater_than_eq(3), ()),
(constraints.interval(8, 13), (17,)),
(constraints.l1_ball, (4,)),
(constraints.less_than(-1), ()),
(constraints.lower_cholesky, (21,)),
(constraints.open_interval(3, 4), ()),
(constraints.ordered_vector, (5,)),
(constraints.positive_definite, (6,)),
(constraints.positive_semidefinite, (6,)),
(constraints.positive_ordered_vector, (7,)),
(constraints.positive, (7,)),
(constraints.real_matrix, (17,)),
Expand All @@ -379,6 +381,7 @@ def test_batched_recursive_linear_transform():
(constraints.softplus_lower_cholesky, (21,)),
(constraints.softplus_positive, (2,)),
(constraints.unit_interval, (4,)),
(constraints.nonnegative, (7,)),
],
ids=str,
)
Expand Down
Loading