diff --git a/funsor/affine.py b/funsor/affine.py
index fc1e46a3d..f26ab74ca 100644
--- a/funsor/affine.py
+++ b/funsor/affine.py
@@ -8,6 +8,11 @@
 from funsor.torch import Tensor


+# FIXME change this to a sound but incomplete test using pattern matching.
+def is_affine(fn):
+    return True
+
+
 def extract_affine(fn):
     """
     Extracts an affine representation of a funsor, which is exact for affine
diff --git a/funsor/distributions.py b/funsor/distributions.py
index c4d4d7e00..75df088ef 100644
--- a/funsor/distributions.py
+++ b/funsor/distributions.py
@@ -7,9 +7,10 @@
 import funsor.delta
 import funsor.ops as ops
+from funsor.affine import extract_affine, is_affine
 from funsor.cnf import Contraction
 from funsor.domains import bint, reals
-from funsor.gaussian import BlockMatrix, BlockVector, Gaussian, cholesky_inverse
+from funsor.gaussian import BlockMatrix, BlockVector, Gaussian
 from funsor.interpreter import interpretation
 from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, lazy, to_funsor
 from funsor.torch import Tensor, align_tensors, ignore_jit_warnings, materialize, torch_stack
@@ -449,23 +450,80 @@ def eager_mvn(loc, scale_tril, value):


 # Create a Gaussian from a ground observation.
-@eager.register(MultivariateNormal, Tensor, Tensor, Variable)
-@eager.register(MultivariateNormal, Variable, Tensor, Tensor)
+# TODO refactor this logic into Gaussian.eager_subs() and
+# here return Gaussian(...scale_tril...)(value=loc-value).
+@eager.register(MultivariateNormal, (Variable, Contraction), Tensor, (Variable, Contraction))
+@eager.register(MultivariateNormal, (Variable, Contraction), Tensor, Tensor)
+@eager.register(MultivariateNormal, Tensor, Tensor, (Variable, Contraction))
 def eager_mvn(loc, scale_tril, value):
-    if isinstance(loc, Variable):
-        loc, value = value, loc
+    assert len(loc.shape) == 1
+    assert len(scale_tril.shape) == 2
+    assert value.output == loc.output
+    if not is_affine(loc) or not is_affine(value):
+        return None  # lazy
+
+    # Extract an affine representation.
+    eye = torch.eye(scale_tril.data.size(-1)).expand(scale_tril.data.shape)
+    prec_sqrt = Tensor(eye.triangular_solve(scale_tril.data, upper=False).solution,
+                       scale_tril.inputs)
+    affine = prec_sqrt @ (loc - value)
+    const, coeffs = extract_affine(affine)
+    if not isinstance(const, Tensor):
+        return None  # lazy
+    if not all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values()):
+        return None  # lazy
+
+    # Compute log_prob using funsors.
+    scale_diag = Tensor(scale_tril.data.diagonal(dim1=-1, dim2=-2), scale_tril.inputs)
+    log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi)
+                - scale_diag.log().sum() - 0.5 * (const ** 2).sum())
+
+    # Dovetail to avoid variable name collision in einsum.
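+    # chr(ord(c) * 2 - ord('a')) renames the index letters a,b,c,... of each
+    # coefficient's einsum equation to the even letters a,c,e,..., while the
+    # "+ 1" variant below renames them to the odd letters b,d,f,..., so the
+    # two renamed equation sets can never share an index letter.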
+    equations1 = [''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a')) for c in eqn)
+                  for _, eqn in coeffs.values()]
+    equations2 = [''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a') + 1) for c in eqn)
+                  for _, eqn in coeffs.values()]

-    dim, = loc.output.shape
-    inputs, (loc, scale_tril) = align_tensors(loc, scale_tril)
-    inputs.update(value.inputs)
-    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
+    real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real')
+    assert tuple(real_inputs) == tuple(coeffs)

-    precision = cholesky_inverse(scale_tril)
-    info_vec = precision.matmul(loc.unsqueeze(-1)).squeeze(-1)
-    log_prob = (-0.5 * dim * math.log(2 * math.pi)
-                - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
-                - 0.5 * (loc * info_vec).sum(-1))
-    return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs)
+    # Align and broadcast tensors.
+    neg_const = -const
+    tensors = [neg_const] + [coeff for coeff, _ in coeffs.values()]
+    inputs, tensors = align_tensors(*tensors, expand=True)
+    neg_const, coeffs = tensors[0], tensors[1:]
+    dim = sum(d.num_elements for d in real_inputs.values())
+    batch_shape = neg_const.shape[:-1]
+
+    info_vec = BlockVector(batch_shape + (dim,))
+    precision = BlockMatrix(batch_shape + (dim, dim))
+    offset1 = 0
+    for i1, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
+        size1 = real_inputs[v1].num_elements
+        slice1 = slice(offset1, offset1 + size1)
+        inputs1, output1 = equations1[i1].split('->')
+        input11, input12 = inputs1.split(',')
+        assert input11 == input12 + output1
+        info_vec[..., slice1] = torch.einsum(
+            f'...{input11},...{output1}->...{input12}', c1, neg_const) \
+            .reshape(batch_shape + (size1,))
+        offset2 = 0
+        for i2, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
+            size2 = real_inputs[v2].num_elements
+            slice2 = slice(offset2, offset2 + size2)
+            inputs2, output2 = equations2[i2].split('->')
+            input21, input22 = inputs2.split(',')
+            assert input21 == input22 + output2
+            precision[..., slice1, slice2] = torch.einsum(
+                f'...{input11},...{input22}{output1}->...{input12}{input22}', c1, c2) \
+                .reshape(batch_shape + (size1, size2))
+            offset2 += size2
+        offset1 += size1
+
+    info_vec = info_vec.as_tensor()
+    precision = precision.as_tensor()
+    inputs.update(real_inputs)
+    return log_prob + Gaussian(info_vec, precision, inputs)


 class Poisson(Distribution):
diff --git a/funsor/ops.py b/funsor/ops.py
index 15aba813b..66c5a81aa 100644
--- a/funsor/ops.py
+++ b/funsor/ops.py
@@ -53,6 +53,7 @@ def log_abs_det_jacobian(x, y):
     raise NotImplementedError


+# FIXME Most code assumes this is an AssociativeCommutativeOp.
 class AssociativeOp(Op):
     pass

@@ -65,7 +66,7 @@ class MulOp(AssociativeOp):
     pass


-class MatmulOp(AssociativeOp):
+class MatmulOp(Op):  # Associative but not commutative.
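+    # Matmul cannot stay an AssociativeOp: pattern-matching code elsewhere
+    # treats AssociativeOp as commutative (see the FIXME above), and matmul
+    # is not commutative.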
     pass
diff --git a/funsor/torch.py b/funsor/torch.py
index e092746a4..ef420dc25 100644
--- a/funsor/torch.py
+++ b/funsor/torch.py
@@ -470,10 +470,6 @@ def eager_binary_tensor_tensor(op, lhs, rhs):
         shape = shape[:cut] + (1,) * (lhs_dim - rhs_dim) + shape[cut:]
         rhs_data = rhs_data.reshape(shape)

-    print(f"lhs.data.shape = {lhs.data.shape}")
-    print(f"rhs.data.shape = {rhs.data.shape}")
-    print(f"lhs_data.shape = {lhs_data.shape}")
-    print(f"rhs_data.shape = {rhs_data.shape}")
     data = op(lhs_data, rhs_data)
     if len(lhs.shape) == 1:
         data = data.squeeze(-2)
diff --git a/test/examples/test_sensor_fusion.py b/test/examples/test_sensor_fusion.py
index 2f9433ddd..b4fe638d4 100644
--- a/test/examples/test_sensor_fusion.py
+++ b/test/examples/test_sensor_fusion.py
@@ -3,6 +3,7 @@
 import pytest
 import torch

+import funsor.distributions as dist
 import funsor.ops as ops
 from funsor.cnf import Contraction
 from funsor.domains import bint, reals
@@ -13,8 +14,49 @@
 from funsor.torch import Tensor


+# This version constructs factors using funsor.distributions.
+@pytest.mark.parametrize('state_dim,obs_dim', [(3, 2), (2, 3)])
+def test_distributions(state_dim, obs_dim):
+    data = Tensor(torch.randn(2, obs_dim))["time"]
+
+    bias = Variable("bias", reals(obs_dim))
+    bias_dist = dist_to_funsor(random_mvn((), obs_dim))(value=bias)
+
+    prev = Variable("prev", reals(state_dim))
+    curr = Variable("curr", reals(state_dim))
+    trans_mat = Tensor(torch.eye(state_dim) + 0.1 * torch.randn(state_dim, state_dim))
+    trans_mvn = random_mvn((), state_dim)
+    trans_dist = dist.MultivariateNormal(
+        loc=trans_mvn.loc,
+        scale_tril=trans_mvn.scale_tril,
+        value=curr - prev @ trans_mat)
+
+    state = Variable("state", reals(state_dim))
+    obs = Variable("obs", reals(obs_dim))
+    obs_mat = Tensor(torch.randn(state_dim, obs_dim))
+    obs_mvn = random_mvn((), obs_dim)
+    obs_dist = dist.MultivariateNormal(
+        loc=obs_mvn.loc,
+        scale_tril=obs_mvn.scale_tril,
+        value=state @ obs_mat + bias - obs)
+
+    log_prob = 0
+    log_prob += bias_dist
+
+    state_0 = Variable("state_0", reals(state_dim))
+    log_prob += obs_dist(state=state_0, obs=data(time=0))
+
+    state_1 = Variable("state_1", reals(state_dim))
+    log_prob += trans_dist(prev=state_0, curr=state_1)
+    log_prob += obs_dist(state=state_1, obs=data(time=1))
+
+    log_prob = log_prob.reduce(ops.logaddexp)
+    assert isinstance(log_prob, Tensor), log_prob.pretty()
+
+
+# This version constructs factors using funsor.pyro.convert.
 @pytest.mark.xfail(reason="missing pattern")
-def test_end_to_end():
+def test_pyro_convert():
     data = Tensor(torch.randn(2, 2), OrderedDict([("time", bint(2))]))

     bias_dist = dist_to_funsor(random_mvn((), 2))
@@ -44,7 +86,7 @@

 @pytest.mark.xfail(reason="missing pattern")
 def test_affine_subs():
-    # This was recorded from test_end_to_end.
+    # This was recorded from test_pyro_convert.
     x = Subs(
         Gaussian(
             torch.tensor([1.3027106523513794, 1.4167094230651855, -0.9750942587852478, 0.5321089029312134, -0.9039931297302246], dtype=torch.float32),  # noqa
diff --git a/test/test_affine.py b/test/test_affine.py
index 87cd8484c..83d0895aa 100644
--- a/test/test_affine.py
+++ b/test/test_affine.py
@@ -80,6 +80,7 @@ def test_affine_subs(expr, expected_type, expected_inputs):

 @pytest.mark.parametrize('expr', [
     "Variable('x', reals()) + 0.5",
+    "Variable('x', reals(2, 3)) + Variable('y', reals(2, 3))",
     "Variable('x', reals(2)) + Variable('y', reals(2))",
     "Variable('x', reals(2)) + torch.ones(2)",
     "Variable('x', reals(2)) * torch.randn(2)",
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 2fbe878c9..2a4d992e1 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -7,12 +7,14 @@
 import funsor
 import funsor.distributions as dist
-from funsor.cnf import Contraction
+from funsor.cnf import Contraction, GaussianMixture
 from funsor.delta import Delta
 from funsor.domains import bint, reals
-from funsor.terms import Independent, Variable
-from funsor.testing import assert_close, check_funsor, random_tensor
-from funsor.torch import Tensor
+from funsor.interpreter import interpretation, reinterpret
+from funsor.pyro.convert import dist_to_funsor
+from funsor.terms import Independent, Variable, lazy
+from funsor.testing import assert_close, check_funsor, random_mvn, random_tensor
+from funsor.torch import Einsum, Tensor


 @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@@ -398,7 +400,6 @@
     'dist.Normal(x - y, scale, 0)',
     'dist.Normal(0, scale, y - x)',
     'dist.Normal(2 * x - y, scale, x)',
-    # TODO should we expect these to work without correction terms?
     'dist.Normal(0, 1, (x - y) / scale) - scale.log()',
     'dist.Normal(2 * y, 2 * scale, 2 * x) + math.log(2)',
 ]
@@ -488,6 +489,75 @@
     assert_close(actual, expected, atol=1e-3, rtol=1e-4)


+def _check_mvn_affine(d1, data):
+    assert isinstance(d1, dist.MultivariateNormal)
+    d2 = reinterpret(d1)
+    assert issubclass(type(d2), GaussianMixture)
+    actual = d2(**data)
+    expected = d1(**data)
+    assert_close(actual, expected)
+
+
+def test_mvn_affine_one_var():
+    x = Variable('x', reals(2))
+    data = dict(x=Tensor(torch.randn(2)))
+    with interpretation(lazy):
+        d = dist_to_funsor(random_mvn((), 2))
+        d = d(value=2 * x + 1)
+    _check_mvn_affine(d, data)
+
+
+def test_mvn_affine_two_vars():
+    x = Variable('x', reals(2))
+    y = Variable('y', reals(2))
+    data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(2)))
+    with interpretation(lazy):
+        d = dist_to_funsor(random_mvn((), 2))
+        d = d(value=x - y)
+    _check_mvn_affine(d, data)
+
+
+def test_mvn_affine_matmul():
+    x = Variable('x', reals(2))
+    y = Variable('y', reals(3))
+    m = Tensor(torch.randn(2, 3))
+    data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(3)))
+    with interpretation(lazy):
+        d = dist_to_funsor(random_mvn((), 3))
+        d = d(value=x @ m - y)
+    _check_mvn_affine(d, data)
+
+
+def test_mvn_affine_einsum():
+    c = Tensor(torch.randn(3, 2, 2))
+    x = Variable('x', reals(2, 2))
+    y = Variable('y', reals())
+    data = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(())))
+    with interpretation(lazy):
+        d = dist_to_funsor(random_mvn((), 3))
+        d = d(value=Einsum("abc,bc->a", c, x) + y)
+    _check_mvn_affine(d, data)
+
+
+def test_mvn_affine_getitem():
+    x = Variable('x', reals(2, 2))
+    data = dict(x=Tensor(torch.randn(2, 2)))
+    with interpretation(lazy):
+        d = dist_to_funsor(random_mvn((), 2))
+        d = d(value=x[0] - x[1])
+    _check_mvn_affine(d, data)
+
+
+def test_mvn_affine_reshape():
+    x = Variable('x', reals(2, 2))
+    y = Variable('y', reals(4))
+    data = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(4)))
+    with interpretation(lazy):
+        d = dist_to_funsor(random_mvn((), 4))
+        d = d(value=x.reshape((4,)) - y)
+    _check_mvn_affine(d, data)
+
+
 @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
 @pytest.mark.parametrize('syntax', ['eager', 'lazy'])
 def test_poisson_probs_density(batch_shape, syntax):
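Usage sketch (not part of the diff; it mirrors test_mvn_affine_two_vars above).
With this patch, substituting an affine expression of free real variables into
a lazy MultivariateNormal and reinterpreting it should eagerly produce a
GaussianMixture with the same density. All names are taken from the imports in
the tests above.

    import torch

    from funsor.cnf import GaussianMixture
    from funsor.domains import reals
    from funsor.interpreter import interpretation, reinterpret
    from funsor.pyro.convert import dist_to_funsor
    from funsor.terms import Variable, lazy
    from funsor.testing import assert_close, random_mvn
    from funsor.torch import Tensor

    x = Variable('x', reals(2))
    y = Variable('y', reals(2))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 2))(value=x - y)

    g = reinterpret(d)  # triggers the new eager_mvn pattern
    assert issubclass(type(g), GaussianMixture)

    data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(2)))
    assert_close(g(**data), d(**data))  # same log-density either way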