diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 4d9445db8..56075078f 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -415,6 +415,16 @@ def __eq__(self, other): return jnp.array_equal(self.upper_bound, other.upper_bound) +class _LessThanEq(_LessThan): + def __call__(self, x): + return x <= self.upper_bound + + def __eq__(self, other): + if not isinstance(other, _LessThanEq): + return False + return jnp.array_equal(self.upper_bound, other.upper_bound) + + class _IntegerInterval(Constraint): is_discrete = True @@ -768,6 +778,7 @@ def tree_flatten(self): greater_than = _GreaterThan greater_than_eq = _GreaterThanEq less_than = _LessThan +less_than_eq = _LessThanEq independent = _IndependentConstraint integer_interval = _IntegerInterval integer_greater_than = _IntegerGreaterThan diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index b7d984b5b..564c628d8 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1529,6 +1529,7 @@ def _transform_to_greater_than(constraint): @biject_to.register(constraints.less_than) +@biject_to.register(constraints.less_than_eq) def _transform_to_less_than(constraint): return ComposeTransform( [ diff --git a/test/test_constraints.py b/test/test_constraints.py index e413da219..bfb459bdd 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -52,6 +52,7 @@ class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])): "greater_than": T(constraints.greater_than, (_a(0.0),), dict()), "greater_than_eq": T(constraints.greater_than_eq, (_a(0.0),), dict()), "less_than": T(constraints.less_than, (_a(-1.0),), dict()), + "less_than_eq": T(constraints.less_than_eq, (_a(-1.0),), dict()), "independent": T( constraints.independent, (constraints.greater_than(jnp.zeros((2,))),), diff --git a/test/test_transforms.py b/test/test_transforms.py index e3ada6fd9..8f68eaf6e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -366,6 +366,7 @@ def test_batched_recursive_linear_transform(): (constraints.interval(8, 13), (17,)), (constraints.l1_ball, (4,)), (constraints.less_than(-1), ()), + (constraints.less_than_eq(-1), ()), (constraints.lower_cholesky, (15,)), (constraints.open_interval(3, 4), ()), (constraints.ordered_vector, (5,)),