diff --git a/examples/kalman_filter.py b/examples/kalman_filter.py
index aa68974c6..32f033bee 100644
--- a/examples/kalman_filter.py
+++ b/examples/kalman_filter.py
@@ -28,12 +28,12 @@ def model(data):
 
         # A delayed sample statement.
         x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
-        log_prob += dist.Normal(x_prev, trans_noise, value=x_curr)
+        log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr)
 
         if not args.lazy and isinstance(x_prev, funsor.Variable):
             log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)
 
-        log_prob += dist.Normal(x_curr, emit_noise, value=y)
+        log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)
 
     log_prob = log_prob.reduce(ops.logaddexp)
     return log_prob
diff --git a/funsor/affine.py b/funsor/affine.py
index c78fc97a0..0254daea6 100644
--- a/funsor/affine.py
+++ b/funsor/affine.py
@@ -161,7 +161,7 @@ def eager_binary(op, var, other):
     elif op is ops.sub:
         return var + -other
     elif op is ops.truediv:
-        return var * ops.invert(other)
+        return var * (1. / other)
     return None
 
 
diff --git a/funsor/distributions.py b/funsor/distributions.py
index dbd45639d..784ad8c17 100644
--- a/funsor/distributions.py
+++ b/funsor/distributions.py
@@ -7,8 +7,11 @@
 import torch
 from six import add_metaclass
 
+from pyro.distributions.util import broadcast_shape
+
 import funsor.delta
 import funsor.ops as ops
+from funsor.affine import Affine
 from funsor.domains import bint, reals
 from funsor.gaussian import Gaussian
 from funsor.interpreter import interpretation
@@ -241,23 +244,30 @@ def eager_normal(loc, scale, value):
     return Normal(loc, scale, 'value')(value=value)
 
 
-# Create a Gaussian from a noisy identity transform.
-# This is extremely limited but suffices for examples/kalman_filter.py
-@eager.register(Normal, Variable, Tensor, Variable)
+@eager.register(Normal, (Variable, Affine), Tensor, (Variable, Affine))
+@eager.register(Normal, (Variable, Affine), Tensor, Tensor)
+@eager.register(Normal, Tensor, Tensor, (Variable, Affine))
 def eager_normal(loc, scale, value):
-    assert loc.output == reals()
-    assert value.output == reals()
-    assert loc.name != value.name
-    inputs = loc.inputs.copy()
-    inputs.update(scale.inputs)
-    inputs.update(value.inputs)
-    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
+    affine = (loc - value) / scale
+    assert isinstance(affine, Affine)
+    real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real')
+    assert not any(v.shape for v in real_inputs.values())
+
+    tensors = [affine.const] + [c for v, c in affine.coeffs.items()]
+    inputs, tensors = align_tensors(*tensors)
+    shape = broadcast_shape(*(t.shape for t in tensors))
+    const, coeffs = tensors[0], tensors[1:]
+
+    dim = sum(d.num_elements for d in real_inputs.values())
+    loc = const.new_zeros(shape + (dim,))
+    loc[..., 0] = -const / coeffs[0]
+    precision = const.new_empty(shape + (dim, dim))
+    for i, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
+        for j, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
+            precision[..., i, j] = c1 * c2
 
-    log_prob = -0.5 * math.log(2 * math.pi) - scale.data.log()
-    loc = scale.data.new_zeros(scale.data.shape + (2,))
-    p = scale.data.pow(-2)
-    precision = torch.stack([p, -p, -p, p], -1).reshape(p.shape + (2, 2))
-    return Tensor(log_prob, int_inputs) + Gaussian(loc, precision, inputs)
+    log_prob = -0.5 * math.log(2 * math.pi) - scale.log()
+    return log_prob + Gaussian(loc, precision, affine.inputs)
 
 
 class MultivariateNormal(Distribution):
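A note on the construction in the new `eager_normal`: writing the standardized residual as `affine = (loc - value) / scale = c . x + const` over the real inputs `x`, the quadratic term of the log density is `-affine**2 / 2`, which is exactly a Gaussian with rank-1 precision `c cᵀ` centered at any point satisfying `c . loc = -const` (the code zeroes all but the first coordinate); the leftover normalizer `-0.5 * log(2 * pi) - scale.log()` becomes the discrete `Tensor` part of the resulting `Joint`. A minimal standalone sketch of that rank-1 identity in plain PyTorch (not funsor's API; the coefficients are made up, roughly `Normal(x - y + 2, scale, value=0)`):

```python
import torch

# Standalone sketch of the rank-1 construction in eager_normal above.
# Suppose the standardized residual is affine(x, y) = c[0]*x + c[1]*y + const.
scale = torch.tensor(0.3)
c = torch.tensor([1.0, -1.0]) / scale   # coefficients of real inputs x, y
const = torch.tensor(2.0) / scale       # constant term

precision = c.unsqueeze(-1) * c.unsqueeze(-2)   # precision[i, j] = c[i] * c[j]
loc = torch.zeros(2)
loc[0] = -const / c[0]   # any loc with (c * loc).sum() == -const works

# The Gaussian quadratic form recovers the squared standardized residual.
xy = torch.randn(2)
delta = xy - loc
assert torch.allclose(delta @ precision @ delta, ((c * xy).sum() + const) ** 2)
```

Since the precision is rank-1 (rank-deficient for two or more real inputs), such a Gaussian is only normalizable after being multiplied by other terms in a `Joint`, which is why the normalizer is tracked separately.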
diff --git a/funsor/gaussian.py b/funsor/gaussian.py
index 795ac103d..3c87e61d5 100644
--- a/funsor/gaussian.py
+++ b/funsor/gaussian.py
@@ -83,7 +83,7 @@ def sym_inverse(mat):
         return torch.pinverse(mat)
 
 
-def _sym_solve_mv(mat, vec):
+def sym_solve_mv(mat, vec):
     r"""
     Computes ``mat \ vec`` assuming mat is symmetric and usually positive definite,
     but falling back to general pseudoinverse if positive definiteness fails.
@@ -224,6 +224,17 @@ def __repr__(self):
         return 'Gaussian(..., ({}))'.format(' '.join(
             '({}, {}),'.format(*kv) for kv in self.inputs.items()))
 
+    def align(self, names):
+        assert isinstance(names, tuple)
+        assert all(name in self.inputs for name in names)
+        if not names or names == tuple(self.inputs):
+            return self
+
+        inputs = OrderedDict((name, self.inputs[name]) for name in names)
+        inputs.update(self.inputs)
+        loc, precision = align_gaussian(inputs, self)
+        return Gaussian(loc, precision, inputs)
+
     def eager_subs(self, subs):
         assert isinstance(subs, tuple)
         subs = tuple((k, materialize(v)) for k, v in subs if k in self.inputs)
@@ -359,7 +370,7 @@ def eager_reduce(self, op, reduced_vars):
                                     old_ints).reduce(ops.add, reduced_vars)
             assert precision.inputs == new_ints
             assert precision_loc.inputs == new_ints
-            loc = Tensor(_sym_solve_mv(precision.data, precision_loc.data), new_ints)
+            loc = Tensor(sym_solve_mv(precision.data, precision_loc.data), new_ints)
             expanded_loc = align_tensor(old_ints, loc)
             quadratic_term = Tensor(_vmv(self.precision, expanded_loc - self.loc),
                                     old_ints).reduce(ops.add, reduced_vars)
@@ -421,7 +432,7 @@ def eager_add_gaussian_gaussian(op, lhs, rhs):
     # Fuse aligned Gaussians.
     precision_loc = _mv(lhs_precision, lhs_loc) + _mv(rhs_precision, rhs_loc)
     precision = lhs_precision + rhs_precision
-    loc = _sym_solve_mv(precision, precision_loc)
+    loc = sym_solve_mv(precision, precision_loc)
     quadratic_term = _vmv(lhs_precision, loc - lhs_loc) + _vmv(rhs_precision, loc - rhs_loc)
     likelihood = Tensor(-0.5 * quadratic_term, int_inputs)
     return likelihood + Gaussian(loc, precision, inputs)
diff --git a/funsor/torch.py b/funsor/torch.py
index b95b948c1..520fc1b7f 100644
--- a/funsor/torch.py
+++ b/funsor/torch.py
@@ -136,11 +136,9 @@ def align(self, names):
         assert all(name in self.inputs for name in names)
         if not names or names == tuple(self.inputs):
             return self
+
         inputs = OrderedDict((name, self.inputs[name]) for name in names)
         inputs.update(self.inputs)
-
-        if any(d.shape for d in self.inputs.values()):
-            raise NotImplementedError("TODO: Implement align with vector indices.")
         old_dims = tuple(self.inputs)
         new_dims = tuple(inputs)
         data = self.data.permute(tuple(old_dims.index(d) for d in new_dims))
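Both `align` methods (the existing `Tensor.align` above and the new `Gaussian.align`) are pure reorderings: they permute named inputs into a requested order without changing the funsor's meaning, which is what lets the tests below compare results built with differently ordered inputs via `assert_close`. A hypothetical plain-PyTorch sketch of the permutation behind `Tensor.align` (mirroring the `old_dims.index` logic above; the names and sizes are made up):

```python
from collections import OrderedDict
import torch

# Sketch of aligning named batch dimensions by permuting the data.
inputs = OrderedDict([('i', 2), ('j', 3)])   # name -> size, storage order
data = torch.randn(2, 3)

names = ('j', 'i')                           # requested input order
old_dims = tuple(inputs)
aligned = data.permute(tuple(old_dims.index(d) for d in names))

assert aligned.shape == (3, 2)
assert torch.equal(aligned[1, 0], data[0, 1])   # same entry, reindexed
```

`Gaussian.align` additionally has to reorder the blocks of `loc` and `precision` that correspond to real inputs, which it delegates to `align_gaussian`.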
diff --git a/test/test_distributions.py b/test/test_distributions.py
index c8d6a9203..1ef394f80 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -204,6 +204,35 @@ def test_normal_gaussian_3(batch_shape):
     assert_close(actual, expected, atol=1e-4)
 
 
+NORMAL_AFFINE_TESTS = [
+    'dist.Normal(x+2, scale, y+2)',
+    'dist.Normal(y, scale, x)',
+    'dist.Normal(x - y, scale, 0)',
+    'dist.Normal(0, scale, y - x)',
+    'dist.Normal(2 * x - y, scale, x)',
+    # TODO should we expect these to work without correction terms?
+    'dist.Normal(0, 1, (x - y) / scale) - scale.log()',
+    'dist.Normal(2 * y, 2 * scale, 2 * x) + math.log(2)',
+]
+
+
+@pytest.mark.parametrize('expr', NORMAL_AFFINE_TESTS)
+def test_normal_affine(expr):
+
+    scale = Tensor(torch.tensor(0.3), OrderedDict())
+    x = Variable('x', reals())
+    y = Variable('y', reals())
+
+    expected = dist.Normal(x, scale, y)
+    actual = eval(expr)
+
+    assert isinstance(actual, Joint)
+    assert dict(actual.inputs) == dict(expected.inputs), (actual.inputs, expected.inputs)
+
+    assert_close(actual.gaussian.align(tuple(expected.gaussian.inputs)), expected.gaussian)
+    assert_close(actual.discrete.align(tuple(expected.discrete.inputs)), expected.discrete)
+
+
 def test_normal_independent():
     loc = random_tensor(OrderedDict(), reals(2))
     scale = random_tensor(OrderedDict(), reals(2)).exp()
diff --git a/test/test_gaussian.py b/test/test_gaussian.py
index b027d3218..ea2314ae3 100644
--- a/test/test_gaussian.py
+++ b/test/test_gaussian.py
@@ -84,6 +84,30 @@ def test_smoke(expr, expected_type):
     assert isinstance(result, expected_type)
 
 
+@pytest.mark.parametrize('int_inputs', [
+    {},
+    {'i': bint(2)},
+    {'i': bint(2), 'j': bint(3)},
+], ids=id_from_inputs)
+@pytest.mark.parametrize('real_inputs', [
+    {'x': reals()},
+    {'x': reals(4)},
+    {'x': reals(2, 3)},
+    {'x': reals(), 'y': reals()},
+    {'x': reals(2), 'y': reals(3)},
+    {'x': reals(4), 'y': reals(2, 3), 'z': reals()},
+], ids=id_from_inputs)
+def test_align(int_inputs, real_inputs):
+    inputs1 = OrderedDict(list(sorted(int_inputs.items())) +
+                          list(sorted(real_inputs.items())))
+    inputs2 = OrderedDict(reversed(inputs1.items()))
+    g1 = random_gaussian(inputs1)
+    g2 = g1.align(tuple(inputs2))
+    assert g2.inputs == inputs2
+    g3 = g2.align(tuple(inputs1))
+    assert_close(g3, g1)
+
+
 @pytest.mark.parametrize('int_inputs', [
     {},
     {'i': bint(2)},
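On the TODO in `NORMAL_AFFINE_TESTS`: the last two entries cannot hold without the correction terms, because rescaling a Normal's value changes the density by a log-Jacobian factor. Since funsor's `dist.Normal(loc, scale, value)` denotes the log density `log N(value; loc, scale)`, the two identities can be checked numerically with `torch.distributions`, independent of funsor (the point values below are arbitrary):

```python
import math
import torch
from torch.distributions import Normal

scale = torch.tensor(0.3)
x, y = torch.tensor(1.7), torch.tensor(0.4)

expected = Normal(x, scale).log_prob(y)

# 'dist.Normal(0, 1, (x - y) / scale) - scale.log()'
assert torch.allclose(expected,
                      Normal(0., 1.).log_prob((x - y) / scale) - scale.log())

# 'dist.Normal(2 * y, 2 * scale, 2 * x) + math.log(2)'
assert torch.allclose(expected,
                      Normal(2 * y, 2 * scale).log_prob(2 * x) + math.log(2))
```

Whether eager patterns should absorb such Jacobian factors automatically is the open design question the TODO raises; the tests as written treat them as the caller's responsibility.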