
Add rule for creating Gaussians from Affine inputs #119

Merged: 14 commits, Apr 7, 2019
4 changes: 2 additions & 2 deletions examples/kalman_filter.py
@@ -28,12 +28,12 @@ def model(data):

         # A delayed sample statement.
         x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
-        log_prob += dist.Normal(x_prev, trans_noise, value=x_curr)
+        log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr)

         if not args.lazy and isinstance(x_prev, funsor.Variable):
             log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

-        log_prob += dist.Normal(x_curr, emit_noise, value=y)
+        log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)

     log_prob = log_prob.reduce(ops.logaddexp)
     return log_prob
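These two edits replace the noisy identity transforms with genuinely affine ones (`1 + x_prev / 2.` and `0.5 + 3 * x_curr`), so the example now exercises the new Affine-to-Gaussian rule rather than the old identity-only special case. A minimal sketch of the pattern the example relies on (module paths assumed from this diff; names follow test_normal_affine further below):

from collections import OrderedDict

import torch

import funsor.distributions as dist
from funsor.domains import reals
from funsor.terms import Variable
from funsor.torch import Tensor

x = Variable('x', reals())
y = Variable('y', reals())
scale = Tensor(torch.tensor(0.3), OrderedDict())

# Under the new rule this should be eagerly converted to a Gaussian-carrying
# funsor over the free real inputs 'x' and 'y', instead of staying a lazy Normal.
lp = dist.Normal(1 + x / 2., scale, value=y)
print(lp.inputs)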
2 changes: 1 addition & 1 deletion funsor/affine.py
@@ -161,7 +161,7 @@ def eager_binary(op, var, other):
     elif op is ops.sub:
         return var + -other
     elif op is ops.truediv:
-        return var * ops.invert(other)
+        return var * (1. / other)
     return None


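Why this one-liner matters: assuming funsor's `ops.invert` wraps Python's `operator.invert` (as funsor ops generally wrap the `operator` module), it is the bitwise complement `~`, not a multiplicative inverse, so dividing an affine term by a tensor could not have worked; the fix multiplies by the reciprocal. This is exactly the path `eager_normal` below hits when it forms `(loc - value) / scale`. A quick check in plain Python:

import operator

print(operator.invert(2))  # -3, i.e. ~2: bitwise complement, not 1/2
print(1. / 2.)             # 0.5: the reciprocal that truediv actually needs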
40 changes: 25 additions & 15 deletions funsor/distributions.py
@@ -7,8 +7,11 @@
 import torch
 from six import add_metaclass

+from pyro.distributions.util import broadcast_shape
+
 import funsor.delta
 import funsor.ops as ops
+from funsor.affine import Affine
 from funsor.domains import bint, reals
 from funsor.gaussian import Gaussian
 from funsor.interpreter import interpretation
@@ -241,23 +244,30 @@ def eager_normal(loc, scale, value):
     return Normal(loc, scale, 'value')(value=value)


-# Create a Gaussian from a noisy identity transform.
-# This is extremely limited but suffices for examples/kalman_filter.py
-@eager.register(Normal, Variable, Tensor, Variable)
+@eager.register(Normal, (Variable, Affine), Tensor, (Variable, Affine))
+@eager.register(Normal, (Variable, Affine), Tensor, Tensor)
+@eager.register(Normal, Tensor, Tensor, (Variable, Affine))
 def eager_normal(loc, scale, value):
-    assert loc.output == reals()
-    assert value.output == reals()
-    assert loc.name != value.name
-    inputs = loc.inputs.copy()
-    inputs.update(scale.inputs)
-    inputs.update(value.inputs)
-    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
+    affine = (loc - value) / scale
+    assert isinstance(affine, Affine)
+    real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real')
+    assert not any(v.shape for v in real_inputs.values())
+
+    tensors = [affine.const] + [c for v, c in affine.coeffs.items()]
+    inputs, tensors = align_tensors(*tensors)
+    shape = broadcast_shape(*(t.shape for t in tensors))
+    const, coeffs = tensors[0], tensors[1:]
+
+    dim = sum(d.num_elements for d in real_inputs.values())
+    loc = const.new_zeros(shape + (dim,))
+    loc[..., 0] = -const / coeffs[0]
+    precision = const.new_empty(shape + (dim, dim))
+    for i, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
+        for j, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
+            precision[..., i, j] = c1 * c2

-    log_prob = -0.5 * math.log(2 * math.pi) - scale.data.log()
-    loc = scale.data.new_zeros(scale.data.shape + (2,))
-    p = scale.data.pow(-2)
-    precision = torch.stack([p, -p, -p, p], -1).reshape(p.shape + (2, 2))
-    return Tensor(log_prob, int_inputs) + Gaussian(loc, precision, inputs)
+    log_prob = -0.5 * math.log(2 * math.pi) - scale.log()
+    return log_prob + Gaussian(loc, precision, affine.inputs)


 class MultivariateNormal(Distribution):
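The heart of the new rule: since at least one of `loc`, `value` is a `Variable` or `Affine`, the standardized residual `affine = (loc - value) / scale` is itself an `Affine` funsor `const + sum_i c_i * x_i` over the free real variables, with the `1 / scale` factor folded into `const` and the `c_i`. The Normal log-density `-0.5 * affine ** 2 - scale.log() - 0.5 * log(2 * pi)` is then a Gaussian in the `x_i` with rank-one precision `P[i, j] = c_i * c_j`, and any point where the residual vanishes can serve as its `loc`, which is what `loc[..., 0] = -const / coeffs[0]` picks. A standalone numeric check of that identity (plain torch, no funsor; names are illustrative):

import torch

const = 0.7
c = torch.tensor([1.5, -2.0])        # affine coefficients, scale already folded in
precision = c[:, None] * c[None, :]  # rank-one precision, c c^T
mode = torch.zeros(2)
mode[0] = -const / c[0]              # same trick as loc[..., 0] above: residual(mode) == 0

x = torch.randn(2)
z = const + c @ x
assert torch.allclose(-0.5 * z ** 2, -0.5 * (x - mode) @ precision @ (x - mode))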
17 changes: 14 additions & 3 deletions funsor/gaussian.py
@@ -83,7 +83,7 @@ def sym_inverse(mat):
     return torch.pinverse(mat)


-def _sym_solve_mv(mat, vec):
+def sym_solve_mv(mat, vec):
     r"""
     Computes ``mat \ vec`` assuming mat is symmetric and usually positive definite,
     but falling back to general pseudoinverse if positive definiteness fails.
@@ -224,6 +224,17 @@ def __repr__(self):
         return 'Gaussian(..., ({}))'.format(' '.join(
             '({}, {}),'.format(*kv) for kv in self.inputs.items()))

+    def align(self, names):
+        assert isinstance(names, tuple)
+        assert all(name in self.inputs for name in names)
+        if not names or names == tuple(self.inputs):
+            return self
+
+        inputs = OrderedDict((name, self.inputs[name]) for name in names)
+        inputs.update(self.inputs)
+        loc, precision = align_gaussian(inputs, self)
+        return Gaussian(loc, precision, inputs)
+
     def eager_subs(self, subs):
         assert isinstance(subs, tuple)
         subs = tuple((k, materialize(v)) for k, v in subs if k in self.inputs)
@@ -359,7 +370,7 @@ def eager_reduce(self, op, reduced_vars):
                                     old_ints).reduce(ops.add, reduced_vars)
             assert precision.inputs == new_ints
             assert precision_loc.inputs == new_ints
-            loc = Tensor(_sym_solve_mv(precision.data, precision_loc.data), new_ints)
+            loc = Tensor(sym_solve_mv(precision.data, precision_loc.data), new_ints)
             expanded_loc = align_tensor(old_ints, loc)
             quadratic_term = Tensor(_vmv(self.precision, expanded_loc - self.loc),
                                     old_ints).reduce(ops.add, reduced_vars)
@@ -421,7 +432,7 @@ def eager_add_gaussian_gaussian(op, lhs, rhs):
     # Fuse aligned Gaussians.
     precision_loc = _mv(lhs_precision, lhs_loc) + _mv(rhs_precision, rhs_loc)
     precision = lhs_precision + rhs_precision
-    loc = _sym_solve_mv(precision, precision_loc)
+    loc = sym_solve_mv(precision, precision_loc)
     quadratic_term = _vmv(lhs_precision, loc - lhs_loc) + _vmv(rhs_precision, loc - rhs_loc)
     likelihood = Tensor(-0.5 * quadratic_term, int_inputs)
     return likelihood + Gaussian(loc, precision, inputs)
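The new `Gaussian.align` mirrors `Tensor.align` in funsor/torch.py (next file): the listed names are moved to the front, the remaining inputs keep their original order (`inputs.update(self.inputs)` appends them), and the actual permutation of `loc` and `precision` is delegated to `align_gaussian`. A minimal usage sketch, assuming `random_gaussian` lives in `funsor.testing` as used by `test_align` below:

from collections import OrderedDict

from funsor.domains import bint, reals
from funsor.testing import random_gaussian

g1 = random_gaussian(OrderedDict([('i', bint(2)), ('x', reals()), ('y', reals(3))]))
g2 = g1.align(('y', 'x'))
# Listed names come first; unlisted inputs keep their original order.
assert tuple(g2.inputs) == ('y', 'x', 'i')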
4 changes: 1 addition & 3 deletions funsor/torch.py
@@ -136,11 +136,9 @@ def align(self, names):
         assert all(name in self.inputs for name in names)
         if not names or names == tuple(self.inputs):
             return self
-
         inputs = OrderedDict((name, self.inputs[name]) for name in names)
         inputs.update(self.inputs)
-
         if any(d.shape for d in self.inputs.values()):
             raise NotImplementedError("TODO: Implement align with vector indices.")
         old_dims = tuple(self.inputs)
         new_dims = tuple(inputs)
         data = self.data.permute(tuple(old_dims.index(d) for d in new_dims))
29 changes: 29 additions & 0 deletions test/test_distributions.py
@@ -204,6 +204,35 @@ def test_normal_gaussian_3(batch_shape):
     assert_close(actual, expected, atol=1e-4)


+NORMAL_AFFINE_TESTS = [
+    'dist.Normal(x+2, scale, y+2)',
+    'dist.Normal(y, scale, x)',
+    'dist.Normal(x - y, scale, 0)',
+    'dist.Normal(0, scale, y - x)',
+    'dist.Normal(2 * x - y, scale, x)',
+    # TODO should we expect these to work without correction terms?
+    'dist.Normal(0, 1, (x - y) / scale) - scale.log()',
+    'dist.Normal(2 * y, 2 * scale, 2 * x) + math.log(2)',
+]
+
+
+@pytest.mark.parametrize('expr', NORMAL_AFFINE_TESTS)
+def test_normal_affine(expr):
+    scale = Tensor(torch.tensor(0.3), OrderedDict())
+    x = Variable('x', reals())
+    y = Variable('y', reals())
+
+    expected = dist.Normal(x, scale, y)
+    actual = eval(expr)
+
+    assert isinstance(actual, Joint)
+    assert dict(actual.inputs) == dict(expected.inputs), (actual.inputs, expected.inputs)
+
+    assert_close(actual.gaussian.align(tuple(expected.gaussian.inputs)), expected.gaussian)
+    assert_close(actual.discrete.align(tuple(expected.discrete.inputs)), expected.discrete)
+
+
 def test_normal_independent():
     loc = random_tensor(OrderedDict(), reals(2))
     scale = random_tensor(OrderedDict(), reals(2)).exp()
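On the TODO in `NORMAL_AFFINE_TESTS`: the correction terms are expected, not a missing feature. `dist.Normal` is a density in `value`, so rescaling the value inside the density drops a log-Jacobian factor; `- scale.log()` and `+ math.log(2)` put it back. A quick standalone check with `torch.distributions` (no funsor):

import math
import torch
from torch.distributions import Normal

scale = torch.tensor(0.3)
x, y = torch.tensor(0.4), torch.tensor(-0.1)
target = Normal(y, scale).log_prob(x)

# Standardizing the residual drops a log-Jacobian of -log(scale) ...
assert torch.allclose(Normal(0., 1.).log_prob((x - y) / scale) - scale.log(), target)
# ... and doubling loc, scale and value drops one of -log(2).
assert torch.allclose(Normal(2 * y, 2 * scale).log_prob(2 * x) + math.log(2), target)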
24 changes: 24 additions & 0 deletions test/test_gaussian.py
@@ -84,6 +84,30 @@ def test_smoke(expr, expected_type):
     assert isinstance(result, expected_type)


+@pytest.mark.parametrize('int_inputs', [
+    {},
+    {'i': bint(2)},
+    {'i': bint(2), 'j': bint(3)},
+], ids=id_from_inputs)
+@pytest.mark.parametrize('real_inputs', [
+    {'x': reals()},
+    {'x': reals(4)},
+    {'x': reals(2, 3)},
+    {'x': reals(), 'y': reals()},
+    {'x': reals(2), 'y': reals(3)},
+    {'x': reals(4), 'y': reals(2, 3), 'z': reals()},
+], ids=id_from_inputs)
+def test_align(int_inputs, real_inputs):
+    inputs1 = OrderedDict(list(sorted(int_inputs.items())) +
+                          list(sorted(real_inputs.items())))
+    # list() is needed so reversed() works on Python < 3.8, where dict views
+    # are not reversible.
+    inputs2 = OrderedDict(reversed(list(inputs1.items())))
+    g1 = random_gaussian(inputs1)
+    g2 = g1.align(tuple(inputs2))
+    assert g2.inputs == inputs2
+    g3 = g2.align(tuple(inputs1))
+    assert_close(g3, g1)
+
+
 @pytest.mark.parametrize('int_inputs', [
     {},
     {'i': bint(2)},