From a21a5e56c14a8734261f93da15f48217bac9a371 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Wed, 24 Nov 2021 20:18:16 +0000 Subject: [PATCH 1/8] check that mean param of betaproportion is in 0-1 open interval --- numpyro/distributions/constraints.py | 15 +++++++++++++++ numpyro/distributions/continuous.py | 4 ++-- test/test_distributions.py | 13 +++++++++++++ 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index acf099918..aeab1a556 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -319,6 +319,20 @@ def feasible_like(self, prototype): ) +class _OpenInterval(Constraint): + def __init__(self, lower_bound, upper_bound): + self.lower_bound = lower_bound + self.upper_bound = upper_bound + + def __call__(self, x): + return (x > self.lower_bound) & (x < self.upper_bound) + + def feasible_like(self, prototype): + return jax.numpy.broadcast_to( + (self.lower_bound + self.upper_bound) / 2, jax.numpy.shape(prototype) + ) + + class _LowerCholesky(Constraint): event_dim = 2 @@ -507,3 +521,4 @@ def feasible_like(self, prototype): softplus_positive = _SoftplusPositive() sphere = _Sphere() unit_interval = _Interval(0.0, 1.0) +open_interval = _OpenInterval diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index af2a9b452..337b7efe4 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -72,10 +72,10 @@ def __init__(self, concentration1, concentration0, validate_args=None): ) concentration1 = jnp.broadcast_to(concentration1, batch_shape) concentration0 = jnp.broadcast_to(concentration0, batch_shape) + super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args) self._dirichlet = Dirichlet( jnp.stack([concentration1, concentration0], axis=-1) ) - super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args) def sample(self, key, sample_shape=()): assert is_prng_key(key) @@ -1659,7 +1659,7 @@ class BetaProportion(Beta): """ arg_constraints = { - "mean": constraints.unit_interval, + "mean": constraints.open_interval(0.0, 1.0), "concentration": constraints.positive, } reparametrized_params = ["mean", "concentration"] diff --git a/test/test_distributions.py b/test/test_distributions.py index 1e36aaf88..88605b6f5 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1593,6 +1593,13 @@ def g(x): assert_allclose(grad_fx, grad_gx, atol=1e-4) +def test_beta_proportion_invalid_mean(): + with dist.distribution.validation_enabled(), pytest.raises( + ValueError, match=r"^BetaProportion distribution got invalid mean parameter\.$" + ): + dist.BetaProportion(1.0, 1.0) + + ######################################## # Tests for constraints and transforms # ######################################## @@ -1713,6 +1720,11 @@ def g(x): jnp.array([[1, 0, 0], [0.5, 0.5, 0]]), jnp.array([True, False]), ), + ( + constraints.open_interval(0.0, 1.0), + jnp.array([-5, 0, 0.5, 1, 7]), + jnp.array([False, False, True, False, False]), + ), ], ) def test_constraints(constraint, x, expected): @@ -1754,6 +1766,7 @@ def test_constraints(constraint, x, expected): constraints.softplus_positive, constraints.softplus_lower_cholesky, constraints.unit_interval, + constraints.open_interval(0.0, 1.0), ], ids=lambda x: x.__class__, ) From 85cce9296d6b9b427c333cce32dd506303f004a2 Mon Sep 17 00:00:00 2001 From: MarcoGorelli Date: Thu, 25 Nov 2021 09:56:31 +0000 Subject: [PATCH 2/8] inherit from _Interval, only overwrite __call__ --- numpyro/distributions/constraints.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index aeab1a556..98ba8c7de 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -319,19 +319,10 @@ def feasible_like(self, prototype): ) -class _OpenInterval(Constraint): - def __init__(self, lower_bound, upper_bound): - self.lower_bound = lower_bound - self.upper_bound = upper_bound - +class _OpenInterval(_Interval): def __call__(self, x): return (x > self.lower_bound) & (x < self.upper_bound) - def feasible_like(self, prototype): - return jax.numpy.broadcast_to( - (self.lower_bound + self.upper_bound) / 2, jax.numpy.shape(prototype) - ) - class _LowerCholesky(Constraint): event_dim = 2 From 0a2f465d7cffcc15a686134201f279b9345cada7 Mon Sep 17 00:00:00 2001 From: MarcoGorelli Date: Thu, 25 Nov 2021 10:45:59 +0000 Subject: [PATCH 3/8] register transform_to_open_interval --- numpyro/distributions/transforms.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 9cc4b5fd4..772958028 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1121,6 +1121,21 @@ def _transform_to_interval(constraint): ) +@biject_to.register(constraints.open_interval) +def _transform_to_open_interval(constraint): + scale = constraint.upper_bound - constraint.lower_bound + return ComposeTransform( + [ + SigmoidTransform(), + AffineTransform( + constraint.lower_bound, + scale, + domain=constraints.open_interval(0.0, 1.0), + ), + ] + ) + + @biject_to.register(constraints.l1_ball) def _transform_to_l1_ball(constraint): return L1BallTransform() From 17d442a0442027dfad679bb46595bf992f2fab48 Mon Sep 17 00:00:00 2001 From: MarcoGorelli Date: Thu, 25 Nov 2021 12:06:51 +0000 Subject: [PATCH 4/8] fix domain of transform_to_open_interval --- numpyro/distributions/transforms.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 772958028..c8ef6f55e 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1128,9 +1128,7 @@ def _transform_to_open_interval(constraint): [ SigmoidTransform(), AffineTransform( - constraint.lower_bound, - scale, - domain=constraints.open_interval(0.0, 1.0), + constraint.lower_bound, scale, domain=constraints.unit_interval ), ] ) From bd5b7a5d23fecd7b5fd4a54bb7d2112b48137248 Mon Sep 17 00:00:00 2001 From: MarcoGorelli Date: Thu, 25 Nov 2021 12:54:47 +0000 Subject: [PATCH 5/8] noop From 128959bd89c08c54e8c1c5ca4bbe860ec3da42d5 Mon Sep 17 00:00:00 2001 From: MarcoGorelli Date: Sat, 27 Nov 2021 08:54:21 +0000 Subject: [PATCH 6/8] deduplicate _transform_to_interval --- numpyro/distributions/transforms.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index c8ef6f55e..96120ca71 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1106,6 +1106,7 @@ def _biject_to_independent(constraint): ) +@biject_to.register(constraints.open_interval) @biject_to.register(constraints.interval) def _transform_to_interval(constraint): if constraint is constraints.unit_interval: @@ -1121,19 +1122,6 @@ def _transform_to_interval(constraint): ) -@biject_to.register(constraints.open_interval) -def _transform_to_open_interval(constraint): - scale = constraint.upper_bound - constraint.lower_bound - return ComposeTransform( - [ - SigmoidTransform(), - AffineTransform( - constraint.lower_bound, scale, domain=constraints.unit_interval - ), - ] - ) - - @biject_to.register(constraints.l1_ball) def _transform_to_l1_ball(constraint): return L1BallTransform() From ae59cac115beb3924f6e237cdebffe77a2e26f90 Mon Sep 17 00:00:00 2001 From: MarcoGorelli Date: Sat, 27 Nov 2021 12:55:35 +0000 Subject: [PATCH 7/8] register list of constraints --- numpyro/contrib/tfp/distributions.py | 2 +- numpyro/distributions/transforms.py | 45 ++++++++++++++-------------- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py index 2c8d3c40f..c174195c0 100644 --- a/numpyro/contrib/tfp/distributions.py +++ b/numpyro/contrib/tfp/distributions.py @@ -102,7 +102,7 @@ def inverse_shape(self, shape): return batch_shape + in_shape -@biject_to.register(BijectorConstraint) +@biject_to.register([BijectorConstraint]) def _transform_to_bijector_constraint(constraint): return BijectorTransform(constraint.bijector) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 96120ca71..10bb8e9cf 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1044,14 +1044,14 @@ class ConstraintRegistry(object): def __init__(self): self._registry = {} - def register(self, constraint, factory=None): + def register(self, constraint_list, factory=None): if factory is None: - return lambda factory: self.register(constraint, factory) + return lambda factory: self.register(constraint_list, factory) - if isinstance(constraint, constraints.Constraint): - constraint = type(constraint) - - self._registry[constraint] = factory + for constraint in constraint_list: + if isinstance(constraint, constraints.Constraint): + constraint = type(constraint) + self._registry[constraint] = factory def __call__(self, constraint): try: @@ -1065,19 +1065,19 @@ def __call__(self, constraint): biject_to = ConstraintRegistry() -@biject_to.register(constraints.corr_cholesky) +@biject_to.register([constraints.corr_cholesky]) def _transform_to_corr_cholesky(constraint): return CorrCholeskyTransform() -@biject_to.register(constraints.corr_matrix) +@biject_to.register([constraints.corr_matrix]) def _transform_to_corr_matrix(constraint): return ComposeTransform( [CorrCholeskyTransform(), CorrMatrixCholeskyTransform().inv] ) -@biject_to.register(constraints.greater_than) +@biject_to.register([constraints.greater_than]) def _transform_to_greater_than(constraint): if constraint is constraints.positive: return ExpTransform() @@ -1089,7 +1089,7 @@ def _transform_to_greater_than(constraint): ) -@biject_to.register(constraints.less_than) +@biject_to.register([constraints.less_than]) def _transform_to_less_than(constraint): return ComposeTransform( [ @@ -1099,15 +1099,14 @@ def _transform_to_less_than(constraint): ) -@biject_to.register(constraints.independent) +@biject_to.register([constraints.independent]) def _biject_to_independent(constraint): return IndependentTransform( biject_to(constraint.base_constraint), constraint.reinterpreted_batch_ndims ) -@biject_to.register(constraints.open_interval) -@biject_to.register(constraints.interval) +@biject_to.register([constraints.interval, constraints.open_interval]) def _transform_to_interval(constraint): if constraint is constraints.unit_interval: return SigmoidTransform() @@ -1122,51 +1121,51 @@ def _transform_to_interval(constraint): ) -@biject_to.register(constraints.l1_ball) +@biject_to.register([constraints.l1_ball]) def _transform_to_l1_ball(constraint): return L1BallTransform() -@biject_to.register(constraints.lower_cholesky) +@biject_to.register([constraints.lower_cholesky]) def _transform_to_lower_cholesky(constraint): return LowerCholeskyTransform() -@biject_to.register(constraints.scaled_unit_lower_cholesky) +@biject_to.register([constraints.scaled_unit_lower_cholesky]) def _transform_to_scaled_unit_lower_cholesky(constraint): return ScaledUnitLowerCholeskyTransform() -@biject_to.register(constraints.ordered_vector) +@biject_to.register([constraints.ordered_vector]) def _transform_to_ordered_vector(constraint): return OrderedTransform() -@biject_to.register(constraints.positive_definite) +@biject_to.register([constraints.positive_definite]) def _transform_to_positive_definite(constraint): return ComposeTransform([LowerCholeskyTransform(), CholeskyTransform().inv]) -@biject_to.register(constraints.positive_ordered_vector) +@biject_to.register([constraints.positive_ordered_vector]) def _transform_to_positive_ordered_vector(constraint): return ComposeTransform([OrderedTransform(), ExpTransform()]) -@biject_to.register(constraints.real) +@biject_to.register([constraints.real]) def _transform_to_real(constraint): return IdentityTransform() -@biject_to.register(constraints.softplus_positive) +@biject_to.register([constraints.softplus_positive]) def _transform_to_softplus_positive(constraint): return SoftplusTransform() -@biject_to.register(constraints.softplus_lower_cholesky) +@biject_to.register([constraints.softplus_lower_cholesky]) def _transform_to_softplus_lower_cholesky(constraint): return SoftplusLowerCholeskyTransform() -@biject_to.register(constraints.simplex) +@biject_to.register([constraints.simplex]) def _transform_to_simplex(constraint): return StickBreakingTransform() From 96dfe16f2bb3e784bb0ca3f68b0f9c000a2ecd9c Mon Sep 17 00:00:00 2001 From: MarcoGorelli Date: Sun, 28 Nov 2021 09:00:23 +0000 Subject: [PATCH 8/8] return factory from biject_to.register --- numpyro/contrib/tfp/distributions.py | 2 +- numpyro/distributions/transforms.py | 46 +++++++++++++++------------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py index c174195c0..2c8d3c40f 100644 --- a/numpyro/contrib/tfp/distributions.py +++ b/numpyro/contrib/tfp/distributions.py @@ -102,7 +102,7 @@ def inverse_shape(self, shape): return batch_shape + in_shape -@biject_to.register([BijectorConstraint]) +@biject_to.register(BijectorConstraint) def _transform_to_bijector_constraint(constraint): return BijectorTransform(constraint.bijector) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 10bb8e9cf..e468b9b8c 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1044,14 +1044,15 @@ class ConstraintRegistry(object): def __init__(self): self._registry = {} - def register(self, constraint_list, factory=None): + def register(self, constraint, factory=None): if factory is None: - return lambda factory: self.register(constraint_list, factory) + return lambda factory: self.register(constraint, factory) - for constraint in constraint_list: - if isinstance(constraint, constraints.Constraint): - constraint = type(constraint) - self._registry[constraint] = factory + if isinstance(constraint, constraints.Constraint): + constraint = type(constraint) + + self._registry[constraint] = factory + return factory def __call__(self, constraint): try: @@ -1065,19 +1066,19 @@ def __call__(self, constraint): biject_to = ConstraintRegistry() -@biject_to.register([constraints.corr_cholesky]) +@biject_to.register(constraints.corr_cholesky) def _transform_to_corr_cholesky(constraint): return CorrCholeskyTransform() -@biject_to.register([constraints.corr_matrix]) +@biject_to.register(constraints.corr_matrix) def _transform_to_corr_matrix(constraint): return ComposeTransform( [CorrCholeskyTransform(), CorrMatrixCholeskyTransform().inv] ) -@biject_to.register([constraints.greater_than]) +@biject_to.register(constraints.greater_than) def _transform_to_greater_than(constraint): if constraint is constraints.positive: return ExpTransform() @@ -1089,7 +1090,7 @@ def _transform_to_greater_than(constraint): ) -@biject_to.register([constraints.less_than]) +@biject_to.register(constraints.less_than) def _transform_to_less_than(constraint): return ComposeTransform( [ @@ -1099,14 +1100,15 @@ def _transform_to_less_than(constraint): ) -@biject_to.register([constraints.independent]) +@biject_to.register(constraints.independent) def _biject_to_independent(constraint): return IndependentTransform( biject_to(constraint.base_constraint), constraint.reinterpreted_batch_ndims ) -@biject_to.register([constraints.interval, constraints.open_interval]) +@biject_to.register(constraints.open_interval) +@biject_to.register(constraints.interval) def _transform_to_interval(constraint): if constraint is constraints.unit_interval: return SigmoidTransform() @@ -1121,51 +1123,51 @@ def _transform_to_interval(constraint): ) -@biject_to.register([constraints.l1_ball]) +@biject_to.register(constraints.l1_ball) def _transform_to_l1_ball(constraint): return L1BallTransform() -@biject_to.register([constraints.lower_cholesky]) +@biject_to.register(constraints.lower_cholesky) def _transform_to_lower_cholesky(constraint): return LowerCholeskyTransform() -@biject_to.register([constraints.scaled_unit_lower_cholesky]) +@biject_to.register(constraints.scaled_unit_lower_cholesky) def _transform_to_scaled_unit_lower_cholesky(constraint): return ScaledUnitLowerCholeskyTransform() -@biject_to.register([constraints.ordered_vector]) +@biject_to.register(constraints.ordered_vector) def _transform_to_ordered_vector(constraint): return OrderedTransform() -@biject_to.register([constraints.positive_definite]) +@biject_to.register(constraints.positive_definite) def _transform_to_positive_definite(constraint): return ComposeTransform([LowerCholeskyTransform(), CholeskyTransform().inv]) -@biject_to.register([constraints.positive_ordered_vector]) +@biject_to.register(constraints.positive_ordered_vector) def _transform_to_positive_ordered_vector(constraint): return ComposeTransform([OrderedTransform(), ExpTransform()]) -@biject_to.register([constraints.real]) +@biject_to.register(constraints.real) def _transform_to_real(constraint): return IdentityTransform() -@biject_to.register([constraints.softplus_positive]) +@biject_to.register(constraints.softplus_positive) def _transform_to_softplus_positive(constraint): return SoftplusTransform() -@biject_to.register([constraints.softplus_lower_cholesky]) +@biject_to.register(constraints.softplus_lower_cholesky) def _transform_to_softplus_lower_cholesky(constraint): return SoftplusLowerCholeskyTransform() -@biject_to.register([constraints.simplex]) +@biject_to.register(constraints.simplex) def _transform_to_simplex(constraint): return StickBreakingTransform()