Skip to content

Commit

Permalink
Fix handling of event dimensions in ComposeTransform (fixes pyro-pp…
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Oct 23, 2024
1 parent 8ace34f commit 2938a7c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
2 changes: 1 addition & 1 deletion numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def __eq__(self, other):

def _get_compose_transform_input_event_dim(parts):
input_event_dim = parts[-1].domain.event_dim
for part in parts[len(parts) - 1 :: -1]:
for part in parts[:-1][::-1]:
input_event_dim = part.domain.event_dim + max(
input_event_dim - part.codomain.event_dim, 0
)
Expand Down
13 changes: 13 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,16 @@ def test_biject_to(constraint, shape):
expected_shape = constrained.shape[: constrained.ndim - constraint.event_dim]
assert passed.shape == expected_shape
assert jnp.all(passed)


@pytest.mark.parametrize(
"transform",
[
CorrCholeskyTransform(),
CorrCholeskyTransform().inv,
],
)
def test_compose_domain_codomain(transform):
composed = ComposeTransform([transform])
assert transform.domain.event_dim == composed.domain.event_dim
assert transform.codomain.event_dim == composed.codomain.event_dim

0 comments on commit 2938a7c

Please sign in to comment.