Implement negation and subtraction ops for R-N derivatives #104

Merged 1 commit on Mar 28, 2019
33 changes: 20 additions & 13 deletions funsor/delta.py
@@ -8,7 +8,7 @@
 from funsor.domains import Domain, reals
 from funsor.integrate import Integrate, integrator
 from funsor.interpreter import debug_logged
-from funsor.ops import Op, TransformOp
+from funsor.ops import AddOp, SubOp, TransformOp
 from funsor.registry import KeyedRegistry
 from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Subs, Unary, Variable, eager, to_funsor

@@ -96,22 +96,29 @@ def eager_reduce(self, op, reduced_vars):
         return None  # defer to default implementation


-@eager.register(Binary, Op, Delta, (Funsor, Delta, Align))
-def eager_binary(op, lhs, rhs):
-    if op is ops.add or op is ops.sub:
-        if lhs.name in rhs.inputs:
-            rhs = rhs(**{lhs.name: lhs.point})
-        return op(lhs, rhs)
+@eager.register(Binary, AddOp, Delta, (Funsor, Align))
+def eager_add(op, lhs, rhs):
+    if lhs.name in rhs.inputs:
+        rhs = rhs(**{lhs.name: lhs.point})
+        return op(lhs, rhs)

     return None  # defer to default implementation


-@eager.register(Binary, Op, (Funsor, Align), Delta)
-def eager_binary(op, lhs, rhs):
-    if op is ops.add:
-        if rhs.name in lhs.inputs:
-            lhs = lhs(**{rhs.name: rhs.point})
-        return op(lhs, rhs)
+@eager.register(Binary, SubOp, Delta, (Funsor, Align))
+def eager_sub(op, lhs, rhs):
+    if lhs.name in rhs.inputs:
+        rhs = rhs(**{lhs.name: lhs.point})
+        return op(lhs, rhs)

     return None  # defer to default implementation


+@eager.register(Binary, AddOp, (Funsor, Align), Delta)
+def eager_add(op, lhs, rhs):
+    if rhs.name in lhs.inputs:
+        lhs = lhs(**{rhs.name: rhs.point})
+        return op(lhs, rhs)
+
+    return None  # defer to default implementation
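
The delta.py change replaces the single op-switching handler with one registration per op class: a sum or difference involving a Delta substitutes the Delta's point into the other term. Roughly, as a minimal sketch of that behavior (assuming the funsor 0.x API of this era; the variable names are illustrative, not from the diff):

    import torch
    from funsor.delta import Delta
    from funsor.domains import reals
    from funsor.terms import Variable
    from funsor.torch import Tensor

    # A point mass at x = 1.5, represented as a log-density funsor.
    d = Delta('x', Tensor(torch.tensor(1.5)))

    # 'x' appears in f's inputs, so d + f substitutes the point into f
    # before adding; d - f now routes through the analogous SubOp pattern.
    f = Variable('x', reals())
    print(d + f)
    print(d - f)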
39 changes: 35 additions & 4 deletions funsor/gaussian.py
@@ -1,8 +1,10 @@
 from __future__ import absolute_import, division, print_function

 import math
+import sys
 from collections import OrderedDict

+import six
 import torch
 from pyro.distributions.util import broadcast_shape
 from six import add_metaclass, integer_types

@@ -13,8 +15,8 @@
 from funsor.domains import reals
 from funsor.integrate import Integrate, integrator
 from funsor.montecarlo import monte_carlo
-from funsor.ops import AddOp
-from funsor.terms import Binary, Funsor, FunsorMeta, Number, Subs, Variable, eager
+from funsor.ops import AddOp, NegOp, SubOp
+from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Subs, Unary, Variable, eager
 from funsor.torch import Tensor, align_tensor, align_tensors, materialize
 from funsor.util import lazy_property

@@ -60,6 +62,24 @@ def _trace_mm(x, y):
     return xy.reshape(xy.shape[:-2] + (-1,)).sum(-1)

+def _sym_solve_mv(mat, vec):
+    r"""
+    Computes ``mat \ vec`` assuming mat is symmetric and usually positive definite,
+    falling back to a general pseudoinverse if positive definiteness fails.
+    """
+    try:
+        # Attempt to use stable positive definite math.
+        tri = torch.inverse(torch.cholesky(mat))
+        return _mv(tri.transpose(-1, -2), _mv(tri, vec))
+    except RuntimeError as e:
+        if 'not positive definite' not in str(e):
+            _, exc_value, traceback = sys.exc_info()
+            six.reraise(RuntimeError, e, traceback)
+
+    # Fall back to pseudoinverse.
+    return _mv(torch.pinverse(mat), vec)
+
+
 def _compute_offsets(inputs):
     """
     Compute offsets of real inputs into the concatenated Gaussian dims.

@@ -356,13 +376,24 @@ def eager_add_gaussian_gaussian(op, lhs, rhs):
     # Fuse aligned Gaussians.
     precision_loc = _mv(lhs_precision, lhs_loc) + _mv(rhs_precision, rhs_loc)
     precision = lhs_precision + rhs_precision
-    scale_tri = torch.inverse(torch.cholesky(precision)).transpose(-1, -2)
-    loc = _mv(scale_tri, _mv(scale_tri.transpose(-1, -2), precision_loc))
+    loc = _sym_solve_mv(precision, precision_loc)
     quadratic_term = _vmv(lhs_precision, loc - lhs_loc) + _vmv(rhs_precision, loc - rhs_loc)
     likelihood = Tensor(-0.5 * quadratic_term, int_inputs)
     return likelihood + Gaussian(loc, precision, inputs)


+@eager.register(Binary, SubOp, Gaussian, (Funsor, Align, Gaussian))
+@eager.register(Binary, SubOp, (Funsor, Align), Gaussian)
+def eager_sub(op, lhs, rhs):
+    return lhs + -rhs
+
+
+@eager.register(Unary, NegOp, Gaussian)
+def eager_neg(op, arg):
+    precision = -arg.precision
+    return Gaussian(arg.loc, precision, arg.inputs)
+
+
 @eager.register(Integrate, Gaussian, Variable, frozenset)
 @integrator
 def eager_integrate(log_measure, integrand, reduced_vars):
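
_sym_solve_mv factors the Cholesky-based solve out of Gaussian fusion and adds a pseudoinverse fallback. This matters because eager_neg flips the sign of a Gaussian's precision, so a fused precision like lhs_precision + rhs_precision is no longer guaranteed positive definite. A self-contained sketch of the same solve, written here with the modern torch.linalg spelling of the 2019-era calls in the diff (the helper name is reused for illustration; nothing below is part of the diff):

    import torch

    def sym_solve_mv(mat, vec):
        """Solve mat @ x = vec for symmetric mat, tolerating indefinite mat."""
        try:
            # With mat = L @ L.T, inv(mat) = inv(L).T @ inv(L), so
            # x = tri.T @ (tri @ vec) where tri = inv(L).
            tri = torch.linalg.inv(torch.linalg.cholesky(mat))
            return tri.T @ (tri @ vec)
        except RuntimeError:
            # Cholesky failed: mat is not positive definite, use pseudoinverse.
            return torch.linalg.pinv(mat) @ vec

    a = torch.randn(3, 3)
    spd = a @ a.T + 3 * torch.eye(3)        # symmetric positive definite
    vec = torch.randn(3)
    print(torch.allclose(spd @ sym_solve_mv(spd, vec), vec, atol=1e-4))  # True

    print(sym_solve_mv(-torch.eye(3), vec))  # indefinite: falls back, x = -vec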
50 changes: 30 additions & 20 deletions funsor/joint.py
@@ -13,8 +13,8 @@
 from funsor.gaussian import Gaussian
 from funsor.integrate import Integrate, integrator
 from funsor.montecarlo import monte_carlo
-from funsor.ops import AddOp, Op
-from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Subs, Variable, eager, to_funsor
+from funsor.ops import AddOp, NegOp, SubOp
+from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Subs, Unary, Variable, eager, to_funsor
 from funsor.torch import Tensor, arange


@@ -191,14 +191,6 @@ def eager_add(op, joint, other):
     return Joint(joint.deltas, joint.discrete + other, joint.gaussian)


-@eager.register(Binary, Op, Joint, (Number, Tensor))
-def eager_add(op, joint, other):
-    if op is ops.sub:
-        return joint + -other
-
-    return None  # defer to default implementation
-
-
 @eager.register(Binary, AddOp, Joint, Gaussian)
 def eager_add(op, joint, other):
     # Update with a delayed gaussian random variable.

@@ -244,12 +236,6 @@ def eager_add(op, delta, other):
     return Joint((delta,), gaussian=other)


-@eager.register(Binary, Op, Delta, (Number, Tensor))
-def eager_binary(op, delta, other):
-    if op is ops.sub:
-        return delta + -other
-
-
 @eager.register(Binary, AddOp, (Number, Tensor, Gaussian), Delta)
 def eager_add(op, other, delta):
     return delta + other

@@ -265,10 +251,34 @@ def eager_add(op, discrete, gaussian):
     return Joint(discrete=discrete, gaussian=gaussian)


-@eager.register(Binary, Op, Gaussian, (Number, Tensor))
-def eager_binary(op, gaussian, discrete):
-    if op is ops.sub:
-        return Joint(discrete=-discrete, gaussian=gaussian)
+################################################################################
+# Patterns to compute Radon-Nikodym derivatives
+################################################################################
+
+@eager.register(Binary, SubOp, Joint, (Funsor, Align, Gaussian, Joint))
+def eager_sub(op, joint, other):
+    return joint + -other
+
+
+@eager.register(Binary, SubOp, (Funsor, Align), Joint)
+def eager_sub(op, other, joint):
+    return -joint + other
+
+
+@eager.register(Binary, SubOp, Delta, (Number, Tensor, Gaussian, Joint))
+@eager.register(Binary, SubOp, (Number, Tensor), Gaussian)
+@eager.register(Binary, SubOp, Gaussian, (Number, Tensor, Joint))
+def eager_sub(op, lhs, rhs):
+    return lhs + -rhs
+
+
+@eager.register(Unary, NegOp, Joint)
+def eager_neg(op, joint):
+    if joint.deltas:
+        raise ValueError("Cannot negate deltas")
+    discrete = -joint.discrete
+    gaussian = -joint.gaussian
+    return Joint(discrete=discrete, gaussian=gaussian)


 ################################################################################
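
Together with the gaussian.py patterns, these registrations make log-density ratios (Radon-Nikodym derivatives) expressible: p - q rewrites to p + -q, negation distributes over a Joint's discrete and Gaussian parts, and deltas are rejected since a negated point mass has no representation here. A hedged sketch of the intended use (constructor arguments follow the era's Gaussian(loc, precision, inputs) signature; treat the details as assumptions):

    from collections import OrderedDict

    import torch
    from funsor.domains import reals
    from funsor.gaussian import Gaussian

    inputs = OrderedDict(x=reals())
    p = Gaussian(torch.zeros(1), torch.eye(1), inputs)       # loc 0, precision 1
    q = Gaussian(torch.ones(1), 2 * torch.eye(1), inputs)    # loc 1, precision 2

    # eager_sub rewrites this to p + -q; -q negates q's precision, so the fused
    # precision 1 + (-2) = -1 is not positive definite, which is exactly the
    # case _sym_solve_mv's pseudoinverse fallback exists to handle.
    log_ratio = p - q
    print(type(log_ratio), log_ratio.inputs)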
15 changes: 13 additions & 2 deletions funsor/ops.py
@@ -64,6 +64,14 @@ class AddOp(AssociativeOp):
     pass


+class SubOp(Op):
+    pass
+
+
+class NegOp(Op):
+    pass
+
+
 class GetitemMeta(type):
     _cache = {}

@@ -102,8 +110,8 @@ def _default(self, x, y):
 le = Op(operator.le)
 lt = Op(operator.lt)
 ne = Op(operator.ne)
-neg = Op(operator.neg)
-sub = Op(operator.sub)
+neg = NegOp(operator.neg)
+sub = SubOp(operator.sub)
 truediv = Op(operator.truediv)

 add = AddOp(operator.add)

@@ -227,11 +235,14 @@
 __all__ = [
+    'AddOp',
     'AssociativeOp',
     'DISTRIBUTIVE_OPS',
     'GetitemOp',
+    'NegOp',
     'Op',
     'PRODUCT_INVERSES',
+    'SubOp',
     'abs',
     'add',
     'and_',
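
Promoting neg and sub to their own Op subclasses moves the op is ops.sub checks out of generic handlers and into the dispatch key itself, which is what lets delta.py, gaussian.py, and joint.py register Binary/SubOp and Unary/NegOp patterns directly. A pure-Python sketch of the idea (simplified; funsor's actual registry is multipledispatch-based):

    import operator

    class Op:
        def __init__(self, fn):
            self.fn = fn

        def __call__(self, *args):
            return self.fn(*args)

    class SubOp(Op):
        pass

    class NegOp(Op):
        pass

    sub = SubOp(operator.sub)
    neg = NegOp(operator.neg)

    # A registry keyed on (type(op), ...) can route subtraction patterns to a
    # dedicated handler, with no op-identity branching inside the handler.
    handlers = {SubOp: lambda op, x, y: op(x, y)}
    print(handlers[type(sub)](sub, 3, 1))  # 2
    print(isinstance(sub, Op), neg(4))     # True -4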
5 changes: 4 additions & 1 deletion test/test_gaussian.py
@@ -19,13 +19,16 @@

 @pytest.mark.parametrize('expr,expected_type', [
+    ('-g1', Gaussian),
     ('g1 + 1', Joint),
     ('g1 - 1', Joint),
     ('1 + g1', Joint),
     ('g1 + shift', Joint),
+    ('g1 - shift', Joint),
     ('g1 + shift', Joint),
     ('shift + g1', Joint),
+    ('shift - g1', Joint),
     ('g1 + g1', Joint),
+    ('(g1 + g2 + g2) - g2', Joint),
     ('g1(i=i0)', Gaussian),
     ('g2(i=i0)', Gaussian),
     ('g1(i=i0) + g2(i=i0)', Joint),
2 changes: 2 additions & 0 deletions test/test_joint.py
@@ -39,7 +39,9 @@ def id_from_inputs(inputs):
     ('g + t', Joint),
     ('g - t', Joint),
     ('t + g', Joint),
+    ('t - g', Joint),
     ('g + g', Joint),
+    ('-(g + g)', Joint),
     ('(dx + dy)(i=i0)', Joint),
     ('(dx + g)(i=i0)', Joint),
     ('(dy + g)(i=i0)', Joint),