Skip to content

Commit

Permalink
Fix event dimensions of (co)domain of ReshapeTransform.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Aug 21, 2024
1 parent b19a83d commit 6e3e9ad
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,9 +1175,6 @@ class ReshapeTransform(Transform):
:param inverse_shape: Shape of the sample for the inverse transform.
"""

domain = constraints.real
codomain = constraints.real

def __init__(self, forward_shape, inverse_shape) -> None:
forward_size = math.prod(forward_shape)
inverse_size = math.prod(inverse_shape)
Expand All @@ -1189,6 +1186,14 @@ def __init__(self, forward_shape, inverse_shape) -> None:
self._forward_shape = forward_shape
self._inverse_shape = inverse_shape

@property
def domain(self):
return constraints.independent(constraints.real, len(self._inverse_shape))

@property
def codomain(self):
return constraints.independent(constraints.real, len(self._forward_shape))

def forward_shape(self, shape):
return _get_target_shape(shape, self._forward_shape, self._inverse_shape)

Expand Down

0 comments on commit 6e3e9ad

Please sign in to comment.