Skip to content

Commit

Permalink
fix tests, addresses comments
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Jan 23, 2021
1 parent 9921ec1 commit 9da9b9b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 11 deletions.
10 changes: 8 additions & 2 deletions numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,16 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
return self.bijector.forward_log_det_jacobian(x, self.domain.event_dim)

def forward_shape(self, shape):
return self.bijector.forward_event_shape(shape)
out_shape = self.bijector.forward_event_shape(shape)
in_event_shape = self.bijector.inverse_event_shape(out_shape)
batch_shape = shape[:len(shape) - len(in_event_shape)]
return batch_shape + out_shape

def inverse_shape(self, shape):
return self.bijector.inverse_event_shape(shape)
in_shape = self.bijector.inverse_event_shape(shape)
out_event_shape = self.bijector.forward_event_shape(in_shape)
batch_shape = shape[:len(shape) - len(out_event_shape)]
return batch_shape + in_shape


@biject_to.register(BijectorConstraint)
Expand Down
30 changes: 21 additions & 9 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,20 +749,31 @@ def __init__(self, base_distribution, transforms, validate_args=None):
else:
raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))
if isinstance(base_distribution, TransformedDistribution):
self.base_dist = base_distribution.base_dist
base_dist = base_distribution.base_dist
self.transforms = base_distribution.transforms + transforms
else:
self.base_dist = base_distribution
base_dist = base_distribution
self.transforms = transforms
shape = base_distribution.batch_shape + base_distribution.event_shape
base_ndim = len(shape)
base_shape = base_dist.shape()
base_event_dim = base_dist.event_dim
transform = ComposeTransform(self.transforms)
transform_input_event_dim = transform.domain.event_dim
if base_ndim < transform_input_event_dim:
domain_event_dim = transform.domain.event_dim
if len(base_shape) < domain_event_dim:
raise ValueError("Base distribution needs to have shape with size at least {}, but got {}."
.format(transform_input_event_dim, base_ndim))
event_dim = transform.codomain.event_dim + max(self.base_dist.event_dim - transform_input_event_dim, 0)
shape = transform.forward_shape(shape)
.format(domain_event_dim, base_shape))
shape = transform.forward_shape(base_shape)
expanded_base_shape = transform.inverse_shape(shape)
if base_shape != expanded_base_shape:
base_batch_shape = expanded_base_shape[:len(expanded_base_shape) - base_event_dim]
base_dist = base_dist.expand(base_batch_shape)
reinterpreted_batch_ndims = domain_event_dim - base_event_dim
if reinterpreted_batch_ndims > 0:
base_dist = base_dist.to_event(reinterpreted_batch_ndims)
self.base_dist = base_dist

# Compute shapes.
event_dim = transform.codomain.event_dim + max(base_event_dim - domain_event_dim, 0)
assert len(shape) >= event_dim
cut = len(shape) - event_dim
batch_shape = shape[:cut]
event_shape = shape[cut:]
Expand Down Expand Up @@ -858,6 +869,7 @@ def __init__(self, v=0., log_density=0., event_dim=0, validate_args=None):
self.log_density = promote_shapes(log_density, shape=batch_shape)[0]
super(Delta, self).__init__(batch_shape, event_shape, validate_args=validate_args)

@property
def support(self):
return independent(real, self.event_dim)

Expand Down

0 comments on commit 9da9b9b

Please sign in to comment.