Implement a Joint normal form funsor (#69)
* Add a simple delta distribution

* Add tests for nontrivial event dim

* Simplify unit test

* Sketch a general Delta funsor class

* Simplify to binding a single name in Delta

* Add some tests for Delta

* Add test for ground substitution

* Add tests for reduction

* Add test for conversion from dist.Delta to Delta

* Sketch JointNormalForm funsor

* Settle on Joint interface

* Add more + handling

* Remove .log_density field from Delta funsor

* Drop handling of .log_density from Joint

* Add logic promoting various Binary(-,-) to Joint

* Revert "Remove .log_density field from Delta funsor"

This reverts commit 897f523.

* Revert "Drop handling of .log_density from Joint"

This reverts commit a7d0082.

* Simplify Gaussian funsor

* WIP Refactor Joint patterns

* Get Gaussian working with Joint

* Add a smoke test for Joint

* Add test for reduction

* Make xfail more targeted

* Update docstring on Joint

* Remove unnecessary handling of Binary(ops.add,...)
fritzo authored and eb8680 committed Mar 13, 2019
1 parent 9dbb231 commit 5fa7fa6
Showing 12 changed files with 494 additions and 163 deletions.
funsor/__init__.py (5 changes: 3 additions & 2 deletions)
@@ -5,8 +5,8 @@
 from funsor.terms import Funsor, Number, Variable, of_shape, to_funsor
 from funsor.torch import Function, Tensor, arange, function, torch_einsum

-from . import (adjoint, contract, delta, distributions, domains, einsum, gaussian, handlers, interpreter, minipyro, ops,
-               terms, torch)
+from . import (adjoint, contract, delta, distributions, domains, einsum, gaussian, handlers, interpreter, joint,
+               minipyro, ops, terms, torch)

 __all__ = [
     'Domain',
@@ -29,6 +29,7 @@
     'gaussian',
     'handlers',
     'interpreter',
+    'joint',
     'minipyro',
     'of_shape',
     'ops',
funsor/distributions.py (15 changes: 9 additions & 6 deletions)
@@ -198,11 +198,12 @@ def eager_normal(loc, scale, value):

     inputs, (loc, scale) = align_tensors(loc, scale)
     inputs.update(value.inputs)
+    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')

-    log_density = -0.5 * math.log(2 * math.pi) - scale.log()
+    log_prob = -0.5 * math.log(2 * math.pi) - scale.log()
     loc = loc.unsqueeze(-1)
     precision = scale.pow(-2).unsqueeze(-1).unsqueeze(-1)
-    return Gaussian(log_density, loc, precision, inputs)
+    return Tensor(log_prob, int_inputs) + Gaussian(loc, precision, inputs)


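A quick sanity check of the split this hunk performs: the constant normalizer now lives in a Tensor funsor while the Gaussian keeps only the quadratic term. The following standalone PyTorch sketch (not part of the commit; it uses no funsor APIs) verifies the underlying identity:

import math
import torch

# Illustrative only: Normal log-density = normalizer + quadratic term.
loc = torch.tensor(0.5)
scale = torch.tensor(2.0)
x = torch.tensor(1.3)

log_prob = -0.5 * math.log(2 * math.pi) - scale.log()  # the Tensor part
quad = -0.5 * ((x - loc) / scale) ** 2                  # the Gaussian part
expected = torch.distributions.Normal(loc, scale).log_prob(x)
assert torch.allclose(log_prob + quad, expected)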
# Create a Gaussian from a noisy identity transform.
@@ -215,12 +216,13 @@ def eager_normal(loc, scale, value):
     inputs = loc.inputs.copy()
     inputs.update(scale.inputs)
     inputs.update(value.inputs)
+    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')

-    log_density = -0.5 * math.log(2 * math.pi) - scale.data.log()
+    log_prob = -0.5 * math.log(2 * math.pi) - scale.data.log()
     loc = scale.data.new_zeros(scale.data.shape + (2,))
     p = scale.data.pow(-2)
     precision = torch.stack([p, -p, -p, p], -1).reshape(p.shape + (2, 2))
-    return Gaussian(log_density, loc, precision, inputs)
+    return Tensor(log_prob, int_inputs) + Gaussian(loc, precision, inputs)


class MultivariateNormal(Distribution):
@@ -261,11 +263,12 @@ def eager_mvn(loc, scale_tril, value):
     dim, = loc.output.shape
     inputs, (loc, scale_tril) = align_tensors(loc, scale_tril)
     inputs.update(value.inputs)
+    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')

-    log_density = -0.5 * dim * math.log(2 * math.pi) - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
+    log_prob = -0.5 * dim * math.log(2 * math.pi) - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
     inv_scale_tril = torch.inverse(scale_tril)
     precision = torch.matmul(inv_scale_tril.transpose(-1, -2), inv_scale_tril)
-    return Gaussian(log_density, loc, precision, inputs)
+    return Tensor(log_prob, int_inputs) + Gaussian(loc, precision, inputs)


__all__ = [
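The eager_mvn hunk above performs the same decomposition in the multivariate case: the precision matrix is built as inv(L)' @ inv(L) from the Cholesky factor L, and the normalizer moves into the Tensor. A standalone check of that identity (not part of the commit, plain PyTorch):

import math
import torch

# Illustrative only: MVN log-density = normalizer + quadratic term.
d = 3
scale_tril = torch.randn(d, d).tril(-1) + torch.diag(torch.rand(d) + 0.5)
loc = torch.randn(d)
x = torch.randn(d)

log_prob = -0.5 * d * math.log(2 * math.pi) - scale_tril.diagonal().log().sum()
inv_scale_tril = torch.inverse(scale_tril)
precision = inv_scale_tril.t() @ inv_scale_tril  # Sigma^-1 = L^-T L^-1
quad = -0.5 * (x - loc) @ precision @ (x - loc)
expected = torch.distributions.MultivariateNormal(
    loc, scale_tril=scale_tril).log_prob(x)
assert torch.allclose(log_prob + quad, expected, atol=1e-5)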
funsor/gaussian.py (161 changes: 42 additions & 119 deletions)
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import funsor.ops as ops
from funsor.domains import reals
from funsor.ops import Op
from funsor.ops import AddOp
from funsor.terms import Binary, Funsor, FunsorMeta, Number, eager
from funsor.torch import Tensor, align_tensor, align_tensors, materialize

@@ -24,7 +24,7 @@ def _issubshape(subshape, supershape):


 def _log_det_tril(x):
-    return x.diagonal(dim1=-1, dim2=-2).log().sum()
+    return x.diagonal(dim1=-1, dim2=-2).log().sum(-1)


def _mv(mat, vec):
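The .sum() to .sum(-1) change makes _log_det_tril batched: it now reduces only over the diagonal and preserves leading batch dimensions. A minimal check (not part of the commit):

import math
import torch

# Illustrative only: a (5, 2, 2) batch of triangular matrices now yields a
# (5,)-shaped log-determinant instead of a single collapsed scalar.
x = torch.eye(2).mul(3.0).expand(5, 2, 2)
log_det = x.diagonal(dim1=-1, dim2=-2).log().sum(-1)
assert log_det.shape == (5,)
assert torch.allclose(log_det, torch.full((5,), 2 * math.log(3.0)))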
@@ -66,7 +66,6 @@ def align_gaussian(new_inputs, old):
     """
     assert isinstance(new_inputs, OrderedDict)
     assert isinstance(old, Gaussian)
-    log_density = old.log_density
     loc = old.loc
     precision = old.precision

@@ -75,7 +74,6 @@
     new_ints = OrderedDict((k, d) for k, d in new_inputs.items() if d.dtype != 'real')
     old_ints = OrderedDict((k, d) for k, d in old.inputs.items() if d.dtype != 'real')
     if new_ints != old_ints:
-        log_density = align_tensor(new_ints, Tensor(log_density, old_ints))
         loc = align_tensor(new_ints, Tensor(loc, old_ints))
         precision = align_tensor(new_ints, Tensor(precision, old_ints))

@@ -106,17 +104,18 @@
                 new_slice2 = slice(new_offset2, new_offset2 + num_elements2)
                 precision[..., new_slice1, new_slice2] = old_precision[..., old_slice1, old_slice2]

-    return log_density, loc, precision
+    return loc, precision


 class GaussianMeta(FunsorMeta):
     """
     Wrapper to convert between OrderedDict and tuple.
     """
-    def __call__(cls, log_density, loc, precision, inputs):
+    def __call__(cls, loc, precision, inputs):
         if isinstance(inputs, OrderedDict):
             inputs = tuple(inputs.items())
-        return super(GaussianMeta, cls).__call__(log_density, loc, precision, inputs)
+        assert isinstance(inputs, tuple)
+        return super(GaussianMeta, cls).__call__(loc, precision, inputs)


@add_metaclass(GaussianMeta)
@@ -125,8 +124,7 @@ class Gaussian(Funsor):
     Funsor representing a batched joint Gaussian distribution as a log-density
     function.
     """
-    def __init__(self, log_density, loc, precision, inputs):
-        assert isinstance(log_density, torch.Tensor)
+    def __init__(self, loc, precision, inputs):
         assert isinstance(loc, torch.Tensor)
         assert isinstance(precision, torch.Tensor)
         assert isinstance(inputs, tuple)
@@ -141,13 +139,11 @@ def __init__(self, log_density, loc, precision, inputs):
         # Compute total shape of all bint inputs.
         batch_shape = tuple(d.dtype for d in inputs.values()
                             if isinstance(d.dtype, integer_types))
-        assert _issubshape(log_density.shape, batch_shape)
         assert _issubshape(loc.shape, batch_shape + (dim,))
         assert _issubshape(precision.shape, batch_shape + (dim, dim))

         output = reals()
         super(Gaussian, self).__init__(inputs, output)
-        self.log_density = log_density
         self.loc = loc
         self.precision = precision
         self.batch_shape = batch_shape
Expand All @@ -173,26 +169,25 @@ def eager_subs(self, subs):
if int_subs:
int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
tensors = [self.log_density, self.loc, self.precision]
tensors = [self.loc, self.precision]
funsors = [Tensor(x, int_inputs).eager_subs(int_subs) for x in tensors]
inputs = funsors[0].inputs.copy()
inputs.update(real_inputs)
int_result = Gaussian(funsors[0].data, funsors[1].data, funsors[2].data, inputs)
int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
return int_result.eager_subs(real_subs)

# Try to perform a complete substitution of all real variables, resulting in a Tensor.
assert real_subs and not int_subs
if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
# Broadcast all component tensors.
int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
tensors = [Tensor(self.log_density, int_inputs),
Tensor(self.loc, int_inputs),
tensors = [Tensor(self.loc, int_inputs),
Tensor(self.precision, int_inputs)]
tensors.extend(subs.values())
inputs, tensors = align_tensors(*tensors)
batch_dim = tensors[0].dim()
batch_dim = self.loc.dim() - 1
batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
(log_density, loc, precision), values = tensors[:3], tensors[3:]
(loc, precision), values = tensors[:2], tensors[2:]

# Form the concatenated value.
offsets, event_size = _compute_offsets(self.inputs)
@@ -204,8 +199,10 @@ def eager_subs(self, subs):
                 value[..., offset: offset + self.inputs[k].num_elements] = value_k

             # Evaluate the non-normalized log density.
-            result = log_density - 0.5 * _vmv(precision, value - loc)
-            return Tensor(result, inputs)
+            result = -0.5 * _vmv(precision, value - loc)
+            result = Tensor(result, inputs)
+            assert result.output == reals()
+            return result

         raise NotImplementedError('TODO implement partial substitution of real variables')

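After this change, substituting every real input into a Gaussian yields only the quadratic term -0.5 (v - loc)' P (v - loc); the normalizer is tracked separately in Tensor terms. A standalone sketch of that quadratic (not part of the commit):

import torch

# Illustrative only: the non-normalized log density peaks at 0 when v == loc
# and is negative elsewhere for a positive-definite precision.
precision = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
loc = torch.tensor([1.0, -1.0])
v = torch.tensor([0.5, 0.2])

log_density = -0.5 * (v - loc) @ precision @ (v - loc)
assert log_density.item() < 0.0
assert (-0.5 * (loc - loc) @ precision @ (loc - loc)).item() == 0.0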
Expand All @@ -220,12 +217,12 @@ def eager_reduce(self, op, reduced_vars):
return None # defer to default implementation

inputs = OrderedDict((k, d) for k, d in self.inputs.items() if k not in reduced_reals)
log_density = self.log_density
int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
if reduced_reals == real_vars:
dim = self.loc.size(-1)
log_det_term = _log_det_tril(torch.cholesky(self.precision))
data = log_density + log_det_term - 0.5 * math.log(2 * math.pi) * dim
result = Tensor(data, inputs)
data = log_det_term - 0.5 * math.log(2 * math.pi) * dim
result = Tensor(data, int_inputs)
else:
offsets, _ = _compute_offsets(self.inputs)
index = []
@@ -243,8 +240,8 @@ def eager_reduce(self, op, reduced_vars):
                 precision = torch.matmul(inv_scale_tril, inv_scale_tril.transpose(-1, -2))
                 reduced_dim = sum(self.inputs[k].num_elements for k in reduced_reals)
                 log_det_term = _log_det_tril(scale_tril) - _log_det_tril(self_scale_tril)
-                log_density = log_density + log_det_term - 0.5 * math.log(2 * math.pi) * reduced_dim
-                result = Gaussian(log_density, loc, precision, inputs)
+                log_prob = Tensor(log_det_term - 0.5 * math.log(2 * math.pi) * reduced_dim, int_inputs)
+                result = log_prob + Gaussian(loc, precision, inputs)

             return result.reduce(ops.logaddexp, reduced_ints)

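For the partial-reduction branch above, the relevant fact is that marginalizing a joint Gaussian over a subset of its real variables is again Gaussian, with the marginal precision given by a Schur complement. A small check of that identity (not part of the commit):

import torch

# Illustrative only: for joint precision [[a, b], [b, c]] over (x1, x2), the
# marginal precision of x1 is the Schur complement a - b * c^-1 * b, which
# equals the inverse of the top-left entry of the covariance P^-1.
P = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
cov = torch.inverse(P)
marginal_precision = 1.0 / cov[0, 0]
schur = P[0, 0] - P[0, 1] * P[1, 1].reciprocal() * P[1, 0]
assert torch.allclose(marginal_precision, schur)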
Expand All @@ -254,101 +251,27 @@ def eager_reduce(self, op, reduced_vars):
return None # defer to default implementation


@eager.register(Binary, Op, Gaussian, Number)
def eager_binary_gaussian_number(op, lhs, rhs):
if op is ops.add or op is ops.sub:
# Add a constant log_density term to a Gaussian.
log_density = op(lhs.log_density, rhs.data)
return Gaussian(log_density, lhs.loc, lhs.precision, lhs.inputs)

if op is ops.mul or op is ops.truediv:
# Scale a Gaussian, as under pyro.poutine.scale.
raise NotImplementedError('TODO')

return None # defer to default implementation


@eager.register(Binary, Op, Number, Gaussian)
def eager_binary_number_gaussian(op, lhs, rhs):
if op is ops.add:
# Add a constant log_density term to a Gaussian.
log_density = op(lhs.data, rhs.log_density)
return Gaussian(log_density, rhs.loc, rhs.precision, rhs.inputs)

if op is ops.mul:
# Scale a Gaussian, as under pyro.poutine.scale.
raise NotImplementedError('TODO')

return None # defer to default implementation


@eager.register(Binary, Op, Gaussian, Tensor)
def eager_binary_gaussian_tensor(op, lhs, rhs):
if op is ops.add or op is ops.sub:
# Add a batch-dependent log_density term to a Gaussian.
nonreal_inputs = OrderedDict((k, d) for k, d in lhs.inputs.items()
if d.dtype != 'real')
inputs, (rhs_data, log_density, loc, precision) = align_tensors(
rhs,
Tensor(lhs.log_density, nonreal_inputs),
Tensor(lhs.loc, nonreal_inputs),
Tensor(lhs.precision, nonreal_inputs))
log_density = op(log_density, rhs_data)
inputs.update(lhs.inputs)
return Gaussian(log_density, loc, precision, inputs)

if op is ops.mul or op is ops.truediv:
# Scale a Gaussian, as under pyro.poutine.scale.
raise NotImplementedError('TODO')

return None # defer to default implementation


@eager.register(Binary, Op, Tensor, Gaussian)
def eager_binary_tensor_gaussian(op, lhs, rhs):
if op is ops.add:
# Add a batch-dependent log_density term to a Gaussian.
nonreal_inputs = OrderedDict((k, d) for k, d in rhs.inputs.items()
if d.dtype != 'real')
inputs, (lhs_data, log_density, loc, precision) = align_tensors(
lhs,
Tensor(rhs.log_density, nonreal_inputs),
Tensor(rhs.loc, nonreal_inputs),
Tensor(rhs.precision, nonreal_inputs))
log_density = op(lhs_data, log_density)
inputs.update(rhs.inputs)
return Gaussian(log_density, loc, precision, inputs)

if op is ops.mul:
# Scale a Gaussian, as under pyro.poutine.scale.
raise NotImplementedError('TODO')

return None # defer to default implementation


@eager.register(Binary, Op, Gaussian, Gaussian)
def eager_binary_gaussian_gaussian(op, lhs, rhs):
if op is ops.add:
# Fuse two Gaussians by adding their log-densities pointwise.
# This is similar to a Kalman filter update, but also keeps track of
# the marginal likelihood which accumulates into log_density.

# Align data.
inputs = lhs.inputs.copy()
inputs.update(rhs.inputs)
lhs_log_density, lhs_loc, lhs_precision = align_gaussian(inputs, lhs)
rhs_log_density, rhs_loc, rhs_precision = align_gaussian(inputs, rhs)

# Fuse aligned Gaussians.
precision_loc = _mv(lhs_precision, lhs_loc) + _mv(rhs_precision, rhs_loc)
precision = lhs_precision + rhs_precision
scale_tril = torch.inverse(torch.cholesky(precision))
loc = _mv(scale_tril.transpose(-1, -2), _mv(scale_tril, precision_loc))
quadratic_term = _vmv(lhs_precision, loc - lhs_loc) + _vmv(rhs_precision, loc - rhs_loc)
log_density = lhs_log_density + rhs_log_density - 0.5 * quadratic_term
return Gaussian(log_density, loc, precision, inputs)

return None # defer to default implementation
@eager.register(Binary, AddOp, Gaussian, Gaussian)
def eager_add_gaussian_gaussian(op, lhs, rhs):
# Fuse two Gaussians by adding their log-densities pointwise.
# This is similar to a Kalman filter update, but also keeps track of
# the marginal likelihood which accumulates into a Tensor.

# Align data.
inputs = lhs.inputs.copy()
inputs.update(rhs.inputs)
int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
lhs_loc, lhs_precision = align_gaussian(inputs, lhs)
rhs_loc, rhs_precision = align_gaussian(inputs, rhs)

# Fuse aligned Gaussians.
precision_loc = _mv(lhs_precision, lhs_loc) + _mv(rhs_precision, rhs_loc)
precision = lhs_precision + rhs_precision
scale_tril = torch.inverse(torch.cholesky(precision))
loc = _mv(scale_tril.transpose(-1, -2), _mv(scale_tril, precision_loc))
quadratic_term = _vmv(lhs_precision, loc - lhs_loc) + _vmv(rhs_precision, loc - rhs_loc)
likelihood = Tensor(-0.5 * quadratic_term, int_inputs)
return likelihood + Gaussian(loc, precision, inputs)


__all__ = [
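The new eager_add_gaussian_gaussian handler relies on the identity that a sum of two Gaussian quadratics is another quadratic plus a constant: precision A + B, mean (A + B)^-1 (A a + B b), and a likelihood term -0.5 [(mu - a)' A (mu - a) + (mu - b)' B (mu - b)] that the handler stores in a Tensor. A standalone numerical check (not part of the commit, plain PyTorch):

import torch

# Illustrative only: fusing two non-normalized Gaussian log-densities.
A = torch.tensor([[2.0, 0.4], [0.4, 1.5]])    # lhs precision
B = torch.tensor([[1.0, -0.2], [-0.2, 3.0]])  # rhs precision
a = torch.tensor([0.3, -1.0])                 # lhs loc
b = torch.tensor([1.2, 0.5])                  # rhs loc

precision = A + B
loc = torch.inverse(precision) @ (A @ a + B @ b)
likelihood = -0.5 * ((loc - a) @ A @ (loc - a) + (loc - b) @ B @ (loc - b))

x = torch.randn(2)
lhs = -0.5 * ((x - a) @ A @ (x - a) + (x - b) @ B @ (x - b))
rhs = likelihood - 0.5 * (x - loc) @ precision @ (x - loc)
assert torch.allclose(lhs, rhs, atol=1e-5)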