diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py
index c8017eb57..c656c2cc4 100644
--- a/numpyro/distributions/transforms.py
+++ b/numpyro/distributions/transforms.py
@@ -871,6 +871,12 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
         J_logdet = (softplus(y) + softplus(-y)).sum(-1)
         return J_logdet
 
+    def forward_shape(self, shape):
+        return shape[:-1] + (shape[-1] - 1,)
+
+    def inverse_shape(self, shape):
+        return shape[:-1] + (shape[-1] + 1,)
+
 
 def _softplus_inv(y):
     return jnp.log(-jnp.expm1(-y)) + y
diff --git a/numpyro/primitives.py b/numpyro/primitives.py
index 4ccedfaea..99cf35902 100644
--- a/numpyro/primitives.py
+++ b/numpyro/primitives.py
@@ -135,14 +135,14 @@ def sample(
 
     :param str name: name of the sample site.
     :param fn: a stochastic function that returns a sample.
-    :param numpy.ndarray obs: observed value
+    :param jnp.ndarray obs: observed value
    :param jax.random.PRNGKey rng_key: an optional random key for `fn`.
     :param sample_shape: Shape of samples to be drawn.
     :param dict infer: an optional dictionary containing additional information
         for inference algorithms. For example, if `fn` is a discrete distribution,
         setting `infer={'enumerate': 'parallel'}` to tell MCMC marginalize
         this discrete latent site.
-    :param numpy.ndarray obs_mask: Optional boolean array mask of shape
+    :param jnp.ndarray obs_mask: Optional boolean array mask of shape
         broadcastable with ``fn.batch_shape``. If provided, events with
         mask=True will be conditioned on ``obs`` and remaining events will be
         imputed by sampling. This introduces a latent sample site named ``name
@@ -235,7 +235,7 @@ def param(name, init_value=None, **kwargs):
         Note that the onus of using this to initialize the optimizer is
         on the user inference algorithm, since there is no global parameter
         store in NumPyro.
-    :type init_value: numpy.ndarray or callable
+    :type init_value: jnp.ndarray or callable
     :param constraint: NumPyro constraint, defaults to ``constraints.real``.
     :type constraint: numpyro.distributions.constraints.Constraint
     :param int event_dim: (optional) number of rightmost dimensions unrelated
@@ -289,7 +289,7 @@ def deterministic(name, value):
     values in the model execution trace.
 
     :param str name: name of the deterministic site.
-    :param numpy.ndarray value: deterministic value to record in the trace.
+    :param jnp.ndarray value: deterministic value to record in the trace.
     """
     if not _PYRO_STACK:
         return value
@@ -376,7 +376,7 @@ def model():
            # ...
 
     :returns: The mask.
-    :rtype: None, bool, or numpy.ndarray
+    :rtype: None, bool, or jnp.ndarray
     """
     return _inspect()["mask"]
 
@@ -607,7 +607,7 @@ def factor(name, log_factor):
     probabilistic model.
 
     :param str name: Name of the trivial sample.
-    :param numpy.ndarray log_factor: A possibly batched log probability factor.
+    :param jnp.ndarray log_factor: A possibly batched log probability factor.
     """
     unit_dist = numpyro.distributions.distribution.Unit(log_factor)
     unit_value = unit_dist.sample(None)
@@ -657,11 +657,11 @@ def model(data):
                data = numpyro.subsample(data, event_dim=0)
                # ...
 
-    :param numpy.ndarray data: A tensor of batched data.
+    :param jnp.ndarray data: A tensor of batched data.
     :param int event_dim: The event dimension of the data tensor. Dimensions to
         the left are considered batch dimensions.
     :returns: A subsampled version of ``data``
-    :rtype: ~numpy.ndarray
+    :rtype: ~jnp.ndarray
     """
     if not _PYRO_STACK:
         return data
diff --git a/test/test_distributions.py b/test/test_distributions.py
index fcfe5e284..9144e0387 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -2300,6 +2300,16 @@ def test_composed_transform_1(batch_shape):
     assert_allclose(log_det, expected_log_det)
 
 
+@pytest.mark.parametrize("batch_shape", [(), (5,)])
+def test_simplex_to_order_transform(batch_shape):
+    simplex = jnp.arange(5.0) / jnp.arange(5.0).sum()
+    simplex = jnp.broadcast_to(simplex, batch_shape + simplex.shape)
+    transform = SimplexToOrderedTransform()
+    out = transform(simplex)
+    assert out.shape == transform.forward_shape(simplex.shape)
+    assert simplex.shape == transform.inverse_shape(out.shape)
+
+
 @pytest.mark.parametrize("batch_shape", [(), (5,)])
 @pytest.mark.parametrize("prepend_event_shape", [(), (4,)])
 @pytest.mark.parametrize("sample_shape", [(), (7,)])
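For context, a minimal usage sketch of what the new shape methods encode (illustrative only, not part of the diff, and assumes the patch above is applied): SimplexToOrderedTransform maps a length-K simplex to K-1 ordered cutpoints, so forward_shape drops one element along the last axis and inverse_shape restores it.

    # Illustrative sketch only; assumes this patch is applied.
    import jax.numpy as jnp
    from numpyro.distributions.transforms import SimplexToOrderedTransform

    transform = SimplexToOrderedTransform()

    # A length-4 simplex (entries sum to 1) maps to 3 ordered cutpoints.
    simplex = jnp.array([0.1, 0.2, 0.3, 0.4])
    ordered = transform(simplex)

    assert ordered.shape == (3,)
    assert ordered.shape == transform.forward_shape(simplex.shape)   # (4,) -> (3,)
    assert simplex.shape == transform.inverse_shape(ordered.shape)   # (3,) -> (4,)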