From 9da9b9b2fb750f404a492f2a795b1e16f07e0696 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 22 Jan 2021 23:25:51 -0600 Subject: [PATCH] fix tests, addresses comments --- numpyro/contrib/tfp/distributions.py | 10 +++++++-- numpyro/distributions/distribution.py | 30 +++++++++++++++++++-------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py index 1d01a97c0..5714099a3 100644 --- a/numpyro/contrib/tfp/distributions.py +++ b/numpyro/contrib/tfp/distributions.py @@ -87,10 +87,16 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): return self.bijector.forward_log_det_jacobian(x, self.domain.event_dim) def forward_shape(self, shape): - return self.bijector.forward_event_shape(shape) + out_shape = self.bijector.forward_event_shape(shape) + in_event_shape = self.bijector.inverse_event_shape(out_shape) + batch_shape = shape[:len(shape) - len(in_event_shape)] + return batch_shape + out_shape def inverse_shape(self, shape): - return self.bijector.inverse_event_shape(shape) + in_shape = self.bijector.inverse_event_shape(shape) + out_event_shape = self.bijector.forward_event_shape(in_shape) + batch_shape = shape[:len(shape) - len(out_event_shape)] + return batch_shape + in_shape @biject_to.register(BijectorConstraint) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 14b009717..b59920454 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -749,20 +749,31 @@ def __init__(self, base_distribution, transforms, validate_args=None): else: raise ValueError("transforms must be a Transform or list, but was {}".format(transforms)) if isinstance(base_distribution, TransformedDistribution): - self.base_dist = base_distribution.base_dist + base_dist = base_distribution.base_dist self.transforms = base_distribution.transforms + transforms else: - self.base_dist = base_distribution + base_dist = base_distribution self.transforms = transforms - shape = base_distribution.batch_shape + base_distribution.event_shape - base_ndim = len(shape) + base_shape = base_dist.shape() + base_event_dim = base_dist.event_dim transform = ComposeTransform(self.transforms) - transform_input_event_dim = transform.domain.event_dim - if base_ndim < transform_input_event_dim: + domain_event_dim = transform.domain.event_dim + if len(base_shape) < domain_event_dim: raise ValueError("Base distribution needs to have shape with size at least {}, but got {}." - .format(transform_input_event_dim, base_ndim)) - event_dim = transform.codomain.event_dim + max(self.base_dist.event_dim - transform_input_event_dim, 0) - shape = transform.forward_shape(shape) + .format(domain_event_dim, base_shape)) + shape = transform.forward_shape(base_shape) + expanded_base_shape = transform.inverse_shape(shape) + if base_shape != expanded_base_shape: + base_batch_shape = expanded_base_shape[:len(expanded_base_shape) - base_event_dim] + base_dist = base_dist.expand(base_batch_shape) + reinterpreted_batch_ndims = domain_event_dim - base_event_dim + if reinterpreted_batch_ndims > 0: + base_dist = base_dist.to_event(reinterpreted_batch_ndims) + self.base_dist = base_dist + + # Compute shapes. + event_dim = transform.codomain.event_dim + max(base_event_dim - domain_event_dim, 0) + assert len(shape) >= event_dim cut = len(shape) - event_dim batch_shape = shape[:cut] event_shape = shape[cut:] @@ -858,6 +869,7 @@ def __init__(self, v=0., log_density=0., event_dim=0, validate_args=None): self.log_density = promote_shapes(log_density, shape=batch_shape)[0] super(Delta, self).__init__(batch_shape, event_shape, validate_args=validate_args) + @property def support(self): return independent(real, self.event_dim)