diff --git a/funsor/__init__.py b/funsor/__init__.py
index 4b23515f4..aabfa800d 100644
--- a/funsor/__init__.py
+++ b/funsor/__init__.py
@@ -3,9 +3,9 @@
 from funsor.domains import Domain, bint, find_domain, reals
 from funsor.interpreter import reinterpret
 from funsor.terms import Funsor, Number, Variable, of_shape, to_funsor
-from funsor.torch import Function, Tensor, arange, einsum, function
+from funsor.torch import Function, Tensor, arange, torch_einsum, function
 
-from . import distributions, domains, gaussian, handlers, interpreter, minipyro, ops, terms, torch
+from . import distributions, domains, einsum, gaussian, handlers, interpreter, minipyro, ops, terms, torch
 
 __all__ = [
     'Domain',
@@ -33,4 +33,5 @@
     'terms',
     'to_funsor',
     'torch',
+    'torch_einsum',
 ]
diff --git a/funsor/einsum.py b/funsor/einsum.py
new file mode 100644
index 000000000..446b77051
--- /dev/null
+++ b/funsor/einsum.py
@@ -0,0 +1,133 @@
+from __future__ import absolute_import, division, print_function
+
+from collections import defaultdict, OrderedDict
+from six.moves import reduce
+
+import funsor.ops as ops
+from funsor.interpreter import interpretation, reinterpret
+from funsor.optimizer import apply_optimizer
+from funsor.terms import Funsor, reflect
+
+
+def naive_einsum(eqn, *terms, **kwargs):
+    backend = kwargs.pop('backend', 'torch')
+    if backend == 'torch':
+        sum_op, prod_op = ops.add, ops.mul
+    elif backend == 'pyro.ops.einsum.torch_log':
+        sum_op, prod_op = ops.logaddexp, ops.add
+    else:
+        raise ValueError("{} backend not implemented".format(backend))
+
+    assert isinstance(eqn, str)
+    assert all(isinstance(term, Funsor) for term in terms)
+    inputs, output = eqn.split('->')
+    assert len(output.split(',')) == 1
+    input_dims = frozenset(d for inp in inputs.split(',') for d in inp)
+    output_dims = frozenset(output)
+    reduce_dims = input_dims - output_dims
+    return reduce(prod_op, terms).reduce(sum_op, reduce_dims)
+
+
+def _partition(terms, sum_vars):
+    # Construct a bipartite graph between terms and the vars
+    neighbors = OrderedDict([(t, []) for t in terms])
+    for term in terms:
+        for dim in term.inputs.keys():
+            if dim in sum_vars:
+                neighbors[term].append(dim)
+                neighbors.setdefault(dim, []).append(term)
+
+    # Partition the bipartite graph into connected components for contraction.
+    components = []
+    while neighbors:
+        v, pending = neighbors.popitem()
+        component = OrderedDict([(v, None)])  # used as an OrderedSet
+        for v in pending:
+            component[v] = None
+        while pending:
+            v = pending.pop()
+            for v in neighbors.pop(v):
+                if v not in component:
+                    component[v] = None
+                    pending.append(v)
+
+        # Split this connected component into tensors and dims.
+        component_terms = tuple(v for v in component if isinstance(v, Funsor))
+        if component_terms:
+            component_dims = frozenset(v for v in component if not isinstance(v, Funsor))
+            components.append((component_terms, component_dims))
+    return components
+
+
+def naive_plated_einsum(eqn, *terms, **kwargs):
+    """
+    Implements Tensor Variable Elimination (Algorithm 1 in [Obermeyer et al 2019])
+
+    [Obermeyer et al 2019] Obermeyer, F., Bingham, E., Jankowiak, M., Chiu, J.,
+        Pradhan, N., Rush, A., and Goodman, N. Tensor Variable Elimination for
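Note: a minimal sketch of how the two new entry points in funsor/einsum.py compose; the dim names i, j and the sizes 2, 3 below are invented for illustration and are not part of the diff. naive_einsum contracts eagerly in one product-then-reduce step, while einsum builds the same expression lazily under the reflect interpretation, optimizes the AST, and reinterprets it eagerly:

    from collections import OrderedDict

    import torch

    from funsor.domains import bint
    from funsor.einsum import einsum, naive_einsum
    from funsor.testing import assert_close
    from funsor.torch import Tensor

    # Two factors f(i, j) and g(j), with every dim summed out ('->' output).
    f = Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2)), ('j', bint(3))]))
    g = Tensor(torch.randn(3), OrderedDict([('j', bint(3))]))

    expected = naive_einsum('ij,j->', f, g)   # eager product-then-reduce
    actual = einsum('ij,j->', f, g)           # reflect -> optimize -> eager
    assert_close(actual, expected, atol=1e-4)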
+        Plated Factor Graphs, 2019
+    """
+    plates = kwargs.pop('plates', '')
+    if not plates:
+        return naive_einsum(eqn, *terms, **kwargs)
+
+    backend = kwargs.pop('backend', 'torch')
+    if backend == 'torch':
+        sum_op, prod_op = ops.add, ops.mul
+    elif backend == 'pyro.ops.einsum.torch_log':
+        sum_op, prod_op = ops.logaddexp, ops.add
+    else:
+        raise ValueError("{} backend not implemented".format(backend))
+
+    assert isinstance(eqn, str)
+    assert all(isinstance(term, Funsor) for term in terms)
+    inputs, output = eqn.split('->')
+    assert len(output.split(',')) == 1
+    input_dims = frozenset(d for inp in inputs.split(',') for d in inp)
+    output_dims = frozenset(d for d in output)
+    plate_dims = frozenset(plates) - output_dims
+    reduce_vars = input_dims - output_dims - frozenset(plates)
+
+    if output_dims:
+        raise NotImplementedError("TODO")
+
+    var_tree = {}
+    term_tree = defaultdict(list)
+    for term in terms:
+        ordinal = frozenset(term.inputs) & plate_dims
+        term_tree[ordinal].append(term)
+        for var in term.inputs:
+            if var not in plate_dims:
+                var_tree[var] = var_tree.get(var, ordinal) & ordinal
+
+    ordinal_to_var = defaultdict(set)
+    for var, ordinal in var_tree.items():
+        ordinal_to_var[ordinal].add(var)
+
+    # Direct translation of Algorithm 1
+    scalars = []
+    while term_tree:
+        leaf = max(term_tree, key=len)
+        leaf_terms = term_tree.pop(leaf)
+        leaf_reduce_vars = ordinal_to_var[leaf]
+        for (group_terms, group_vars) in _partition(leaf_terms, leaf_reduce_vars):
+            term = reduce(prod_op, group_terms).reduce(sum_op, group_vars)
+            remaining_vars = frozenset(term.inputs) & reduce_vars
+            if not remaining_vars:
+                scalars.append(term.reduce(prod_op, leaf))
+            else:
+                new_plates = frozenset().union(
+                    *(var_tree[v] for v in remaining_vars))
+                if new_plates == leaf:
+                    raise ValueError("intractable!")
+                term = term.reduce(prod_op, leaf - new_plates)
+                term_tree[new_plates].append(term)
+
+    return reduce(prod_op, scalars)
+
+
+def einsum(eqn, *terms, **kwargs):
+    with interpretation(reflect):
+        naive_ast = naive_plated_einsum(eqn, *terms, **kwargs)
+        optimized_ast = apply_optimizer(naive_ast)
+    return reinterpret(optimized_ast)  # eager by default
diff --git a/funsor/ops.py b/funsor/ops.py
index 4989fb0db..98df701fc 100644
--- a/funsor/ops.py
+++ b/funsor/ops.py
@@ -83,6 +83,16 @@ def sample(x, y):
     raise ValueError
 
+
+def reciprocal(x):
+    if isinstance(x, Number):
+        return 1. / x
+    if isinstance(x, torch.Tensor):
+        result = x.reciprocal()
+        result.clamp_(max=torch.finfo(result.dtype).max)
+        return result
+    raise ValueError("No reciprocal for type {}".format(type(x)))
+
 
 REDUCE_OP_TO_TORCH = {
     add: torch.sum,
     mul: torch.prod,
@@ -115,10 +125,17 @@ def sample(x, y):
 ])
 
 
+PRODUCT_INVERSES = {
+    mul: reciprocal,
+    add: neg,
+}
+
+
 __all__ = [
-    'REDUCE_OP_TO_TORCH',
     'ASSOCIATIVE_OPS',
     'DISTRIBUTIVE_OPS',
+    'PRODUCT_INVERSES',
+    'REDUCE_OP_TO_TORCH',
     'abs',
     'add',
     'and_',
@@ -137,6 +154,7 @@ def sample(x, y):
     'neg',
     'or_',
     'pow',
+    'reciprocal',
     'sample',
     'sub',
     'truediv',
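Note: PRODUCT_INVERSES records, for each product op, the pointwise op that undoes it, so a factor can later be divided out of a product: mul pairs with reciprocal, and add, which plays the product role in log space, pairs with neg. A small illustrative check, not part of the diff, assuming ops.neg applies elementwise like operator.neg:

    import torch

    import funsor.ops as ops

    x = torch.tensor([0.5, 2.0, 4.0])

    # x * reciprocal(x) == 1: reciprocal inverts ordinary products.
    assert torch.allclose(x * ops.reciprocal(x), torch.ones(3))

    # x + neg(x) == 0: neg inverts log-space products (sums of log factors).
    assert torch.allclose(x + ops.neg(x), torch.zeros(3))

    # The in-place clamp keeps reciprocal(0.) at finfo.max instead of inf.
    assert torch.isfinite(ops.reciprocal(torch.zeros(1))).all()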
diff --git a/funsor/testing.py b/funsor/testing.py
index 8a14c9464..ce9665b4c 100644
--- a/funsor/testing.py
+++ b/funsor/testing.py
@@ -9,10 +9,9 @@
 import torch
 from six.moves import reduce
 
-import funsor.ops as ops
 from funsor.domains import Domain, bint
 from funsor.gaussian import Gaussian
-from funsor.terms import Binary, Funsor
+from funsor.terms import Funsor
 from funsor.torch import Tensor
 
 
@@ -140,27 +139,3 @@ def random_gaussian(inputs):
     prec_sqrt = torch.randn(batch_shape + event_shape + event_shape)
     precision = torch.matmul(prec_sqrt, prec_sqrt.transpose(-1, -2))
     return Gaussian(log_density, loc, precision, inputs)
-
-
-def naive_einsum(eqn, *terms, **kwargs):
-    backend = kwargs.pop('backend', 'torch')
-    if backend == 'torch':
-        sum_op, prod_op = ops.add, ops.mul
-    elif backend == 'pyro.ops.einsum.torch_log':
-        sum_op, prod_op = ops.logaddexp, ops.add
-    else:
-        raise ValueError("{} backend not implemented".format(backend))
-
-    assert isinstance(eqn, str)
-    assert all(isinstance(term, Funsor) for term in terms)
-    inputs, output = eqn.split('->')
-    assert len(output.split(',')) == 1
-    input_dims = frozenset(d for inp in inputs.split(',') for d in inp)
-    output_dims = frozenset(d for d in output)
-    reduce_dims = tuple(d for d in input_dims - output_dims)
-    prod = terms[0]
-    for term in terms[1:]:
-        prod = Binary(prod_op, prod, term)
-    for reduce_dim in reduce_dims:
-        prod = prod.reduce(sum_op, reduce_dim)
-    return prod
diff --git a/funsor/torch.py b/funsor/torch.py
index 544171e26..1cd67011d 100644
--- a/funsor/torch.py
+++ b/funsor/torch.py
@@ -383,7 +383,7 @@ def mvn_log_prob(loc, scale_tril, x):
     return functools.partial(_function, inputs, output)
 
 
-def einsum(equation, *operands):
+def torch_einsum(equation, *operands):
     """
     Wrapper around :func:`torch.einsum` to operate on real-valued Funsors.
 
@@ -412,7 +412,7 @@
     'align_tensor',
     'align_tensors',
     'arange',
-    'einsum',
+    'torch_einsum',
     'function',
     'materialize',
 ]
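Note: after the rename, funsor.torch.torch_einsum is the same thin wrapper over torch.einsum for real-valued funsors; the rename frees the name funsor.einsum for the new module above. An illustrative use, mirroring test_einsum in test/test_torch.py below (the shapes are invented):

    import torch

    from funsor.torch import Tensor, torch_einsum

    x = Tensor(torch.randn(2, 3))
    y = Tensor(torch.randn(3, 4))

    # The letters here are positional torch.einsum dims, not funsor input names.
    z = torch_einsum('ab,bc->ac', x, y)
    assert z.data.shape == (2, 4)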
diff --git a/test/test_einsum.py b/test/test_einsum.py
index 178803a6f..81bff8cb4 100644
--- a/test/test_einsum.py
+++ b/test/test_einsum.py
@@ -15,15 +15,9 @@
 from funsor.torch import Tensor
 from funsor.interpreter import interpretation, reinterpret
 from funsor.optimizer import apply_optimizer
+from funsor.testing import assert_close, make_einsum_example
-
-from funsor.testing import assert_close, make_einsum_example, naive_einsum
-
-
-def naive_plated_einsum(eqn, *terms, **kwargs):
-    assert isinstance(eqn, str)
-    assert all(isinstance(term, funsor.Funsor) for term in terms)
-    # ...
-    raise NotImplementedError("TODO implement naive plated einsum")
+from funsor.einsum import naive_einsum, naive_plated_einsum
 
 
 EINSUM_EXAMPLES = [
@@ -103,25 +97,34 @@ def test_einsum_categorical(equation):
         assert actual.inputs[output_dim].dtype == sizes[output_dim]
 
 
-PLATED_EINSUM_EXAMPLES = [(ex, '') for ex in EINSUM_EXAMPLES] + [
+PLATED_EINSUM_EXAMPLES = [
     ('i->', 'i'),
-    ('i->i', 'i'),
     (',i->', 'i'),
-    (',i->i', 'i'),
     ('ai->', 'i'),
-    ('ai->i', 'i'),
-    ('ai->ai', 'i'),
-    (',ai,abij->aij', 'ij'),
-    ('a,ai,bij->bij', 'ij'),
+    (',ai,abij->', 'ij'),
+    ('a,ai,bij->', 'ij'),
+    ('ai,abi,bci,cdi->', 'i'),
+    ('aij,abij,bcij,cdij->', 'ij'),
+    ('a,abi,bcij,cdij->', 'ij'),
 ]
 
 
-@pytest.mark.xfail(reason="naive plated einsum not implemented")
 @pytest.mark.parametrize('equation,plates', PLATED_EINSUM_EXAMPLES)
-def test_plated_einsum(equation, plates):
+@pytest.mark.parametrize('backend', ['torch', 'pyro.ops.einsum.torch_log'])
+def test_plated_einsum(equation, plates, backend):
     inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
-    expected = naive_ubersum(equation, *operands, plates=plates, backend='torch', modulo_total=False)[0]
-    actual = naive_plated_einsum(equation, *funsor_operands, plates=plates)
+    expected = naive_ubersum(equation, *operands, plates=plates, backend=backend, modulo_total=False)[0]
+    with interpretation(reflect):
+        naive_ast = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)
+        optimized_ast = apply_optimizer(naive_ast)
+    actual_optimized = reinterpret(optimized_ast)  # eager by default
+    actual = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)
+
+    assert_close(actual, actual_optimized, atol=1e-4)
+
+    if len(outputs[0]) > 0:
+        actual = actual.align(tuple(outputs[0]))
+
     assert expected.shape == actual.data.shape
     assert torch.allclose(expected, actual.data)
     for output in outputs:
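Note: to make the plated semantics being tested above concrete: for equation 'ai->' with plates='i', 'i' indexes independent replicates and 'a' is an ordinary sum dim, so the result should equal prod_i sum_a x[a, i]. A hand check against naive_plated_einsum (the sizes 2 and 3 are invented; this is illustrative, not part of the diff):

    from collections import OrderedDict

    import torch

    from funsor.domains import bint
    from funsor.einsum import naive_plated_einsum
    from funsor.torch import Tensor

    x = torch.rand(2, 3)
    fx = Tensor(x, OrderedDict([('a', bint(2)), ('i', bint(3))]))

    actual = naive_plated_einsum('ai->', fx, plates='i')
    expected = x.sum(0).prod(0)  # sum out 'a', then multiply across plate 'i'
    assert torch.allclose(actual.data, expected)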
diff --git a/test/test_optimizer.py b/test/test_optimizer.py
index cbf337133..479c2f7a7 100644
--- a/test/test_optimizer.py
+++ b/test/test_optimizer.py
@@ -5,6 +5,7 @@
 import opt_einsum
 import torch
+import pyro.ops.contract as pyro_einsum
 
 import funsor
 
@@ -15,7 +16,8 @@
 from funsor.terms import reflect, Variable
 from funsor.torch import Tensor
 
-from funsor.testing import make_einsum_example, naive_einsum
+from funsor.testing import make_einsum_example, assert_close
+from funsor.einsum import naive_einsum, naive_plated_einsum, einsum
 
 
 def make_chain_einsum(num_steps):
@@ -101,3 +103,49 @@ def test_nested_einsum(eqn1, eqn2, optimize1, optimize2, backend1, backend2):
 
     assert torch.allclose(expected1, actual1.data)
     assert torch.allclose(expected2, actual2.data)
+
+
+def make_plated_hmm_einsum(num_steps, num_obs_plates=1, num_hidden_plates=0):
+
+    assert num_obs_plates >= num_hidden_plates
+    t0 = num_obs_plates
+
+    obs_plates = ''.join(opt_einsum.get_symbol(i) for i in range(num_obs_plates))
+    hidden_plates = ''.join(opt_einsum.get_symbol(i) for i in range(num_hidden_plates))
+
+    inputs = [str(opt_einsum.get_symbol(t0))]
+    for t in range(t0, num_steps+t0):
+        inputs.append(str(opt_einsum.get_symbol(t)) + str(opt_einsum.get_symbol(t+1)) + hidden_plates)
+        inputs.append(str(opt_einsum.get_symbol(t+1)) + obs_plates)
+    equation = ",".join(inputs) + "->"
+    return (equation, ''.join(set(obs_plates + hidden_plates)))
+
+
+PLATED_EINSUM_EXAMPLES = [
+    make_plated_hmm_einsum(num_steps, num_obs_plates=b, num_hidden_plates=a)
+    for num_steps in range(2, 6)
+    for (a, b) in [(0, 1), (0, 2), (0, 0), (1, 1), (1, 2), (1, 2)]
+]
+
+
+@pytest.mark.parametrize('equation,plates', PLATED_EINSUM_EXAMPLES)
+@pytest.mark.parametrize('backend', ['pyro.ops.einsum.torch_log'])
+def test_optimized_plated_einsum(equation, plates, backend):
+    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
+    expected = pyro_einsum.einsum(equation, *operands, plates=plates, backend=backend)[0]
+    actual = einsum(equation, *funsor_operands, plates=plates, backend=backend)
+
+    if len(equation) < 10:
+        actual_naive = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)
+        assert_close(actual, actual_naive)
+
+    assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
+    if len(outputs[0]) > 0:
+        actual = actual.align(tuple(outputs[0]))
+
+    assert expected.shape == actual.data.shape
+    assert torch.allclose(expected, actual.data)
+    for output in outputs:
+        for i, output_dim in enumerate(output):
+            assert output_dim in actual.inputs
+            assert actual.inputs[output_dim].dtype == sizes[output_dim]
diff --git a/test/test_torch.py b/test/test_torch.py
index c565b6da8..142b72636 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -10,7 +10,7 @@
 from funsor.domains import Domain, bint, reals
 from funsor.terms import Variable
 from funsor.testing import assert_close, assert_equiv, check_funsor, random_tensor
-from funsor.torch import Tensor, align_tensors
+from funsor.torch import Tensor, align_tensors, torch_einsum
 
 
 @pytest.mark.parametrize('shape', [(), (4,), (3, 2)])
@@ -412,5 +412,5 @@ def test_einsum(equation):
     tensors = [torch.randn(tuple(sizes[d] for d in dims)) for dims in inputs]
     funsors = [Tensor(x) for x in tensors]
     expected = Tensor(torch.einsum(equation, *tensors))
-    actual = funsor.einsum(equation, *funsors)
+    actual = torch_einsum(equation, *funsors)
     assert_close(actual, expected)
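Note: a hand trace of make_plated_hmm_einsum(2, num_obs_plates=1) from the test above. Symbol 0 ('a') is the observation plate and the hidden chain starts at t0 = 1, giving an initial-state factor plus one transition and one plated observation per step; this snippet just mirrors the function body to show the equation it emits:

    import opt_einsum

    t0 = 1
    inputs = [opt_einsum.get_symbol(t0)]  # 'b': initial hidden state
    for t in range(t0, 2 + t0):
        # transition factor over consecutive hidden states
        inputs.append(opt_einsum.get_symbol(t) + opt_einsum.get_symbol(t + 1))
        # observation factor, replicated over plate 'a'
        inputs.append(opt_einsum.get_symbol(t + 1) + 'a')
    assert ",".join(inputs) + "->" == 'b,bc,ca,cd,da->'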