From cd694ebf05d03b273673729ada2d55f30fe23a0a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 18 Sep 2019 19:12:02 -0700 Subject: [PATCH 01/31] WIP sketch multivariate pattern recognition --- funsor/cnf.py | 20 ++++++++ funsor/distributions.py | 104 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 114 insertions(+), 10 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index 9c917e6f2..10450b192 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -93,6 +93,26 @@ def _alpha_convert(self, alpha_subs): red_op, bin_op, _, terms = super()._alpha_convert(alpha_subs) return red_op, bin_op, reduced_vars, terms + def extract_affine(self): + """ + Tries to return a pair ``(const, coeffs)`` where const is a funsor with + no real inputs and ``coeffs`` is an OrderedDict mapping Variable to a + ``(coefficient, eqn)`` pair in einsum form, i.e. satisfying:: + + affine = expected.extract_affine() + actual = sum(torch_einsum(eqn, coeff, var) + for var, (coeff, eqn) in affine.items()) + assert_close(actual, expected) + + If any real input appears nonlinearly, this returns None. + """ + const = affine(**{k: 0. for k, v in real_inputs.items()}) + coeffs = OrderedDict() + for c in real_inputs.keys(): + # TODO adapt this univariate code to multivariate setting. + # coeffs[c] = affine(**{k: 1. if c == k else 0. for k in real_inputs.keys()}) - const + raise NotImplementedError("TODO") + @quote.register(Contraction) def _(arg, indent, out): diff --git a/funsor/distributions.py b/funsor/distributions.py index 05c178728..e17fcddf1 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -451,23 +451,107 @@ def eager_mvn(loc, scale_tril, value): # Create a Gaussian from a ground observation. -@eager.register(MultivariateNormal, Tensor, Tensor, Variable) -@eager.register(MultivariateNormal, Variable, Tensor, Tensor) +@eager.register(MultivariateNormal, (Variable, Contraction), Tensor, (Variable, Contraction)) +@eager.register(MultivariateNormal, (Variable, Contraction), Tensor, Tensor) +@eager.register(MultivariateNormal, Tensor, Tensor, (Variable, Contraction)) def eager_mvn(loc, scale_tril, value): - if isinstance(loc, Variable): - loc, value = value, loc + affine = loc - value + const, coeffs = affine.extract_affine() + if not isinstance(const, Tensor): + return None # lazy + if not all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values()): + return None # lazy + + # Dovetail to avoid variable name collision in einsum. + keep = {k: k for k in ',->.'} + equations1 = [''.join(keep.get(c, chr(ord(c) * 2))) + for _, eqn in coeffs.values() for c in eqn] + equations2 = [''.join(keep.get(c, chr(ord(c) * 2 + 1))) + for _, eqn in coeffs.values() for c in eqn] - dim, = loc.output.shape - inputs, (loc, scale_tril) = align_tensors(loc, scale_tril) - inputs.update(value.inputs) - int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real') + real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real') + int_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype != 'real') + assert len(real_inputs) == len(coeffs) + + tensors = [scale_tril, const] + [coeffs[k] for k in real_inputs] + inputs, tensors = align_tensors(*tensors) + scale_tril, const, coeffs = tensors[0], tensors[1], tensors[2:] + + # Incorporate scale_tril before broadcasting. 
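+    # Rescaling each coefficient by prec_sqrt up front lets the precision
+    # blocks below be assembled as plain pairwise products of coefficients.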
+ prec_sqrt = cholesky_inverse(scale_tril) + for i in range(len(coeffs)): + coeff = tensors[1 + i] + inputs1, output1 = equations1[i].split('->') + inputs12, _ = inputs1.split(',') + inputs2, output2 = equations2[i].split('->') + inputs22, _ = inputs2.split(',') + coeff = torch.einsum( + f'...{input11},...{output1}{output2}->...{input21}', + coeff, prec_sqrt) + tensors[1 + i] = coeff + + tensors = torch.broadcast_tensors(*tensors) + scale_tril, const, coeffs = tensors[0], tensors[1], tensors[2:] + batch_shape = const.shape + dim = sum(d.num_elements for d in real_inputs.values()) + + loc = BlockVector(batch_shape + (dim,)) + loc[..., 0] = -const / coeffs[0] # FIXME consider directly constructing info_vec + + precision = BlockMatrix(batch_shape + (dim, dim)) + offset1 = 0 + for i1, (v1, c1) in enumerate(zip(real_inputs, coeffs)): + slice1 = slice(offset1, offset1 + real_inputs[v1].num_elements) + input11, input12 = equations1[i1] # FIXME + offset2 = 0 + for i2, (v2, c2) in enumerate(zip(real_inputs, coeffs)): + slice2 = slice(offset2, offset2 + real_inputs[v2].num_elements) + input21, input22 = equations2[i2] # FIXME + precision[..., slice1, slice2] = torch.einsum( + f'...{input11}{input12},...{input21}{input22}->...{input12}{input22}', c1, c2) + offset2 = slice2.stop + offset1 = slice1.stop - precision = cholesky_inverse(scale_tril) + loc = loc.as_tensor() + precision = precision.as_tensor() info_vec = precision.matmul(loc.unsqueeze(-1)).squeeze(-1) + log_prob = (-0.5 * dim * math.log(2 * math.pi) - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1) - 0.5 * (loc * info_vec).sum(-1)) - return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs) + return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, affine.inputs) + + # -------------- + # OLD VERSION + # dim, = affine.output.shape + # inputs, (loc, scale_tril) = align_tensors(loc, scale_tril) + # inputs.update(value.inputs) + # int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real') + # + # precision = cholesky_inverse(scale_tril) + # info_vec = precision.matmul(loc.unsqueeze(-1)).squeeze(-1) + # log_prob = (-0.5 * dim * math.log(2 * math.pi) + # - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1) + # - 0.5 * (loc * info_vec).sum(-1)) + # return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs) + + # -------------- + # FUTURE VERSION that defers to Gaussian(value=Contraction(...)). 
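+    # (i.e. build a standard Gaussian over a fresh "value" input, then
+    # substitute value=affine, reusing Gaussian's substitution machinery)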
+ # + # dim, = affine.output.shape + # int_inputs = scale_tril.inputs + # scale_tril = scale_tril.data + # inputs = int_inputs.copy() + # value_name = gensym("value") + # inputs[value_name] = affine.output + # int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real') + # + # precision = cholesky_inverse(scale_tril) + # info_vec = scale_tril.new_zeros(scale_tril.shape[:-1]) + # log_prob = (-0.5 * dim * math.log(2 * math.pi) + # - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)) + # mvn = Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs) + # return mvn(value=affine) class Poisson(Distribution): From 3c51566381ee54847ea9faef3d0f11a073fc5108 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 19 Sep 2019 14:00:39 -0700 Subject: [PATCH 02/31] WIP more progress on eager_mvn and .extract_affine() --- funsor/cnf.py | 27 ++++++++++----- funsor/distributions.py | 75 +++++++++++++---------------------------- 2 files changed, 42 insertions(+), 60 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index 10450b192..33337b594 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -2,13 +2,15 @@ from functools import reduce from typing import Tuple, Union +import opt_einsum +import torch from multipledispatch.variadic import Variadic import funsor.ops as ops from funsor.delta import Delta from funsor.domains import find_domain from funsor.gaussian import Gaussian -from funsor.interpreter import recursion_reinterpret +from funsor.interpreter import gensym, recursion_reinterpret from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, nullop from funsor.terms import Align, Binary, Funsor, Number, Reduce, Subs, Unary, Variable, eager, normalize, to_funsor from funsor.torch import Tensor @@ -99,19 +101,28 @@ def extract_affine(self): no real inputs and ``coeffs`` is an OrderedDict mapping Variable to a ``(coefficient, eqn)`` pair in einsum form, i.e. satisfying:: + assert expected.is_affine affine = expected.extract_affine() actual = sum(torch_einsum(eqn, coeff, var) for var, (coeff, eqn) in affine.items()) assert_close(actual, expected) - - If any real input appears nonlinearly, this returns None. """ - const = affine(**{k: 0. for k, v in real_inputs.items()}) + assert self.is_affine + real_inputs = OrderedDict((k, v) for k, v in self.inputs if v.dtype == 'real') coeffs = OrderedDict() - for c in real_inputs.keys(): - # TODO adapt this univariate code to multivariate setting. - # coeffs[c] = affine(**{k: 1. if c == k else 0. for k in real_inputs.keys()}) - const - raise NotImplementedError("TODO") + zeros = {k: Tensor(torch.zeros(v.shape)) for k, v in real_inputs.items()} + const = self(**zeros) + name = gensym('probe') + for k, v in real_inputs.items(): + dim = v.num_elements + var = Variable(name, bint(dim)) + subs = zeros.copy() + subs[k] = Tensor(torch.eye(dim).reshape(dim, *v.shape), ((name, var.output,))) + coeff = Lambda(var, self(**subs) - const).reshape(TODO) + symbols = ''.join(map(opt_einsum.get_symbol, range(1 + len(v.shape)))) + eqn = f"...{symbols},...{symbols[1:]}->TODO" + coeffs[k] = TODO + return const, coeffs @quote.register(Contraction) diff --git a/funsor/distributions.py b/funsor/distributions.py index e17fcddf1..2de478a0d 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -451,11 +451,16 @@ def eager_mvn(loc, scale_tril, value): # Create a Gaussian from a ground observation. +# TODO refactor this logic into Gaussian.eager_subs() and +# here return Gaussian(...scale_tril...)(value=loc-value). 
@eager.register(MultivariateNormal, (Variable, Contraction), Tensor, (Variable, Contraction)) @eager.register(MultivariateNormal, (Variable, Contraction), Tensor, Tensor) @eager.register(MultivariateNormal, Tensor, Tensor, (Variable, Contraction)) def eager_mvn(loc, scale_tril, value): affine = loc - value + if not isinstance(affine, Contraction): + return None # lazy + const, coeffs = affine.extract_affine() if not isinstance(const, Tensor): return None # lazy @@ -463,35 +468,31 @@ def eager_mvn(loc, scale_tril, value): return None # lazy # Dovetail to avoid variable name collision in einsum. - keep = {k: k for k in ',->.'} - equations1 = [''.join(keep.get(c, chr(ord(c) * 2))) + equations1 = [''.join(c if c in '.,->' else chr(ord(c) * 2)) for _, eqn in coeffs.values() for c in eqn] - equations2 = [''.join(keep.get(c, chr(ord(c) * 2 + 1))) + equations2 = [''.join(c if c in '.,->' else chr(ord(c) * 2 + 1)) for _, eqn in coeffs.values() for c in eqn] real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real') - int_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype != 'real') assert len(real_inputs) == len(coeffs) - - tensors = [scale_tril, const] + [coeffs[k] for k in real_inputs] + tensors = [coeffs[k] for k in real_inputs] + [const, scale_tril] inputs, tensors = align_tensors(*tensors) - scale_tril, const, coeffs = tensors[0], tensors[1], tensors[2:] # Incorporate scale_tril before broadcasting. - prec_sqrt = cholesky_inverse(scale_tril) + scale_tril = tensors[-1] + prec_sqrt = cholesky_inverse(scale_tril) # FIXME is transpose correct? for i in range(len(coeffs)): - coeff = tensors[1 + i] + coeff = tensors[i] inputs1, output1 = equations1[i].split('->') - inputs12, _ = inputs1.split(',') + input11, _ = inputs1.split(',') inputs2, output2 = equations2[i].split('->') - inputs22, _ = inputs2.split(',') - coeff = torch.einsum( - f'...{input11},...{output1}{output2}->...{input21}', - coeff, prec_sqrt) - tensors[1 + i] = coeff + input21, _ = inputs2.split(',') + coeff = torch.einsum(f'...{input11},...{output1}{output2}->...{input21}', + coeff, prec_sqrt) + tensors[i] = coeff tensors = torch.broadcast_tensors(*tensors) - scale_tril, const, coeffs = tensors[0], tensors[1], tensors[2:] + coeffs, const, scale_tril = tensors[:-2], tensors[-2], tensors[-1] batch_shape = const.shape dim = sum(d.num_elements for d in real_inputs.values()) @@ -502,13 +503,14 @@ def eager_mvn(loc, scale_tril, value): offset1 = 0 for i1, (v1, c1) in enumerate(zip(real_inputs, coeffs)): slice1 = slice(offset1, offset1 + real_inputs[v1].num_elements) - input11, input12 = equations1[i1] # FIXME + input11, input12 = equations1[i1].split('->')[0].split(',') offset2 = 0 for i2, (v2, c2) in enumerate(zip(real_inputs, coeffs)): slice2 = slice(offset2, offset2 + real_inputs[v2].num_elements) - input21, input22 = equations2[i2] # FIXME + input21, input22 = equations2[i2].split('->')[0].split(',') precision[..., slice1, slice2] = torch.einsum( - f'...{input11}{input12},...{input21}{input22}->...{input12}{input22}', c1, c2) + f'...{input11}{input12},...{input21}{input22}->...{input12}{input22}', + c1, c2) offset2 = slice2.stop offset1 = slice1.stop @@ -519,39 +521,8 @@ def eager_mvn(loc, scale_tril, value): log_prob = (-0.5 * dim * math.log(2 * math.pi) - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1) - 0.5 * (loc * info_vec).sum(-1)) - return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, affine.inputs) - - # -------------- - # OLD VERSION - # dim, = 
affine.output.shape - # inputs, (loc, scale_tril) = align_tensors(loc, scale_tril) - # inputs.update(value.inputs) - # int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real') - # - # precision = cholesky_inverse(scale_tril) - # info_vec = precision.matmul(loc.unsqueeze(-1)).squeeze(-1) - # log_prob = (-0.5 * dim * math.log(2 * math.pi) - # - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1) - # - 0.5 * (loc * info_vec).sum(-1)) - # return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs) - - # -------------- - # FUTURE VERSION that defers to Gaussian(value=Contraction(...)). - # - # dim, = affine.output.shape - # int_inputs = scale_tril.inputs - # scale_tril = scale_tril.data - # inputs = int_inputs.copy() - # value_name = gensym("value") - # inputs[value_name] = affine.output - # int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real') - # - # precision = cholesky_inverse(scale_tril) - # info_vec = scale_tril.new_zeros(scale_tril.shape[:-1]) - # log_prob = (-0.5 * dim * math.log(2 * math.pi) - # - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)) - # mvn = Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs) - # return mvn(value=affine) + int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real') + return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs) class Poisson(Distribution): From c0ec20385475b94a850ce1b46ac81c45118af1cc Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 19 Sep 2019 15:56:07 -0700 Subject: [PATCH 03/31] Add minimal failing tests for sensor fusion example (#247) --- test/examples/test_sensor_fusion.py | 64 +++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 test/examples/test_sensor_fusion.py diff --git a/test/examples/test_sensor_fusion.py b/test/examples/test_sensor_fusion.py new file mode 100644 index 000000000..2f9433ddd --- /dev/null +++ b/test/examples/test_sensor_fusion.py @@ -0,0 +1,64 @@ +from collections import OrderedDict + +import pytest +import torch + +import funsor.ops as ops +from funsor.cnf import Contraction +from funsor.domains import bint, reals +from funsor.gaussian import Gaussian +from funsor.pyro.convert import dist_to_funsor, matrix_and_mvn_to_funsor +from funsor.terms import Subs, Variable +from funsor.testing import random_mvn +from funsor.torch import Tensor + + +@pytest.mark.xfail(reason="missing pattern") +def test_end_to_end(): + data = Tensor(torch.randn(2, 2), OrderedDict([("time", bint(2))])) + + bias_dist = dist_to_funsor(random_mvn((), 2)) + + trans_mat = torch.randn(3, 3) + trans_mvn = random_mvn((), 3) + trans = matrix_and_mvn_to_funsor(trans_mat, trans_mvn, (), "prev", "curr") + + obs_mat = torch.randn(3, 2) + obs_mvn = random_mvn((), 2) + obs = matrix_and_mvn_to_funsor(obs_mat, obs_mvn, (), "state", "obs") + + log_prob = 0 + bias = Variable("bias", reals(2)) + log_prob += bias_dist(value=bias) + + state_0 = Variable("state_0", reals(3)) + log_prob += obs(state=state_0, obs=bias + data(time=0)) + + state_1 = Variable("state_1", reals(3)) + log_prob += trans(prev=state_0, curr=state_1) + log_prob += obs(state=state_1, obs=bias + data(time=1)) + + log_prob = log_prob.reduce(ops.logaddexp) + assert isinstance(log_prob, Tensor), log_prob.pretty() + + +@pytest.mark.xfail(reason="missing pattern") +def test_affine_subs(): + # This was recorded from test_end_to_end. 
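+    # The substitution below replaces the Gaussian's real input 'obs_b2' with
+    # an affine Contraction (a Variable plus a constant Tensor).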
+ x = Subs( + Gaussian( + torch.tensor([1.3027106523513794, 1.4167094230651855, -0.9750942587852478, 0.5321089029312134, -0.9039931297302246], dtype=torch.float32), # noqa + torch.tensor([[1.0199567079544067, 0.9840421676635742, -0.473368763923645, 0.34206756949424744, -0.7562517523765564], [0.9840421676635742, 1.511502742767334, -1.7593903541564941, 0.6647964119911194, -0.5119513273239136], [-0.4733688533306122, -1.7593903541564941, 3.2386727333068848, -0.9345928430557251, -0.1534711718559265], [0.34206756949424744, 0.6647964119911194, -0.9345928430557251, 0.3141004145145416, -0.12399007380008698], [-0.7562517523765564, -0.5119513273239136, -0.1534711718559265, -0.12399007380008698, 0.6450173854827881]], dtype=torch.float32), # noqa + (('state_1_b6', + reals(3,),), + ('obs_b2', + reals(2,),),)), + (('obs_b2', + Contraction(ops.nullop, ops.add, + frozenset(), + (Variable('bias_b5', reals(2,)), + Tensor( + torch.tensor([-2.1787893772125244, 0.5684312582015991], dtype=torch.float32), # noqa + (), + 'real'),)),),)) + assert isinstance(x, (Gaussian, Contraction)), x.pretty() From 85dc66acefd89561d44e412b85fed3a17f0cb10e Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 19 Sep 2019 14:04:48 -0700 Subject: [PATCH 04/31] Add ReshapeOp, .shape property, .reshape() method --- funsor/ops.py | 23 +++++++++++++++++++++++ funsor/terms.py | 7 +++++++ 2 files changed, 30 insertions(+) diff --git a/funsor/ops.py b/funsor/ops.py index c756e7e3a..7a3e691c9 100644 --- a/funsor/ops.py +++ b/funsor/ops.py @@ -91,6 +91,28 @@ def nullop(x, y): raise ValueError("should never actually evaluate this!") +class ReshapeMeta(type): + _cache = {} + + def __call__(cls, shape): + shape = tuple(shape) + try: + return ReshapeMeta._cache[shape] + except KeyError: + instance = super().__call__(shape) + ReshapeMeta._cache[shape] = instance + return instance + + +class ReshapeOp(Op, metaclass=ReshapeMeta): + def __init__(self, shape): + self.shape = shape + super().__init__(self._default) + + def _default(self, x): + return x.reshape(self.shape) + + class GetitemMeta(type): _cache = {} @@ -288,6 +310,7 @@ def reciprocal(x): 'PRODUCT_INVERSES', 'ReciprocalOp', 'SubOp', + 'ReshapeOp', 'UNITS', 'abs', 'add', diff --git a/funsor/terms.py b/funsor/terms.py index 7fd4f3d94..582187a77 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -325,6 +325,10 @@ def __init__(self, inputs, output, fresh=None, bound=None): def dtype(self): return self.output.dtype + @property + def shape(self): + return self.output.shape + def __hash__(self): return id(self) @@ -564,6 +568,9 @@ def log1p(self): def sigmoid(self): return Unary(ops.sigmoid, self) + def reshape(self, shape): + return Unary(ops.ReshapeOp(shape), self) + # The following reductions are treated as Unary ops because they # reduce over output shape while preserving all inputs. # To reduce over inputs, instead call .reduce(op, reduced_vars). 
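(For illustration, a minimal sketch of the intended semantics of the new
.reshape() method, which acts only on the output shape and leaves batch
inputs untouched; names are illustrative, and eager evaluation of reshape
on Tensor lands in the next patch.)

    from collections import OrderedDict

    import torch

    from funsor.domains import bint
    from funsor.torch import Tensor

    x = Tensor(torch.randn(5, 6), OrderedDict(i=bint(5)))  # input i, output reals(6)
    y = x.reshape((2, 3))  # only the output shape changes
    assert y.shape == (2, 3)
    assert y.inputs == x.inputs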
From 4580d5fc2792f429cba3751169ba76897aa9ce17 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 19 Sep 2019 16:26:22 -0700 Subject: [PATCH 05/31] Add Tensor special case and tests --- funsor/domains.py | 2 ++ funsor/torch.py | 12 +++++++++++- test/test_torch.py | 25 +++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/funsor/domains.py b/funsor/domains.py index 02817b1c5..f43c636de 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -88,6 +88,8 @@ def find_domain(op, *domains): shape = domains[0].shape if op is ops.log or op is ops.exp: dtype = 'real' + elif isinstance(op, ops.ReshapeOp): + shape = op.shape return Domain(shape, dtype) lhs, rhs = domains diff --git a/funsor/torch.py b/funsor/torch.py index 881dfe8cc..a0ea1a823 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -11,7 +11,7 @@ import funsor.ops as ops from funsor.delta import Delta from funsor.domains import Domain, bint, find_domain, reals -from funsor.ops import GetitemOp, Op +from funsor.ops import GetitemOp, Op, ReshapeOp from funsor.terms import ( Binary, Funsor, @@ -19,6 +19,7 @@ Lambda, Number, Slice, + Unary, Variable, eager, substitute, @@ -429,6 +430,15 @@ def eager_binary_tensor_tensor(op, lhs, rhs): return Tensor(data, inputs, dtype) +@eager.register(Unary, ReshapeOp, Tensor) +def eager_reshape_tensor(op, arg): + if arg.shape == op.shape: + return arg + batch_shape = arg.data.shape[:arg.data.dim() - len(arg.shape)] + data = arg.data.reshape(batch_shape + op.shape) + return Tensor(data, arg.inputs, arg.dtype) + + @eager.register(Binary, GetitemOp, Tensor, Number) def eager_getitem_tensor_number(op, lhs, rhs): index = [slice(None)] * (len(lhs.inputs) + op.offset) diff --git a/test/test_torch.py b/test/test_torch.py index 670d048ff..d0047c937 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -352,6 +352,31 @@ def test_binary_scalar_funsor(symbol, dims, scalar): check_funsor(actual, inputs, reals(), expected_data) +@pytest.mark.parametrize("batch_shape", [(), (5,), (4, 3)]) +@pytest.mark.parametrize("old_shape,new_shape", [ + ((), ()), + ((), (1,)), + ((2,), (2, 1)), + ((2,), (1, 2)), + ((6,), (2, 3)), + ((6,), (2, 1, 3)), + ((2, 3, 2), (3, 2, 2)), + ((2, 3, 2), (2, 2, 3)), +]) +def test_reshape(batch_shape, old_shape, new_shape): + inputs = OrderedDict(zip("abc", map(bint, batch_shape))) + old = random_tensor(inputs, reals(*old_shape)) + assert old.reshape(old.shape) is old + + new = old.reshape(new_shape) + assert new.inputs == inputs + assert new.shape == new_shape + assert new.dtype == old.dtype + + old2 = new.reshape(old_shape) + assert_close(old2, old) + + def test_getitem_number_0_inputs(): data = torch.randn((5, 4, 3, 2)) x = Tensor(data) From 708307e560346bf587095766bc9add3e4964c25f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 19 Sep 2019 17:56:53 -0700 Subject: [PATCH 06/31] Start writing a test --- funsor/cnf.py | 44 ++++++++++++++++++++++++++++-------------- test/test_cnf.py | 50 +++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 75 insertions(+), 19 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index 33337b594..05c451cf9 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -8,11 +8,24 @@ import funsor.ops as ops from funsor.delta import Delta -from funsor.domains import find_domain +from funsor.domains import bint, find_domain from funsor.gaussian import Gaussian from funsor.interpreter import gensym, recursion_reinterpret from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, nullop -from 
funsor.terms import Align, Binary, Funsor, Number, Reduce, Subs, Unary, Variable, eager, normalize, to_funsor +from funsor.terms import ( + Align, + Binary, + Funsor, + Lambda, + Number, + Reduce, + Subs, + Unary, + Variable, + eager, + normalize, + to_funsor +) from funsor.torch import Tensor from funsor.util import quote @@ -98,17 +111,18 @@ def _alpha_convert(self, alpha_subs): def extract_affine(self): """ Tries to return a pair ``(const, coeffs)`` where const is a funsor with - no real inputs and ``coeffs`` is an OrderedDict mapping Variable to a + no real inputs and ``coeffs`` is an OrderedDict mapping input name to a ``(coefficient, eqn)`` pair in einsum form, i.e. satisfying:: - assert expected.is_affine - affine = expected.extract_affine() - actual = sum(torch_einsum(eqn, coeff, var) - for var, (coeff, eqn) in affine.items()) - assert_close(actual, expected) + x = Contraction(...) + assert x.is_affine + const, coeffs = x.extract_affine() + y = sum(Einsum(eqn, (coeff, Variable(var, coeff.output))) + for var, (coeff, eqn) in coeffs.items()) + assert_close(y, x) """ assert self.is_affine - real_inputs = OrderedDict((k, v) for k, v in self.inputs if v.dtype == 'real') + real_inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype == 'real') coeffs = OrderedDict() zeros = {k: Tensor(torch.zeros(v.shape)) for k, v in real_inputs.items()} const = self(**zeros) @@ -117,11 +131,13 @@ def extract_affine(self): dim = v.num_elements var = Variable(name, bint(dim)) subs = zeros.copy() - subs[k] = Tensor(torch.eye(dim).reshape(dim, *v.shape), ((name, var.output,))) - coeff = Lambda(var, self(**subs) - const).reshape(TODO) - symbols = ''.join(map(opt_einsum.get_symbol, range(1 + len(v.shape)))) - eqn = f"...{symbols},...{symbols[1:]}->TODO" - coeffs[k] = TODO + subs[k] = Tensor(torch.eye(dim).reshape((dim,) + v.shape))[var] + coeff = Lambda(var, self(**subs) - const).reshape(v.shape + const.shape) + inputs1 = ''.join(map(opt_einsum.get_symbol, range(len(coeff.shape)))) + inputs2 = inputs1[:len(v.shape)] + output = inputs1[len(v.shape):] + eqn = f"...{inputs1},...{inputs2}->...{output}" + coeffs[k] = coeff, eqn return const, coeffs diff --git a/test/test_cnf.py b/test/test_cnf.py index 196a692a2..87a431e13 100644 --- a/test/test_cnf.py +++ b/test/test_cnf.py @@ -1,15 +1,22 @@ +from collections import OrderedDict + import pytest -import torch # noqa F403 +import torch from funsor.cnf import Contraction -from funsor.domains import bint # noqa F403 +from funsor.domains import bint, reals from funsor.einsum import einsum, naive_plated_einsum from funsor.interpreter import interpretation, reinterpret -from funsor.terms import Number, eager, normalize, reflect -from funsor.testing import assert_close, check_funsor, make_einsum_example # , xfail_param -from funsor.torch import Tensor +from funsor.terms import Number, Variable, eager, normalize, reflect +from funsor.testing import assert_close, check_funsor, make_einsum_example, random_tensor +from funsor.torch import Einsum, Tensor from funsor.util import quote +assert Variable # flake8 +assert bint # flake8 +assert reals # flake8 +assert torch # flake8 + EINSUM_EXAMPLES = [ ("a,b->", ''), ("ab,a->", ''), @@ -59,3 +66,36 @@ def test_normalize_einsum(equation, plates, backend, einsum_impl): actual = eval(quote(expected)) # requires torch, bint assert_close(actual, expected) + + +@pytest.mark.parametrize('expr', [ + "Variable('x', reals()) + Number(0.5)", + "Variable('x', reals(2)) + Variable('y', reals(2))", + "Variable('x', reals(2)) + 
Tensor(torch.ones(2))", +]) +def test_extract_affine(expr): + x = eval(expr) + assert isinstance(x, Contraction) + assert x.is_affine + + const, coeffs = x.extract_affine() + assert isinstance(const, Tensor) + assert const.shape == x.shape + assert list(coeffs) == list(x.inputs) + for name, (coeff, eqn) in coeffs.items(): + assert isinstance(name, str) + assert isinstance(coeff, Tensor) + assert isinstance(eqn, str) + + real_inputs = OrderedDict((k, d) for k, d in x.inputs.items() + if d.dtype == 'real') + int_inputs = OrderedDict((k, d) for k, d in x.inputs.items() + if d.dtype != 'real') + subs = {k: random_tensor(int_inputs, d) for k, d in real_inputs.items()} + expected = x(**subs) + assert isinstance(expected, Tensor) + + actual = const + sum(Einsum(eqn, (coeff, subs[k])) + for k, (coeff, eqn) in coeffs.items()) + assert isinstance(actual, Tensor) + assert_close(actual, expected) From fb50652508164a2a24c5cc629227ab440d15ab79 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 19 Sep 2019 18:46:29 -0700 Subject: [PATCH 07/31] Fix bugs in Einsum --- funsor/__init__.py | 2 ++ funsor/affine.py | 44 +++++++++++++++++++++++++++++++++++++++ funsor/cnf.py | 8 ++++++-- funsor/torch.py | 31 ++++++++++++++++++++++++---- test/test_affine.py | 41 +++++++++++++++++++++++++++++++++++-- test/test_cnf.py | 50 +++++---------------------------------------- 6 files changed, 123 insertions(+), 53 deletions(-) create mode 100644 funsor/affine.py diff --git a/funsor/__init__.py b/funsor/__init__.py index 9e41b2501..e831fc958 100644 --- a/funsor/__init__.py +++ b/funsor/__init__.py @@ -7,6 +7,7 @@ from . import ( adjoint, + affine, cnf, delta, distributions, @@ -38,6 +39,7 @@ 'Tensor', 'Variable', 'adjoint', + 'affine', 'arange', 'backward', 'bint', diff --git a/funsor/affine.py b/funsor/affine.py new file mode 100644 index 000000000..54792ceb4 --- /dev/null +++ b/funsor/affine.py @@ -0,0 +1,44 @@ +from collections import OrderedDict + +import opt_einsum +import torch + +from funsor.interpreter import gensym +from funsor.terms import Lambda, Variable, bint +from funsor.torch import Tensor + + +def extract_affine(fn): + """ + Returns a pair ``(const, coeffs)`` where const is a funsor with no real + inputs and ``coeffs`` is an OrderedDict mapping input name to a + ``(coefficient, eqn)`` pair in einsum form. For funsors that are affine, + this satisfies:: + + x = Contraction(...) + assert x.is_affine + const, coeffs = x.extract_affine() + y = sum(Einsum(eqn, (coeff, Variable(var, coeff.output))) + for var, (coeff, eqn) in coeffs.items()) + assert_close(y, x) + + :param Funsor fn: + """ + # Avoid adding these dependencies to funsor core. 
+ real_inputs = OrderedDict((k, v) for k, v in fn.inputs.items() if v.dtype == 'real') + coeffs = OrderedDict() + zeros = {k: Tensor(torch.zeros(v.shape)) for k, v in real_inputs.items()} + const = fn(**zeros) + name = gensym('probe') + for k, v in real_inputs.items(): + dim = v.num_elements + var = Variable(name, bint(dim)) + subs = zeros.copy() + subs[k] = Tensor(torch.eye(dim).reshape((dim,) + v.shape))[var] + coeff = Lambda(var, fn(**subs) - const).reshape(v.shape + const.shape) + inputs1 = ''.join(map(opt_einsum.get_symbol, range(len(coeff.shape)))) + inputs2 = inputs1[:len(v.shape)] + output = inputs1[len(v.shape):] + eqn = f"{inputs1},{inputs2}->{output}" + coeffs[k] = coeff, eqn + return const, coeffs diff --git a/funsor/cnf.py b/funsor/cnf.py index 05c451cf9..2a386d932 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -33,6 +33,8 @@ class Contraction(Funsor): """ Declarative representation of a finitary sum-product operation + + :ivar bool is_affine: Whether this contraction is affine. """ def __init__(self, red_op, bin_op, reduced_vars, terms): terms = (terms,) if isinstance(terms, Funsor) else terms @@ -110,8 +112,8 @@ def _alpha_convert(self, alpha_subs): def extract_affine(self): """ - Tries to return a pair ``(const, coeffs)`` where const is a funsor with - no real inputs and ``coeffs`` is an OrderedDict mapping input name to a + Returns a pair ``(const, coeffs)`` where const is a funsor with no real + inputs and ``coeffs`` is an OrderedDict mapping input name to a ``(coefficient, eqn)`` pair in einsum form, i.e. satisfying:: x = Contraction(...) @@ -120,6 +122,8 @@ def extract_affine(self): y = sum(Einsum(eqn, (coeff, Variable(var, coeff.output))) for var, (coeff, eqn) in coeffs.items()) assert_close(y, x) + + This only works for affine funsors. Check with :ivar:`.is_affine` """ assert self.is_affine real_inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype == 'real') diff --git a/funsor/torch.py b/funsor/torch.py index a0ea1a823..c71396612 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -1,8 +1,10 @@ import functools +import itertools import warnings from collections import OrderedDict from functools import reduce +import opt_einsum import torch from contextlib2 import contextmanager from multipledispatch import dispatch @@ -750,8 +752,8 @@ def __init__(self, equation, operands): for ein_input, x in zip(ein_inputs, operands): assert x.dtype == 'real' inputs.update(x.inputs) - assert len(ein_inputs) == len(x.output.shape) - for name, size in zip(ein_inputs, x.output.shape): + assert len(ein_input) == len(x.output.shape) + for name, size in zip(ein_input, x.output.shape): other_size = size_dict.setdefault(name, size) if other_size != size: raise ValueError("Size mismatch at {}: {} vs {}" @@ -771,8 +773,29 @@ def __str__(self): @eager.register(Einsum, str, tuple) def eager_einsum(equation, operands): if all(isinstance(x, Tensor) for x in operands): - inputs, tensors = align_tensors(*operands) - data = torch.einsum(equation, tensors) + # Make new symbols for inputs of operands. + inputs = OrderedDict() + for x in operands: + inputs.update(x.inputs) + symbols = set(equation) + get_symbol = iter(map(opt_einsum.get_symbol, itertools.count())) + new_symbols = {} + for k in inputs: + symbol = next(get_symbol) + while symbol in symbols: + symbol = next(get_symbol) + new_symbols[k] = symbol + + # Manually broadcast using einsum symbols. + assert '.' 
not in equation + ins, out = equation.split('->') + ins = ins.split(',') + ins = [''.join(new_symbols[k] for k in x.inputs) + x_out + for x, x_out in zip(operands, ins)] + out = ''.join(new_symbols[k] for k in inputs) + out + equation = ','.join(ins) + '->' + out + + data = torch.einsum(equation, [x.data for x in operands]) return Tensor(data, inputs) return None # defer to default implementation diff --git a/test/test_affine.py b/test/test_affine.py index 9e09f2fbc..328c08caa 100644 --- a/test/test_affine.py +++ b/test/test_affine.py @@ -3,11 +3,12 @@ import pytest import torch +from funsor.affine import extract_affine from funsor.cnf import Contraction from funsor.domains import bint, reals from funsor.terms import Number, Variable -from funsor.testing import check_funsor -from funsor.torch import Tensor +from funsor.testing import assert_close, check_funsor, random_tensor +from funsor.torch import Einsum, Tensor SMOKE_TESTS = [ ('t+x', Contraction), @@ -75,3 +76,39 @@ def test_affine_subs(expr, expected_type, expected_inputs): assert isinstance(result, expected_type) check_funsor(result, expected_inputs, expected_output) assert result.is_affine + + +@pytest.mark.parametrize('expr', [ + "Variable('x', reals()) + 0.5", + "Variable('x', reals(2)) + Variable('y', reals(2))", + "Variable('x', reals(2)) + torch.ones(2)", + "Variable('x', reals(2)) * torch.randn(2)", + "Variable('x', reals(2)) * torch.randn(2) + torch.ones(2)", + "Einsum('abcd,ac->bd', (Tensor(torch.randn(2, 3, 4, 5)), " + " Variable('x', reals(2, 4))))", +]) +def test_extract_affine(expr): + x = eval(expr) + assert isinstance(x, (Contraction, Einsum)) + + const, coeffs = extract_affine(x) + assert isinstance(const, Tensor) + assert const.shape == x.shape + assert list(coeffs) == list(x.inputs) + for name, (coeff, eqn) in coeffs.items(): + assert isinstance(name, str) + assert isinstance(coeff, Tensor) + assert isinstance(eqn, str) + + real_inputs = OrderedDict((k, d) for k, d in x.inputs.items() + if d.dtype == 'real') + int_inputs = OrderedDict((k, d) for k, d in x.inputs.items() + if d.dtype != 'real') + subs = {k: random_tensor(int_inputs, d) for k, d in real_inputs.items()} + expected = x(**subs) + assert isinstance(expected, Tensor) + + actual = const + sum(Einsum(eqn, (coeff, subs[k])) + for k, (coeff, eqn) in coeffs.items()) + assert isinstance(actual, Tensor) + assert_close(actual, expected) diff --git a/test/test_cnf.py b/test/test_cnf.py index 87a431e13..196a692a2 100644 --- a/test/test_cnf.py +++ b/test/test_cnf.py @@ -1,22 +1,15 @@ -from collections import OrderedDict - import pytest -import torch +import torch # noqa F403 from funsor.cnf import Contraction -from funsor.domains import bint, reals +from funsor.domains import bint # noqa F403 from funsor.einsum import einsum, naive_plated_einsum from funsor.interpreter import interpretation, reinterpret -from funsor.terms import Number, Variable, eager, normalize, reflect -from funsor.testing import assert_close, check_funsor, make_einsum_example, random_tensor -from funsor.torch import Einsum, Tensor +from funsor.terms import Number, eager, normalize, reflect +from funsor.testing import assert_close, check_funsor, make_einsum_example # , xfail_param +from funsor.torch import Tensor from funsor.util import quote -assert Variable # flake8 -assert bint # flake8 -assert reals # flake8 -assert torch # flake8 - EINSUM_EXAMPLES = [ ("a,b->", ''), ("ab,a->", ''), @@ -66,36 +59,3 @@ def test_normalize_einsum(equation, plates, backend, einsum_impl): actual = 
eval(quote(expected)) # requires torch, bint assert_close(actual, expected) - - -@pytest.mark.parametrize('expr', [ - "Variable('x', reals()) + Number(0.5)", - "Variable('x', reals(2)) + Variable('y', reals(2))", - "Variable('x', reals(2)) + Tensor(torch.ones(2))", -]) -def test_extract_affine(expr): - x = eval(expr) - assert isinstance(x, Contraction) - assert x.is_affine - - const, coeffs = x.extract_affine() - assert isinstance(const, Tensor) - assert const.shape == x.shape - assert list(coeffs) == list(x.inputs) - for name, (coeff, eqn) in coeffs.items(): - assert isinstance(name, str) - assert isinstance(coeff, Tensor) - assert isinstance(eqn, str) - - real_inputs = OrderedDict((k, d) for k, d in x.inputs.items() - if d.dtype == 'real') - int_inputs = OrderedDict((k, d) for k, d in x.inputs.items() - if d.dtype != 'real') - subs = {k: random_tensor(int_inputs, d) for k, d in real_inputs.items()} - expected = x(**subs) - assert isinstance(expected, Tensor) - - actual = const + sum(Einsum(eqn, (coeff, subs[k])) - for k, (coeff, eqn) in coeffs.items()) - assert isinstance(actual, Tensor) - assert_close(actual, expected) From 4ee22e680f0b44c240d21817915c1c2a3d45adc7 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 19 Sep 2019 18:57:19 -0700 Subject: [PATCH 08/31] Add more affine tests --- funsor/affine.py | 20 ++++++++++++-------- test/test_affine.py | 7 +++++-- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/funsor/affine.py b/funsor/affine.py index 54792ceb4..d2b17d8d5 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -10,10 +10,8 @@ def extract_affine(fn): """ - Returns a pair ``(const, coeffs)`` where const is a funsor with no real - inputs and ``coeffs`` is an OrderedDict mapping input name to a - ``(coefficient, eqn)`` pair in einsum form. For funsors that are affine, - this satisfies:: + Extracts an affine representation of a funsor, which is exact for affine + funsors and approximate otherwise. For affine funsors this satisfies:: x = Contraction(...) assert x.is_affine @@ -22,14 +20,20 @@ def extract_affine(fn): for var, (coeff, eqn) in coeffs.items()) assert_close(y, x) - :param Funsor fn: + :param Funsor fn: A funsor. + :return: A pair ``(const, coeffs)`` where const is a funsor with no real + inputs and ``coeffs`` is an OrderedDict mapping input name to a + ``(coefficient, eqn)`` pair in einsum form. + :rtype: tuple """ - # Avoid adding these dependencies to funsor core. + # Determine constant part by evaluating fn at zero. real_inputs = OrderedDict((k, v) for k, v in fn.inputs.items() if v.dtype == 'real') - coeffs = OrderedDict() zeros = {k: Tensor(torch.zeros(v.shape)) for k, v in real_inputs.items()} const = fn(**zeros) + + # Determine linear coefficients by evaluating fn on basis vectors. 
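+    # Each coefficient is recovered as fn(basis_vector) - fn(0): the probe
+    # variable indexes a one-hot basis of each real input, and Lambda stacks
+    # the responses into a tensor of shape v.shape + const.shape.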
name = gensym('probe') + coeffs = OrderedDict() for k, v in real_inputs.items(): dim = v.num_elements var = Variable(name, bint(dim)) @@ -39,6 +43,6 @@ def extract_affine(fn): inputs1 = ''.join(map(opt_einsum.get_symbol, range(len(coeff.shape)))) inputs2 = inputs1[:len(v.shape)] output = inputs1[len(v.shape):] - eqn = f"{inputs1},{inputs2}->{output}" + eqn = f'{inputs1},{inputs2}->{output}' coeffs[k] = coeff, eqn return const, coeffs diff --git a/test/test_affine.py b/test/test_affine.py index 328c08caa..f8ffd4aa4 100644 --- a/test/test_affine.py +++ b/test/test_affine.py @@ -84,8 +84,11 @@ def test_affine_subs(expr, expected_type, expected_inputs): "Variable('x', reals(2)) + torch.ones(2)", "Variable('x', reals(2)) * torch.randn(2)", "Variable('x', reals(2)) * torch.randn(2) + torch.ones(2)", - "Einsum('abcd,ac->bd', (Tensor(torch.randn(2, 3, 4, 5)), " - " Variable('x', reals(2, 4))))", + "Einsum('abcd,ac->bd'," + " (Tensor(torch.randn(2, 3, 4, 5)), Variable('x', reals(2, 4))))", + "Tensor(torch.randn(3, 5)) + Einsum('abcd,ac->bd'," + " (Tensor(torch.randn(2, 3, 4, 5)), Variable('x', reals(2, 4))))", + "Variable('x', reals(2, 8))[0] + torch.randn(8)", ]) def test_extract_affine(expr): x = eval(expr) From 155faaa1b1bac3d44fd01463949ca14e7f8c3917 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 19 Sep 2019 19:03:16 -0700 Subject: [PATCH 09/31] Add more tests --- test/test_affine.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_affine.py b/test/test_affine.py index f8ffd4aa4..87cd8484c 100644 --- a/test/test_affine.py +++ b/test/test_affine.py @@ -84,30 +84,30 @@ def test_affine_subs(expr, expected_type, expected_inputs): "Variable('x', reals(2)) + torch.ones(2)", "Variable('x', reals(2)) * torch.randn(2)", "Variable('x', reals(2)) * torch.randn(2) + torch.ones(2)", + "Variable('x', reals(2)) + Tensor(torch.randn(3, 2), OrderedDict(i=bint(3)))", "Einsum('abcd,ac->bd'," " (Tensor(torch.randn(2, 3, 4, 5)), Variable('x', reals(2, 4))))", "Tensor(torch.randn(3, 5)) + Einsum('abcd,ac->bd'," " (Tensor(torch.randn(2, 3, 4, 5)), Variable('x', reals(2, 4))))", "Variable('x', reals(2, 8))[0] + torch.randn(8)", + "Variable('x', reals(2, 8))[Variable('i', bint(2))] / 4 - 3.5", ]) def test_extract_affine(expr): x = eval(expr) assert isinstance(x, (Contraction, Einsum)) + real_inputs = OrderedDict((k, d) for k, d in x.inputs.items() + if d.dtype == 'real') const, coeffs = extract_affine(x) assert isinstance(const, Tensor) assert const.shape == x.shape - assert list(coeffs) == list(x.inputs) + assert list(coeffs) == list(real_inputs) for name, (coeff, eqn) in coeffs.items(): assert isinstance(name, str) assert isinstance(coeff, Tensor) assert isinstance(eqn, str) - real_inputs = OrderedDict((k, d) for k, d in x.inputs.items() - if d.dtype == 'real') - int_inputs = OrderedDict((k, d) for k, d in x.inputs.items() - if d.dtype != 'real') - subs = {k: random_tensor(int_inputs, d) for k, d in real_inputs.items()} + subs = {k: random_tensor(OrderedDict(), d) for k, d in real_inputs.items()} expected = x(**subs) assert isinstance(expected, Tensor) From 07b3da66ecf5499f43e0a6da8997995d00f71a86 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 19 Sep 2019 19:07:18 -0700 Subject: [PATCH 10/31] Revert changes to cnf.py --- funsor/cnf.py | 55 +++-------------------------------------- funsor/distributions.py | 3 ++- 2 files changed, 5 insertions(+), 53 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index 2a386d932..633e2bde9 100644 --- a/funsor/cnf.py 
+++ b/funsor/cnf.py @@ -2,30 +2,15 @@ from functools import reduce from typing import Tuple, Union -import opt_einsum -import torch from multipledispatch.variadic import Variadic import funsor.ops as ops from funsor.delta import Delta -from funsor.domains import bint, find_domain +from funsor.domains import find_domain from funsor.gaussian import Gaussian -from funsor.interpreter import gensym, recursion_reinterpret +from funsor.interpreter import recursion_reinterpret from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, nullop -from funsor.terms import ( - Align, - Binary, - Funsor, - Lambda, - Number, - Reduce, - Subs, - Unary, - Variable, - eager, - normalize, - to_funsor -) +from funsor.terms import Align, Binary, Funsor, Number, Reduce, Subs, Unary, Variable, eager, normalize, to_funsor from funsor.torch import Tensor from funsor.util import quote @@ -110,40 +95,6 @@ def _alpha_convert(self, alpha_subs): red_op, bin_op, _, terms = super()._alpha_convert(alpha_subs) return red_op, bin_op, reduced_vars, terms - def extract_affine(self): - """ - Returns a pair ``(const, coeffs)`` where const is a funsor with no real - inputs and ``coeffs`` is an OrderedDict mapping input name to a - ``(coefficient, eqn)`` pair in einsum form, i.e. satisfying:: - - x = Contraction(...) - assert x.is_affine - const, coeffs = x.extract_affine() - y = sum(Einsum(eqn, (coeff, Variable(var, coeff.output))) - for var, (coeff, eqn) in coeffs.items()) - assert_close(y, x) - - This only works for affine funsors. Check with :ivar:`.is_affine` - """ - assert self.is_affine - real_inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype == 'real') - coeffs = OrderedDict() - zeros = {k: Tensor(torch.zeros(v.shape)) for k, v in real_inputs.items()} - const = self(**zeros) - name = gensym('probe') - for k, v in real_inputs.items(): - dim = v.num_elements - var = Variable(name, bint(dim)) - subs = zeros.copy() - subs[k] = Tensor(torch.eye(dim).reshape((dim,) + v.shape))[var] - coeff = Lambda(var, self(**subs) - const).reshape(v.shape + const.shape) - inputs1 = ''.join(map(opt_einsum.get_symbol, range(len(coeff.shape)))) - inputs2 = inputs1[:len(v.shape)] - output = inputs1[len(v.shape):] - eqn = f"...{inputs1},...{inputs2}->...{output}" - coeffs[k] = coeff, eqn - return const, coeffs - @quote.register(Contraction) def _(arg, indent, out): diff --git a/funsor/distributions.py b/funsor/distributions.py index 2de478a0d..98cea6b2e 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -7,6 +7,7 @@ import funsor.delta import funsor.ops as ops +from funsor.affine import extract_affine from funsor.cnf import Contraction from funsor.domains import bint, reals from funsor.gaussian import BlockMatrix, BlockVector, Gaussian, cholesky_inverse @@ -461,7 +462,7 @@ def eager_mvn(loc, scale_tril, value): if not isinstance(affine, Contraction): return None # lazy - const, coeffs = affine.extract_affine() + const, coeffs = extract_affine(affine) if not isinstance(const, Tensor): return None # lazy if not all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values()): From 279d7a91ac1f71b621c0c56fa2d691b50bca2fa6 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 19 Sep 2019 19:09:10 -0700 Subject: [PATCH 11/31] Implement affine funsor approximation --- funsor/__init__.py | 2 ++ funsor/affine.py | 48 +++++++++++++++++++++++++++++++++++++++++++++ funsor/torch.py | 31 +++++++++++++++++++++++++---- test/test_affine.py | 44 +++++++++++++++++++++++++++++++++++++++-- 4 files changed, 119 
insertions(+), 6 deletions(-) create mode 100644 funsor/affine.py diff --git a/funsor/__init__.py b/funsor/__init__.py index 9e41b2501..e831fc958 100644 --- a/funsor/__init__.py +++ b/funsor/__init__.py @@ -7,6 +7,7 @@ from . import ( adjoint, + affine, cnf, delta, distributions, @@ -38,6 +39,7 @@ 'Tensor', 'Variable', 'adjoint', + 'affine', 'arange', 'backward', 'bint', diff --git a/funsor/affine.py b/funsor/affine.py new file mode 100644 index 000000000..d2b17d8d5 --- /dev/null +++ b/funsor/affine.py @@ -0,0 +1,48 @@ +from collections import OrderedDict + +import opt_einsum +import torch + +from funsor.interpreter import gensym +from funsor.terms import Lambda, Variable, bint +from funsor.torch import Tensor + + +def extract_affine(fn): + """ + Extracts an affine representation of a funsor, which is exact for affine + funsors and approximate otherwise. For affine funsors this satisfies:: + + x = Contraction(...) + assert x.is_affine + const, coeffs = x.extract_affine() + y = sum(Einsum(eqn, (coeff, Variable(var, coeff.output))) + for var, (coeff, eqn) in coeffs.items()) + assert_close(y, x) + + :param Funsor fn: A funsor. + :return: A pair ``(const, coeffs)`` where const is a funsor with no real + inputs and ``coeffs`` is an OrderedDict mapping input name to a + ``(coefficient, eqn)`` pair in einsum form. + :rtype: tuple + """ + # Determine constant part by evaluating fn at zero. + real_inputs = OrderedDict((k, v) for k, v in fn.inputs.items() if v.dtype == 'real') + zeros = {k: Tensor(torch.zeros(v.shape)) for k, v in real_inputs.items()} + const = fn(**zeros) + + # Determine linear coefficients by evaluating fn on basis vectors. + name = gensym('probe') + coeffs = OrderedDict() + for k, v in real_inputs.items(): + dim = v.num_elements + var = Variable(name, bint(dim)) + subs = zeros.copy() + subs[k] = Tensor(torch.eye(dim).reshape((dim,) + v.shape))[var] + coeff = Lambda(var, fn(**subs) - const).reshape(v.shape + const.shape) + inputs1 = ''.join(map(opt_einsum.get_symbol, range(len(coeff.shape)))) + inputs2 = inputs1[:len(v.shape)] + output = inputs1[len(v.shape):] + eqn = f'{inputs1},{inputs2}->{output}' + coeffs[k] = coeff, eqn + return const, coeffs diff --git a/funsor/torch.py b/funsor/torch.py index a0ea1a823..c71396612 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -1,8 +1,10 @@ import functools +import itertools import warnings from collections import OrderedDict from functools import reduce +import opt_einsum import torch from contextlib2 import contextmanager from multipledispatch import dispatch @@ -750,8 +752,8 @@ def __init__(self, equation, operands): for ein_input, x in zip(ein_inputs, operands): assert x.dtype == 'real' inputs.update(x.inputs) - assert len(ein_inputs) == len(x.output.shape) - for name, size in zip(ein_inputs, x.output.shape): + assert len(ein_input) == len(x.output.shape) + for name, size in zip(ein_input, x.output.shape): other_size = size_dict.setdefault(name, size) if other_size != size: raise ValueError("Size mismatch at {}: {} vs {}" @@ -771,8 +773,29 @@ def __str__(self): @eager.register(Einsum, str, tuple) def eager_einsum(equation, operands): if all(isinstance(x, Tensor) for x in operands): - inputs, tensors = align_tensors(*operands) - data = torch.einsum(equation, tensors) + # Make new symbols for inputs of operands. 
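+        # (each fresh symbol must avoid every symbol already present in the
+        # user-supplied equation, since batch dims are appended to it below)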
+ inputs = OrderedDict() + for x in operands: + inputs.update(x.inputs) + symbols = set(equation) + get_symbol = iter(map(opt_einsum.get_symbol, itertools.count())) + new_symbols = {} + for k in inputs: + symbol = next(get_symbol) + while symbol in symbols: + symbol = next(get_symbol) + new_symbols[k] = symbol + + # Manually broadcast using einsum symbols. + assert '.' not in equation + ins, out = equation.split('->') + ins = ins.split(',') + ins = [''.join(new_symbols[k] for k in x.inputs) + x_out + for x, x_out in zip(operands, ins)] + out = ''.join(new_symbols[k] for k in inputs) + out + equation = ','.join(ins) + '->' + out + + data = torch.einsum(equation, [x.data for x in operands]) return Tensor(data, inputs) return None # defer to default implementation diff --git a/test/test_affine.py b/test/test_affine.py index 9e09f2fbc..87cd8484c 100644 --- a/test/test_affine.py +++ b/test/test_affine.py @@ -3,11 +3,12 @@ import pytest import torch +from funsor.affine import extract_affine from funsor.cnf import Contraction from funsor.domains import bint, reals from funsor.terms import Number, Variable -from funsor.testing import check_funsor -from funsor.torch import Tensor +from funsor.testing import assert_close, check_funsor, random_tensor +from funsor.torch import Einsum, Tensor SMOKE_TESTS = [ ('t+x', Contraction), @@ -75,3 +76,42 @@ def test_affine_subs(expr, expected_type, expected_inputs): assert isinstance(result, expected_type) check_funsor(result, expected_inputs, expected_output) assert result.is_affine + + +@pytest.mark.parametrize('expr', [ + "Variable('x', reals()) + 0.5", + "Variable('x', reals(2)) + Variable('y', reals(2))", + "Variable('x', reals(2)) + torch.ones(2)", + "Variable('x', reals(2)) * torch.randn(2)", + "Variable('x', reals(2)) * torch.randn(2) + torch.ones(2)", + "Variable('x', reals(2)) + Tensor(torch.randn(3, 2), OrderedDict(i=bint(3)))", + "Einsum('abcd,ac->bd'," + " (Tensor(torch.randn(2, 3, 4, 5)), Variable('x', reals(2, 4))))", + "Tensor(torch.randn(3, 5)) + Einsum('abcd,ac->bd'," + " (Tensor(torch.randn(2, 3, 4, 5)), Variable('x', reals(2, 4))))", + "Variable('x', reals(2, 8))[0] + torch.randn(8)", + "Variable('x', reals(2, 8))[Variable('i', bint(2))] / 4 - 3.5", +]) +def test_extract_affine(expr): + x = eval(expr) + assert isinstance(x, (Contraction, Einsum)) + real_inputs = OrderedDict((k, d) for k, d in x.inputs.items() + if d.dtype == 'real') + + const, coeffs = extract_affine(x) + assert isinstance(const, Tensor) + assert const.shape == x.shape + assert list(coeffs) == list(real_inputs) + for name, (coeff, eqn) in coeffs.items(): + assert isinstance(name, str) + assert isinstance(coeff, Tensor) + assert isinstance(eqn, str) + + subs = {k: random_tensor(OrderedDict(), d) for k, d in real_inputs.items()} + expected = x(**subs) + assert isinstance(expected, Tensor) + + actual = const + sum(Einsum(eqn, (coeff, subs[k])) + for k, (coeff, eqn) in coeffs.items()) + assert isinstance(actual, Tensor) + assert_close(actual, expected) From 0a834714a5a9cabc6729ae31a55a0349ca6e3ad6 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 19 Sep 2019 19:36:35 -0700 Subject: [PATCH 12/31] Add test for Einsum batching --- funsor/torch.py | 1 + test/test_torch.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/funsor/torch.py b/funsor/torch.py index c71396612..2329438b7 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -784,6 +784,7 @@ def eager_einsum(equation, operands): symbol = next(get_symbol) while symbol 
in symbols: symbol = next(get_symbol) + symbols.add(symbol) new_symbols[k] = symbol # Manually broadcast using einsum symbols. diff --git a/test/test_torch.py b/test/test_torch.py index d0047c937..ec27094a8 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -675,7 +675,7 @@ def test_align(): assert x(i=i, j=j, k=k) == y(i=i, j=j, k=k) -@pytest.mark.parametrize('equation', [ +EINSUM_EXAMPLES = [ 'a->a', 'a,a->a', 'a,b->', @@ -689,7 +689,10 @@ def test_align(): 'ab,ba->ab', 'ab,ba->ba', 'ab,bc->ac', -]) +] + + +@pytest.mark.parametrize('equation', EINSUM_EXAMPLES) def test_einsum(equation): sizes = dict(a=2, b=3, c=4) inputs, outputs = equation.split('->') @@ -701,6 +704,28 @@ def test_einsum(equation): assert_close(actual, expected, atol=1e-5, rtol=None) +@pytest.mark.parametrize('equation', EINSUM_EXAMPLES) +@pytest.mark.parametrize('batch1', ['', 'i', 'j', 'ij']) +@pytest.mark.parametrize('batch2', ['', 'i', 'j', 'ij']) +def test_batched_einsum(equation, batch1, batch2): + inputs, output = equation.split('->') + inputs = inputs.split(',') + + sizes = dict(a=2, b=3, c=4, i=5, j=6) + batch1 = OrderedDict([(k, bint(sizes[k])) for k in batch1]) + batch2 = OrderedDict([(k, bint(sizes[k])) for k in batch2]) + funsors = [random_tensor(batch, reals(*(sizes[d] for d in dims))) + for batch, dims in zip([batch1, batch2], inputs)] + actual = Einsum(equation, tuple(funsors)) + + _equation = ','.join('...' + i for i in inputs) + '->...' + output + inputs, tensors = align_tensors(*funsors) + batch = tuple(v.size for v in inputs.values()) + tensors = [x.expand(batch + f.shape) for (x, f) in zip(tensors, funsors)] + expected = Tensor(torch.einsum(_equation, tensors), inputs) + assert_close(actual, expected, atol=1e-5, rtol=None) + + @pytest.mark.parametrize('y_shape', [(), (4,), (4, 5)], ids=str) @pytest.mark.parametrize('xy_shape', [(), (6,), (6, 7)], ids=str) @pytest.mark.parametrize('x_shape', [(), (2,), (2, 3)], ids=str) From d9138fc9eed87dd6965281ff7759edae7e0bf654 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 19 Sep 2019 19:54:06 -0700 Subject: [PATCH 13/31] Fix docs --- funsor/affine.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/funsor/affine.py b/funsor/affine.py index d2b17d8d5..18cf52c01 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -13,14 +13,13 @@ def extract_affine(fn): Extracts an affine representation of a funsor, which is exact for affine funsors and approximate otherwise. For affine funsors this satisfies:: - x = Contraction(...) - assert x.is_affine - const, coeffs = x.extract_affine() + x = ... + const, coeffs = extract_affine(x) y = sum(Einsum(eqn, (coeff, Variable(var, coeff.output))) for var, (coeff, eqn) in coeffs.items()) assert_close(y, x) - :param Funsor fn: A funsor. + :param Funsor fn: A funsor assumed to be affine (this is not checked). :return: A pair ``(const, coeffs)`` where const is a funsor with no real inputs and ``coeffs`` is an OrderedDict mapping input name to a ``(coefficient, eqn)`` pair in einsum form. 
From ad70e1d0c7fc797c9d322c7be5811c6b90818af8 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 20 Sep 2019 08:24:19 -0700 Subject: [PATCH 14/31] Address review comments --- funsor/affine.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/funsor/affine.py b/funsor/affine.py index 18cf52c01..fc1e46a3d 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -19,7 +19,12 @@ def extract_affine(fn): for var, (coeff, eqn) in coeffs.items()) assert_close(y, x) - :param Funsor fn: A funsor assumed to be affine (this is not checked). + The affine approximation is computed by ev evaluating ``fn`` at + zero and each basis vector. To improve performance, users may want to run + under the :func:`~funsor.memoize.memoize` interpretation. + + :param Funsor fn: A funsor assumed to be affine wrt the (add,mul) semiring. + The affine assumption is not checked. :return: A pair ``(const, coeffs)`` where const is a funsor with no real inputs and ``coeffs`` is an OrderedDict mapping input name to a ``(coefficient, eqn)`` pair in einsum form. From c7d10c62facddb8f307a32701a86b05741949ab7 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 20 Sep 2019 08:56:43 -0700 Subject: [PATCH 15/31] Add sensor fusion test using dist.MultivariateNormal --- test/examples/test_sensor_fusion.py | 45 +++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/test/examples/test_sensor_fusion.py b/test/examples/test_sensor_fusion.py index 2f9433ddd..5b47d1edc 100644 --- a/test/examples/test_sensor_fusion.py +++ b/test/examples/test_sensor_fusion.py @@ -3,6 +3,7 @@ import pytest import torch +import funsor.distributions as dist import funsor.ops as ops from funsor.cnf import Contraction from funsor.domains import bint, reals @@ -13,8 +14,48 @@ from funsor.torch import Tensor +# This version constructs factors using funsor.distributions. +def test_distributions(): + data = Tensor(torch.randn(2, 2), OrderedDict([("time", bint(2))])) + + bias = Variable("bias", reals(2)) + bias_dist = dist_to_funsor(random_mvn((), 2))(value=bias) + + prev = Variable("prev", reals(3)) + curr = Variable("curr", reals(3)) + trans_mat = Tensor(torch.randn(3, 3)) + trans_mvn = random_mvn((), 3) + trans_dist = dist.MultivariateNormal( + loc=trans_mvn.loc, + scale_tril=trans_mvn.scale_tril, + value=curr - prev @ trans_mat) + + state = Variable("state", reals(3)) + obs = Variable("obs", reals(2)) + obs_mat = Tensor(torch.randn(3, 2)) + obs_mvn = random_mvn((), 2) + obs_dist = dist.MultivariateNormal( + loc=obs_mvn.loc, + scale_tril=obs_mvn.scale_tril, + value=state @ obs_mat + bias - obs) + + log_prob = 0 + log_prob += bias_dist + + state_0 = Variable("state_0", reals(3)) + log_prob += obs_dist(state=state_0, obs=data(time=0)) + + state_1 = Variable("state_1", reals(3)) + log_prob += trans_dist(prev=state_0, curr=state_1) + log_prob += obs_dist(state=state_1, obs=data(time=1)) + + log_prob = log_prob.reduce(ops.logaddexp) + assert isinstance(log_prob, Tensor), log_prob.pretty() + + +# This version constructs factors using funsor.pyro.convert. @pytest.mark.xfail(reason="missing pattern") -def test_end_to_end(): +def test_pyro_convert(): data = Tensor(torch.randn(2, 2), OrderedDict([("time", bint(2))])) bias_dist = dist_to_funsor(random_mvn((), 2)) @@ -44,7 +85,7 @@ def test_end_to_end(): @pytest.mark.xfail(reason="missing pattern") def test_affine_subs(): - # This was recorded from test_end_to_end. + # This was recorded from test_pyro_convert. 
x = Subs( Gaussian( torch.tensor([1.3027106523513794, 1.4167094230651855, -0.9750942587852478, 0.5321089029312134, -0.9039931297302246], dtype=torch.float32), # noqa From 8d57895c23ed58b011ef1486944c4512459d6ccb Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 20 Sep 2019 11:36:02 -0700 Subject: [PATCH 16/31] WIP attempt to fix eager_mvn --- funsor/distributions.py | 40 +++++++++++++++++----------------------- test/test_affine.py | 1 + 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/funsor/distributions.py b/funsor/distributions.py index 98cea6b2e..8d2dd04ce 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -10,7 +10,7 @@ from funsor.affine import extract_affine from funsor.cnf import Contraction from funsor.domains import bint, reals -from funsor.gaussian import BlockMatrix, BlockVector, Gaussian, cholesky_inverse +from funsor.gaussian import BlockMatrix, BlockVector, Gaussian from funsor.interpreter import interpretation from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, lazy, to_funsor from funsor.torch import Tensor, align_tensors, ignore_jit_warnings, materialize, torch_stack @@ -369,8 +369,7 @@ def eager_normal(loc, scale, value): if isinstance(loc, Variable): loc, value = value, loc - inputs, (loc, scale) = align_tensors(loc, scale) - loc, scale = torch.broadcast_tensors(loc, scale) + inputs, (loc, scale) = align_tensors(loc, scale, expand=True) inputs.update(value.inputs) int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real') @@ -408,8 +407,7 @@ def eager_normal(loc, scale, value): coeffs[c] = affine(**{k: 1. if c == k else 0. for k in real_inputs.keys()}) - const tensors = [const] + list(coeffs.values()) - inputs, tensors = align_tensors(*tensors) - tensors = torch.broadcast_tensors(*tensors) + inputs, tensors = align_tensors(*tensors, expand=True) const, coeffs = tensors[0], tensors[1:] dim = sum(d.num_elements for d in real_inputs.values()) @@ -469,30 +467,26 @@ def eager_mvn(loc, scale_tril, value): return None # lazy # Dovetail to avoid variable name collision in einsum. - equations1 = [''.join(c if c in '.,->' else chr(ord(c) * 2)) + equations1 = [''.join(c if c in ',->' else chr(ord(c) * 2)) for _, eqn in coeffs.values() for c in eqn] - equations2 = [''.join(c if c in '.,->' else chr(ord(c) * 2 + 1)) + equations2 = [''.join(c if c in ',->' else chr(ord(c) * 2 + 1)) for _, eqn in coeffs.values() for c in eqn] real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real') - assert len(real_inputs) == len(coeffs) - tensors = [coeffs[k] for k in real_inputs] + [const, scale_tril] - inputs, tensors = align_tensors(*tensors) + assert tuple(real_inputs) == tuple(coeffs) # Incorporate scale_tril before broadcasting. - scale_tril = tensors[-1] - prec_sqrt = cholesky_inverse(scale_tril) # FIXME is transpose correct? 
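
As a sanity check on the FIXME above: for Sigma = L @ L.t() with lower-triangular
L = scale_tril, the whitening matrix is L^-1 (so that ||L^-1 r||^2 = r' Sigma^-1 r),
and Sigma^-1 = L^-T L^-1. By contrast cholesky_inverse returns the full inverse
Sigma^-1, not a square root of it, which is what the triangular solve below corrects.
A quick numerical check, assuming funsor.gaussian.cholesky_inverse matches
torch.cholesky_inverse semantics::

    import torch
    A = torch.randn(3, 3)
    L = (A @ A.t() + torch.eye(3)).cholesky()  # lower triangular, like scale_tril
    Linv = torch.eye(3).triangular_solve(L, upper=False).solution
    assert torch.allclose(Linv.t() @ Linv, torch.cholesky_inverse(L), atol=1e-4)
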
- for i in range(len(coeffs)): - coeff = tensors[i] - inputs1, output1 = equations1[i].split('->') - input11, _ = inputs1.split(',') - inputs2, output2 = equations2[i].split('->') - input21, _ = inputs2.split(',') - coeff = torch.einsum(f'...{input11},...{output1}{output2}->...{input21}', - coeff, prec_sqrt) - tensors[i] = coeff - - tensors = torch.broadcast_tensors(*tensors) + eye = torch.eye(scale_tril.data.size(-1)).expand(scale_tril.data.shape) + prec_sqrt = Tensor(eye.triangular_solve(scale_tril.data, upper=False).solution, + scale_tril.inputs) + tensors = [] + for k, (coeff, eqn) in coeffs.items(): + shape = real_inputs[k].shape + size = real_inputs[k].num_elements + tensors.append((prec_sqrt @ coeff.reshape((size, size))).reshape(shape + shape)) + + tensors.extend([const, scale_tril]) + inputs, tensors = align_tensors(*tensors, expand=True) coeffs, const, scale_tril = tensors[:-2], tensors[-2], tensors[-1] batch_shape = const.shape dim = sum(d.num_elements for d in real_inputs.values()) diff --git a/test/test_affine.py b/test/test_affine.py index 87cd8484c..83d0895aa 100644 --- a/test/test_affine.py +++ b/test/test_affine.py @@ -80,6 +80,7 @@ def test_affine_subs(expr, expected_type, expected_inputs): @pytest.mark.parametrize('expr', [ "Variable('x', reals()) + 0.5", + "Variable('x', reals(2, 3)) + Variable('y', reals(2, 3))", "Variable('x', reals(2)) + Variable('y', reals(2))", "Variable('x', reals(2)) + torch.ones(2)", "Variable('x', reals(2)) * torch.randn(2)", From 4f55bd65286c2c7a07ff6f868091e7ab67f1db50 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 20 Sep 2019 11:37:03 -0700 Subject: [PATCH 17/31] Add ops.matmul --- funsor/domains.py | 12 +++++++++ funsor/ops.py | 6 +++++ funsor/terms.py | 6 +++++ funsor/torch.py | 64 +++++++++++++++++++++++++++++++++++++++++----- test/test_torch.py | 20 +++++++++++++++ 5 files changed, 101 insertions(+), 7 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index f43c636de..267906254 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -97,6 +97,18 @@ def find_domain(op, *domains): dtype = lhs.dtype shape = lhs.shape[:op.offset] + lhs.shape[1 + op.offset:] return Domain(shape, dtype) + elif op == ops.matmul: + assert lhs.shape and rhs.shape + if len(rhs.shape) == 1: + assert lhs.shape[-1] == rhs.shape[-1] + shape = lhs.shape[:-1] + elif len(lhs.shape) == 1: + assert lhs.shape[-1] == rhs.shape[-2] + shape = rhs.shape[:-2] + rhs.shape[-1:] + else: + assert lhs.shape[-1] == rhs.shape[-2] + shape = broadcast_shape(lhs.shape[:-1], rhs.shape[:-2] + (1,)) + rhs.shape[-1:] + return Domain(shape, 'real') if lhs.dtype == 'real' or rhs.dtype == 'real': dtype = 'real' diff --git a/funsor/ops.py b/funsor/ops.py index 7a3e691c9..15aba813b 100644 --- a/funsor/ops.py +++ b/funsor/ops.py @@ -65,6 +65,10 @@ class MulOp(AssociativeOp): pass +class MatmulOp(AssociativeOp): + pass + + class LogAddExpOp(AssociativeOp): pass @@ -157,6 +161,7 @@ def _default(self, x, y): add = AddOp(operator.add) and_ = AssociativeOp(operator.and_) mul = MulOp(operator.mul) +matmul = MatmulOp(operator.matmul) or_ = AssociativeOp(operator.or_) xor = AssociativeOp(operator.xor) @@ -325,6 +330,7 @@ def reciprocal(x): 'log', 'log1p', 'lt', + 'matmul', 'max', 'min', 'mul', diff --git a/funsor/terms.py b/funsor/terms.py index 582187a77..7b0cecf3b 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -620,6 +620,12 @@ def __truediv__(self, other): def __rtruediv__(self, other): return Binary(ops.truediv, to_funsor(other), self) + def __matmul__(self, other): + 
return Binary(ops.matmul, self, to_funsor(other))
+
+    def __rmatmul__(self, other):
+        return Binary(ops.matmul, to_funsor(other), self)
+
     def __pow__(self, other):
         return Binary(ops.pow, self, to_funsor(other))

diff --git a/funsor/torch.py b/funsor/torch.py
index 2329438b7..e092746a4 100644
--- a/funsor/torch.py
+++ b/funsor/torch.py
@@ -13,7 +13,7 @@
 import funsor.ops as ops
 from funsor.delta import Delta
 from funsor.domains import Domain, bint, find_domain, reals
-from funsor.ops import GetitemOp, Op, ReshapeOp
+from funsor.ops import GetitemOp, MatmulOp, Op, ReshapeOp
 from funsor.terms import (
     Binary,
     Funsor,
@@ -50,13 +50,15 @@ def _(x, indent, out):
     out.append((indent, f"torch.tensor({repr(x.tolist())}, dtype={x.dtype})"))


-def align_tensor(new_inputs, x):
+def align_tensor(new_inputs, x, expand=False):
     r"""
-    Permute and expand a tensor to match desired ``new_inputs``.
+    Permute and add dims to a tensor to match desired ``new_inputs``.

     :param OrderedDict new_inputs: A target set of inputs.
     :param funsor.terms.Funsor x: A :class:`Tensor` or
         :class:`~funsor.terms.Number` .
+    :param bool expand: If False (default), set result size to 1 for any input
+        of ``x`` not in ``new_inputs``; if True expand to ``new_inputs`` size.
     :return: a number or :class:`torch.Tensor` that can be broadcast to
         other tensors with inputs ``new_inputs``.
     :rtype: tuple
@@ -81,10 +83,14 @@ def align_tensor(new_inputs, x):
     # Unsquash multivariate input dims by filling in ones.
     data = data.reshape(tuple(old_inputs[k].dtype if k in old_inputs else 1
                               for k in new_inputs) + x.output.shape)
+
+    # Optionally expand new dims.
+    if expand:
+        data = data.expand(tuple(d.dtype for d in new_inputs.values()) + x.output.shape)
     return data


-def align_tensors(*args):
+def align_tensors(*args, **kwargs):
     r"""
     Permute multiple tensors before applying a broadcasted op.

     This is mainly useful for implementing eager funsor operations.

     :param funsor.terms.Funsor \*args: Multiple :class:`Tensor` s and
         :class:`~funsor.terms.Number` s.
+    :param bool expand: Whether to expand input tensors. Defaults to False.
     :return: a pair ``(inputs, tensors)`` where tensors are all
         :class:`torch.Tensor` s that can be broadcast together to a single
         data with given ``inputs``.
     :rtype: tuple
     """
+    expand = kwargs.pop('expand', False)
+    assert not kwargs
     inputs = OrderedDict()
     for x in args:
         inputs.update(x.inputs)
-    tensors = [align_tensor(inputs, x) for x in args]
+    tensors = [align_tensor(inputs, x, expand=expand) for x in args]
     return inputs, tensors


@@ -415,8 +424,41 @@ def eager_binary_tensor_tensor(op, lhs, rhs):

     # Reshape to support broadcasting of output shape.
     if inputs:
-        lhs_dim = len(lhs.output.shape)
-        rhs_dim = len(rhs.output.shape)
+        lhs_dim = len(lhs.shape)
+        rhs_dim = len(rhs.shape)
         if lhs_dim < rhs_dim:
             cut = lhs_data.dim() - lhs_dim
             shape = lhs_data.shape
             shape = shape[:cut] + (1,) * (rhs_dim - lhs_dim) + shape[cut:]
             lhs_data = lhs_data.reshape(shape)
         elif rhs_dim < lhs_dim:
             cut = rhs_data.dim() - rhs_dim
             shape = rhs_data.shape
             shape = shape[:cut] + (1,) * (lhs_dim - rhs_dim) + shape[cut:]
             rhs_data = rhs_data.reshape(shape)
+
+    data = op(lhs_data, rhs_data)
+    return Tensor(data, inputs, dtype)
+
+
+@eager.register(Binary, MatmulOp, Tensor, Tensor)
+def eager_binary_tensor_tensor(op, lhs, rhs):
+    # Compute inputs and outputs.
+    dtype = find_domain(op, lhs.output, rhs.output).dtype
+    if lhs.inputs == rhs.inputs:
+        inputs = lhs.inputs
+        lhs_data, rhs_data = lhs.data, rhs.data
+    else:
+        inputs, (lhs_data, rhs_data) = align_tensors(lhs, rhs)
+    if len(lhs.shape) == 1:
+        lhs_data = lhs_data.unsqueeze(-2)
+    if len(rhs.shape) == 1:
+        rhs_data = rhs_data.unsqueeze(-1)
+
+    # Reshape to support broadcasting of output shape.
+    if inputs:
+        lhs_dim = max(2, len(lhs.shape))
+        rhs_dim = max(2, len(rhs.shape))
     if lhs_dim < rhs_dim:
         cut = lhs_data.dim() - lhs_dim
         shape = lhs_data.shape
         shape = shape[:cut] + (1,) * (rhs_dim - lhs_dim) + shape[cut:]
         lhs_data = lhs_data.reshape(shape)
     elif rhs_dim < lhs_dim:
         cut = rhs_data.dim() - rhs_dim
         shape = rhs_data.shape
         shape = shape[:cut] + (1,) * (lhs_dim - rhs_dim) + shape[cut:]
         rhs_data = rhs_data.reshape(shape)

+    print(f"lhs.data.shape = {lhs.data.shape}")
+    print(f"rhs.data.shape = {rhs.data.shape}")
+    print(f"lhs_data.shape = {lhs_data.shape}")
+    print(f"rhs_data.shape = {rhs_data.shape}")
     data = op(lhs_data, rhs_data)
+    if len(lhs.shape) == 1:
+        data = data.squeeze(-2)
+    if len(rhs.shape) == 1:
+        data = data.squeeze(-1)
     return Tensor(data, inputs, dtype)

diff --git a/test/test_torch.py b/test/test_torch.py
index ec27094a8..f59127b40 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -322,6 +322,26 @@ def test_binary_broadcast(inputs1, inputs2, output_shape1, output_shape2):
         assert_close(actual_block, expected_block)


+@pytest.mark.parametrize('output_shape2', [(2,), (2, 5), (4, 2, 5)], ids=str)
+@pytest.mark.parametrize('output_shape1', [(2,), (3, 2), (4, 3, 2)], ids=str)
+@pytest.mark.parametrize('inputs2', [(), ('a',), ('b', 'a'), ('b', 'c', 'a')], ids=str)
+@pytest.mark.parametrize('inputs1', [(), ('a',), ('a', 'b'), ('b', 'a', 'c')], ids=str)
+def test_matmul(inputs1, inputs2, output_shape1, output_shape2):
+    sizes = {'a': 6, 'b': 7, 'c': 8}
+    inputs1 = OrderedDict((k, bint(sizes[k])) for k in inputs1)
+    inputs2 = OrderedDict((k, bint(sizes[k])) for k in inputs2)
+    x1 = random_tensor(inputs1, reals(*output_shape1))
+    x2 = random_tensor(inputs2, reals(*output_shape2))
+
+    actual = x1 @ x2
+    assert actual.output == find_domain(ops.matmul, x1.output, x2.output)
+
+    block = {'a': 1, 'b': 2, 'c': 3}
+    actual_block = actual(**block)
+    expected_block = Tensor(x1(**block).data @ x2(**block).data)
+    assert_close(actual_block, expected_block)
+
+
 @pytest.mark.parametrize('scalar', [0.5])
 @pytest.mark.parametrize('dims', [(), ('a',), ('a', 'b'), ('b', 'a', 'c')])
 @pytest.mark.parametrize('symbol', BINARY_OPS)

From 3e543d32c12b1f3ad7b630d9d3c6a8580b524c24 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Fri, 20 Sep 2019 11:43:40 -0700
Subject: [PATCH 18/31] Use matmul and expand=True in distributions and
 pyro.convert

---
 funsor/distributions.py | 6 ++----
 funsor/pyro/convert.py  | 7 ++-----
 2 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/funsor/distributions.py b/funsor/distributions.py
index 05c178728..c4d4d7e00 100644
--- a/funsor/distributions.py
+++ b/funsor/distributions.py
@@ -368,8 +368,7 @@ def eager_normal(loc, scale, value):
     if isinstance(loc, Variable):
         loc, value = value, loc

-    inputs, (loc, scale) = align_tensors(loc, scale)
-    loc, scale = torch.broadcast_tensors(loc, scale)
+    inputs, (loc, scale) = align_tensors(loc, scale, expand=True)
     inputs.update(value.inputs)
     int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')

@@ -407,8 +406,7 @@ def eager_normal(loc, scale, value):
         coeffs[c] = affine(**{k: 1. if c == k else 0.
for k in real_inputs.keys()}) - const tensors = [const] + list(coeffs.values()) - inputs, tensors = align_tensors(*tensors) - tensors = torch.broadcast_tensors(*tensors) + inputs, tensors = align_tensors(*tensors, expand=True) const, coeffs = tensors[0], tensors[1:] dim = sum(d.num_elements for d in real_inputs.values()) diff --git a/funsor/pyro/convert.py b/funsor/pyro/convert.py index 05e388b7d..b794d9fdf 100644 --- a/funsor/pyro/convert.py +++ b/funsor/pyro/convert.py @@ -254,16 +254,13 @@ def eager_affine_normal(matrix, loc, scale, value_x, value_y): assert len(matrix.output.shape) == 2 assert value_x.output == reals(matrix.output.shape[0]) assert value_y.output == reals(matrix.output.shape[1]) - tensors = (matrix, loc, scale, value_x) - int_inputs, tensors = align_tensors(*tensors) - matrix, loc, scale, value_x = tensors + loc += value_x @ matrix + int_inputs, (loc, scale) = align_tensors(loc, scale, expand=True) - loc = loc + value_x.unsqueeze(-2).matmul(matrix).squeeze(-2) i_name = gensym("i") y_name = gensym("y") y_i_name = gensym("y_i") int_inputs[i_name] = bint(value_y.output.shape[0]) - loc, scale = torch.broadcast_tensors(loc, scale) loc = Tensor(loc, int_inputs) scale = Tensor(scale, int_inputs) y_dist = Independent(Normal(loc, scale, y_i_name), y_name, i_name, y_i_name) From 45878c52cab98ff04da206dbc1262bf8f33ad1b0 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 20 Sep 2019 14:50:30 -0700 Subject: [PATCH 19/31] More fixes to eager_mvn --- funsor/affine.py | 5 +++++ funsor/distributions.py | 38 ++++++++++++++++++++++---------------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/funsor/affine.py b/funsor/affine.py index fc1e46a3d..f26ab74ca 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -8,6 +8,11 @@ from funsor.torch import Tensor +# FIXME change this to a sound but incomplete test using pattern matching. +def is_affine(fn): + return True + + def extract_affine(fn): """ Extracts an affine representation of a funsor, which is exact for affine diff --git a/funsor/distributions.py b/funsor/distributions.py index 8d2dd04ce..9b812e680 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -7,7 +7,7 @@ import funsor.delta import funsor.ops as ops -from funsor.affine import extract_affine +from funsor.affine import extract_affine, is_affine from funsor.cnf import Contraction from funsor.domains import bint, reals from funsor.gaussian import BlockMatrix, BlockVector, Gaussian @@ -456,8 +456,12 @@ def eager_mvn(loc, scale_tril, value): @eager.register(MultivariateNormal, (Variable, Contraction), Tensor, Tensor) @eager.register(MultivariateNormal, Tensor, Tensor, (Variable, Contraction)) def eager_mvn(loc, scale_tril, value): + assert len(loc.shape) == 1 + assert len(scale_tril.shape) == 2 + assert value.output == loc.output + affine = loc - value - if not isinstance(affine, Contraction): + if not is_affine(affine): return None # lazy const, coeffs = extract_affine(affine) @@ -467,10 +471,10 @@ def eager_mvn(loc, scale_tril, value): return None # lazy # Dovetail to avoid variable name collision in einsum. 
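
On the dovetailing below: each coefficient's equation uses letters starting from 'a',
so splicing two equations into one einsum would collide; mapping the first copy to
even offsets and the second to odd offsets keeps them disjoint. A small illustration
of the remapping introduced in the next hunk (which stays within valid einsum letters
only for short equations)::

    remap1 = lambda eqn: ''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a')) for c in eqn)
    remap2 = lambda eqn: ''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a') + 1) for c in eqn)
    assert remap1('ab,a->b') == 'ac,a->c'
    assert remap2('ab,a->b') == 'bd,b->d'
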
- equations1 = [''.join(c if c in ',->' else chr(ord(c) * 2)) - for _, eqn in coeffs.values() for c in eqn] - equations2 = [''.join(c if c in ',->' else chr(ord(c) * 2 + 1)) - for _, eqn in coeffs.values() for c in eqn] + equations1 = [''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a')) for c in eqn) + for _, eqn in coeffs.values()] + equations2 = [''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a') + 1) for c in eqn) + for _, eqn in coeffs.values()] real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real') assert tuple(real_inputs) == tuple(coeffs) @@ -491,28 +495,30 @@ def eager_mvn(loc, scale_tril, value): batch_shape = const.shape dim = sum(d.num_elements for d in real_inputs.values()) - loc = BlockVector(batch_shape + (dim,)) - loc[..., 0] = -const / coeffs[0] # FIXME consider directly constructing info_vec - + info_vec = BlockVector(batch_shape + (dim,)) precision = BlockMatrix(batch_shape + (dim, dim)) offset1 = 0 for i1, (v1, c1) in enumerate(zip(real_inputs, coeffs)): slice1 = slice(offset1, offset1 + real_inputs[v1].num_elements) - input11, input12 = equations1[i1].split('->')[0].split(',') + inputs1, output1 = equations1[i1].split('->') + input11, input12 = inputs1.split(',')[0] + input112 = input11[len(input11) // 2:] + info_vec[..., slice1] = torch.einsum( + f'...{input11},...{output1}->...{input12}', c1, const) + offset2 = 0 for i2, (v2, c2) in enumerate(zip(real_inputs, coeffs)): slice2 = slice(offset2, offset2 + real_inputs[v2].num_elements) - input21, input22 = equations2[i2].split('->')[0].split(',') + input21 = equations2[i2].split(',')[0] + input212 = input21[len(input21) // 2:] precision[..., slice1, slice2] = torch.einsum( - f'...{input11}{input12},...{input21}{input22}->...{input12}{input22}', - c1, c2) + f'...{input11},...{input21}->...{input112}{input212}', c1, c2) + offset2 = slice2.stop offset1 = slice1.stop - loc = loc.as_tensor() + info_vec = info_vec.as_tensor() precision = precision.as_tensor() - info_vec = precision.matmul(loc.unsqueeze(-1)).squeeze(-1) - log_prob = (-0.5 * dim * math.log(2 * math.pi) - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1) - 0.5 * (loc * info_vec).sum(-1)) From 4eec408413e7995ccc9816080c108277131cab3a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 20 Sep 2019 14:51:18 -0700 Subject: [PATCH 20/31] Remove debug statements --- funsor/torch.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/funsor/torch.py b/funsor/torch.py index e092746a4..ef420dc25 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -470,10 +470,6 @@ def eager_binary_tensor_tensor(op, lhs, rhs): shape = shape[:cut] + (1,) * (lhs_dim - rhs_dim) + shape[cut:] rhs_data = rhs_data.reshape(shape) - print(f"lhs.data.shape = {lhs.data.shape}") - print(f"rhs.data.shape = {rhs.data.shape}") - print(f"lhs_data.shape = {lhs_data.shape}") - print(f"rhs_data.shape = {rhs_data.shape}") data = op(lhs_data, rhs_data) if len(lhs.shape) == 1: data = data.squeeze(-2) From f36e8647df83e0d037e1358965ae498a594fb34a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 20 Sep 2019 15:06:03 -0700 Subject: [PATCH 21/31] Fix some typos --- funsor/distributions.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/funsor/distributions.py b/funsor/distributions.py index 9b812e680..0c920939c 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -501,19 +501,16 @@ def eager_mvn(loc, scale_tril, value): for i1, (v1, c1) in enumerate(zip(real_inputs, coeffs)): slice1 = slice(offset1, offset1 
+ real_inputs[v1].num_elements)
         inputs1, output1 = equations1[i1].split('->')
-        input11, input12 = inputs1.split(',')[0]
-        input112 = input11[len(input11) // 2:]
+        input11, input12 = inputs1.split(',')
         info_vec[..., slice1] = torch.einsum(
             f'...{input11},...{output1}->...{input12}', c1, const)
-
         offset2 = 0
         for i2, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
             slice2 = slice(offset2, offset2 + real_inputs[v2].num_elements)
-            input21 = equations2[i2].split(',')[0]
-            input212 = input21[len(input21) // 2:]
+            inputs2, output2 = equations2[i2].split('->')
+            input21, input22 = inputs2.split(',')
             precision[..., slice1, slice2] = torch.einsum(
-                f'...{input11},...{input21}->...{input112}{input212}', c1, c2)
-
+                f'...{input11},...{input21}->...{input12}{input22}', c1, c2)
             offset2 = slice2.stop
         offset1 = slice1.stop

From 3e8eed10e6915878cb4be2e9b33c7192b5e5c4b2 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Fri, 20 Sep 2019 15:32:12 -0700
Subject: [PATCH 22/31] Fix shape errors

---
 funsor/distributions.py             | 15 ++++++++-------
 funsor/ops.py                       |  3 ++-
 test/examples/test_sensor_fusion.py | 29 +++++++++++++++--------------
 3 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/funsor/distributions.py b/funsor/distributions.py
index 0c920939c..c9b799e3e 100644
--- a/funsor/distributions.py
+++ b/funsor/distributions.py
@@ -484,15 +484,15 @@ def eager_mvn(loc, scale_tril, value):
     prec_sqrt = Tensor(eye.triangular_solve(scale_tril.data, upper=False).solution,
                        scale_tril.inputs)
     tensors = []
+    d1 = const.output
     for k, (coeff, eqn) in coeffs.items():
-        shape = real_inputs[k].shape
-        size = real_inputs[k].num_elements
-        tensors.append((prec_sqrt @ coeff.reshape((size, size))).reshape(shape + shape))
+        shape = (prec_sqrt.shape[-1], real_inputs[k].num_elements)
+        tensors.append((prec_sqrt @ coeff.reshape(shape)).reshape(coeff.shape))

     tensors.extend([const, scale_tril])
-    inputs, tensors = align_tensors(*tensors, expand=True)
+    int_inputs, tensors = align_tensors(*tensors, expand=True)
     coeffs, const, scale_tril = tensors[:-2], tensors[-2], tensors[-1]
-    batch_shape = const.shape
+    batch_shape = const.shape[:-1]
     dim = sum(d.num_elements for d in real_inputs.values())

     info_vec = BlockVector(batch_shape + (dim,))
@@ -518,8 +518,9 @@ def eager_mvn(loc, scale_tril, value):
     precision = precision.as_tensor()
     log_prob = (-0.5 * dim * math.log(2 * math.pi)
                 - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
-                - 0.5 * (loc * info_vec).sum(-1))
-    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
+                - 0.5 * const.pow(2).reshape(batch_shape + (-1,)).sum(-1))
+    inputs = int_inputs.copy()
+    inputs.update(real_inputs)
     return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs)

diff --git a/funsor/ops.py b/funsor/ops.py
index 15aba813b..66c5a81aa 100644
--- a/funsor/ops.py
+++ b/funsor/ops.py
@@ -53,6 +53,7 @@ def log_abs_det_jacobian(x, y):
         raise NotImplementedError


+# FIXME Most code assumes this is an AssociativeCommutativeOp.
 class AssociativeOp(Op):
     pass

@@ -65,7 +66,7 @@ class MulOp(AssociativeOp):
     pass


-class MatmulOp(AssociativeOp):
+class MatmulOp(Op):  # Associative but not commutative.
     pass

diff --git a/test/examples/test_sensor_fusion.py b/test/examples/test_sensor_fusion.py
index 5b47d1edc..b4fe638d4 100644
--- a/test/examples/test_sensor_fusion.py
+++ b/test/examples/test_sensor_fusion.py
@@ -15,25 +15,26 @@
 # This version constructs factors using funsor.distributions.
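
For reference, the generative model this test exercises is a Gaussian prior over a
shared sensor bias combined with a linear-Gaussian state space model,

    p(bias) * p(obs_0 | state_0, bias) * p(state_1 | state_0) * p(obs_1 | state_1, bias),

and the final reduce(ops.logaddexp) eliminates bias and both states, so the test passes
exactly when every factor eagerly becomes a Tensor or Gaussian rather than staying lazy.
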
-def test_distributions(): - data = Tensor(torch.randn(2, 2), OrderedDict([("time", bint(2))])) +@pytest.mark.parametrize('state_dim,obs_dim', [(3, 2), (2, 3)]) +def test_distributions(state_dim, obs_dim): + data = Tensor(torch.randn(2, obs_dim))["time"] - bias = Variable("bias", reals(2)) - bias_dist = dist_to_funsor(random_mvn((), 2))(value=bias) + bias = Variable("bias", reals(obs_dim)) + bias_dist = dist_to_funsor(random_mvn((), obs_dim))(value=bias) - prev = Variable("prev", reals(3)) - curr = Variable("curr", reals(3)) - trans_mat = Tensor(torch.randn(3, 3)) - trans_mvn = random_mvn((), 3) + prev = Variable("prev", reals(state_dim)) + curr = Variable("curr", reals(state_dim)) + trans_mat = Tensor(torch.eye(state_dim) + 0.1 * torch.randn(state_dim, state_dim)) + trans_mvn = random_mvn((), state_dim) trans_dist = dist.MultivariateNormal( loc=trans_mvn.loc, scale_tril=trans_mvn.scale_tril, value=curr - prev @ trans_mat) - state = Variable("state", reals(3)) - obs = Variable("obs", reals(2)) - obs_mat = Tensor(torch.randn(3, 2)) - obs_mvn = random_mvn((), 2) + state = Variable("state", reals(state_dim)) + obs = Variable("obs", reals(obs_dim)) + obs_mat = Tensor(torch.randn(state_dim, obs_dim)) + obs_mvn = random_mvn((), obs_dim) obs_dist = dist.MultivariateNormal( loc=obs_mvn.loc, scale_tril=obs_mvn.scale_tril, @@ -42,10 +43,10 @@ def test_distributions(): log_prob = 0 log_prob += bias_dist - state_0 = Variable("state_0", reals(3)) + state_0 = Variable("state_0", reals(state_dim)) log_prob += obs_dist(state=state_0, obs=data(time=0)) - state_1 = Variable("state_1", reals(3)) + state_1 = Variable("state_1", reals(state_dim)) log_prob += trans_dist(prev=state_0, curr=state_1) log_prob += obs_dist(state=state_1, obs=data(time=1)) From cc181f0a219f4f8297e624160ae799bc1631f9bf Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 20 Sep 2019 15:37:29 -0700 Subject: [PATCH 23/31] Fix const computation --- funsor/distributions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/funsor/distributions.py b/funsor/distributions.py index c9b799e3e..ed9a6f8d4 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -488,8 +488,9 @@ def eager_mvn(loc, scale_tril, value): for k, (coeff, eqn) in coeffs.items(): shape = (prec_sqrt.shape[-1], real_inputs[k].num_elements) tensors.append((prec_sqrt @ coeff.reshape(shape)).reshape(coeff.shape)) + tensors.append((prec_sqrt @ const.reshape((-1,))).reshape(const.shape)) + tensors.append(scale_tril) - tensors.extend([const, scale_tril]) int_inputs, tensors = align_tensors(*tensors, expand=True) coeffs, const, scale_tril = tensors[:-2], tensors[-2], tensors[-1] batch_shape = const.shape[:-1] From ffe0c5509dfd6e52e80f8f50ad88dca8ae1270f7 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 20 Sep 2019 16:07:28 -0700 Subject: [PATCH 24/31] Add failing tests --- funsor/distributions.py | 1 - test/test_distributions.py | 61 +++++++++++++++++++++++++++++++++++--- 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/funsor/distributions.py b/funsor/distributions.py index ed9a6f8d4..1d2958a81 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -484,7 +484,6 @@ def eager_mvn(loc, scale_tril, value): prec_sqrt = Tensor(eye.triangular_solve(scale_tril.data, upper=False).solution, scale_tril.inputs) tensors = [] - d1 = const.output for k, (coeff, eqn) in coeffs.items(): shape = (prec_sqrt.shape[-1], real_inputs[k].num_elements) tensors.append((prec_sqrt @ coeff.reshape(shape)).reshape(coeff.shape)) diff 
--git a/test/test_distributions.py b/test/test_distributions.py index 2fbe878c9..87c3f1857 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -7,12 +7,13 @@ import funsor import funsor.distributions as dist -from funsor.cnf import Contraction +from funsor.cnf import Contraction, GaussianMixture from funsor.delta import Delta from funsor.domains import bint, reals -from funsor.terms import Independent, Variable +from funsor.interpreter import interpretation, reinterpret +from funsor.terms import Independent, Variable, lazy from funsor.testing import assert_close, check_funsor, random_tensor -from funsor.torch import Tensor +from funsor.torch import Einsum, Tensor @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) @@ -398,7 +399,6 @@ def test_normal_gaussian_3(batch_shape): 'dist.Normal(x - y, scale, 0)', 'dist.Normal(0, scale, y - x)', 'dist.Normal(2 * x - y, scale, x)', - # TODO should we expect these to work without correction terms? 'dist.Normal(0, 1, (x - y) / scale) - scale.log()', 'dist.Normal(2 * y, 2 * scale, 2 * x) + math.log(2)', ] @@ -488,6 +488,59 @@ def test_mvn_gaussian(batch_shape): assert_close(actual, expected, atol=1e-3, rtol=1e-4) +def _check_mvn_affine(d1, data): + d2 = reinterpret(d1) + assert issubclass(type(d2), GaussianMixture) + actual = d2(**data) + expected = d1(**data) + assert_close(actual, expected) + + +def test_mvn_affine_1(): + x = Variable('x', reals(2)) + data = dict(x=Tensor(torch.randn(2))) + with interpretation(lazy): + d = dist.MultivariateNormal(Tensor(torch.zeros(2)), + Tensor(torch.eye(2)), + 2 * x + 1) + _check_mvn_affine(d, data) + + +def test_mvn_affine_2(): + x = Variable('x', reals(2)) + y = Variable('y', reals(2)) + data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(2))) + with interpretation(lazy): + d = dist.MultivariateNormal(Tensor(torch.zeros(2)), + Tensor(torch.eye(2)), + x - y) + _check_mvn_affine(d, data) + + +def test_mvn_affine_3(): + x = Variable('x', reals(2)) + y = Variable('y', reals(3)) + m = Tensor(torch.randn(2, 3)) + data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(3))) + with interpretation(lazy): + d = dist.MultivariateNormal(Tensor(torch.zeros(3)), + Tensor(torch.eye(3)), + x @ m - y) + _check_mvn_affine(d, data) + + +def test_mvn_affine_4(): + c = Tensor(torch.randn(3, 2, 2)) + x = Variable('x', reals(2, 2)) + y = Variable('y', reals()) + data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(3))) + with interpretation(lazy): + d = dist.MultivariateNormal(Tensor(torch.zeros(3)), + Tensor(torch.eye(3)), + Einsum("abc,bc->a", c, x) + y) + _check_mvn_affine(d, data) + + @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) @pytest.mark.parametrize('syntax', ['eager', 'lazy']) def test_poisson_probs_density(batch_shape, syntax): From 47683235f46aae50cff80a50c1aa727e85910f43 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 20 Sep 2019 16:18:43 -0700 Subject: [PATCH 25/31] Add more failing tests --- funsor/distributions.py | 16 ++++++++++------ test/test_distributions.py | 31 ++++++++++++++++++++++++++----- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/funsor/distributions.py b/funsor/distributions.py index 1d2958a81..142eb5c7e 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -499,20 +499,24 @@ def eager_mvn(loc, scale_tril, value): precision = BlockMatrix(batch_shape + (dim, dim)) offset1 = 0 for i1, (v1, c1) in enumerate(zip(real_inputs, coeffs)): - slice1 = slice(offset1, offset1 + 
real_inputs[v1].num_elements) + size1 = real_inputs[v1].num_elements + slice1 = slice(offset1, offset1 + size1) inputs1, output1 = equations1[i1].split('->') input11, input12 = inputs1.split(',') info_vec[..., slice1] = torch.einsum( - f'...{input11},...{output1}->...{input12}', c1, const) + f'...{input11},...{output1}->...{input12}', c1, const) \ + .reshape(batch_shape + (size1,)) offset2 = 0 for i2, (v2, c2) in enumerate(zip(real_inputs, coeffs)): - slice2 = slice(offset2, offset2 + real_inputs[v2].num_elements) + size2 = real_inputs[v2].num_elements + slice2 = slice(offset2, offset2 + size2) inputs2, output2 = equations2[i2].split('->') input21, input22 = inputs2.split(',') precision[..., slice1, slice2] = torch.einsum( - f'...{input11},...{input21}->...{input12}{input22}', c1, c2) - offset2 = slice2.stop - offset1 = slice1.stop + f'...{input11},...{input21}->...{input12}{input22}', c1, c2) \ + .reshape(batch_shape + (size1, size2)) + offset2 += size2 + offset1 += size1 info_vec = info_vec.as_tensor() precision = precision.as_tensor() diff --git a/test/test_distributions.py b/test/test_distributions.py index 87c3f1857..a77c8e0fe 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -496,7 +496,7 @@ def _check_mvn_affine(d1, data): assert_close(actual, expected) -def test_mvn_affine_1(): +def test_mvn_affine_affine(): x = Variable('x', reals(2)) data = dict(x=Tensor(torch.randn(2))) with interpretation(lazy): @@ -506,7 +506,7 @@ def test_mvn_affine_1(): _check_mvn_affine(d, data) -def test_mvn_affine_2(): +def test_mvn_affine_two_vars(): x = Variable('x', reals(2)) y = Variable('y', reals(2)) data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(2))) @@ -517,7 +517,7 @@ def test_mvn_affine_2(): _check_mvn_affine(d, data) -def test_mvn_affine_3(): +def test_mvn_affine_matmul(): x = Variable('x', reals(2)) y = Variable('y', reals(3)) m = Tensor(torch.randn(2, 3)) @@ -529,11 +529,11 @@ def test_mvn_affine_3(): _check_mvn_affine(d, data) -def test_mvn_affine_4(): +def test_mvn_affine_einsum(): c = Tensor(torch.randn(3, 2, 2)) x = Variable('x', reals(2, 2)) y = Variable('y', reals()) - data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(3))) + data = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(()))) with interpretation(lazy): d = dist.MultivariateNormal(Tensor(torch.zeros(3)), Tensor(torch.eye(3)), @@ -541,6 +541,27 @@ def test_mvn_affine_4(): _check_mvn_affine(d, data) +def test_mvn_affine_getitem(): + x = Variable('x', reals(2, 2)) + data = dict(x=Tensor(torch.randn(2, 2))) + with interpretation(lazy): + d = dist.MultivariateNormal(Tensor(torch.zeros(2)), + Tensor(torch.eye(2)), + x[0] - x[1]) + _check_mvn_affine(d, data) + + +def test_mvn_affine_reshape(): + x = Variable('x', reals(2, 2)) + y = Variable('y', reals(4)) + data = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(4))) + with interpretation(lazy): + d = dist.MultivariateNormal(Tensor(torch.zeros(4)), + Tensor(torch.eye(4)), + x.reshape((4,)) - y) + _check_mvn_affine(d, data) + + @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) @pytest.mark.parametrize('syntax', ['eager', 'lazy']) def test_poisson_probs_density(batch_shape, syntax): From 914b223045f142abbc7991aa1f3c93ee847033c0 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 21 Sep 2019 05:55:19 -0700 Subject: [PATCH 26/31] Address review comments --- funsor/distributions.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/funsor/distributions.py 
b/funsor/distributions.py index 142eb5c7e..0e1a6a112 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -481,13 +481,14 @@ def eager_mvn(loc, scale_tril, value): # Incorporate scale_tril before broadcasting. eye = torch.eye(scale_tril.data.size(-1)).expand(scale_tril.data.shape) - prec_sqrt = Tensor(eye.triangular_solve(scale_tril.data, upper=False).solution, + prec_sqrt = Tensor(eye.triangular_solve(scale_tril.data, upper=False) + .solution.transpose(-1, -2), scale_tril.inputs) tensors = [] for k, (coeff, eqn) in coeffs.items(): - shape = (prec_sqrt.shape[-1], real_inputs[k].num_elements) - tensors.append((prec_sqrt @ coeff.reshape(shape)).reshape(coeff.shape)) - tensors.append((prec_sqrt @ const.reshape((-1,))).reshape(const.shape)) + shape = (real_inputs[k].num_elements, prec_sqrt.shape[-1]) + tensors.append((coeff.reshape(shape) @ prec_sqrt).reshape(coeff.shape)) + tensors.append((const.reshape((-1,)) @ prec_sqrt).reshape(const.shape)) tensors.append(scale_tril) int_inputs, tensors = align_tensors(*tensors, expand=True) @@ -503,6 +504,7 @@ def eager_mvn(loc, scale_tril, value): slice1 = slice(offset1, offset1 + size1) inputs1, output1 = equations1[i1].split('->') input11, input12 = inputs1.split(',') + assert input11 == input12 + output1 info_vec[..., slice1] = torch.einsum( f'...{input11},...{output1}->...{input12}', c1, const) \ .reshape(batch_shape + (size1,)) @@ -512,15 +514,16 @@ def eager_mvn(loc, scale_tril, value): slice2 = slice(offset2, offset2 + size2) inputs2, output2 = equations2[i2].split('->') input21, input22 = inputs2.split(',') + assert input21 == input22 + output2 precision[..., slice1, slice2] = torch.einsum( - f'...{input11},...{input21}->...{input12}{input22}', c1, c2) \ + f'...{input11},...{input22}{output1}->...{input12}{input22}', c1, c2) \ .reshape(batch_shape + (size1, size2)) offset2 += size2 offset1 += size1 info_vec = info_vec.as_tensor() precision = precision.as_tensor() - log_prob = (-0.5 * dim * math.log(2 * math.pi) + log_prob = (-0.5 * scale_tril.size(-1) * math.log(2 * math.pi) - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1) - 0.5 * const.pow(2).reshape(batch_shape + (-1,)).sum(-1)) inputs = int_inputs.copy() From b17cc7eaef2c7c1c9aa8e6332f3ca6bc2e5f8670 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 21 Sep 2019 06:01:40 -0700 Subject: [PATCH 27/31] Simplify --- funsor/distributions.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/funsor/distributions.py b/funsor/distributions.py index 0e1a6a112..afb11f602 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -460,7 +460,10 @@ def eager_mvn(loc, scale_tril, value): assert len(scale_tril.shape) == 2 assert value.output == loc.output - affine = loc - value + eye = torch.eye(scale_tril.data.size(-1)).expand(scale_tril.data.shape) + prec_sqrt = Tensor(eye.triangular_solve(scale_tril.data, upper=False).solution, + scale_tril.inputs) + affine = prec_sqrt @ (loc - value) if not is_affine(affine): return None # lazy @@ -479,18 +482,7 @@ def eager_mvn(loc, scale_tril, value): real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real') assert tuple(real_inputs) == tuple(coeffs) - # Incorporate scale_tril before broadcasting. 
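
A note on the simplification in this patch: with Sigma = L L' and r = loc - value, the
density factors entirely through the whitened residual L^-1 r,

    log_prob = -0.5 * ||L^-1 r||^2 - log(diag(L)).sum() - 0.5 * d * log(2 * pi),

so extracting the affine representation of prec_sqrt @ (loc - value) directly yields
the constant part (contributing -0.5 * (const ** 2).sum()) and the coefficients of the
Gaussian, and scale_tril is no longer needed downstream except through its log-diagonal.
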
- eye = torch.eye(scale_tril.data.size(-1)).expand(scale_tril.data.shape) - prec_sqrt = Tensor(eye.triangular_solve(scale_tril.data, upper=False) - .solution.transpose(-1, -2), - scale_tril.inputs) - tensors = [] - for k, (coeff, eqn) in coeffs.items(): - shape = (real_inputs[k].num_elements, prec_sqrt.shape[-1]) - tensors.append((coeff.reshape(shape) @ prec_sqrt).reshape(coeff.shape)) - tensors.append((const.reshape((-1,)) @ prec_sqrt).reshape(const.shape)) - tensors.append(scale_tril) - + tensors = [coeff for coeff, _ in coeffs.values()] + [const, scale_tril] int_inputs, tensors = align_tensors(*tensors, expand=True) coeffs, const, scale_tril = tensors[:-2], tensors[-2], tensors[-1] batch_shape = const.shape[:-1] From 9e001bb017bfeb5fe91e8be063fb4fdaee7b74b2 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 21 Sep 2019 06:19:00 -0700 Subject: [PATCH 28/31] Simplify more --- funsor/distributions.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/funsor/distributions.py b/funsor/distributions.py index afb11f602..2cd24474a 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -459,19 +459,24 @@ def eager_mvn(loc, scale_tril, value): assert len(loc.shape) == 1 assert len(scale_tril.shape) == 2 assert value.output == loc.output + if not is_affine(loc) or not is_affine(value): + return None # lazy + # Extract an affine representation. eye = torch.eye(scale_tril.data.size(-1)).expand(scale_tril.data.shape) prec_sqrt = Tensor(eye.triangular_solve(scale_tril.data, upper=False).solution, scale_tril.inputs) affine = prec_sqrt @ (loc - value) - if not is_affine(affine): - return None # lazy - const, coeffs = extract_affine(affine) if not isinstance(const, Tensor): return None # lazy if not all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values()): return None # lazy + scale_diag = Tensor(scale_tril.data.diagonal(dim1=-1, dim2=-2), + scale_tril.inputs) + log_prob = (-0.5 * scale_diag.shape[-1] * math.log(2 * math.pi) + - scale_diag.log().sum() - 0.5 * (const ** 2).sum()) + print(f"log_prob = {log_prob}") # Dovetail to avoid variable name collision in einsum. equations1 = [''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a')) for c in eqn) @@ -482,11 +487,12 @@ def eager_mvn(loc, scale_tril, value): real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real') assert tuple(real_inputs) == tuple(coeffs) - tensors = [coeff for coeff, _ in coeffs.values()] + [const, scale_tril] - int_inputs, tensors = align_tensors(*tensors, expand=True) - coeffs, const, scale_tril = tensors[:-2], tensors[-2], tensors[-1] - batch_shape = const.shape[:-1] + # Align and broadcast tensors. 
+ tensors = [coeff for coeff, _ in coeffs.values()] + [const, scale_diag] + inputs, tensors = align_tensors(*tensors, expand=True) + coeffs, const, scale_diag = tensors[:-2], tensors[-2], tensors[-1] dim = sum(d.num_elements for d in real_inputs.values()) + batch_shape = const.shape[:-1] info_vec = BlockVector(batch_shape + (dim,)) precision = BlockMatrix(batch_shape + (dim, dim)) @@ -515,12 +521,8 @@ def eager_mvn(loc, scale_tril, value): info_vec = info_vec.as_tensor() precision = precision.as_tensor() - log_prob = (-0.5 * scale_tril.size(-1) * math.log(2 * math.pi) - - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1) - - 0.5 * const.pow(2).reshape(batch_shape + (-1,)).sum(-1)) - inputs = int_inputs.copy() inputs.update(real_inputs) - return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs) + return log_prob + Gaussian(info_vec, precision, inputs) class Poisson(Distribution): From 3c9e3e51aa8094f57316ad0aef946d5bde07b30d Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 21 Sep 2019 06:29:21 -0700 Subject: [PATCH 29/31] Fix bugs --- funsor/distributions.py | 17 +++++++++-------- test/test_distributions.py | 20 ++++++++++++++++++++ 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/funsor/distributions.py b/funsor/distributions.py index 2cd24474a..6950decaf 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -472,11 +472,11 @@ def eager_mvn(loc, scale_tril, value): return None # lazy if not all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values()): return None # lazy - scale_diag = Tensor(scale_tril.data.diagonal(dim1=-1, dim2=-2), - scale_tril.inputs) - log_prob = (-0.5 * scale_diag.shape[-1] * math.log(2 * math.pi) + + # Compute log_prob using funsors. + scale_diag = Tensor(scale_tril.data.diagonal(dim1=-1, dim2=-2), scale_tril.inputs) + log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi) - scale_diag.log().sum() - 0.5 * (const ** 2).sum()) - print(f"log_prob = {log_prob}") # Dovetail to avoid variable name collision in einsum. equations1 = [''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a')) for c in eqn) @@ -488,11 +488,12 @@ def eager_mvn(loc, scale_tril, value): assert tuple(real_inputs) == tuple(coeffs) # Align and broadcast tensors. 
- tensors = [coeff for coeff, _ in coeffs.values()] + [const, scale_diag] + neg_const = - const + tensors = [neg_const] + [coeff for coeff, _ in coeffs.values()] inputs, tensors = align_tensors(*tensors, expand=True) - coeffs, const, scale_diag = tensors[:-2], tensors[-2], tensors[-1] + neg_const, coeffs = tensors[0], tensors[1:] dim = sum(d.num_elements for d in real_inputs.values()) - batch_shape = const.shape[:-1] + batch_shape = neg_const.shape[:-1] info_vec = BlockVector(batch_shape + (dim,)) precision = BlockMatrix(batch_shape + (dim, dim)) @@ -504,7 +505,7 @@ def eager_mvn(loc, scale_tril, value): input11, input12 = inputs1.split(',') assert input11 == input12 + output1 info_vec[..., slice1] = torch.einsum( - f'...{input11},...{output1}->...{input12}', c1, const) \ + f'...{input11},...{output1}->...{input12}', c1, neg_const) \ .reshape(batch_shape + (size1,)) offset2 = 0 for i2, (v2, c2) in enumerate(zip(real_inputs, coeffs)): diff --git a/test/test_distributions.py b/test/test_distributions.py index a77c8e0fe..0e5502805 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -496,6 +496,26 @@ def _check_mvn_affine(d1, data): assert_close(actual, expected) +def test_mvn_affine_mul(): + x = Variable('x', reals(2)) + data = dict(x=Tensor(torch.randn(2))) + with interpretation(lazy): + d = dist.MultivariateNormal(Tensor(torch.zeros(2)), + Tensor(torch.eye(2)), + x * 2) + _check_mvn_affine(d, data) + + +def test_mvn_affine_add(): + x = Variable('x', reals(2)) + data = dict(x=Tensor(torch.randn(2))) + with interpretation(lazy): + d = dist.MultivariateNormal(Tensor(torch.zeros(2)), + Tensor(torch.eye(2)), + x + 1) + _check_mvn_affine(d, data) + + def test_mvn_affine_affine(): x = Variable('x', reals(2)) data = dict(x=Tensor(torch.randn(2))) From 2da0e3bbbee1a71ff338907e7cc1892840b8cf18 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 21 Sep 2019 06:39:12 -0700 Subject: [PATCH 30/31] Strengthen tests --- funsor/distributions.py | 2 +- test/test_distributions.py | 56 +++++++++++--------------------------- 2 files changed, 17 insertions(+), 41 deletions(-) diff --git a/funsor/distributions.py b/funsor/distributions.py index 6950decaf..75df088ef 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -488,7 +488,7 @@ def eager_mvn(loc, scale_tril, value): assert tuple(real_inputs) == tuple(coeffs) # Align and broadcast tensors. 
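
For orientation, the quadratic form assembled by the BlockVector/BlockMatrix loops
above: writing the whitened affine funsor as c + sum_k A_k x_k over the real inputs
x_k, we have

    -0.5 * ||c + sum_k A_k x_k||^2 = -0.5 * x' P x + x' v - 0.5 * ||c||^2,

with block precision P[i, j] = A_i' A_j and information vector v[i] = A_i' (-c); this
is why neg_const rather than const enters info_vec.
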
- neg_const = - const + neg_const = -const tensors = [neg_const] + [coeff for coeff, _ in coeffs.values()] inputs, tensors = align_tensors(*tensors, expand=True) neg_const, coeffs = tensors[0], tensors[1:] diff --git a/test/test_distributions.py b/test/test_distributions.py index 0e5502805..2a4d992e1 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -11,8 +11,9 @@ from funsor.delta import Delta from funsor.domains import bint, reals from funsor.interpreter import interpretation, reinterpret +from funsor.pyro.convert import dist_to_funsor from funsor.terms import Independent, Variable, lazy -from funsor.testing import assert_close, check_funsor, random_tensor +from funsor.testing import assert_close, check_funsor, random_mvn, random_tensor from funsor.torch import Einsum, Tensor @@ -489,6 +490,7 @@ def test_mvn_gaussian(batch_shape): def _check_mvn_affine(d1, data): + assert isinstance(d1, dist.MultivariateNormal) d2 = reinterpret(d1) assert issubclass(type(d2), GaussianMixture) actual = d2(**data) @@ -496,33 +498,12 @@ def _check_mvn_affine(d1, data): assert_close(actual, expected) -def test_mvn_affine_mul(): +def test_mvn_affine_one_var(): x = Variable('x', reals(2)) data = dict(x=Tensor(torch.randn(2))) with interpretation(lazy): - d = dist.MultivariateNormal(Tensor(torch.zeros(2)), - Tensor(torch.eye(2)), - x * 2) - _check_mvn_affine(d, data) - - -def test_mvn_affine_add(): - x = Variable('x', reals(2)) - data = dict(x=Tensor(torch.randn(2))) - with interpretation(lazy): - d = dist.MultivariateNormal(Tensor(torch.zeros(2)), - Tensor(torch.eye(2)), - x + 1) - _check_mvn_affine(d, data) - - -def test_mvn_affine_affine(): - x = Variable('x', reals(2)) - data = dict(x=Tensor(torch.randn(2))) - with interpretation(lazy): - d = dist.MultivariateNormal(Tensor(torch.zeros(2)), - Tensor(torch.eye(2)), - 2 * x + 1) + d = dist_to_funsor(random_mvn((), 2)) + d = d(value=2 * x + 1) _check_mvn_affine(d, data) @@ -531,9 +512,8 @@ def test_mvn_affine_two_vars(): y = Variable('y', reals(2)) data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(2))) with interpretation(lazy): - d = dist.MultivariateNormal(Tensor(torch.zeros(2)), - Tensor(torch.eye(2)), - x - y) + d = dist_to_funsor(random_mvn((), 2)) + d = d(value=x - y) _check_mvn_affine(d, data) @@ -543,9 +523,8 @@ def test_mvn_affine_matmul(): m = Tensor(torch.randn(2, 3)) data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(3))) with interpretation(lazy): - d = dist.MultivariateNormal(Tensor(torch.zeros(3)), - Tensor(torch.eye(3)), - x @ m - y) + d = dist_to_funsor(random_mvn((), 3)) + d = d(value=x @ m - y) _check_mvn_affine(d, data) @@ -555,9 +534,8 @@ def test_mvn_affine_einsum(): y = Variable('y', reals()) data = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(()))) with interpretation(lazy): - d = dist.MultivariateNormal(Tensor(torch.zeros(3)), - Tensor(torch.eye(3)), - Einsum("abc,bc->a", c, x) + y) + d = dist_to_funsor(random_mvn((), 3)) + d = d(value=Einsum("abc,bc->a", c, x) + y) _check_mvn_affine(d, data) @@ -565,9 +543,8 @@ def test_mvn_affine_getitem(): x = Variable('x', reals(2, 2)) data = dict(x=Tensor(torch.randn(2, 2))) with interpretation(lazy): - d = dist.MultivariateNormal(Tensor(torch.zeros(2)), - Tensor(torch.eye(2)), - x[0] - x[1]) + d = dist_to_funsor(random_mvn((), 2)) + d = d(value=x[0] - x[1]) _check_mvn_affine(d, data) @@ -576,9 +553,8 @@ def test_mvn_affine_reshape(): y = Variable('y', reals(4)) data = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(4))) with 
interpretation(lazy): - d = dist.MultivariateNormal(Tensor(torch.zeros(4)), - Tensor(torch.eye(4)), - x.reshape((4,)) - y) + d = dist_to_funsor(random_mvn((), 4)) + d = d(value=x.reshape((4,)) - y) _check_mvn_affine(d, data) From c7f71a5412316d399d2acf93cca11efd190db37f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 21 Sep 2019 06:55:37 -0700 Subject: [PATCH 31/31] Revert unnecessary change --- funsor/cnf.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index 633e2bde9..9c917e6f2 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -18,8 +18,6 @@ class Contraction(Funsor): """ Declarative representation of a finitary sum-product operation - - :ivar bool is_affine: Whether this contraction is affine. """ def __init__(self, red_op, bin_op, reduced_vars, terms): terms = (terms,) if isinstance(terms, Funsor) else terms
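
Taken together, the series enables the pattern the new tests check: a lazy
dist.MultivariateNormal whose value is an affine expression of real variables now
eagerly reinterprets to a GaussianMixture. A condensed sketch along the lines of
test_mvn_affine_matmul::

    import torch
    from funsor.cnf import GaussianMixture
    from funsor.domains import reals
    from funsor.interpreter import interpretation, reinterpret
    from funsor.pyro.convert import dist_to_funsor
    from funsor.terms import Variable, lazy
    from funsor.testing import random_mvn
    from funsor.torch import Tensor

    x = Variable('x', reals(2))
    y = Variable('y', reals(3))
    m = Tensor(torch.randn(2, 3))

    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 3))
        d = d(value=x @ m - y)

    assert issubclass(type(reinterpret(d)), GaussianMixture)
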