Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scale factors across plate dims in partial_sum_product #606

Merged
merged 16 commits into from
Aug 31, 2023
48 changes: 44 additions & 4 deletions funsor/sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,13 @@ def partial_unroll(factors, eliminate=frozenset(), plate_to_step=dict()):


def partial_sum_product(
sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset(), pedantic=False
sum_op,
prod_op,
fritzo marked this conversation as resolved.
Show resolved Hide resolved
factors,
eliminate=frozenset(),
plates=frozenset(),
pedantic=False,
plate_to_scale={},
):
"""
Performs partial sum-product contraction of a collection of factors.
Expand All @@ -217,6 +223,15 @@ def partial_sum_product(
assert all(isinstance(f, Funsor) for f in factors)
assert isinstance(eliminate, frozenset)
assert isinstance(plates, frozenset)
assert isinstance(plate_to_scale, dict)

if plate_to_scale:
if sum_op is ops.logaddexp and prod_op is ops.add:
pow_op = ops.mul
elif sum_op is ops.add and prod_op is ops.mul:
pow_op = ops.pow
else:
raise ValueError("should not be here!")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be NotImplementedError

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this out to a PROD_TO_POWER dict or similar, see ops


if pedantic:
var_to_errors = defaultdict(lambda: eliminate)
Expand Down Expand Up @@ -256,7 +271,16 @@ def partial_sum_product(
f = reduce(prod_op, group_factors).reduce(sum_op, group_vars & eliminate)
remaining_sum_vars = sum_vars.intersection(f.inputs)
if not remaining_sum_vars:
results.append(f.reduce(prod_op, leaf & eliminate))
f = f.reduce(prod_op, leaf & eliminate)
f_scales = [
plate_to_scale[plate]
for plate in leaf & eliminate
if plate in plate_to_scale
]
if f_scales:
scale = reduce(ops.mul, f_scales)
f = pow_op(f, scale)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we wrap this in an if plate_to_scale: guard to improve readability?

results.append(f)
else:
new_plates = frozenset().union(
*(var_to_ordinal[v] for v in remaining_sum_vars)
Expand Down Expand Up @@ -306,6 +330,14 @@ def partial_sum_product(
reduced_plates = leaf - new_plates
assert reduced_plates.issubset(eliminate)
f = f.reduce(prod_op, reduced_plates)
f_scales = [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto: if plate_to_scale: ...

plate_to_scale[plate]
for plate in reduced_plates
if plate in plate_to_scale
]
if f_scales:
scale = reduce(ops.mul, f_scales)
f = pow_op(f, scale)
ordinal_to_factors[new_plates].append(f)

return results
Expand Down Expand Up @@ -571,15 +603,23 @@ def modified_partial_sum_product(


def sum_product(
sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset(), pedantic=False
sum_op,
prod_op,
factors,
eliminate=frozenset(),
plates=frozenset(),
pedantic=False,
plate_to_scale={},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default to None, maybe comment on datatype

):
"""
Performs sum-product contraction of a collection of factors.

:return: a single contracted Funsor.
:rtype: :class:`~funsor.terms.Funsor`
"""
factors = partial_sum_product(sum_op, prod_op, factors, eliminate, plates, pedantic)
factors = partial_sum_product(
sum_op, prod_op, factors, eliminate, plates, pedantic, plate_to_scale
)
return reduce(prod_op, factors, Number(UNITS[prod_op]))


Expand Down
96 changes: 95 additions & 1 deletion test/test_sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
sum_product,
)
from funsor.tensor import Tensor, get_default_prototype
from funsor.terms import Variable
from funsor.terms import Cat, Variable
from funsor.testing import assert_close, random_gaussian, random_tensor
from funsor.util import get_backend

Expand Down Expand Up @@ -2899,3 +2899,97 @@ def test_mixed_sequential_sum_product(duration, num_segments):
)

assert_close(actual, expected)


@pytest.mark.parametrize(
"sum_op,prod_op",
[(ops.logaddexp, ops.add), (ops.add, ops.mul)],
)
@pytest.mark.parametrize("scale", [1, 2])
def test_partial_sum_product_scale_1(sum_op, prod_op, scale):
f1 = random_tensor(OrderedDict(a=Bint[2]))
f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3]))

eliminate = frozenset("ai")
plates = frozenset("i")

# Actual result based on applying scaling
factors = [f1, f2]
scales = {"i": scale}
actual = sum_product(
sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales
)

# Expected result based on concatenating factors
f3 = Cat("i", (f2,) * scale)
factors = [f1, f3]
expected = sum_product(sum_op, prod_op, factors, eliminate, plates)

assert_close(actual, expected, atol=1e-4, rtol=1e-4)


@pytest.mark.parametrize(
"sum_op,prod_op",
[(ops.logaddexp, ops.add), (ops.add, ops.mul)],
)
@pytest.mark.parametrize("scale_i", [1, 2])
@pytest.mark.parametrize("scale_j", [1, 3])
def test_partial_sum_product_scale_2(sum_op, prod_op, scale_i, scale_j):
f1 = random_tensor(OrderedDict(a=Bint[2]))
f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3]))
f3 = random_tensor(OrderedDict(a=Bint[2], j=Bint[4]))

eliminate = frozenset("aij")
plates = frozenset("ij")

# Actual result based on applying scaling
factors = [f1, f2, f3]
scales = {"i": scale_i, "j": scale_j}
actual = sum_product(
sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales
)

# Expected result based on concatenating factors
f4 = Cat("i", (f2,) * scale_i)
f5 = Cat("j", (f3,) * scale_j)
factors = [f1, f4, f5]
expected = sum_product(sum_op, prod_op, factors, eliminate, plates)

assert_close(actual, expected, atol=1e-4, rtol=1e-4)


@pytest.mark.parametrize(
"sum_op,prod_op",
[(ops.logaddexp, ops.add), (ops.add, ops.mul)],
)
@pytest.mark.parametrize("scale_i", [1, 2])
@pytest.mark.parametrize("scale_j", [1, 3])
@pytest.mark.parametrize("scale_k", [1, 4])
def test_partial_sum_product_scale_3(sum_op, prod_op, scale_i, scale_j, scale_k):
f1 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2]))
f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3]))
f3 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3], k=Bint[3]))

eliminate = frozenset("aijk")
plates = frozenset("ijk")

# Actual result based on applying scaling
factors = [f1, f2, f3]
scales = {"i": scale_i, "j": scale_j, "k": scale_k}
actual = sum_product(
sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales
)

# Expected result based on concatenating factors
f4 = Cat("i", (f1,) * scale_i)
# concatenate across multiple dims
f5 = Cat("i", (f2,) * scale_i)
f5 = Cat("j", (f5,) * scale_j)
# concatenate across multiple dims
f6 = Cat("i", (f3,) * scale_i)
f6 = Cat("j", (f6,) * scale_j)
f6 = Cat("k", (f6,) * scale_k)
factors = [f4, f5, f6]
expected = sum_product(sum_op, prod_op, factors, eliminate, plates)

assert_close(actual, expected, atol=1e-4, rtol=1e-4)