From bc7d651bd3d75c941d7ffea7eabbbe248b31201a Mon Sep 17 00:00:00 2001 From: Ami Falk <96739930+amifalk@users.noreply.github.com> Date: Sat, 4 May 2024 22:24:23 -0400 Subject: [PATCH] add constraints.greater_than_eq, constraints.positive_semidefinite, constraints.nonnegative (#1793) * fix singleton plate bug * add greater_than_eq, nonnegative, positive_semidefinite --- numpyro/distributions/constraints.py | 38 ++++++++++++++++++++++++++++ numpyro/distributions/transforms.py | 3 +++ test/test_constraints.py | 3 +++ test/test_transforms.py | 3 +++ 4 files changed, 47 insertions(+) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index af29eb038..4d9445db8 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -34,6 +34,7 @@ "corr_matrix", "dependent", "greater_than", + "greater_than_eq", "integer_interval", "integer_greater_than", "interval", @@ -42,9 +43,11 @@ "less_than", "lower_cholesky", "multinomial", + "nonnegative", "nonnegative_integer", "positive", "positive_definite", + "positive_semidefinite", "positive_integer", "real", "real_vector", @@ -291,11 +294,26 @@ def __eq__(self, other): return jnp.array_equal(self.lower_bound, other.lower_bound) +class _GreaterThanEq(_GreaterThan): + def __call__(self, x): + return x >= self.lower_bound + + def __eq__(self, other): + if not isinstance(other, _GreaterThanEq): + return False + return jnp.array_equal(self.lower_bound, other.lower_bound) + + class _Positive(_SingletonConstraint, _GreaterThan): def __init__(self): super().__init__(0.0) +class _Nonnegative(_SingletonConstraint, _GreaterThanEq): + def __init__(self): + super().__init__(0.0) + + class _IndependentConstraint(Constraint): """ Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many @@ -614,6 +632,23 @@ def feasible_like(self, prototype): ) +class _PositiveSemiDefinite(_SingletonConstraint): + event_dim = 2 + + def __call__(self, x): + jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + # check for symmetric + symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) + # check for the smallest eigenvalue is nonnegative + nonnegative = jnp.linalg.eigh(x)[0][..., 0] >= 0 + return symmetric & nonnegative + + def feasible_like(self, prototype): + return jax.numpy.broadcast_to( + jax.numpy.eye(prototype.shape[-1]), prototype.shape + ) + + class _PositiveOrderedVector(_SingletonConstraint): """ Constrains to a positive real-valued tensor where the elements are monotonically @@ -731,6 +766,7 @@ def tree_flatten(self): corr_matrix = _CorrMatrix() dependent = _Dependent() greater_than = _GreaterThan +greater_than_eq = _GreaterThanEq less_than = _LessThan independent = _IndependentConstraint integer_interval = _IntegerInterval @@ -740,10 +776,12 @@ def tree_flatten(self): lower_cholesky = _LowerCholesky() scaled_unit_lower_cholesky = _ScaledUnitLowerCholesky() multinomial = _Multinomial +nonnegative = _Nonnegative() nonnegative_integer = _IntegerNonnegative() ordered_vector = _OrderedVector() positive = _Positive() positive_definite = _PositiveDefinite() +positive_semidefinite = _PositiveSemiDefinite() positive_integer = _IntegerPositive() positive_ordered_vector = _PositiveOrderedVector() real = _Real() diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index a057d86d2..b7d984b5b 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1512,11 +1512,13 @@ def _transform_to_corr_matrix(constraint): @biject_to.register(type(constraints.positive)) +@biject_to.register(type(constraints.nonnegative)) def _transform_to_positive(constraint): return ExpTransform() @biject_to.register(constraints.greater_than) +@biject_to.register(constraints.greater_than_eq) def _transform_to_greater_than(constraint): return ComposeTransform( [ @@ -1586,6 +1588,7 @@ def _transform_to_ordered_vector(constraint): @biject_to.register(constraints.positive_definite) +@biject_to.register(constraints.positive_semidefinite) def _transform_to_positive_definite(constraint): return ComposeTransform([LowerCholeskyTransform(), CholeskyTransform().inv]) diff --git a/test/test_constraints.py b/test/test_constraints.py index 67832c476..e413da219 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -20,10 +20,12 @@ "l1_ball": constraints.l1_ball, "lower_cholesky": constraints.lower_cholesky, "scaled_unit_lower_cholesky": constraints.scaled_unit_lower_cholesky, + "nonnegative": constraints.nonnegative, "nonnegative_integer": constraints.nonnegative_integer, "ordered_vector": constraints.ordered_vector, "positive": constraints.positive, "positive_definite": constraints.positive_definite, + "positive_semidefinite": constraints.positive_semidefinite, "positive_integer": constraints.positive_integer, "positive_ordered_vector": constraints.positive_ordered_vector, "real": constraints.real, @@ -48,6 +50,7 @@ class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])): type(constraints.dependent), (), dict(is_discrete=True, event_dim=2) ), "greater_than": T(constraints.greater_than, (_a(0.0),), dict()), + "greater_than_eq": T(constraints.greater_than_eq, (_a(0.0),), dict()), "less_than": T(constraints.less_than, (_a(-1.0),), dict()), "independent": T( constraints.independent, diff --git a/test/test_transforms.py b/test/test_transforms.py index 4329d8b63..f35345193 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -362,6 +362,7 @@ def test_batched_recursive_linear_transform(): (constraints.corr_cholesky, (10, 10)), (constraints.corr_matrix, (21,)), (constraints.greater_than(3), ()), + (constraints.greater_than_eq(3), ()), (constraints.interval(8, 13), (17,)), (constraints.l1_ball, (4,)), (constraints.less_than(-1), ()), @@ -369,6 +370,7 @@ def test_batched_recursive_linear_transform(): (constraints.open_interval(3, 4), ()), (constraints.ordered_vector, (5,)), (constraints.positive_definite, (6,)), + (constraints.positive_semidefinite, (6,)), (constraints.positive_ordered_vector, (7,)), (constraints.positive, (7,)), (constraints.real_matrix, (17,)), @@ -379,6 +381,7 @@ def test_batched_recursive_linear_transform(): (constraints.softplus_lower_cholesky, (21,)), (constraints.softplus_positive, (2,)), (constraints.unit_interval, (4,)), + (constraints.nonnegative, (7,)), ], ids=str, )