Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug: Dirichlet distribution got invalid concentration parameter for BetaDistribution #1237

Merged
merged 8 commits into from
Nov 29, 2021
6 changes: 6 additions & 0 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,11 @@ def feasible_like(self, prototype):
)


class _OpenInterval(_Interval):
def __call__(self, x):
return (x > self.lower_bound) & (x < self.upper_bound)


class _LowerCholesky(Constraint):
event_dim = 2

Expand Down Expand Up @@ -507,3 +512,4 @@ def feasible_like(self, prototype):
softplus_positive = _SoftplusPositive()
sphere = _Sphere()
unit_interval = _Interval(0.0, 1.0)
open_interval = _OpenInterval
4 changes: 2 additions & 2 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def __init__(self, concentration1, concentration0, validate_args=None):
)
concentration1 = jnp.broadcast_to(concentration1, batch_shape)
concentration0 = jnp.broadcast_to(concentration0, batch_shape)
super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
self._dirichlet = Dirichlet(
jnp.stack([concentration1, concentration0], axis=-1)
)
super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

def sample(self, key, sample_shape=()):
assert is_prng_key(key)
Expand Down Expand Up @@ -1659,7 +1659,7 @@ class BetaProportion(Beta):
"""

arg_constraints = {
"mean": constraints.unit_interval,
"mean": constraints.open_interval(0.0, 1.0),
"concentration": constraints.positive,
}
reparametrized_params = ["mean", "concentration"]
Expand Down
2 changes: 2 additions & 0 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,7 @@ def register(self, constraint, factory=None):
constraint = type(constraint)

self._registry[constraint] = factory
return factory

def __call__(self, constraint):
try:
Expand Down Expand Up @@ -1106,6 +1107,7 @@ def _biject_to_independent(constraint):
)


@biject_to.register(constraints.open_interval)
@biject_to.register(constraints.interval)
def _transform_to_interval(constraint):
if constraint is constraints.unit_interval:
Expand Down
13 changes: 13 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1593,6 +1593,13 @@ def g(x):
assert_allclose(grad_fx, grad_gx, atol=1e-4)


def test_beta_proportion_invalid_mean():
with dist.distribution.validation_enabled(), pytest.raises(
ValueError, match=r"^BetaProportion distribution got invalid mean parameter\.$"
):
dist.BetaProportion(1.0, 1.0)


########################################
# Tests for constraints and transforms #
########################################
Expand Down Expand Up @@ -1713,6 +1720,11 @@ def g(x):
jnp.array([[1, 0, 0], [0.5, 0.5, 0]]),
jnp.array([True, False]),
),
(
constraints.open_interval(0.0, 1.0),
jnp.array([-5, 0, 0.5, 1, 7]),
jnp.array([False, False, True, False, False]),
),
],
)
def test_constraints(constraint, x, expected):
Expand Down Expand Up @@ -1754,6 +1766,7 @@ def test_constraints(constraint, x, expected):
constraints.softplus_positive,
constraints.softplus_lower_cholesky,
constraints.unit_interval,
constraints.open_interval(0.0, 1.0),
],
ids=lambda x: x.__class__,
)
Expand Down