Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add forward shape for SimplexToOrderTransform #1583

Merged
merged 2 commits into from
May 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,12 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
J_logdet = (softplus(y) + softplus(-y)).sum(-1)
return J_logdet

def forward_shape(self, shape):
return shape[:-1] + (shape[-1] - 1,)

def inverse_shape(self, shape):
return shape[:-1] + (shape[-1] + 1,)


def _softplus_inv(y):
return jnp.log(-jnp.expm1(-y)) + y
Expand Down
16 changes: 8 additions & 8 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,14 @@ def sample(

:param str name: name of the sample site.
:param fn: a stochastic function that returns a sample.
:param numpy.ndarray obs: observed value
:param jnp.ndarray obs: observed value
:param jax.random.PRNGKey rng_key: an optional random key for `fn`.
:param sample_shape: Shape of samples to be drawn.
:param dict infer: an optional dictionary containing additional information
for inference algorithms. For example, if `fn` is a discrete distribution,
setting `infer={'enumerate': 'parallel'}` to tell MCMC marginalize
this discrete latent site.
:param numpy.ndarray obs_mask: Optional boolean array mask of shape
:param jnp.ndarray obs_mask: Optional boolean array mask of shape
broadcastable with ``fn.batch_shape``. If provided, events with
mask=True will be conditioned on ``obs`` and remaining events will be
imputed by sampling. This introduces a latent sample site named ``name
Expand Down Expand Up @@ -235,7 +235,7 @@ def param(name, init_value=None, **kwargs):
Note that the onus of using this to initialize the optimizer is
on the user inference algorithm, since there is no global parameter
store in NumPyro.
:type init_value: numpy.ndarray or callable
:type init_value: jnp.ndarray or callable
:param constraint: NumPyro constraint, defaults to ``constraints.real``.
:type constraint: numpyro.distributions.constraints.Constraint
:param int event_dim: (optional) number of rightmost dimensions unrelated
Expand Down Expand Up @@ -289,7 +289,7 @@ def deterministic(name, value):
values in the model execution trace.

:param str name: name of the deterministic site.
:param numpy.ndarray value: deterministic value to record in the trace.
:param jnp.ndarray value: deterministic value to record in the trace.
"""
if not _PYRO_STACK:
return value
Expand Down Expand Up @@ -376,7 +376,7 @@ def model():
# ...

:returns: The mask.
:rtype: None, bool, or numpy.ndarray
:rtype: None, bool, or jnp.ndarray
"""
return _inspect()["mask"]

Expand Down Expand Up @@ -607,7 +607,7 @@ def factor(name, log_factor):
probabilistic model.

:param str name: Name of the trivial sample.
:param numpy.ndarray log_factor: A possibly batched log probability factor.
:param jnp.ndarray log_factor: A possibly batched log probability factor.
"""
unit_dist = numpyro.distributions.distribution.Unit(log_factor)
unit_value = unit_dist.sample(None)
Expand Down Expand Up @@ -657,11 +657,11 @@ def model(data):
data = numpyro.subsample(data, event_dim=0)
# ...

:param numpy.ndarray data: A tensor of batched data.
:param jnp.ndarray data: A tensor of batched data.
:param int event_dim: The event dimension of the data tensor. Dimensions to
the left are considered batch dimensions.
:returns: A subsampled version of ``data``
:rtype: ~numpy.ndarray
:rtype: ~jnp.ndarray
"""
if not _PYRO_STACK:
return data
Expand Down
10 changes: 10 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2300,6 +2300,16 @@ def test_composed_transform_1(batch_shape):
assert_allclose(log_det, expected_log_det)


@pytest.mark.parametrize("batch_shape", [(), (5,)])
def test_simplex_to_order_transform(batch_shape):
simplex = jnp.arange(5.0) / jnp.arange(5.0).sum()
simplex = jnp.broadcast_to(simplex, batch_shape + simplex.shape)
transform = SimplexToOrderedTransform()
out = transform(simplex)
assert out.shape == transform.forward_shape(simplex.shape)
assert simplex.shape == transform.inverse_shape(out.shape)


@pytest.mark.parametrize("batch_shape", [(), (5,)])
@pytest.mark.parametrize("prepend_event_shape", [(), (4,)])
@pytest.mark.parametrize("sample_shape", [(), (7,)])
Expand Down