From 6e3e9ad33d576b748f3e6dce183d17ea38742046 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Wed, 21 Aug 2024 18:56:51 -0400 Subject: [PATCH] Fix event dimensions of (co)domain of `ReshapeTransform`. --- numpyro/distributions/transforms.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 40c31fa513..2e24881811 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1175,9 +1175,6 @@ class ReshapeTransform(Transform): :param inverse_shape: Shape of the sample for the inverse transform. """ - domain = constraints.real - codomain = constraints.real - def __init__(self, forward_shape, inverse_shape) -> None: forward_size = math.prod(forward_shape) inverse_size = math.prod(inverse_shape) @@ -1189,6 +1186,14 @@ def __init__(self, forward_shape, inverse_shape) -> None: self._forward_shape = forward_shape self._inverse_shape = inverse_shape + @property + def domain(self): + return constraints.independent(constraints.real, len(self._inverse_shape)) + + @property + def codomain(self): + return constraints.independent(constraints.real, len(self._forward_shape)) + def forward_shape(self, shape): return _get_target_shape(shape, self._forward_shape, self._inverse_shape)