Add transforms' forward_shape and inverse_shape #887

Merged
merged 6 commits into from
Jan 24, 2021

Changes from 5 commits
6 changes: 6 additions & 0 deletions numpyro/contrib/tfp/distributions.py
@@ -86,6 +86,12 @@ def _inverse(self, y):
def log_abs_det_jacobian(self, x, y, intermediates=None):
return self.bijector.forward_log_det_jacobian(x, self.domain.event_dim)

def forward_shape(self, shape):
return self.bijector.forward_event_shape(shape)

def inverse_shape(self, shape):
return self.bijector.inverse_event_shape(shape)


@biject_to.register(BijectorConstraint)
def _transform_to_bijector_constraint(constraint):
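
For context, the wrapper above defers all shape inference to the underlying TFP bijector. A minimal sketch of the intended behaviour, not part of this diff and assuming TFP's JAX substrate is installed and that BijectorTransform is constructed from a bijector instance:

from tensorflow_probability.substrates import jax as tfp
from numpyro.contrib.tfp.distributions import BijectorTransform

# CorrelationCholesky maps an unconstrained vector of length n * (n - 1) / 2
# to an n x n Cholesky factor of a correlation matrix.
t = BijectorTransform(tfp.bijectors.CorrelationCholesky())
print(t.forward_shape((6,)))    # delegates to bijector.forward_event_shape; expect (4, 4)
print(t.inverse_shape((4, 4)))  # delegates to bijector.inverse_event_shape; expect (6,)
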
3 changes: 3 additions & 0 deletions numpyro/distributions/constraints.py
@@ -152,6 +152,9 @@ def __init__(self, base_constraint, reinterpreted_batch_ndims):
assert isinstance(base_constraint, Constraint)
assert isinstance(reinterpreted_batch_ndims, int)
assert reinterpreted_batch_ndims >= 0
if isinstance(base_constraint, _IndependentConstraint):
reinterpreted_batch_ndims = reinterpreted_batch_ndims + base_constraint.reinterpreted_batch_ndims
base_constraint = base_constraint.base_constraint
self.base_constraint = base_constraint
self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
super().__init__()
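
To illustrate the new collapsing logic above: nesting independent constraints now flattens into a single one. A short sketch, assuming constraints.independent is the public alias of _IndependentConstraint:

from numpyro.distributions import constraints

c = constraints.independent(constraints.independent(constraints.real, 1), 1)
assert c.base_constraint is constraints.real   # nested base constraint is unwrapped
assert c.reinterpreted_batch_ndims == 2        # reinterpreted dims accumulate
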
32 changes: 8 additions & 24 deletions numpyro/distributions/distribution.py
@@ -518,7 +518,7 @@ class ImproperUniform(Distribution):
arg_constraints = {}

def __init__(self, support, batch_shape, event_shape, validate_args=None):
self.support = support
self.support = independent(support, len(event_shape) - support.event_dim)
super().__init__(batch_shape, event_shape, validate_args=validate_args)

@validate_sample
@@ -748,24 +748,12 @@ def __init__(self, base_distribution, transforms, validate_args=None):
raise ValueError("transforms must be a Transform or a list of Transforms")
else:
raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))
# XXX: this logic will not be valid when IndependentDistribution is support;
# in that case, it is more involved to support Transform(Indep(Transform));
# however, we might not need to support such kind of distribution
# and should raise an error if base_distribution is an Indep one
if isinstance(base_distribution, TransformedDistribution):
self.base_dist = base_distribution.base_dist
self.transforms = base_distribution.transforms + transforms
else:
self.base_dist = base_distribution
self.transforms = transforms
# NB: here we assume that base_dist.shape == transformed_dist.shape
# but that might not be True for some transforms such as StickBreakingTransform
# because the event dimension is transformed from (n - 1,) to (n,).
# Currently, we have no mechanism to fix this issue. Given that
# this is just an edge case, we might skip this issue but need
# to pay attention to any inference function that inspects
# transformed distribution's shape.
# TODO: address this and the comment below when infer_shapes is available
shape = base_distribution.batch_shape + base_distribution.event_shape
base_ndim = len(shape)
transform = ComposeTransform(self.transforms)
@@ -774,16 +762,10 @@
raise ValueError("Base distribution needs to have shape with size at least {}, but got {}."
.format(transform_input_event_dim, base_ndim))
event_dim = transform.codomain.event_dim + max(self.base_dist.event_dim - transform_input_event_dim, 0)
# See the above note. Currently, there is no way to interpret the shape of output after
# transforming. To solve this issue, we need something like Bijector.forward_event_shape
# as in TFP. For now, we will prepend singleton dimensions to compromise, so that
# event_dim, len(batch_shape) are still correct.
if event_dim <= base_ndim:
batch_shape = shape[:base_ndim - event_dim]
event_shape = shape[base_ndim - event_dim:]
else:
event_shape = (-1,) * event_dim
batch_shape = ()
shape = transform.forward_shape(shape)
cut = len(shape) - event_dim
batch_shape = shape[:cut]
event_shape = shape[cut:]
super(TransformedDistribution, self).__init__(batch_shape, event_shape, validate_args=validate_args)

@property
@@ -862,7 +844,6 @@ def tree_flatten(self):
class Delta(Distribution):
arg_constraints = {'v': real, 'log_density': real}
reparameterized_params = ['v', 'log_density']
support = real
is_discrete = True

def __init__(self, v=0., log_density=0., event_dim=0, validate_args=None):
@@ -877,6 +858,9 @@ def __init__(self, v=0., log_density=0., event_dim=0, validate_args=None):
self.log_density = promote_shapes(log_density, shape=batch_shape)[0]
super(Delta, self).__init__(batch_shape, event_shape, validate_args=validate_args)

def support(self):
return independent(real, self.event_dim)

def sample(self, key, sample_shape=()):
shape = sample_shape + self.batch_shape + self.event_shape
return jnp.broadcast_to(self.v, shape)
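
With forward_shape available, TransformedDistribution can now infer its shapes even when a transform changes the event size, which is exactly the stick-breaking case the removed comments worked around. A sketch of the expected shapes, not part of this diff:

import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.distributions.transforms import StickBreakingTransform

base = dist.Normal(jnp.zeros(3), 1.).to_event(1)              # event_shape (3,)
d = dist.TransformedDistribution(base, StickBreakingTransform())
assert d.batch_shape == ()
assert d.event_shape == (4,)   # (n - 1,) -> (n,), computed via transform.forward_shape
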
4 changes: 2 additions & 2 deletions numpyro/distributions/flows.py
@@ -30,7 +30,6 @@ class InverseAutoregressiveTransform(Transform):
"""
domain = real_vector
codomain = real_vector
event_dim = 1

def __init__(self, autoregressive_nn, log_scale_min_clip=-5., log_scale_max_clip=3.):
"""
@@ -93,7 +92,8 @@ class BlockNeuralAutoregressiveTransform(Transform):
1. *Block Neural Autoregressive Flow*,
Nicola De Cao, Ivan Titov, Wilker Aziz
"""
event_dim = 1
domain = real_vector
codomain = real_vector

def __init__(self, bn_arn):
self.bn_arn = bn_arn
115 changes: 114 additions & 1 deletion numpyro/distributions/transforms.py
@@ -6,7 +6,7 @@

import numpy as np

from jax import ops, tree_flatten, tree_map, vmap
from jax import lax, ops, tree_flatten, tree_map, vmap
from jax.dtypes import canonicalize_dtype
from jax.flatten_util import ravel_pytree
from jax.nn import softplus
@@ -76,6 +76,20 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
def call_with_intermediates(self, x):
return self(x), None

def forward_shape(self, shape):
"""
Infers the shape of the forward computation, given the input shape.
Defaults to preserving shape.
"""
return shape

def inverse_shape(self, shape):
"""
Infers the shapes of the inverse computation, given the output shape.
Defaults to preserving shape.
"""
return shape
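
Both hooks default to the identity on shapes, so element-wise transforms need no override. A minimal sketch:

from numpyro.distributions.transforms import ExpTransform

t = ExpTransform()                        # element-wise, shapes pass through unchanged
assert t.forward_shape((2, 3)) == (2, 3)
assert t.inverse_shape((2, 3)) == (2, 3)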


class _InverseTransform(Transform):
def __init__(self, transform):
@@ -101,6 +115,12 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
# NB: we don't use intermediates for inverse transform
return -self._inv.log_abs_det_jacobian(y, x, None)

def forward_shape(self, shape):
return self._inv.inverse_shape(shape)

def inverse_shape(self, shape):
return self._inv.forward_shape(shape)


class AbsTransform(Transform):
domain = constraints.real
@@ -161,6 +181,16 @@ def _inverse(self, y):
def log_abs_det_jacobian(self, x, y, intermediates=None):
return jnp.broadcast_to(jnp.log(jnp.abs(self.scale)), jnp.shape(x))

def forward_shape(self, shape):
return lax.broadcast_shapes(shape,
getattr(self.loc, "shape", ()),
getattr(self.scale, "shape", ()))

def inverse_shape(self, shape):
return lax.broadcast_shapes(shape,
getattr(self.loc, "shape", ()),
getattr(self.scale, "shape", ()))
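
Here the output shape also accounts for broadcasting against loc and scale. A sketch of the expected result:

import jax.numpy as jnp
from numpyro.distributions.transforms import AffineTransform

t = AffineTransform(jnp.zeros((2, 1)), jnp.ones(3))
assert t.forward_shape((5, 1, 1)) == (5, 2, 3)   # broadcast of (5, 1, 1), (2, 1) and (3,)
assert t.inverse_shape((5, 1, 1)) == (5, 2, 3)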


def _get_compose_transform_input_event_dim(parts):
input_event_dim = parts[-1].domain.event_dim
@@ -243,6 +273,39 @@ def call_with_intermediates(self, x):
intermediates.append(inter)
return x, intermediates

def forward_shape(self, shape):
for part in self.parts:
shape = part.forward_shape(shape)
return shape

def inverse_shape(self, shape):
for part in reversed(self.parts):
shape = part.inverse_shape(shape)
return shape
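
Composition threads the shape through each part in order, and in reverse order for the inverse. A sketch:

import jax.numpy as jnp
from numpyro.distributions.transforms import AffineTransform, ComposeTransform, StickBreakingTransform

t = ComposeTransform([AffineTransform(0., jnp.ones(3)), StickBreakingTransform()])
assert t.forward_shape((3,)) == (4,)   # affine keeps (3,), stick-breaking appends one entry
assert t.inverse_shape((4,)) == (3,)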


def _matrix_forward_shape(shape, offset=0):
# Reshape from (..., N) to (..., D, D).
if len(shape) < 1:
raise ValueError("Too few dimensions in input")
N = shape[-1]
D = round((0.25 + 2 * N) ** 0.5 - 0.5)
if D * (D + 1) // 2 != N:
raise ValueError("Input is not a flattend lower-diagonal number")
D = D - offset
return shape[:-1] + (D, D)


def _matrix_inverse_shape(shape, offset=0):
# Reshape from (..., D, D) to (..., N).
if len(shape) < 2:
raise ValueError("Too few dimensions on input")
if shape[-2] != shape[-1]:
raise ValueError("Input is not square")
D = shape[-1] + offset
N = D * (D + 1) // 2
return shape[:-2] + (N,)
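
These helpers convert between a flat vector of length N = D * (D + 1) / 2 and a (D, D) matrix; the offset accommodates transforms such as CorrCholeskyTransform whose output carries a unit diagonal. A sketch of the round trip using the helpers defined above:

assert _matrix_forward_shape((10, 6)) == (10, 3, 3)       # 6 = 3 * 4 / 2
assert _matrix_inverse_shape((10, 3, 3)) == (10, 6)
assert _matrix_forward_shape((6,), offset=-1) == (4, 4)   # unit diagonal adds one row/column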


class CorrCholeskyTransform(Transform):
r"""
@@ -306,6 +369,12 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
tanh_logdet = -2 * jnp.sum(x + softplus(-2 * x) - jnp.log(2.), axis=-1)
return stick_breaking_logdet + tanh_logdet

def forward_shape(self, shape):
return _matrix_forward_shape(shape, offset=-1)

def inverse_shape(self, shape):
return _matrix_inverse_shape(shape, offset=-1)


class ExpTransform(Transform):
# TODO: refine domain/codomain logic through setters, especially when
@@ -386,6 +455,12 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
def call_with_intermediates(self, x):
return self.base_transform.call_with_intermediates(x)

def forward_shape(self, shape):
return self.base_transform.forward_shape(shape)

def inverse_shape(self, shape):
return self.base_transform.inverse_shape(shape)


class InvCholeskyTransform(Transform):
r"""
@@ -455,6 +530,16 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
return jnp.broadcast_to(jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1),
jnp.shape(x)[:-1])

def forward_shape(self, shape):
if len(shape) < 1:
raise ValueError("Too few dimensions on input")
return lax.broadcast_shapes(shape, self.loc.shape, self.scale_tril.shape[:-1])

def inverse_shape(self, shape):
if len(shape) < 1:
raise ValueError("Too few dimensions on input")
return lax.broadcast_shapes(shape, self.loc.shape, self.scale_tril.shape[:-1])


class LowerCholeskyTransform(Transform):
domain = constraints.real_vector
@@ -475,6 +560,12 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
return x[..., -n:].sum(-1)

def forward_shape(self, shape):
return _matrix_forward_shape(shape)

def inverse_shape(self, shape):
return _matrix_inverse_shape(shape)
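
For example, the lower-Cholesky reshaping behaves as follows (a sketch):

from numpyro.distributions.transforms import LowerCholeskyTransform

t = LowerCholeskyTransform()
assert t.forward_shape((6,)) == (3, 3)   # 6 packed entries -> a 3 x 3 lower-triangular factor
assert t.inverse_shape((3, 3)) == (6,)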


class OrderedTransform(Transform):
"""
@@ -537,6 +628,12 @@ def _inverse(self, y):
def log_abs_det_jacobian(self, x, y, intermediates=None):
return jnp.log(jnp.abs(self.exponent * y / x))

def forward_shape(self, shape):
return lax.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))

def inverse_shape(self, shape):
return lax.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))


class SigmoidTransform(Transform):
codomain = constraints.unit_interval
@@ -586,6 +683,16 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
# the case z ~ 1
return jnp.sum(jnp.log(y[..., :-1] * z) - x, axis=-1)

def forward_shape(self, shape):
if len(shape) < 1:
raise ValueError("Too few dimensions on input")
return shape[:-1] + (shape[-1] + 1,)

def inverse_shape(self, shape):
if len(shape) < 1:
raise ValueError("Too few dimensions on input")
return shape[:-1] + (shape[-1] - 1,)
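
The stick-breaking map adds one entry on the simplex side, the case the old TransformedDistribution workaround could not represent. A sketch:

from numpyro.distributions.transforms import StickBreakingTransform

t = StickBreakingTransform()
assert t.forward_shape((7, 3)) == (7, 4)   # unconstrained (n - 1,) -> simplex (n,)
assert t.inverse_shape((7, 4)) == (7, 3)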


class UnpackTransform(Transform):
"""
Expand Down Expand Up @@ -620,6 +727,12 @@ def _inverse(self, y):
def log_abs_det_jacobian(self, x, y, intermediates=None):
return jnp.zeros(jnp.shape(x)[:-1])

def forward_shape(self, shape):
raise NotImplementedError

def inverse_shape(self, shape):
raise NotImplementedError


##########################################################
# CONSTRAINT_REGISTRY
14 changes: 1 addition & 13 deletions numpyro/infer/initialization.py
@@ -52,20 +52,8 @@ def init_to_uniform(site=None, radius=2):
sample_shape = site['kwargs'].get('sample_shape')
rng_key, subkey = random.split(rng_key)

# this is used to interpret the changes of event_shape in
# domain and codomain spaces
try:
prototype_value = site['fn'](rng_key=subkey, sample_shape=())
except NotImplementedError:
# XXX: this works for ImproperUniform prior,
# we can't use this logic for general priors
# because some distributions such as TransformedDistribution might
# have wrong event_shape.
# TODO: address this when infer_shapes is available
prototype_value = jnp.full(site['fn'].shape(), jnp.nan)

transform = biject_to(site['fn'].support)
unconstrained_shape = jnp.shape(transform.inv(prototype_value))
unconstrained_shape = transform.inverse_shape(site["fn"].shape())
unconstrained_samples = dist.Uniform(-radius, radius)(
rng_key=rng_key, sample_shape=sample_shape + unconstrained_shape)
return transform(unconstrained_samples)
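
The prototype-sample workaround is no longer needed: init_to_uniform now infers the unconstrained shape directly from the site's shape. A sketch of why this matters for a simplex-supported site, not part of this diff:

import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.distributions.transforms import biject_to

fn = dist.Dirichlet(jnp.ones(4))                     # constrained sample shape (4,)
transform = biject_to(fn.support)                    # a StickBreakingTransform
assert transform.inverse_shape(fn.shape()) == (3,)   # unconstrained samples have shape (3,)
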
3 changes: 3 additions & 0 deletions test/test_distributions.py
@@ -959,6 +959,9 @@ def test_biject_to(constraint, shape):
x = random.normal(rng_key, shape)
y = transform(x)

assert transform.forward_shape(x.shape) == y.shape
assert transform.inverse_shape(y.shape) == x.shape

# test inv work for NaN arrays:
x_nan = transform.inv(jnp.full(jnp.shape(y), jnp.nan))
assert (x_nan.shape == x.shape)