Support UnpackTransform.inv via pack_fn (pyro-ppl#1824)
* add pack_fn to UnpackTransform for inverse

* better pack_fn
fehiepsi authored Jul 1, 2024
1 parent f4e69eb commit 2ed9f92
Showing 3 changed files with 32 additions and 11 deletions.
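
In short: UnpackTransform previously reconstructed the flat array in _inverse via ravel_pytree; after this commit it delegates to a user-supplied pack_fn and refuses to invert without one. A minimal sketch of the new API, with an illustrative unpack_fn/pack_fn pair (the names and shapes below are hypothetical, not taken from the diff):

    import jax.numpy as jnp
    from numpyro.distributions.transforms import UnpackTransform

    # Hypothetical layout: first two entries are "loc", the rest "scale".
    unpack_fn = lambda x: {"loc": x[..., :2], "scale": x[..., 2:]}
    # pack_fn must undo unpack_fn, concatenating in the same layout.
    pack_fn = lambda y: jnp.concatenate([y["loc"], y["scale"]], axis=-1)

    t = UnpackTransform(unpack_fn, pack_fn)
    x = jnp.arange(5.0)
    y = t(x)                      # {"loc": [0., 1.], "scale": [2., 3., 4.]}
    assert (t.inv(y) == x).all()  # raises NotImplementedError if pack_fn is None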
21 changes: 16 additions & 5 deletions numpyro/distributions/transforms.py
@@ -9,7 +9,6 @@
 import jax
 from jax import lax, vmap
-from jax.flatten_util import ravel_pytree
 from jax.nn import log_sigmoid, softplus
 import jax.numpy as jnp
 from jax.scipy.linalg import solve_triangular
@@ -1102,13 +1101,15 @@ class UnpackTransform(Transform):
     Transforms a contiguous array to a pytree of subarrays.
 
     :param unpack_fn: callable used to unpack a contiguous array.
+    :param pack_fn: callable used to pack a pytree into a contiguous array.
     """
 
     domain = constraints.real_vector
     codomain = constraints.dependent
 
-    def __init__(self, unpack_fn):
+    def __init__(self, unpack_fn, pack_fn=None):
         self.unpack_fn = unpack_fn
+        self.pack_fn = pack_fn
 
     def __call__(self, x):
         batch_shape = x.shape[:-1]
@@ -1121,9 +1122,15 @@ def __call__(self, x):
         return self.unpack_fn(x)
 
     def _inverse(self, y):
+        if self.pack_fn is None:
+            raise NotImplementedError(
+                "pack_fn needs to be provided to perform UnpackTransform.inv."
+            )
         leading_dims = [
             v.shape[0] if jnp.ndim(v) > 0 else 0 for v in jax.tree.flatten(y)[0]
         ]
+        if not leading_dims:
+            return jnp.array([])
         d0 = leading_dims[0]
         not_scalar = d0 > 0 or len(leading_dims) > 1
         if not_scalar and all(d == d0 for d in leading_dims[1:]):
@@ -1132,7 +1139,7 @@ def _inverse(self, y):
                 " cannot transform a batch of unpacked arrays.",
                 stacklevel=find_stack_level(),
             )
-        return ravel_pytree(y)[0]
+        return self.pack_fn(y)
 
     def log_abs_det_jacobian(self, x, y, intermediates=None):
         return jnp.zeros(jnp.shape(x)[:-1])
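
The new _inverse does three things: it fails fast with NotImplementedError when no pack_fn was provided, returns an empty array for an empty pytree, and keeps the pre-existing warning for the ambiguous batch case. One reading of that heuristic: when every leaf shares the same leading dimension, a single unpacked sample is indistinguishable from a batch, for example:

    # All leaves have leading dimension 3, so this dict could be one sample
    # with two length-3 leaves or a batch of 3 samples with scalar leaves;
    # _inverse warns and packs it as a single sample.
    y = {"x": jnp.zeros(3), "y": jnp.zeros(3)}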
@@ -1145,10 +1152,14 @@ def inverse_shape(self, shape):
 
     def tree_flatten(self):
         # XXX: what if unpack_fn is a parametrized callable pytree?
-        return (), ((), {"unpack_fn": self.unpack_fn})
+        return (), ((), {"unpack_fn": self.unpack_fn, "pack_fn": self.pack_fn})
 
     def __eq__(self, other):
-        return isinstance(other, UnpackTransform) and self.unpack_fn is other.unpack_fn
+        return (
+            isinstance(other, UnpackTransform)
+            and (self.unpack_fn is other.unpack_fn)
+            and (self.pack_fn is other.pack_fn)
+        )
 
 
 def _get_target_shape(shape, forward_shape, inverse_shape):
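
Why not keep ravel_pytree as a default pack_fn? The commit does not say, but one concrete hazard: JAX flattens dict leaves in sorted-key order, which need not match the layout unpack_fn produced, so ravel_pytree could silently pack leaves in the wrong order. A sketch of the mismatch:

    import jax.numpy as jnp
    from jax.flatten_util import ravel_pytree

    # Suppose unpack_fn put the head of the flat vector under "z" and the
    # tail under "a". ravel_pytree re-packs by sorted key, "a" first:
    y = {"z": jnp.array([0.0, 1.0]), "a": jnp.array([2.0])}
    flat, _ = ravel_pytree(y)
    # flat == [2., 0., 1.], not the original [0., 1., 2.]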
19 changes: 14 additions & 5 deletions numpyro/infer/autoguide.py
@@ -624,13 +624,20 @@ def _unravel_dict(x_flat, shape_dict):
 def _ravel_dict(x):
     """Return the flatten version of `x` and shapes of each item in `x`."""
     assert isinstance(x, dict)
-    shape_dict = {}
+    shape_dict = {name: jnp.shape(value) for name, value in x.items()}
+    x_flat = _ravel_dict_with_shape_dict(x, shape_dict)
+    return x_flat, shape_dict
+
+
+def _ravel_dict_with_shape_dict(x, shape_dict):
+    assert set(x.keys()) == set(shape_dict.keys())
     x_flat = []
-    for name, value in x.items():
-        shape_dict[name] = jnp.shape(value)
+    for name, shape in shape_dict.items():
+        value = x[name]
+        assert shape == jnp.shape(value)
         x_flat.append(jnp.reshape(value, -1))
     x_flat = jnp.concatenate(x_flat) if x_flat else jnp.zeros((0,))
-    return x_flat, shape_dict
+    return x_flat
 
 
 class AutoContinuous(AutoGuide):
@@ -661,7 +668,9 @@ def _setup_prototype(self, *args, **kwargs):
         unpack_latent = partial(_unravel_dict, shape_dict=shape_dict)
         # this is to match the behavior of Pyro, where we can apply
         # unpack_latent for a batch of samples
-        self._unpack_latent = UnpackTransform(unpack_latent)
+        self._unpack_latent = UnpackTransform(
+            unpack_latent, partial(_ravel_dict_with_shape_dict, shape_dict=shape_dict)
+        )
         self.latent_dim = jnp.size(self._init_latent)
         if self.latent_dim == 0:
             raise RuntimeError(
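
The guide records each site's shape once at setup and binds it into both directions with functools.partial, so the transform carries a matched pack/unpack pair. A sketch of the resulting round trip, using the internal helpers from this diff (the latent shapes are hypothetical):

    from functools import partial
    import jax.numpy as jnp
    from numpyro.infer.autoguide import (
        _ravel_dict, _ravel_dict_with_shape_dict, _unravel_dict
    )

    latent = {"loc": jnp.zeros((2,)), "scale": jnp.ones(())}
    flat, shape_dict = _ravel_dict(latent)   # flat has size 3
    unpack = partial(_unravel_dict, shape_dict=shape_dict)
    pack = partial(_ravel_dict_with_shape_dict, shape_dict=shape_dict)
    assert (pack(unpack(flat)) == flat).all()  # matched pair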
3 changes: 2 additions & 1 deletion test/test_distributions.py
@@ -2703,7 +2703,8 @@ def test_compose_transform_with_intermediates(ts):
 def test_unpack_transform(x_dim, y_dim):
     xy = np.random.randn(x_dim + y_dim)
     unpack_fn = lambda xy: {"x": xy[:x_dim], "y": xy[x_dim:]}  # noqa: E731
-    transform = transforms.UnpackTransform(unpack_fn)
+    pack_fn = lambda d: jnp.concatenate([d["x"], d["y"]], axis=-1)  # noqa: E731
+    transform = transforms.UnpackTransform(unpack_fn, pack_fn)
     z = transform(xy)
     if x_dim == y_dim:
         with pytest.warns(UserWarning, match="UnpackTransform.inv"):
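
For the unambiguous parametrization (x_dim != y_dim) the inverse should now round-trip without a warning; presumably the rest of the test, outside this hunk, asserts something along these lines:

    # Sketch, not the test's literal assertions: distinct leading dims mean
    # no batch ambiguity, so inv simply re-concatenates via pack_fn.
    z = transform(xy)          # {"x": xy[:x_dim], "y": xy[x_dim:]}
    assert jnp.allclose(transform.inv(z), xy)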
