-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Sketch funsor.contract module * Add unit test for _partition * Add tests for partial_sum_product() * Fix test in python 3
- Loading branch information
Showing
5 changed files
with
193 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from __future__ import absolute_import, division, print_function | ||
|
||
from collections import defaultdict, OrderedDict | ||
from six.moves import reduce | ||
|
||
from funsor.terms import Funsor | ||
|
||
|
||
def _partition(terms, sum_vars): | ||
# Construct a bipartite graph between terms and the vars | ||
neighbors = OrderedDict([(t, []) for t in terms]) | ||
for term in terms: | ||
for dim in term.inputs.keys(): | ||
if dim in sum_vars: | ||
neighbors[term].append(dim) | ||
neighbors.setdefault(dim, []).append(term) | ||
|
||
# Partition the bipartite graph into connected components for contraction. | ||
components = [] | ||
while neighbors: | ||
v, pending = neighbors.popitem() | ||
component = OrderedDict([(v, None)]) # used as an OrderedSet | ||
for v in pending: | ||
component[v] = None | ||
while pending: | ||
v = pending.pop() | ||
for v in neighbors.pop(v): | ||
if v not in component: | ||
component[v] = None | ||
pending.append(v) | ||
|
||
# Split this connected component into tensors and dims. | ||
component_terms = tuple(v for v in component if isinstance(v, Funsor)) | ||
if component_terms: | ||
component_dims = frozenset(v for v in component if not isinstance(v, Funsor)) | ||
components.append((component_terms, component_dims)) | ||
return components | ||
|
||
|
||
def partial_sum_product(sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset()): | ||
""" | ||
Performs partial sum-product contraction of a collection of factors. | ||
:return: a list of partially contracted Funsors. | ||
:rtype: list | ||
""" | ||
assert callable(sum_op) | ||
assert callable(prod_op) | ||
assert isinstance(factors, (tuple, list)) | ||
assert all(isinstance(f, Funsor) for f in factors) | ||
assert isinstance(eliminate, frozenset) | ||
assert isinstance(plates, frozenset) | ||
sum_vars = eliminate - plates | ||
|
||
var_to_ordinal = {} | ||
ordinal_to_factors = defaultdict(list) | ||
for f in factors: | ||
ordinal = plates.intersection(f.inputs) | ||
ordinal_to_factors[ordinal].append(f) | ||
for var in sum_vars.intersection(f.inputs): | ||
var_to_ordinal[var] = var_to_ordinal.get(var, ordinal) & ordinal | ||
|
||
ordinal_to_vars = defaultdict(set) | ||
for var, ordinal in var_to_ordinal.items(): | ||
ordinal_to_vars[ordinal].add(var) | ||
|
||
results = [] | ||
while ordinal_to_factors: | ||
leaf = max(ordinal_to_factors, key=len) | ||
leaf_factors = ordinal_to_factors.pop(leaf) | ||
leaf_reduce_vars = ordinal_to_vars[leaf] | ||
for (group_factors, group_vars) in _partition(leaf_factors, leaf_reduce_vars): | ||
f = reduce(prod_op, group_factors).reduce(sum_op, group_vars) | ||
remaining_sum_vars = sum_vars.intersection(f.inputs) | ||
if not remaining_sum_vars: | ||
results.append(f.reduce(prod_op, leaf & eliminate)) | ||
else: | ||
new_plates = frozenset().union( | ||
*(var_to_ordinal[v] for v in remaining_sum_vars)) | ||
if new_plates == leaf: | ||
raise ValueError("intractable!") | ||
if not (leaf - new_plates).issubset(eliminate): | ||
raise ValueError("cannot reduce {} before {}".format( | ||
remaining_sum_vars, (leaf - new_plates) - eliminate)) | ||
f = f.reduce(prod_op, leaf - new_plates) | ||
ordinal_to_factors[new_plates].append(f) | ||
|
||
return results | ||
|
||
|
||
def sum_product(sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset()): | ||
""" | ||
Performs sum-product contraction of a collection of factors. | ||
:return: a single contracted Funsor. | ||
:rtype: :class:`~funsor.terms.Funsor` | ||
""" | ||
factors = partial_sum_product(sum_op, prod_op, factors, eliminate, plates) | ||
return reduce(prod_op, factors) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from __future__ import absolute_import, division, print_function | ||
|
||
from collections import OrderedDict | ||
|
||
import pytest | ||
from six.moves import reduce | ||
|
||
import funsor.ops as ops | ||
from funsor.contract import _partition, partial_sum_product, sum_product | ||
from funsor.domains import bint | ||
from funsor.testing import assert_close, random_tensor | ||
|
||
|
||
@pytest.mark.parametrize('inputs,dims,expected_num_components', [ | ||
([''], set(), 1), | ||
(['a'], set(), 1), | ||
(['a'], set('a'), 1), | ||
(['a', 'a'], set(), 2), | ||
(['a', 'a'], set('a'), 1), | ||
(['a', 'a', 'b', 'b'], set(), 4), | ||
(['a', 'a', 'b', 'b'], set('a'), 3), | ||
(['a', 'a', 'b', 'b'], set('b'), 3), | ||
(['a', 'a', 'b', 'b'], set('ab'), 2), | ||
(['a', 'ab', 'b'], set(), 3), | ||
(['a', 'ab', 'b'], set('a'), 2), | ||
(['a', 'ab', 'b'], set('b'), 2), | ||
(['a', 'ab', 'b'], set('ab'), 1), | ||
(['a', 'ab', 'bc', 'c'], set(), 4), | ||
(['a', 'ab', 'bc', 'c'], set('c'), 3), | ||
(['a', 'ab', 'bc', 'c'], set('b'), 3), | ||
(['a', 'ab', 'bc', 'c'], set('a'), 3), | ||
(['a', 'ab', 'bc', 'c'], set('ac'), 2), | ||
(['a', 'ab', 'bc', 'c'], set('abc'), 1), | ||
]) | ||
def test_partition(inputs, dims, expected_num_components): | ||
sizes = dict(zip('abc', [2, 3, 4])) | ||
terms = [random_tensor(OrderedDict((s, bint(sizes[s])) for s in input_)) | ||
for input_ in inputs] | ||
components = list(_partition(terms, dims)) | ||
|
||
# Check that result is a partition. | ||
expected_terms = sorted(terms, key=id) | ||
actual_terms = sorted((x for c in components for x in c[0]), key=id) | ||
assert actual_terms == expected_terms | ||
assert dims == set.union(set(), *(c[1] for c in components)) | ||
|
||
# Check that the partition is not too coarse. | ||
assert len(components) == expected_num_components | ||
|
||
# Check that partition is not too fine. | ||
component_dict = {x: i for i, (terms, _) in enumerate(components) for x in terms} | ||
for x in terms: | ||
for y in terms: | ||
if x is not y: | ||
if dims.intersection(x.inputs, y.inputs): | ||
assert component_dict[x] == component_dict[y] | ||
|
||
|
||
@pytest.mark.parametrize('sum_op,prod_op', [(ops.add, ops.mul), (ops.logaddexp, ops.add)]) | ||
@pytest.mark.parametrize('inputs,plates', [('a,abi,bcij', 'ij')]) | ||
@pytest.mark.parametrize('vars1,vars2', [ | ||
('', 'abcij'), | ||
('c', 'abij'), | ||
('cj', 'abi'), | ||
('bcj', 'ai'), | ||
('bcij', 'a'), | ||
('abcij', ''), | ||
]) | ||
def test_partial_sum_product(sum_op, prod_op, inputs, plates, vars1, vars2): | ||
inputs = inputs.split(',') | ||
factors = [random_tensor(OrderedDict((d, bint(2)) for d in ds)) for ds in inputs] | ||
plates = frozenset(plates) | ||
vars1 = frozenset(vars1) | ||
vars2 = frozenset(vars2) | ||
|
||
factors1 = partial_sum_product(sum_op, prod_op, factors, vars1, plates) | ||
factors2 = partial_sum_product(sum_op, prod_op, factors1, vars2, plates) | ||
actual = reduce(prod_op, factors2) | ||
|
||
expected = sum_product(sum_op, prod_op, factors, vars1 | vars2, plates) | ||
assert_close(actual, expected) |