Add pattern for MultivariateNormal(affine) #245

Merged Sep 21, 2019 · 37 commits

Commits
cd694eb
WIP sketch multivariate pattern recognition
fritzo Sep 19, 2019
3c51566
WIP more progress on eager_mvn and .extract_affine()
fritzo Sep 19, 2019
c0ec203
Add minimal failing tests for sensor fusion example (#247)
fritzo Sep 19, 2019
85dc66a
Add ReshapeOp, .shape property, .reshape() method
fritzo Sep 19, 2019
4580d5f
Add Tensor special case and tests
fritzo Sep 19, 2019
98b4ce1
Merge branch 'master' into multivariate-affine
fritzo Sep 20, 2019
708307e
Start writing a test
fritzo Sep 20, 2019
fb50652
Fix bugs in Einsum
fritzo Sep 20, 2019
4ee22e6
Add more affine tests
fritzo Sep 20, 2019
155faaa
Add more tests
fritzo Sep 20, 2019
07b3da6
Revert changes to cnf.py
fritzo Sep 20, 2019
279d7a9
Implement affine funsor approximation
fritzo Sep 20, 2019
0a83471
Add test for Einsum batching
fritzo Sep 20, 2019
437c3a4
Merge branch 'extract-affine' into multivariate-affine
fritzo Sep 20, 2019
d9138fc
Fix docs
fritzo Sep 20, 2019
ad70e1d
Address review comments
fritzo Sep 20, 2019
8df92c0
Merge branch 'extract-affine' into multivariate-affine
fritzo Sep 20, 2019
30dab86
Merge branch 'master' into multivariate-affine
fritzo Sep 20, 2019
c7d10c6
Add sensor fusion test using dist.MultivariateNormal
fritzo Sep 20, 2019
8d57895
WIP attempt to fix eager_mvn
fritzo Sep 20, 2019
4f55bd6
Add ops.matmul
fritzo Sep 20, 2019
3e543d3
Use matmul and expand=True in distributions and pyro.convert
fritzo Sep 20, 2019
6161a86
Merge branch 'matmul-op' into multivariate-affine
fritzo Sep 20, 2019
45878c5
More fixes to eager_mvn
fritzo Sep 20, 2019
dc848d6
Merge branch 'master' into multivariate-affine
fritzo Sep 20, 2019
4eec408
Remove debug statements
fritzo Sep 20, 2019
f36e864
Fix some typos
fritzo Sep 20, 2019
3e8eed1
Fix shape errors
fritzo Sep 20, 2019
cc181f0
Fix const computation
fritzo Sep 20, 2019
ffe0c55
Add failing tests
fritzo Sep 20, 2019
4768323
Add more failing tests
fritzo Sep 20, 2019
914b223
Address review comments
fritzo Sep 21, 2019
b17cc7e
Simplify
fritzo Sep 21, 2019
9e001bb
Simplify more
fritzo Sep 21, 2019
3c9e3e5
Fix bugs
fritzo Sep 21, 2019
2da0e3b
Strengthen tests
fritzo Sep 21, 2019
c7f71a5
Revert unnecessary change
fritzo Sep 21, 2019
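
As a rough orientation before the diffs: the pattern added here lets an eagerly-interpreted MultivariateNormal whose value argument is an affine expression of free variables reduce to a Gaussian. A minimal usage sketch adapted from test_mvn_affine_matmul in the test diff below (imports follow that test file):

import torch

from funsor.domains import reals
from funsor.interpreter import interpretation, reinterpret
from funsor.pyro.convert import dist_to_funsor
from funsor.terms import Variable, lazy
from funsor.testing import random_mvn
from funsor.torch import Tensor

x = Variable('x', reals(2))
y = Variable('y', reals(3))
m = Tensor(torch.randn(2, 3))

with interpretation(lazy):
    d = dist_to_funsor(random_mvn((), 3))
    d = d(value=x @ m - y)  # MultivariateNormal applied to an affine expression

g = reinterpret(d)  # with the new eager_mvn pattern this reduces to a GaussianMixture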
5 changes: 5 additions & 0 deletions funsor/affine.py
@@ -8,6 +8,11 @@
from funsor.torch import Tensor


# FIXME change this to a sound but incomplete test using pattern matching.
def is_affine(fn):
    return True


def extract_affine(fn):
"""
Extracts an affine representation of a funsor, which is exact for affine
Expand Down
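
The is_affine stub above always returns True; per the FIXME, a sound-but-incomplete version would pattern-match the term structure and reject anything it cannot positively recognize as affine. A rough sketch of the idea (a hypothetical helper, not part of this PR; real funsor expressions are often normalized into Contraction nodes, which this sketch does not handle):

import funsor.ops as ops
from funsor.terms import Binary, Number, Variable
from funsor.torch import Tensor


def is_affine_sketch(fn):
    """Sound but incomplete: True only for patterns recognized as affine."""
    if isinstance(fn, (Number, Tensor, Variable)):
        return True
    if isinstance(fn, Binary):
        if isinstance(fn.op, ops.AddOp):
            # A sum of affine terms is affine.
            return is_affine_sketch(fn.lhs) and is_affine_sketch(fn.rhs)
        if isinstance(fn.op, ops.MulOp):
            # A product is affine if one factor is a constant (Number or Tensor).
            if isinstance(fn.lhs, (Number, Tensor)):
                return is_affine_sketch(fn.rhs)
            if isinstance(fn.rhs, (Number, Tensor)):
                return is_affine_sketch(fn.lhs)
    return False  # conservatively reject anything unrecognized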
2 changes: 2 additions & 0 deletions funsor/cnf.py
@@ -18,6 +18,8 @@
class Contraction(Funsor):
    """
    Declarative representation of a finitary sum-product operation

    :ivar bool is_affine: Whether this contraction is affine.
    """
    def __init__(self, red_op, bin_op, reduced_vars, terms):
        terms = (terms,) if isinstance(terms, Funsor) else terms
88 changes: 73 additions & 15 deletions funsor/distributions.py
@@ -7,9 +7,10 @@

import funsor.delta
import funsor.ops as ops
from funsor.affine import extract_affine, is_affine
from funsor.cnf import Contraction
from funsor.domains import bint, reals
from funsor.gaussian import BlockMatrix, BlockVector, Gaussian, cholesky_inverse
from funsor.gaussian import BlockMatrix, BlockVector, Gaussian
from funsor.interpreter import interpretation
from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, lazy, to_funsor
from funsor.torch import Tensor, align_tensors, ignore_jit_warnings, materialize, torch_stack
@@ -449,23 +450,80 @@ def eager_mvn(loc, scale_tril, value):


# Create a Gaussian from a ground observation.
@eager.register(MultivariateNormal, Tensor, Tensor, Variable)
@eager.register(MultivariateNormal, Variable, Tensor, Tensor)
# TODO refactor this logic into Gaussian.eager_subs() and
# here return Gaussian(...scale_tril...)(value=loc-value).
@eager.register(MultivariateNormal, (Variable, Contraction), Tensor, (Variable, Contraction))
@eager.register(MultivariateNormal, (Variable, Contraction), Tensor, Tensor)
@eager.register(MultivariateNormal, Tensor, Tensor, (Variable, Contraction))
def eager_mvn(loc, scale_tril, value):
    if isinstance(loc, Variable):
        loc, value = value, loc
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    # Extract an affine representation.
    eye = torch.eye(scale_tril.data.size(-1)).expand(scale_tril.data.shape)
    prec_sqrt = Tensor(eye.triangular_solve(scale_tril.data, upper=False).solution,
                       scale_tril.inputs)
    affine = prec_sqrt @ (loc - value)
    const, coeffs = extract_affine(affine)
    if not isinstance(const, Tensor):
        return None  # lazy
    if not all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values()):
        return None  # lazy

    # Compute log_prob using funsors.
    scale_diag = Tensor(scale_tril.data.diagonal(dim1=-1, dim2=-2), scale_tril.inputs)
    log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi)
fehiepsi (Member) commented on Sep 21, 2019:

Just curious: why do we use shape[0] instead of shape[-1], scale_diag.log().sum() instead of scale_diag.log().sum(-1), and 0.5 * (const ** 2).sum() instead of 0.5 * (const ** 2).sum(-1) here?

fritzo (Member, Author) replied on Sep 21, 2019:

They are equivalent: scale_diag.shape[0] == scale_diag.shape[-1]. Here scale_diag is a funsor, and it separates "batch" .inputs from "event" .shape. In fact scale_diag.shape == (dim,) regardless of batching. This also allows us to call .sum() below rather than .sum(-1), since there is only one tensor dimension.

fehiepsi (Member) replied:

That's nice! Thanks for explaining.
                - scale_diag.log().sum() - 0.5 * (const ** 2).sum())

    # Dovetail to avoid variable name collision in einsum.
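    # For example, 'ab,b->a' dovetails to 'ac,c->a' in equations1 and to 'bd,d->b'
    # in equations2, so the two symbol sets can never collide.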
    equations1 = [''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a')) for c in eqn)
                  for _, eqn in coeffs.values()]
    equations2 = [''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a') + 1) for c in eqn)
                  for _, eqn in coeffs.values()]

    dim, = loc.output.shape
    inputs, (loc, scale_tril) = align_tensors(loc, scale_tril)
    inputs.update(value.inputs)
    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
    real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real')
    assert tuple(real_inputs) == tuple(coeffs)

    precision = cholesky_inverse(scale_tril)
    info_vec = precision.matmul(loc.unsqueeze(-1)).squeeze(-1)
    log_prob = (-0.5 * dim * math.log(2 * math.pi)
                - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
                - 0.5 * (loc * info_vec).sum(-1))
    return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs)
    # Align and broadcast tensors.
    neg_const = -const
    tensors = [neg_const] + [coeff for coeff, _ in coeffs.values()]
    inputs, tensors = align_tensors(*tensors, expand=True)
    neg_const, coeffs = tensors[0], tensors[1:]
    dim = sum(d.num_elements for d in real_inputs.values())
    batch_shape = neg_const.shape[:-1]

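    # Assemble the Gaussian in information form: up to the normalizer in log_prob
    # above, the density is exp(-0.5 * |const + sum_k coeff_k . x_k|^2), so each
    # info_vec block is coeff_k contracted with -const and each precision block
    # (i, j) is the contraction of coeff_i with coeff_j.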
    info_vec = BlockVector(batch_shape + (dim,))
    precision = BlockMatrix(batch_shape + (dim, dim))
    offset1 = 0
    for i1, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
        size1 = real_inputs[v1].num_elements
        slice1 = slice(offset1, offset1 + size1)
        inputs1, output1 = equations1[i1].split('->')
        input11, input12 = inputs1.split(',')
        assert input11 == input12 + output1
        info_vec[..., slice1] = torch.einsum(
            f'...{input11},...{output1}->...{input12}', c1, neg_const) \
            .reshape(batch_shape + (size1,))
        offset2 = 0
        for i2, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
            size2 = real_inputs[v2].num_elements
            slice2 = slice(offset2, offset2 + size2)
            inputs2, output2 = equations2[i2].split('->')
            input21, input22 = inputs2.split(',')
            assert input21 == input22 + output2
            precision[..., slice1, slice2] = torch.einsum(
                f'...{input11},...{input22}{output1}->...{input12}{input22}', c1, c2) \
                .reshape(batch_shape + (size1, size2))
            offset2 += size2
        offset1 += size1

    info_vec = info_vec.as_tensor()
    precision = precision.as_tensor()
    inputs.update(real_inputs)
    return log_prob + Gaussian(info_vec, precision, inputs)


class Poisson(Distribution):
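To make the review exchange above concrete: a funsor Tensor keeps batch dimensions in .inputs and only the event shape in .shape. A minimal sketch, assuming the funsor version in this PR (which adds the Funsor.shape property); the data and the input name "i" are illustrative only:

import torch
from collections import OrderedDict

from funsor.domains import bint
from funsor.torch import Tensor

data = torch.rand(5, 3) + 0.1  # a batch of 5 positive scale vectors of size 3
scale_diag = Tensor(data, OrderedDict(i=bint(5)))
print(scale_diag.inputs)             # contains the batch input 'i' of size 5
print(scale_diag.shape)              # (3,) -- only the event shape, regardless of batching
print(scale_diag.log().sum().shape)  # () -- .sum() reduces the single event dimension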
3 changes: 2 additions & 1 deletion funsor/ops.py
@@ -53,6 +53,7 @@ def log_abs_det_jacobian(x, y):
    raise NotImplementedError


# FIXME Most code assumes this is an AssociativeCommutativeOp.
class AssociativeOp(Op):
    pass

@@ -65,7 +66,7 @@ class MulOp(AssociativeOp):
    pass


class MatmulOp(AssociativeOp):
class MatmulOp(Op):  # Associative but not commutative.
    pass


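A quick plain-torch illustration of the distinction noted above (matrix multiplication is associative but, unlike add and mul, not commutative); the shapes are arbitrary:

import torch

a, b, c = torch.randn(2, 3), torch.randn(3, 4), torch.randn(4, 5)
assert torch.allclose((a @ b) @ c, a @ (b @ c), atol=1e-5)  # associative
m, n = torch.randn(3, 3), torch.randn(3, 3)
assert not torch.allclose(m @ n, n @ m)  # not commutative for generic matrices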
4 changes: 0 additions & 4 deletions funsor/torch.py
@@ -470,10 +470,6 @@ def eager_binary_tensor_tensor(op, lhs, rhs):
shape = shape[:cut] + (1,) * (lhs_dim - rhs_dim) + shape[cut:]
rhs_data = rhs_data.reshape(shape)

print(f"lhs.data.shape = {lhs.data.shape}")
print(f"rhs.data.shape = {rhs.data.shape}")
print(f"lhs_data.shape = {lhs_data.shape}")
print(f"rhs_data.shape = {rhs_data.shape}")
data = op(lhs_data, rhs_data)
if len(lhs.shape) == 1:
data = data.squeeze(-2)
46 changes: 44 additions & 2 deletions test/examples/test_sensor_fusion.py
@@ -3,6 +3,7 @@
import pytest
import torch

import funsor.distributions as dist
import funsor.ops as ops
from funsor.cnf import Contraction
from funsor.domains import bint, reals
@@ -13,8 +14,49 @@
from funsor.torch import Tensor


# This version constructs factors using funsor.distributions.
@pytest.mark.parametrize('state_dim,obs_dim', [(3, 2), (2, 3)])
def test_distributions(state_dim, obs_dim):
fritzo (Member, Author) commented:

@jpchen you can follow the idioms of this test in your experiment.
    data = Tensor(torch.randn(2, obs_dim))["time"]

    bias = Variable("bias", reals(obs_dim))
    bias_dist = dist_to_funsor(random_mvn((), obs_dim))(value=bias)

    prev = Variable("prev", reals(state_dim))
    curr = Variable("curr", reals(state_dim))
    trans_mat = Tensor(torch.eye(state_dim) + 0.1 * torch.randn(state_dim, state_dim))
    trans_mvn = random_mvn((), state_dim)
    trans_dist = dist.MultivariateNormal(
        loc=trans_mvn.loc,
        scale_tril=trans_mvn.scale_tril,
        value=curr - prev @ trans_mat)

    state = Variable("state", reals(state_dim))
    obs = Variable("obs", reals(obs_dim))
    obs_mat = Tensor(torch.randn(state_dim, obs_dim))
    obs_mvn = random_mvn((), obs_dim)
    obs_dist = dist.MultivariateNormal(
        loc=obs_mvn.loc,
        scale_tril=obs_mvn.scale_tril,
        value=state @ obs_mat + bias - obs)

    log_prob = 0
    log_prob += bias_dist

    state_0 = Variable("state_0", reals(state_dim))
    log_prob += obs_dist(state=state_0, obs=data(time=0))

    state_1 = Variable("state_1", reals(state_dim))
    log_prob += trans_dist(prev=state_0, curr=state_1)
    log_prob += obs_dist(state=state_1, obs=data(time=1))

    log_prob = log_prob.reduce(ops.logaddexp)
    assert isinstance(log_prob, Tensor), log_prob.pretty()


# This version constructs factors using funsor.pyro.convert.
@pytest.mark.xfail(reason="missing pattern")
def test_end_to_end():
def test_pyro_convert():
    data = Tensor(torch.randn(2, 2), OrderedDict([("time", bint(2))]))

    bias_dist = dist_to_funsor(random_mvn((), 2))
@@ -44,7 +86,7 @@ def test_end_to_end():

@pytest.mark.xfail(reason="missing pattern")
def test_affine_subs():
    # This was recorded from test_end_to_end.
    # This was recorded from test_pyro_convert.
    x = Subs(
        Gaussian(
            torch.tensor([1.3027106523513794, 1.4167094230651855, -0.9750942587852478, 0.5321089029312134, -0.9039931297302246], dtype=torch.float32),  # noqa
1 change: 1 addition & 0 deletions test/test_affine.py
@@ -80,6 +80,7 @@ def test_affine_subs(expr, expected_type, expected_inputs):

@pytest.mark.parametrize('expr', [
    "Variable('x', reals()) + 0.5",
    "Variable('x', reals(2, 3)) + Variable('y', reals(2, 3))",
    "Variable('x', reals(2)) + Variable('y', reals(2))",
    "Variable('x', reals(2)) + torch.ones(2)",
    "Variable('x', reals(2)) * torch.randn(2)",
80 changes: 75 additions & 5 deletions test/test_distributions.py
@@ -7,12 +7,14 @@

import funsor
import funsor.distributions as dist
from funsor.cnf import Contraction
from funsor.cnf import Contraction, GaussianMixture
from funsor.delta import Delta
from funsor.domains import bint, reals
from funsor.terms import Independent, Variable
from funsor.testing import assert_close, check_funsor, random_tensor
from funsor.torch import Tensor
from funsor.interpreter import interpretation, reinterpret
from funsor.pyro.convert import dist_to_funsor
from funsor.terms import Independent, Variable, lazy
from funsor.testing import assert_close, check_funsor, random_mvn, random_tensor
from funsor.torch import Einsum, Tensor


@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@@ -398,7 +400,6 @@ def test_normal_gaussian_3(batch_shape):
    'dist.Normal(x - y, scale, 0)',
    'dist.Normal(0, scale, y - x)',
    'dist.Normal(2 * x - y, scale, x)',
    # TODO should we expect these to work without correction terms?
    'dist.Normal(0, 1, (x - y) / scale) - scale.log()',
    'dist.Normal(2 * y, 2 * scale, 2 * x) + math.log(2)',
]
@@ -488,6 +489,75 @@ def test_mvn_gaussian(batch_shape):
    assert_close(actual, expected, atol=1e-3, rtol=1e-4)


def _check_mvn_affine(d1, data):
    assert isinstance(d1, dist.MultivariateNormal)
    d2 = reinterpret(d1)
    assert issubclass(type(d2), GaussianMixture)
    actual = d2(**data)
    expected = d1(**data)
    assert_close(actual, expected)


def test_mvn_affine_one_var():
    x = Variable('x', reals(2))
    data = dict(x=Tensor(torch.randn(2)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 2))
        d = d(value=2 * x + 1)
    _check_mvn_affine(d, data)


def test_mvn_affine_two_vars():
    x = Variable('x', reals(2))
    y = Variable('y', reals(2))
    data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(2)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 2))
        d = d(value=x - y)
    _check_mvn_affine(d, data)


def test_mvn_affine_matmul():
    x = Variable('x', reals(2))
    y = Variable('y', reals(3))
    m = Tensor(torch.randn(2, 3))
    data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(3)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 3))
        d = d(value=x @ m - y)
    _check_mvn_affine(d, data)


def test_mvn_affine_einsum():
    c = Tensor(torch.randn(3, 2, 2))
    x = Variable('x', reals(2, 2))
    y = Variable('y', reals())
    data = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(())))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 3))
        d = d(value=Einsum("abc,bc->a", c, x) + y)
    _check_mvn_affine(d, data)


def test_mvn_affine_getitem():
    x = Variable('x', reals(2, 2))
    data = dict(x=Tensor(torch.randn(2, 2)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 2))
        d = d(value=x[0] - x[1])
    _check_mvn_affine(d, data)


def test_mvn_affine_reshape():
    x = Variable('x', reals(2, 2))
    y = Variable('y', reals(4))
    data = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(4)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 4))
        d = d(value=x.reshape((4,)) - y)
    _check_mvn_affine(d, data)


@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@pytest.mark.parametrize('syntax', ['eager', 'lazy'])
def test_poisson_probs_density(batch_shape, syntax):