Skip to content

Commit

Permalink
add constraints.greater_than_eq, constraints.positive_semidefinite, c…
Browse files Browse the repository at this point in the history
…onstraints.nonnegative (pyro-ppl#1793)

* fix singleton plate bug

* add greater_than_eq, nonnegative, positive_semidefinite
  • Loading branch information
amifalk authored May 5, 2024
1 parent b500936 commit bc7d651
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 0 deletions.
38 changes: 38 additions & 0 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"corr_matrix",
"dependent",
"greater_than",
"greater_than_eq",
"integer_interval",
"integer_greater_than",
"interval",
Expand All @@ -42,9 +43,11 @@
"less_than",
"lower_cholesky",
"multinomial",
"nonnegative",
"nonnegative_integer",
"positive",
"positive_definite",
"positive_semidefinite",
"positive_integer",
"real",
"real_vector",
Expand Down Expand Up @@ -291,11 +294,26 @@ def __eq__(self, other):
return jnp.array_equal(self.lower_bound, other.lower_bound)


class _GreaterThanEq(_GreaterThan):
def __call__(self, x):
return x >= self.lower_bound

def __eq__(self, other):
if not isinstance(other, _GreaterThanEq):
return False
return jnp.array_equal(self.lower_bound, other.lower_bound)


class _Positive(_SingletonConstraint, _GreaterThan):
def __init__(self):
super().__init__(0.0)


class _Nonnegative(_SingletonConstraint, _GreaterThanEq):
def __init__(self):
super().__init__(0.0)


class _IndependentConstraint(Constraint):
"""
Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
Expand Down Expand Up @@ -614,6 +632,23 @@ def feasible_like(self, prototype):
)


class _PositiveSemiDefinite(_SingletonConstraint):
event_dim = 2

def __call__(self, x):
jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
# check for symmetric
symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1))
# check for the smallest eigenvalue is nonnegative
nonnegative = jnp.linalg.eigh(x)[0][..., 0] >= 0
return symmetric & nonnegative

def feasible_like(self, prototype):
return jax.numpy.broadcast_to(
jax.numpy.eye(prototype.shape[-1]), prototype.shape
)


class _PositiveOrderedVector(_SingletonConstraint):
"""
Constrains to a positive real-valued tensor where the elements are monotonically
Expand Down Expand Up @@ -731,6 +766,7 @@ def tree_flatten(self):
corr_matrix = _CorrMatrix()
dependent = _Dependent()
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
independent = _IndependentConstraint
integer_interval = _IntegerInterval
Expand All @@ -740,10 +776,12 @@ def tree_flatten(self):
lower_cholesky = _LowerCholesky()
scaled_unit_lower_cholesky = _ScaledUnitLowerCholesky()
multinomial = _Multinomial
nonnegative = _Nonnegative()
nonnegative_integer = _IntegerNonnegative()
ordered_vector = _OrderedVector()
positive = _Positive()
positive_definite = _PositiveDefinite()
positive_semidefinite = _PositiveSemiDefinite()
positive_integer = _IntegerPositive()
positive_ordered_vector = _PositiveOrderedVector()
real = _Real()
Expand Down
3 changes: 3 additions & 0 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1512,11 +1512,13 @@ def _transform_to_corr_matrix(constraint):


@biject_to.register(type(constraints.positive))
@biject_to.register(type(constraints.nonnegative))
def _transform_to_positive(constraint):
return ExpTransform()


@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
return ComposeTransform(
[
Expand Down Expand Up @@ -1586,6 +1588,7 @@ def _transform_to_ordered_vector(constraint):


@biject_to.register(constraints.positive_definite)
@biject_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
return ComposeTransform([LowerCholeskyTransform(), CholeskyTransform().inv])

Expand Down
3 changes: 3 additions & 0 deletions test/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
"l1_ball": constraints.l1_ball,
"lower_cholesky": constraints.lower_cholesky,
"scaled_unit_lower_cholesky": constraints.scaled_unit_lower_cholesky,
"nonnegative": constraints.nonnegative,
"nonnegative_integer": constraints.nonnegative_integer,
"ordered_vector": constraints.ordered_vector,
"positive": constraints.positive,
"positive_definite": constraints.positive_definite,
"positive_semidefinite": constraints.positive_semidefinite,
"positive_integer": constraints.positive_integer,
"positive_ordered_vector": constraints.positive_ordered_vector,
"real": constraints.real,
Expand All @@ -48,6 +50,7 @@ class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])):
type(constraints.dependent), (), dict(is_discrete=True, event_dim=2)
),
"greater_than": T(constraints.greater_than, (_a(0.0),), dict()),
"greater_than_eq": T(constraints.greater_than_eq, (_a(0.0),), dict()),
"less_than": T(constraints.less_than, (_a(-1.0),), dict()),
"independent": T(
constraints.independent,
Expand Down
3 changes: 3 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,13 +362,15 @@ def test_batched_recursive_linear_transform():
(constraints.corr_cholesky, (10, 10)),
(constraints.corr_matrix, (21,)),
(constraints.greater_than(3), ()),
(constraints.greater_than_eq(3), ()),
(constraints.interval(8, 13), (17,)),
(constraints.l1_ball, (4,)),
(constraints.less_than(-1), ()),
(constraints.lower_cholesky, (21,)),
(constraints.open_interval(3, 4), ()),
(constraints.ordered_vector, (5,)),
(constraints.positive_definite, (6,)),
(constraints.positive_semidefinite, (6,)),
(constraints.positive_ordered_vector, (7,)),
(constraints.positive, (7,)),
(constraints.real_matrix, (17,)),
Expand All @@ -379,6 +381,7 @@ def test_batched_recursive_linear_transform():
(constraints.softplus_lower_cholesky, (21,)),
(constraints.softplus_positive, (2,)),
(constraints.unit_interval, (4,)),
(constraints.nonnegative, (7,)),
],
ids=str,
)
Expand Down

0 comments on commit bc7d651

Please sign in to comment.