Skip to content

Commit

Permalink
check that mean param of betaproportion is in 0-1 open interval
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Nov 24, 2021
1 parent 9d06ea5 commit a21a5e5
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 2 deletions.
15 changes: 15 additions & 0 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,20 @@ def feasible_like(self, prototype):
)


class _OpenInterval(Constraint):
def __init__(self, lower_bound, upper_bound):
self.lower_bound = lower_bound
self.upper_bound = upper_bound

def __call__(self, x):
return (x > self.lower_bound) & (x < self.upper_bound)

def feasible_like(self, prototype):
return jax.numpy.broadcast_to(
(self.lower_bound + self.upper_bound) / 2, jax.numpy.shape(prototype)
)


class _LowerCholesky(Constraint):
event_dim = 2

Expand Down Expand Up @@ -507,3 +521,4 @@ def feasible_like(self, prototype):
softplus_positive = _SoftplusPositive()
sphere = _Sphere()
unit_interval = _Interval(0.0, 1.0)
open_interval = _OpenInterval
4 changes: 2 additions & 2 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def __init__(self, concentration1, concentration0, validate_args=None):
)
concentration1 = jnp.broadcast_to(concentration1, batch_shape)
concentration0 = jnp.broadcast_to(concentration0, batch_shape)
super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
self._dirichlet = Dirichlet(
jnp.stack([concentration1, concentration0], axis=-1)
)
super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

def sample(self, key, sample_shape=()):
assert is_prng_key(key)
Expand Down Expand Up @@ -1659,7 +1659,7 @@ class BetaProportion(Beta):
"""

arg_constraints = {
"mean": constraints.unit_interval,
"mean": constraints.open_interval(0.0, 1.0),
"concentration": constraints.positive,
}
reparametrized_params = ["mean", "concentration"]
Expand Down
13 changes: 13 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1593,6 +1593,13 @@ def g(x):
assert_allclose(grad_fx, grad_gx, atol=1e-4)


def test_beta_proportion_invalid_mean():
with dist.distribution.validation_enabled(), pytest.raises(
ValueError, match=r"^BetaProportion distribution got invalid mean parameter\.$"
):
dist.BetaProportion(1.0, 1.0)


########################################
# Tests for constraints and transforms #
########################################
Expand Down Expand Up @@ -1713,6 +1720,11 @@ def g(x):
jnp.array([[1, 0, 0], [0.5, 0.5, 0]]),
jnp.array([True, False]),
),
(
constraints.open_interval(0.0, 1.0),
jnp.array([-5, 0, 0.5, 1, 7]),
jnp.array([False, False, True, False, False]),
),
],
)
def test_constraints(constraint, x, expected):
Expand Down Expand Up @@ -1754,6 +1766,7 @@ def test_constraints(constraint, x, expected):
constraints.softplus_positive,
constraints.softplus_lower_cholesky,
constraints.unit_interval,
constraints.open_interval(0.0, 1.0),
],
ids=lambda x: x.__class__,
)
Expand Down

0 comments on commit a21a5e5

Please sign in to comment.