Skip to content

Commit

Permalink
Refactor contract (#60)
Browse files Browse the repository at this point in the history
* Sketch funsor.contract module

* Add unit test for _partition

* Add tests for partial_sum_product()

* Fix test in python 3
  • Loading branch information
fritzo authored and eb8680 committed Mar 9, 2019
1 parent 97153c7 commit e46c2df
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 71 deletions.
5 changes: 3 additions & 2 deletions funsor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from funsor.domains import Domain, bint, find_domain, reals
from funsor.interpreter import reinterpret
from funsor.terms import Funsor, Number, Variable, of_shape, to_funsor
from funsor.torch import Function, Tensor, arange, torch_einsum, function
from funsor.torch import Function, Tensor, arange, function, torch_einsum

from . import distributions, domains, einsum, gaussian, handlers, interpreter, minipyro, ops, terms, torch
from . import contract, distributions, domains, einsum, gaussian, handlers, interpreter, minipyro, ops, terms, torch

__all__ = [
'Domain',
Expand All @@ -17,6 +17,7 @@
'arange',
'backward',
'bint',
'contract',
'distributions',
'domains',
'einsum',
Expand Down
99 changes: 99 additions & 0 deletions funsor/contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import absolute_import, division, print_function

from collections import defaultdict, OrderedDict
from six.moves import reduce

from funsor.terms import Funsor


def _partition(terms, sum_vars):
# Construct a bipartite graph between terms and the vars
neighbors = OrderedDict([(t, []) for t in terms])
for term in terms:
for dim in term.inputs.keys():
if dim in sum_vars:
neighbors[term].append(dim)
neighbors.setdefault(dim, []).append(term)

# Partition the bipartite graph into connected components for contraction.
components = []
while neighbors:
v, pending = neighbors.popitem()
component = OrderedDict([(v, None)]) # used as an OrderedSet
for v in pending:
component[v] = None
while pending:
v = pending.pop()
for v in neighbors.pop(v):
if v not in component:
component[v] = None
pending.append(v)

# Split this connected component into tensors and dims.
component_terms = tuple(v for v in component if isinstance(v, Funsor))
if component_terms:
component_dims = frozenset(v for v in component if not isinstance(v, Funsor))
components.append((component_terms, component_dims))
return components


def partial_sum_product(sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset()):
    """
    Partially contract a collection of factors.

    Names in ``eliminate - plates`` are reduced with ``sum_op``; names in
    ``eliminate & plates`` are reduced with ``prod_op``. Factors are grouped
    by the set of plates they depend on (their "ordinal") and processed from
    the largest ordinal downwards.

    :param callable sum_op: reduction used for the summed names.
    :param callable prod_op: operation used to combine factors and to reduce
        over plates.
    :param factors: a tuple or list of Funsors.
    :param frozenset eliminate: names to reduce out.
    :param frozenset plates: names treated as plates.
    :return: a list of partially contracted Funsors.
    :rtype: list
    :raises ValueError: if a contracted group still depends on summed names
        but cannot be moved to a different ordinal, or if moving it would
        require reducing a plate that is not in ``eliminate``.
    """
    assert callable(sum_op)
    assert callable(prod_op)
    assert isinstance(factors, (tuple, list))
    assert all(isinstance(f, Funsor) for f in factors)
    assert isinstance(eliminate, frozenset)
    assert isinstance(plates, frozenset)
    sum_vars = eliminate - plates

    # Bucket factors by ordinal, and give each summed name the intersection
    # of the ordinals of all factors that mention it.
    var_to_ordinal = {}
    ordinal_to_factors = defaultdict(list)
    for factor in factors:
        factor_plates = plates.intersection(factor.inputs)
        ordinal_to_factors[factor_plates].append(factor)
        for name in sum_vars.intersection(factor.inputs):
            if name in var_to_ordinal:
                var_to_ordinal[name] = var_to_ordinal[name] & factor_plates
            else:
                var_to_ordinal[name] = factor_plates

    # Invert the mapping: which summed names are eliminated at each ordinal.
    ordinal_to_vars = {}
    for name, ordinal in var_to_ordinal.items():
        ordinal_to_vars.setdefault(ordinal, set()).add(name)

    results = []
    while ordinal_to_factors:
        # Always work on an ordinal with the most plates.
        leaf = max(ordinal_to_factors, key=len)
        leaf_factors = ordinal_to_factors.pop(leaf)
        leaf_sum_vars = ordinal_to_vars.get(leaf, set())
        for group_factors, group_vars in _partition(leaf_factors, leaf_sum_vars):
            contracted = reduce(prod_op, group_factors).reduce(sum_op, group_vars)
            remaining_sum_vars = sum_vars.intersection(contracted.inputs)
            if not remaining_sum_vars:
                # Nothing left to sum: finish by reducing the eliminated plates.
                results.append(contracted.reduce(prod_op, leaf & eliminate))
                continue

            # Otherwise hand the group to the ordinal spanned by what remains.
            new_plates = frozenset().union(
                *(var_to_ordinal[v] for v in remaining_sum_vars))
            if new_plates == leaf:
                raise ValueError("intractable!")
            dropped_plates = leaf - new_plates
            if not dropped_plates.issubset(eliminate):
                raise ValueError("cannot reduce {} before {}".format(
                    remaining_sum_vars, dropped_plates - eliminate))
            contracted = contracted.reduce(prod_op, dropped_plates)
            ordinal_to_factors[new_plates].append(contracted)

    return results


def sum_product(sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset()):
    """
    Contract a collection of factors down to a single Funsor.

    Runs :func:`partial_sum_product` with the same arguments and then folds
    the resulting factors together with ``prod_op``.

    :return: a single contracted Funsor.
    :rtype: :class:`~funsor.terms.Funsor`
    """
    return reduce(
        prod_op,
        partial_sum_product(sum_op, prod_op, factors, eliminate, plates),
    )
75 changes: 8 additions & 67 deletions funsor/einsum.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import absolute_import, division, print_function

from collections import defaultdict, OrderedDict
from six.moves import reduce

import funsor.ops as ops
from funsor.interpreter import interpretation, reinterpret
from funsor.optimizer import apply_optimizer
from funsor.terms import Funsor, reflect
from funsor.contract import sum_product


def naive_einsum(eqn, *terms, **kwargs):
Expand All @@ -28,37 +28,6 @@ def naive_einsum(eqn, *terms, **kwargs):
return reduce(prod_op, terms).reduce(sum_op, reduce_dims)


def _partition(terms, sum_vars):
# Construct a bipartite graph between terms and the vars
neighbors = OrderedDict([(t, []) for t in terms])
for term in terms:
for dim in term.inputs.keys():
if dim in sum_vars:
neighbors[term].append(dim)
neighbors.setdefault(dim, []).append(term)

# Partition the bipartite graph into connected components for contraction.
components = []
while neighbors:
v, pending = neighbors.popitem()
component = OrderedDict([(v, None)]) # used as an OrderedSet
for v in pending:
component[v] = None
while pending:
v = pending.pop()
for v in neighbors.pop(v):
if v not in component:
component[v] = None
pending.append(v)

# Split this connected component into tensors and dims.
component_terms = tuple(v for v in component if isinstance(v, Funsor))
if component_terms:
component_dims = frozenset(v for v in component if not isinstance(v, Funsor))
components.append((component_terms, component_dims))
return components


def naive_plated_einsum(eqn, *terms, **kwargs):
"""
Implements Tensor Variable Elimination (Algorithm 1 in [Obermeyer et al 2019])
Expand All @@ -82,48 +51,20 @@ def naive_plated_einsum(eqn, *terms, **kwargs):
assert isinstance(eqn, str)
assert all(isinstance(term, Funsor) for term in terms)
inputs, output = eqn.split('->')
inputs = inputs.split(',')
assert len(inputs) == len(terms)
assert len(output.split(',')) == 1
input_dims = frozenset(d for inp in inputs.split(',') for d in inp)
input_dims = frozenset(d for inp in inputs for d in inp)
output_dims = frozenset(d for d in output)
plate_dims = frozenset(plates) - output_dims
reduce_vars = input_dims - output_dims - frozenset(plates)

if output_dims:
output_plates = output_dims & frozenset(plates)
if not all(output_plates.issubset(inp) for inp in inputs):
raise NotImplementedError("TODO")

var_tree = {}
term_tree = defaultdict(list)
for term in terms:
ordinal = frozenset(term.inputs) & plate_dims
term_tree[ordinal].append(term)
for var in term.inputs:
if var not in plate_dims:
var_tree[var] = var_tree.get(var, ordinal) & ordinal

ordinal_to_var = defaultdict(set)
for var, ordinal in var_tree.items():
ordinal_to_var[ordinal].add(var)

# Direct translation of Algorithm 1
scalars = []
while term_tree:
leaf = max(term_tree, key=len)
leaf_terms = term_tree.pop(leaf)
leaf_reduce_vars = ordinal_to_var[leaf]
for (group_terms, group_vars) in _partition(leaf_terms, leaf_reduce_vars):
term = reduce(prod_op, group_terms).reduce(sum_op, group_vars)
remaining_vars = frozenset(term.inputs) & reduce_vars
if not remaining_vars:
scalars.append(term.reduce(prod_op, leaf))
else:
new_plates = frozenset().union(
*(var_tree[v] for v in remaining_vars))
if new_plates == leaf:
raise ValueError("intractable!")
term = term.reduce(prod_op, leaf - new_plates)
term_tree[new_plates].append(term)

return reduce(prod_op, scalars)
eliminate = plate_dims | reduce_vars
return sum_product(sum_op, prod_op, terms, eliminate, frozenset(plates))


def einsum(eqn, *terms, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions funsor/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from six.moves import reduce

from funsor.domains import Domain, bint
from funsor.domains import Domain, bint, reals
from funsor.gaussian import Gaussian
from funsor.terms import Funsor
from funsor.torch import Tensor
Expand Down Expand Up @@ -116,7 +116,7 @@ def assert_equiv(x, y):
check_funsor(x, y.inputs, y.output, y.data)


def random_tensor(inputs, output):
def random_tensor(inputs, output=reals()):
"""
Creates a random :class:`funsor.torch.Tensor` with given inputs and output.
"""
Expand Down
81 changes: 81 additions & 0 deletions test/test_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import absolute_import, division, print_function

from collections import OrderedDict

import pytest
from six.moves import reduce

import funsor.ops as ops
from funsor.contract import _partition, partial_sum_product, sum_product
from funsor.domains import bint
from funsor.testing import assert_close, random_tensor


@pytest.mark.parametrize('inputs,dims,expected_num_components', [
    ([''], set(), 1),
    (['a'], set(), 1),
    (['a'], set('a'), 1),
    (['a', 'a'], set(), 2),
    (['a', 'a'], set('a'), 1),
    (['a', 'a', 'b', 'b'], set(), 4),
    (['a', 'a', 'b', 'b'], set('a'), 3),
    (['a', 'a', 'b', 'b'], set('b'), 3),
    (['a', 'a', 'b', 'b'], set('ab'), 2),
    (['a', 'ab', 'b'], set(), 3),
    (['a', 'ab', 'b'], set('a'), 2),
    (['a', 'ab', 'b'], set('b'), 2),
    (['a', 'ab', 'b'], set('ab'), 1),
    (['a', 'ab', 'bc', 'c'], set(), 4),
    (['a', 'ab', 'bc', 'c'], set('c'), 3),
    (['a', 'ab', 'bc', 'c'], set('b'), 3),
    (['a', 'ab', 'bc', 'c'], set('a'), 3),
    (['a', 'ab', 'bc', 'c'], set('ac'), 2),
    (['a', 'ab', 'bc', 'c'], set('abc'), 1),
])
def test_partition(inputs, dims, expected_num_components):
    """
    Check that ``_partition`` returns a true partition of the terms with the
    expected number of components, and never separates two terms that share
    a dim from ``dims``.
    """
    dim_sizes = {'a': 2, 'b': 3, 'c': 4}
    terms = []
    for names in inputs:
        term_inputs = OrderedDict((name, bint(dim_sizes[name])) for name in names)
        terms.append(random_tensor(term_inputs))
    components = list(_partition(terms, dims))

    # The components together contain each term exactly once ...
    grouped_terms = [term for group, _ in components for term in group]
    assert sorted(grouped_terms, key=id) == sorted(terms, key=id)
    # ... and together cover exactly the requested dims.
    covered_dims = set()
    for _, group_dims in components:
        covered_dims |= group_dims
    assert dims == covered_dims

    # The partition must not be too coarse.
    assert len(components) == expected_num_components

    # The partition must not be too fine: terms sharing a dim stay together.
    component_of = {}
    for index, (group, _) in enumerate(components):
        for term in group:
            component_of[term] = index
    for x in terms:
        for y in terms:
            if x is y:
                continue
            if dims.intersection(x.inputs, y.inputs):
                assert component_of[x] == component_of[y]


@pytest.mark.parametrize('sum_op,prod_op', [(ops.add, ops.mul), (ops.logaddexp, ops.add)])
@pytest.mark.parametrize('inputs,plates', [('a,abi,bcij', 'ij')])
@pytest.mark.parametrize('vars1,vars2', [
    ('', 'abcij'),
    ('c', 'abij'),
    ('cj', 'abi'),
    ('bcj', 'ai'),
    ('bcij', 'a'),
    ('abcij', ''),
])
def test_partial_sum_product(sum_op, prod_op, inputs, plates, vars1, vars2):
    """
    Check that eliminating ``vars1`` and then ``vars2`` in two partial passes
    gives the same result as eliminating both sets in one ``sum_product``.
    """
    factors = []
    for names in inputs.split(','):
        factors.append(random_tensor(OrderedDict((name, bint(2)) for name in names)))
    plate_set = frozenset(plates)
    first_pass = frozenset(vars1)
    second_pass = frozenset(vars2)

    # Two-stage elimination.
    partial = partial_sum_product(sum_op, prod_op, factors, first_pass, plate_set)
    staged = partial_sum_product(sum_op, prod_op, partial, second_pass, plate_set)
    actual = reduce(prod_op, staged)

    # Single-stage reference.
    expected = sum_product(sum_op, prod_op, factors, first_pass | second_pass, plate_set)
    assert_close(actual, expected)

0 comments on commit e46c2df

Please sign in to comment.