Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug: Dirichlet distribution got invalid concentration parameter for BetaDistribution #1237

Merged
merged 8 commits into from
Nov 29, 2021
2 changes: 1 addition & 1 deletion numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def inverse_shape(self, shape):
return batch_shape + in_shape


@biject_to.register(BijectorConstraint)
@biject_to.register([BijectorConstraint])
def _transform_to_bijector_constraint(constraint):
return BijectorTransform(constraint.bijector)

Expand Down
6 changes: 6 additions & 0 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,11 @@ def feasible_like(self, prototype):
)


class _OpenInterval(_Interval):
    """Interval constraint that excludes both of its endpoints."""

    def __call__(self, x):
        # Both comparisons are strict, so a value equal to either bound
        # does not satisfy the constraint.
        above_lower = x > self.lower_bound
        below_upper = x < self.upper_bound
        return above_lower & below_upper


class _LowerCholesky(Constraint):
event_dim = 2

Expand Down Expand Up @@ -507,3 +512,4 @@ def feasible_like(self, prototype):
softplus_positive = _SoftplusPositive()
sphere = _Sphere()
unit_interval = _Interval(0.0, 1.0)
open_interval = _OpenInterval
4 changes: 2 additions & 2 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def __init__(self, concentration1, concentration0, validate_args=None):
)
concentration1 = jnp.broadcast_to(concentration1, batch_shape)
concentration0 = jnp.broadcast_to(concentration0, batch_shape)
super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
self._dirichlet = Dirichlet(
jnp.stack([concentration1, concentration0], axis=-1)
)
super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

def sample(self, key, sample_shape=()):
assert is_prng_key(key)
Expand Down Expand Up @@ -1659,7 +1659,7 @@ class BetaProportion(Beta):
"""

arg_constraints = {
"mean": constraints.unit_interval,
"mean": constraints.open_interval(0.0, 1.0),
"concentration": constraints.positive,
}
reparametrized_params = ["mean", "concentration"]
Expand Down
44 changes: 22 additions & 22 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,14 +1044,14 @@ class ConstraintRegistry(object):
def __init__(self):
self._registry = {}

def register(self, constraint, factory=None):
    """Register ``factory`` as the transform builder for ``constraint``.

    Can be called directly or used as a decorator (when ``factory`` is
    omitted).

    :param constraint: a constraint class, a constraint instance, or a
        list/tuple of either. Instances are registered under their type.
    :param factory: callable stored in the registry for each given
        constraint; if ``None``, a decorator is returned instead.
    :return: ``factory`` itself (or a decorator that returns it), so the
        decorated name stays bound to the function and several
        ``register`` decorators can be stacked on one function.
    """
    if factory is None:
        # Decorator form: defer registration until the function is supplied.
        return lambda factory: self.register(constraint, factory)

    # Accept both the single-constraint and the list-of-constraints call
    # styles so existing callers of either form keep working.
    if isinstance(constraint, (list, tuple)):
        targets = constraint
    else:
        targets = (constraint,)

    for target in targets:
        # The registry is keyed by constraint class, so map instances to
        # their type before storing.
        if isinstance(target, constraints.Constraint):
            target = type(target)
        self._registry[target] = factory

    # Returning the factory (rather than None) is what makes decorator
    # usage leave the function intact.
    return factory
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, really sorry, I think we need to return factory here. Could you revert to the previous commit and add

def register(...):
    ...
    return factory

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, right - that works, thanks!


def __call__(self, constraint):
try:
Expand All @@ -1065,19 +1065,19 @@ def __call__(self, constraint):
biject_to = ConstraintRegistry()


@biject_to.register(constraints.corr_cholesky)
@biject_to.register([constraints.corr_cholesky])
def _transform_to_corr_cholesky(constraint):
return CorrCholeskyTransform()


@biject_to.register(constraints.corr_matrix)
@biject_to.register([constraints.corr_matrix])
def _transform_to_corr_matrix(constraint):
return ComposeTransform(
[CorrCholeskyTransform(), CorrMatrixCholeskyTransform().inv]
)


@biject_to.register(constraints.greater_than)
@biject_to.register([constraints.greater_than])
def _transform_to_greater_than(constraint):
if constraint is constraints.positive:
return ExpTransform()
Expand All @@ -1089,7 +1089,7 @@ def _transform_to_greater_than(constraint):
)


@biject_to.register(constraints.less_than)
@biject_to.register([constraints.less_than])
def _transform_to_less_than(constraint):
return ComposeTransform(
[
Expand All @@ -1099,14 +1099,14 @@ def _transform_to_less_than(constraint):
)


@biject_to.register(constraints.independent)
@biject_to.register([constraints.independent])
def _biject_to_independent(constraint):
return IndependentTransform(
biject_to(constraint.base_constraint), constraint.reinterpreted_batch_ndims
)


@biject_to.register(constraints.interval)
@biject_to.register([constraints.interval, constraints.open_interval])
def _transform_to_interval(constraint):
if constraint is constraints.unit_interval:
return SigmoidTransform()
Expand All @@ -1121,51 +1121,51 @@ def _transform_to_interval(constraint):
)


@biject_to.register(constraints.l1_ball)
@biject_to.register([constraints.l1_ball])
def _transform_to_l1_ball(constraint):
return L1BallTransform()


@biject_to.register(constraints.lower_cholesky)
@biject_to.register([constraints.lower_cholesky])
def _transform_to_lower_cholesky(constraint):
return LowerCholeskyTransform()


@biject_to.register(constraints.scaled_unit_lower_cholesky)
@biject_to.register([constraints.scaled_unit_lower_cholesky])
def _transform_to_scaled_unit_lower_cholesky(constraint):
return ScaledUnitLowerCholeskyTransform()


@biject_to.register(constraints.ordered_vector)
@biject_to.register([constraints.ordered_vector])
def _transform_to_ordered_vector(constraint):
return OrderedTransform()


@biject_to.register(constraints.positive_definite)
@biject_to.register([constraints.positive_definite])
def _transform_to_positive_definite(constraint):
return ComposeTransform([LowerCholeskyTransform(), CholeskyTransform().inv])


@biject_to.register(constraints.positive_ordered_vector)
@biject_to.register([constraints.positive_ordered_vector])
def _transform_to_positive_ordered_vector(constraint):
return ComposeTransform([OrderedTransform(), ExpTransform()])


@biject_to.register(constraints.real)
@biject_to.register([constraints.real])
def _transform_to_real(constraint):
return IdentityTransform()


@biject_to.register(constraints.softplus_positive)
@biject_to.register([constraints.softplus_positive])
def _transform_to_softplus_positive(constraint):
return SoftplusTransform()


@biject_to.register(constraints.softplus_lower_cholesky)
@biject_to.register([constraints.softplus_lower_cholesky])
def _transform_to_softplus_lower_cholesky(constraint):
return SoftplusLowerCholeskyTransform()


@biject_to.register(constraints.simplex)
@biject_to.register([constraints.simplex])
def _transform_to_simplex(constraint):
return StickBreakingTransform()
13 changes: 13 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1593,6 +1593,13 @@ def g(x):
assert_allclose(grad_fx, grad_gx, atol=1e-4)


def test_beta_proportion_invalid_mean():
    """BetaProportion must reject a mean of exactly 1.0 under validation."""
    expected_message = r"^BetaProportion distribution got invalid mean parameter\.$"
    with dist.distribution.validation_enabled():
        with pytest.raises(ValueError, match=expected_message):
            dist.BetaProportion(1.0, 1.0)


########################################
# Tests for constraints and transforms #
########################################
Expand Down Expand Up @@ -1713,6 +1720,11 @@ def g(x):
jnp.array([[1, 0, 0], [0.5, 0.5, 0]]),
jnp.array([True, False]),
),
(
constraints.open_interval(0.0, 1.0),
jnp.array([-5, 0, 0.5, 1, 7]),
jnp.array([False, False, True, False, False]),
),
],
)
def test_constraints(constraint, x, expected):
Expand Down Expand Up @@ -1754,6 +1766,7 @@ def test_constraints(constraint, x, expected):
constraints.softplus_positive,
constraints.softplus_lower_cholesky,
constraints.unit_interval,
constraints.open_interval(0.0, 1.0),
],
ids=lambda x: x.__class__,
)
Expand Down