Add pattern for MultivariateNormal(affine) #245

Merged: 37 commits, Sep 21, 2019
Changes from 14 commits
Commits (37)
cd694eb
WIP sketch multivariate pattern recognition
fritzo Sep 19, 2019
3c51566
WIP more progress on eager_mvn and .extract_affine()
fritzo Sep 19, 2019
c0ec203
Add minimal failing tests for sensor fusion example (#247)
fritzo Sep 19, 2019
85dc66a
Add ReshapeOp, .shape property, .reshape() method
fritzo Sep 19, 2019
4580d5f
Add Tensor special case and tests
fritzo Sep 19, 2019
98b4ce1
Merge branch 'master' into multivariate-affine
fritzo Sep 20, 2019
708307e
Start writing a test
fritzo Sep 20, 2019
fb50652
Fix bugs in Einsum
fritzo Sep 20, 2019
4ee22e6
Add more affine tests
fritzo Sep 20, 2019
155faaa
Add more tests
fritzo Sep 20, 2019
07b3da6
Revert changes to cnf.py
fritzo Sep 20, 2019
279d7a9
Implement affine funsor approximation
fritzo Sep 20, 2019
0a83471
Add test for Einsum batching
fritzo Sep 20, 2019
437c3a4
Merge branch 'extract-affine' into multivariate-affine
fritzo Sep 20, 2019
d9138fc
Fix docs
fritzo Sep 20, 2019
ad70e1d
Address review comments
fritzo Sep 20, 2019
8df92c0
Merge branch 'extract-affine' into multivariate-affine
fritzo Sep 20, 2019
30dab86
Merge branch 'master' into multivariate-affine
fritzo Sep 20, 2019
c7d10c6
Add sensor fusion test using dist.MultivariateNormal
fritzo Sep 20, 2019
8d57895
WIP attempt to fix eager_mvn
fritzo Sep 20, 2019
4f55bd6
Add ops.matmul
fritzo Sep 20, 2019
3e543d3
Use matmul and expand=True in distributions and pyro.convert
fritzo Sep 20, 2019
6161a86
Merge branch 'matmul-op' into multivariate-affine
fritzo Sep 20, 2019
45878c5
More fixes to eager_mvn
fritzo Sep 20, 2019
dc848d6
Merge branch 'master' into multivariate-affine
fritzo Sep 20, 2019
4eec408
Remove debug statements
fritzo Sep 20, 2019
f36e864
Fix some typos
fritzo Sep 20, 2019
3e8eed1
Fix shape errors
fritzo Sep 20, 2019
cc181f0
Fix const computation
fritzo Sep 20, 2019
ffe0c55
Add failing tests
fritzo Sep 20, 2019
4768323
Add more failing tests
fritzo Sep 20, 2019
914b223
Address review comments
fritzo Sep 21, 2019
b17cc7e
Simplify
fritzo Sep 21, 2019
9e001bb
Simplify more
fritzo Sep 21, 2019
3c9e3e5
Fix bugs
fritzo Sep 21, 2019
2da0e3b
Strengthen tests
fritzo Sep 21, 2019
c7f71a5
Revert unnecessary change
fritzo Sep 21, 2019
2 changes: 2 additions & 0 deletions funsor/__init__.py
@@ -7,6 +7,7 @@

from . import (
    adjoint,
    affine,
    cnf,
    delta,
    distributions,
@@ -38,6 +39,7 @@
    'Tensor',
    'Variable',
    'adjoint',
    'affine',
    'arange',
    'backward',
    'bint',
48 changes: 48 additions & 0 deletions funsor/affine.py
@@ -0,0 +1,48 @@
from collections import OrderedDict

import opt_einsum
import torch

from funsor.interpreter import gensym
from funsor.terms import Lambda, Variable, bint
from funsor.torch import Tensor


def extract_affine(fn):
    """
    Extracts an affine representation of a funsor, which is exact for affine
    funsors and approximate otherwise. For an affine funsor this satisfies::

        x = Contraction(...)
        assert x.is_affine
        const, coeffs = extract_affine(x)
        y = const + sum(Einsum(eqn, (coeff, Variable(var, x.inputs[var])))
                        for var, (coeff, eqn) in coeffs.items())
        assert_close(y, x)

    :param Funsor fn: A funsor.
    :return: A pair ``(const, coeffs)`` where ``const`` is a funsor with no
        real inputs and ``coeffs`` is an OrderedDict mapping each real input
        name to a ``(coefficient, eqn)`` pair in einsum form.
    :rtype: tuple
    """
    # Determine the constant part by evaluating fn at zero.
    real_inputs = OrderedDict((k, v) for k, v in fn.inputs.items() if v.dtype == 'real')
    zeros = {k: Tensor(torch.zeros(v.shape)) for k, v in real_inputs.items()}
    const = fn(**zeros)

    # Determine linear coefficients by evaluating fn on basis vectors.
    name = gensym('probe')
    coeffs = OrderedDict()
    for k, v in real_inputs.items():
        dim = v.num_elements
        var = Variable(name, bint(dim))
        subs = zeros.copy()
        subs[k] = Tensor(torch.eye(dim).reshape((dim,) + v.shape))[var]
        coeff = Lambda(var, fn(**subs) - const).reshape(v.shape + const.shape)
        inputs1 = ''.join(map(opt_einsum.get_symbol, range(len(coeff.shape))))
        inputs2 = inputs1[:len(v.shape)]
        output = inputs1[len(v.shape):]
        eqn = f'{inputs1},{inputs2}->{output}'
        coeffs[k] = coeff, eqn
    return const, coeffs
2 changes: 2 additions & 0 deletions funsor/cnf.py
@@ -18,6 +18,8 @@
class Contraction(Funsor):
    """
    Declarative representation of a finitary sum-product operation.

    :ivar bool is_affine: Whether this contraction is affine.
    """
    def __init__(self, red_op, bin_op, reduced_vars, terms):
        terms = (terms,) if isinstance(terms, Funsor) else terms
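
A quick sketch of the new attribute in use (an assumption-laden example, not part of the diff: it presumes the eager interpretation used throughout the tests and mirrors the 't+x' smoke test in test_affine.py):

import torch
from funsor.cnf import Contraction
from funsor.domains import reals
from funsor.terms import Variable
from funsor.torch import Tensor

# A sum of a ground tensor and a real variable normalizes to a Contraction,
# which is linear-plus-constant in its real input and so reports is_affine.
x = Tensor(torch.randn(2)) + Variable('x', reals(2))
assert isinstance(x, Contraction)
assert x.is_affine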
74 changes: 65 additions & 9 deletions funsor/distributions.py
@@ -7,6 +7,7 @@

import funsor.delta
import funsor.ops as ops
from funsor.affine import extract_affine
from funsor.cnf import Contraction
from funsor.domains import bint, reals
from funsor.gaussian import BlockMatrix, BlockVector, Gaussian, cholesky_inverse
@@ -451,22 +452,77 @@ def eager_mvn(loc, scale_tril, value):


# Create a Gaussian from a ground observation.
@eager.register(MultivariateNormal, Tensor, Tensor, Variable)
@eager.register(MultivariateNormal, Variable, Tensor, Tensor)
# TODO refactor this logic into Gaussian.eager_subs() and
# here return Gaussian(...scale_tril...)(value=loc-value).
@eager.register(MultivariateNormal, (Variable, Contraction), Tensor, (Variable, Contraction))
@eager.register(MultivariateNormal, (Variable, Contraction), Tensor, Tensor)
@eager.register(MultivariateNormal, Tensor, Tensor, (Variable, Contraction))
def eager_mvn(loc, scale_tril, value):
    if isinstance(loc, Variable):
        loc, value = value, loc
    affine = loc - value
    if not isinstance(affine, Contraction):
        return None  # lazy

    const, coeffs = extract_affine(affine)
    if not isinstance(const, Tensor):
        return None  # lazy
    if not all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values()):
        return None  # lazy

    # Dovetail to avoid variable name collisions in einsum.
    equations1 = [''.join(c if c in '.,->' else chr(ord(c) * 2) for c in eqn)
                  for _, eqn in coeffs.values()]
    equations2 = [''.join(c if c in '.,->' else chr(ord(c) * 2 + 1) for c in eqn)
                  for _, eqn in coeffs.values()]

    dim, = loc.output.shape
    inputs, (loc, scale_tril) = align_tensors(loc, scale_tril)
    inputs.update(value.inputs)
    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
    real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real')
    assert len(real_inputs) == len(coeffs)
    tensors = [coeffs[k][0] for k in real_inputs] + [const, scale_tril]
    inputs, tensors = align_tensors(*tensors)

    precision = cholesky_inverse(scale_tril)
    # Incorporate scale_tril before broadcasting.
    scale_tril = tensors[-1]
    prec_sqrt = cholesky_inverse(scale_tril)  # FIXME is transpose correct?
    for i in range(len(coeffs)):
        coeff = tensors[i]
        inputs1, output1 = equations1[i].split('->')
        input11, _ = inputs1.split(',')
        inputs2, output2 = equations2[i].split('->')
        input21, _ = inputs2.split(',')
        coeff = torch.einsum(f'...{input11},...{output1}{output2}->...{input21}',
                             coeff, prec_sqrt)
        tensors[i] = coeff

    tensors = torch.broadcast_tensors(*tensors)
    coeffs, const, scale_tril = tensors[:-2], tensors[-2], tensors[-1]
    batch_shape = const.shape
    dim = sum(d.num_elements for d in real_inputs.values())

    loc = BlockVector(batch_shape + (dim,))
    loc[..., 0] = -const / coeffs[0]  # FIXME consider directly constructing info_vec

    precision = BlockMatrix(batch_shape + (dim, dim))
    offset1 = 0
    for i1, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
        slice1 = slice(offset1, offset1 + real_inputs[v1].num_elements)
        input11, input12 = equations1[i1].split('->')[0].split(',')
        offset2 = 0
        for i2, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
            slice2 = slice(offset2, offset2 + real_inputs[v2].num_elements)
            input21, input22 = equations2[i2].split('->')[0].split(',')
            precision[..., slice1, slice2] = torch.einsum(
                f'...{input11}{input12},...{input21}{input22}->...{input12}{input22}',
                c1, c2)
            offset2 = slice2.stop
        offset1 = slice1.stop

    loc = loc.as_tensor()
    precision = precision.as_tensor()
    info_vec = precision.matmul(loc.unsqueeze(-1)).squeeze(-1)

    log_prob = (-0.5 * dim * math.log(2 * math.pi)
                - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
                - 0.5 * (loc * info_vec).sum(-1))
    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
    return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs)


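To see why the character arithmetic in equations1/equations2 above avoids symbol collisions, here is a self-contained sketch (illustrative only, not part of the diff): symbols of the first equation map to even code points and symbols of the second to odd code points, so the two renamed alphabets are disjoint by construction.

keep = set('.,->')
eqn = 'ab,a->b'
eqn1 = ''.join(c if c in keep else chr(ord(c) * 2) for c in eqn)      # even codes
eqn2 = ''.join(c if c in keep else chr(ord(c) * 2 + 1) for c in eqn)  # odd codes
assert not (set(eqn1) - keep) & (set(eqn2) - keep)  # no shared symbols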
32 changes: 28 additions & 4 deletions funsor/torch.py
@@ -1,8 +1,10 @@
import functools
import itertools
import warnings
from collections import OrderedDict
from functools import reduce

import opt_einsum
import torch
from contextlib2 import contextmanager
from multipledispatch import dispatch
@@ -750,8 +752,8 @@ def __init__(self, equation, operands):
        for ein_input, x in zip(ein_inputs, operands):
            assert x.dtype == 'real'
            inputs.update(x.inputs)
            assert len(ein_inputs) == len(x.output.shape)
            for name, size in zip(ein_inputs, x.output.shape):
            assert len(ein_input) == len(x.output.shape)
            for name, size in zip(ein_input, x.output.shape):
                other_size = size_dict.setdefault(name, size)
                if other_size != size:
                    raise ValueError("Size mismatch at {}: {} vs {}"
@@ -771,8 +773,30 @@ def __str__(self):
@eager.register(Einsum, str, tuple)
def eager_einsum(equation, operands):
    if all(isinstance(x, Tensor) for x in operands):
        inputs, tensors = align_tensors(*operands)
        data = torch.einsum(equation, tensors)
        # Make new symbols for inputs of operands.
        inputs = OrderedDict()
        for x in operands:
            inputs.update(x.inputs)
        symbols = set(equation)
        get_symbol = iter(map(opt_einsum.get_symbol, itertools.count()))
        new_symbols = {}
        for k in inputs:
            symbol = next(get_symbol)
            while symbol in symbols:
                symbol = next(get_symbol)
            symbols.add(symbol)
            new_symbols[k] = symbol

        # Manually broadcast using einsum symbols.
        assert '.' not in equation
        ins, out = equation.split('->')
        ins = ins.split(',')
        ins = [''.join(new_symbols[k] for k in x.inputs) + x_out
               for x, x_out in zip(operands, ins)]
        out = ''.join(new_symbols[k] for k in inputs) + out
        equation = ','.join(ins) + '->' + out

        data = torch.einsum(equation, [x.data for x in operands])
        return Tensor(data, inputs)

    return None  # defer to default implementation
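The effect of the renaming in eager_einsum is easiest to see with raw tensors. A minimal sketch (the batch symbol 'd' and all shapes here are illustrative assumptions, not taken from the diff):

import torch

equation = 'ab,bc->ac'    # equation over the operands' output shapes
x = torch.randn(5, 2, 3)  # one batch input of size 5, output shape (2, 3)
y = torch.randn(3, 4)     # no batch inputs, output shape (3, 4)

# eager_einsum prepends a fresh symbol for each funsor input, here 'd' for
# the batch dim of x, to that operand's spec and to the output spec:
batched = 'dab,bc->dac'
out = torch.einsum(batched, x, y)
assert out.shape == (5, 2, 4)  # batched without materializing an expanded y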
44 changes: 42 additions & 2 deletions test/test_affine.py
@@ -3,11 +3,12 @@
import pytest
import torch

from funsor.affine import extract_affine
from funsor.cnf import Contraction
from funsor.domains import bint, reals
from funsor.terms import Number, Variable
from funsor.testing import check_funsor
from funsor.torch import Tensor
from funsor.testing import assert_close, check_funsor, random_tensor
from funsor.torch import Einsum, Tensor

SMOKE_TESTS = [
    ('t+x', Contraction),
@@ -75,3 +76,42 @@ def test_affine_subs(expr, expected_type, expected_inputs):
    assert isinstance(result, expected_type)
    check_funsor(result, expected_inputs, expected_output)
    assert result.is_affine


@pytest.mark.parametrize('expr', [
    "Variable('x', reals()) + 0.5",
    "Variable('x', reals(2)) + Variable('y', reals(2))",
    "Variable('x', reals(2)) + torch.ones(2)",
    "Variable('x', reals(2)) * torch.randn(2)",
    "Variable('x', reals(2)) * torch.randn(2) + torch.ones(2)",
    "Variable('x', reals(2)) + Tensor(torch.randn(3, 2), OrderedDict(i=bint(3)))",
    "Einsum('abcd,ac->bd',"
    " (Tensor(torch.randn(2, 3, 4, 5)), Variable('x', reals(2, 4))))",
    "Tensor(torch.randn(3, 5)) + Einsum('abcd,ac->bd',"
    " (Tensor(torch.randn(2, 3, 4, 5)), Variable('x', reals(2, 4))))",
    "Variable('x', reals(2, 8))[0] + torch.randn(8)",
    "Variable('x', reals(2, 8))[Variable('i', bint(2))] / 4 - 3.5",
])
def test_extract_affine(expr):
    x = eval(expr)
    assert isinstance(x, (Contraction, Einsum))
    real_inputs = OrderedDict((k, d) for k, d in x.inputs.items()
                              if d.dtype == 'real')

    const, coeffs = extract_affine(x)
    assert isinstance(const, Tensor)
    assert const.shape == x.shape
    assert list(coeffs) == list(real_inputs)
    for name, (coeff, eqn) in coeffs.items():
        assert isinstance(name, str)
        assert isinstance(coeff, Tensor)
        assert isinstance(eqn, str)

    subs = {k: random_tensor(OrderedDict(), d) for k, d in real_inputs.items()}
    expected = x(**subs)
    assert isinstance(expected, Tensor)

    actual = const + sum(Einsum(eqn, (coeff, subs[k]))
                         for k, (coeff, eqn) in coeffs.items())
    assert isinstance(actual, Tensor)
    assert_close(actual, expected)
29 changes: 27 additions & 2 deletions test/test_torch.py
@@ -675,7 +675,7 @@ def test_align():
    assert x(i=i, j=j, k=k) == y(i=i, j=j, k=k)


@pytest.mark.parametrize('equation', [
EINSUM_EXAMPLES = [
    'a->a',
    'a,a->a',
    'a,b->',
@@ -689,7 +689,10 @@ def test_align():
    'ab,ba->ab',
    'ab,ba->ba',
    'ab,bc->ac',
])
]


@pytest.mark.parametrize('equation', EINSUM_EXAMPLES)
def test_einsum(equation):
    sizes = dict(a=2, b=3, c=4)
    inputs, outputs = equation.split('->')
@@ -701,6 +704,28 @@ def test_einsum(equation):
    assert_close(actual, expected, atol=1e-5, rtol=None)


@pytest.mark.parametrize('equation', EINSUM_EXAMPLES)
@pytest.mark.parametrize('batch1', ['', 'i', 'j', 'ij'])
@pytest.mark.parametrize('batch2', ['', 'i', 'j', 'ij'])
def test_batched_einsum(equation, batch1, batch2):
    inputs, output = equation.split('->')
    inputs = inputs.split(',')

    sizes = dict(a=2, b=3, c=4, i=5, j=6)
    batch1 = OrderedDict([(k, bint(sizes[k])) for k in batch1])
    batch2 = OrderedDict([(k, bint(sizes[k])) for k in batch2])
    funsors = [random_tensor(batch, reals(*(sizes[d] for d in dims)))
               for batch, dims in zip([batch1, batch2], inputs)]
    actual = Einsum(equation, tuple(funsors))

    _equation = ','.join('...' + i for i in inputs) + '->...' + output
    inputs, tensors = align_tensors(*funsors)
    batch = tuple(v.size for v in inputs.values())
    tensors = [x.expand(batch + f.shape) for (x, f) in zip(tensors, funsors)]
    expected = Tensor(torch.einsum(_equation, tensors), inputs)
    assert_close(actual, expected, atol=1e-5, rtol=None)


@pytest.mark.parametrize('y_shape', [(), (4,), (4, 5)], ids=str)
@pytest.mark.parametrize('xy_shape', [(), (6,), (6, 7)], ids=str)
@pytest.mark.parametrize('x_shape', [(), (2,), (2, 3)], ids=str)