
Add Contract to optimizer #105

Merged
merged 14 commits on Mar 28, 2019
funsor/adjoint.py (15 changes: 14 additions & 1 deletion)
@@ -5,6 +5,7 @@
import torch

import funsor.ops as ops
from funsor.contract import Contract
from funsor.interpreter import interpretation, reinterpret
from funsor.ops import AssociativeOp
from funsor.registry import KeyedRegistry
@@ -19,7 +20,7 @@ def __init__(self):

def __call__(self, cls, *args):
result = eager(cls, *args)
if cls in (Reduce, Binary, Tensor):
if cls in (Reduce, Contract, Binary, Tensor):
self.tape.append((result, cls, args))
return result

@@ -87,3 +88,15 @@ def adjoint_reduce(out_adj, out, op, arg, reduced_vars):
return {arg: out_adj + (arg * 0.)} # XXX hack to simulate "expand"
elif op is ops.add: # plate!
return {arg: out_adj + Binary(ops.safesub, out, arg)}


@adjoint_ops.register(Contract, Funsor, Funsor, AssociativeOp, AssociativeOp, Funsor, Funsor, frozenset)
def adjoint_contract(out_adj, out, sum_op, prod_op, lhs, rhs, reduced_vars):

lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs)
lhs_adj = Contract(sum_op, prod_op, out_adj, rhs, lhs_reduced_vars)

rhs_reduced_vars = frozenset(lhs.inputs) - frozenset(rhs.inputs)
rhs_adj = Contract(sum_op, prod_op, out_adj, lhs, rhs_reduced_vars)

return {lhs: lhs_adj, rhs: rhs_adj}
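The new adjoint rule mirrors the familiar einsum gradient: each operand's adjoint contracts the output adjoint against the other operand, over the variables that operand does not share. A minimal numeric sketch of the (ops.add, ops.mul) case with plain torch tensors (names and shapes here are illustrative, not from the PR):

```python
import torch

lhs = torch.randn(2, 3, requires_grad=True)   # inputs {i, j}
rhs = torch.randn(3, 4, requires_grad=True)   # inputs {j, k}

# Contract(ops.add, ops.mul, lhs, rhs, {j}): out[i,k] = sum_j lhs[i,j] * rhs[j,k]
out = torch.einsum('ij,jk->ik', lhs, rhs)
out_adj = torch.randn(2, 4)                   # incoming adjoint
out.backward(out_adj)

# adjoint_contract: lhs_adj = Contract(out_adj, rhs, rhs_vars - lhs_vars = {k}),
# and symmetrically for rhs_adj over lhs_vars - rhs_vars = {i}.
lhs_adj = torch.einsum('ik,jk->ij', out_adj, rhs.detach())
rhs_adj = torch.einsum('ik,ij->jk', out_adj, lhs.detach())
assert torch.allclose(lhs_adj, lhs.grad)
assert torch.allclose(rhs_adj, rhs.grad)
```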
funsor/contract.py (76 changes: 15 additions & 61 deletions)
@@ -5,47 +5,32 @@

import funsor.interpreter as interpreter
import funsor.ops as ops
from funsor.optimizer import Finitary, optimize
from funsor.sum_product import _partition
from funsor.terms import Funsor, Subs, eager


def _order_lhss(lhs, reduced_vars):
assert isinstance(lhs, Finitary)

components = _partition(lhs.operands, reduced_vars)
root_lhs = Finitary(ops.mul, tuple(components[0][0]))
if len(components) > 1:
remaining_lhs = Finitary(ops.mul, tuple(t for c in components[1:] for t in c[0]))
else:
remaining_lhs = None

return root_lhs, remaining_lhs


def _simplify_contract(fn, lhs, rhs, reduced_vars):
def _simplify_contract(fn, sum_op, prod_op, lhs, rhs, reduced_vars):
"""
Reduce free variables that do not appear explicitly in the lhs
"""
if not reduced_vars:
return lhs * rhs
return prod_op(lhs, rhs)

lhs_vars = frozenset(lhs.inputs)
rhs_vars = frozenset(rhs.inputs)
assert reduced_vars <= lhs_vars | rhs_vars
progress = False
if not reduced_vars <= lhs_vars:
rhs = rhs.reduce(ops.add, reduced_vars - lhs_vars)
rhs = rhs.reduce(sum_op, reduced_vars - lhs_vars)
reduced_vars = reduced_vars & lhs_vars
progress = True
if not reduced_vars <= rhs_vars:
lhs = lhs.reduce(ops.add, reduced_vars - rhs_vars)
lhs = lhs.reduce(sum_op, reduced_vars - rhs_vars)
reduced_vars = reduced_vars & rhs_vars
progress = True
if progress:
return Contract(lhs, rhs, reduced_vars)
return Contract(sum_op, prod_op, lhs, rhs, reduced_vars)

return fn(lhs, rhs, reduced_vars)
return fn(sum_op, prod_op, lhs, rhs, reduced_vars)


def contractor(fn):
@@ -58,14 +43,18 @@ def contractor(fn):

class Contract(Funsor):

def __init__(self, lhs, rhs, reduced_vars):
def __init__(self, sum_op, prod_op, lhs, rhs, reduced_vars):
assert isinstance(sum_op, ops.AssociativeOp)
assert isinstance(prod_op, ops.AssociativeOp)
assert isinstance(lhs, Funsor)
assert isinstance(rhs, Funsor)
assert isinstance(reduced_vars, frozenset)
inputs = OrderedDict([(k, d) for t in (lhs, rhs)
for k, d in t.inputs.items() if k not in reduced_vars])
output = rhs.output
super(Contract, self).__init__(inputs, output)
self.sum_op = sum_op
self.prod_op = prod_op
self.lhs = lhs
self.rhs = rhs
self.reduced_vars = reduced_vars
@@ -79,48 +68,13 @@ def eager_subs(self, subs):
raise NotImplementedError('TODO alpha-convert to avoid conflict')
lhs = Subs(self.lhs, subs)
rhs = Subs(self.rhs, subs)
return Contract(lhs, rhs, self.reduced_vars)


@eager.register(Contract, Funsor, Funsor, frozenset)
@contractor
def eager_contract(lhs, rhs, reduced_vars):
return (lhs * rhs).reduce(ops.add, reduced_vars)


@optimize.register(Contract, Funsor, Funsor, frozenset)
@contractor
def optimize_contract(lhs, rhs, reduced_vars):
return None


@optimize.register(Contract, Funsor, Finitary, frozenset)
@contractor
def optimize_contract_funsor_finitary(lhs, rhs, reduced_vars):
return Contract(rhs, lhs, reduced_vars)
return Contract(self.sum_op, self.prod_op, lhs, rhs, self.reduced_vars)


@optimize.register(Contract, Finitary, (Finitary, Funsor), frozenset)
@eager.register(Contract, ops.AssociativeOp, ops.AssociativeOp, Funsor, Funsor, frozenset)
@contractor
def optimize_contract_finitary_funsor(lhs, rhs, reduced_vars):
# exploit linearity of contraction
if lhs.op is ops.add:
return Finitary(
ops.add,
tuple(Contract(operand, rhs, reduced_vars) for operand in lhs.operands)
)

# recursively apply law of iterated expectation
assert len(lhs.operands) > 1, "Finitary with one operand should have been passed through"
if lhs.op is ops.mul:
root_lhs, remaining_lhs = _order_lhss(lhs, reduced_vars)
if remaining_lhs is not None:
inner = Contract(remaining_lhs, rhs,
reduced_vars & frozenset(remaining_lhs.inputs))
return Contract(root_lhs, inner,
reduced_vars & frozenset(root_lhs.inputs))

return None
def eager_contract(sum_op, prod_op, lhs, rhs, reduced_vars):
return prod_op(lhs, rhs).reduce(sum_op, reduced_vars)


__all__ = [
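The generalized _simplify_contract above relies on prod_op distributing over sum_op: a reduced variable that appears on only one side can be summed out of that side before the pairwise contraction. A small numeric check for the (ops.add, ops.mul) pair, with illustrative shapes:

```python
import torch

lhs = torch.randn(2, 3)      # inputs {i, j}
rhs = torch.randn(3, 4, 5)   # inputs {j, k, l}

# reduced_vars = {j, k, l}; k and l are absent from lhs, so they can be
# reduced out of rhs first, shrinking the pairwise contraction to {j}.
full = torch.einsum('ij,jkl->i', lhs, rhs)
pre = torch.einsum('ij,j->i', lhs, rhs.sum(dim=(1, 2)))
assert torch.allclose(full, pre)
```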
funsor/einsum.py (29 changes: 15 additions & 14 deletions)
@@ -15,18 +15,23 @@
from funsor.torch import Tensor


def _make_base_lhs(arg, reduced_vars, normalized=False):
def _make_base_lhs(prod_op, arg, reduced_vars, normalized=False):
if not all(isinstance(d.dtype, integer_types) for d in arg.inputs.values()):
raise NotImplementedError("TODO implement continuous base lhss")

if prod_op not in (ops.add, ops.mul):
raise NotImplementedError("{} not supported product op".format(prod_op))

make_unit = torch.ones if prod_op is ops.mul else torch.zeros

sizes = OrderedDict(set((var, dtype) for var, dtype in arg.inputs.items()))
terms = tuple(
Tensor(torch.ones((d.dtype,)) / float(d.dtype), OrderedDict([(var, d)]))
Tensor(make_unit((d.dtype,)) / float(d.dtype), OrderedDict([(var, d)]))
if normalized else
Tensor(torch.ones((d.dtype,)), OrderedDict([(var, d)]))
Tensor(make_unit((d.dtype,)), OrderedDict([(var, d)]))
for var, d in sizes.items() if var in reduced_vars
)
return Finitary(ops.mul, terms) if len(terms) > 1 else terms[0]
return Finitary(prod_op, terms) if len(terms) > 1 else terms[0]


def naive_contract_einsum(eqn, *terms, **kwargs):
@@ -36,8 +41,10 @@ def naive_contract_einsum(eqn, *terms, **kwargs):
assert "plates" not in kwargs

backend = kwargs.pop('backend', 'torch')
if backend in ('torch', 'pyro.ops.einsum.torch_log'):
prod_op = ops.mul
if backend == 'torch':
sum_op, prod_op = ops.add, ops.mul
elif backend in ('pyro.ops.einsum.torch_log', 'pyro.ops.einsum.torch_marginal'):
sum_op, prod_op = ops.logaddexp, ops.add
else:
raise ValueError("{} backend not implemented".format(backend))

@@ -51,17 +58,11 @@ def naive_contract_einsum(eqn, *terms, **kwargs):
output_dims = frozenset(d for d in output)
reduced_vars = input_dims - output_dims

if backend == 'pyro.ops.einsum.torch_log':
terms = tuple(term.exp() for term in terms)

with interpretation(optimize):
rhs = Finitary(prod_op, tuple(terms))
lhs = _make_base_lhs(rhs, reduced_vars, normalized=False)
lhs = _make_base_lhs(prod_op, rhs, reduced_vars, normalized=False)
assert frozenset(lhs.inputs) == reduced_vars
result = Contract(lhs, rhs, reduced_vars)

if backend == 'pyro.ops.einsum.torch_log':
result = result.log()
result = Contract(sum_op, prod_op, lhs, rhs, reduced_vars)

return reinterpret(result)

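Threading sum_op/prod_op through (rather than exp-ing the terms in and log-ing the result out, as the deleted lines did) works because einsum in log space is exactly a (logaddexp, add) contraction. A sketch of the identity the torch_log backend implements, with illustrative tensors:

```python
import torch

log_a = torch.randn(2, 3)   # log f(i, j)
log_b = torch.randn(3, 4)   # log g(j, k)

# ordinary-space sum-product, moved back to log space
expected = torch.einsum('ij,jk->ik', log_a.exp(), log_b.exp()).log()
# log-space contraction: prod_op=add broadcast over j, then sum_op=logaddexp
actual = torch.logsumexp(log_a.unsqueeze(-1) + log_b.unsqueeze(0), dim=1)
assert torch.allclose(expected, actual)
```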
funsor/integrate.py (4 changes: 2 additions & 2 deletions)
@@ -66,7 +66,7 @@ def integrator(fn):
@eager.register(Integrate, Funsor, Funsor, frozenset)
@integrator
def eager_integrate(log_measure, integrand, reduced_vars):
return Contract(log_measure.exp(), integrand, reduced_vars)
return Contract(ops.add, ops.mul, log_measure.exp(), integrand, reduced_vars)


@eager.register(Integrate, Reduce, Funsor, frozenset)
@@ -78,7 +78,7 @@ def eager_integrate(log_measure, integrand, reduced_vars):
arg = Integrate(log_measure.arg, integrand, reduced_vars)
return arg.reduce(ops.add, log_measure.reduced_vars)

return Contract(log_measure.exp(), integrand, reduced_vars)
return Contract(ops.add, ops.mul, log_measure.exp(), integrand, reduced_vars)


__all__ = [
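Both registrations now spell out that Integrate is a sum-product contraction against the exponentiated log measure. In ordinary-tensor terms (a sketch, not funsor code):

```python
import torch

log_m = torch.randn(5)    # log measure over a discrete variable x
f = torch.randn(5)        # integrand over x

# Integrate(log_m, f, {x}) == Contract(ops.add, ops.mul, log_m.exp(), f, {x})
integral = (log_m.exp() * f).sum()
```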
funsor/ops.py (7 changes: 7 additions & 0 deletions)
@@ -220,6 +220,12 @@ def reciprocal(x):
])


UNITS = {
mul: 1.,
add: 0.,
}


PRODUCT_INVERSES = {
mul: safediv,
add: safesub,
@@ -232,6 +238,7 @@ def reciprocal(x):
'GetitemOp',
'Op',
'PRODUCT_INVERSES',
'UNITS',
'abs',
'add',
'and_',
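UNITS records the identity element of each product op; the optimizer uses it below to rewrite a bare Reduce as a Contract against a trivial rhs. A quick check, assuming funsor ops are callable on plain floats as they are elsewhere in the codebase:

```python
import funsor.ops as ops

for op in (ops.mul, ops.add):
    assert op(3., ops.UNITS[op]) == 3.   # op(x, unit) == x
```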
funsor/optimizer.py (59 changes: 44 additions & 15 deletions)
@@ -5,10 +5,13 @@
from opt_einsum.paths import greedy
from six.moves import reduce

import funsor.ops as ops
from funsor.contract import Contract, contractor
from funsor.domains import find_domain
from funsor.integrate import Integrate
from funsor.interpreter import dispatched_interpretation, interpretation, reinterpret
from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp
from funsor.terms import Binary, Funsor, Reduce, Subs, eager, lazy
from funsor.ops import DISTRIBUTIVE_OPS, UNITS, AssociativeOp
from funsor.terms import Binary, Funsor, Reduce, Subs, Unary, eager, lazy, to_funsor


class Finitary(Funsor):
@@ -157,13 +160,19 @@ def optimize_reduction(op, arg, reduced_vars):
if not (op, arg.op) in DISTRIBUTIVE_OPS:
return None

return Contract(op, arg.op, arg, to_funsor(UNITS[arg.op]), reduced_vars)


@optimize.register(Contract, AssociativeOp, AssociativeOp, Finitary, (Finitary, Funsor), frozenset)
def optimize_contract_finitary_funsor(sum_op, prod_op, lhs, rhs, reduced_vars):

if prod_op is not lhs.op:
return None

# build opt_einsum optimizer IR
inputs = []
size_dict = {}
for operand in arg.operands:
inputs.append(frozenset(d for d in operand.inputs.keys()))
size_dict.update({k: ((REAL_SIZE * v.num_elements) if v.dtype == 'real' else v.dtype)
for k, v in operand.inputs.items()})
inputs = [frozenset(t.inputs) for t in lhs.operands] + [frozenset(rhs.inputs)]
size_dict = {k: ((REAL_SIZE * v.num_elements) if v.dtype == 'real' else v.dtype)
for arg in (lhs, rhs) for k, v in arg.inputs.items()}
outputs = frozenset().union(*inputs) - reduced_vars

# optimize path with greedy opt_einsum optimizer
@@ -177,8 +186,7 @@ def optimize_reduction(op, arg, reduced_vars):
for input in inputs:
reduce_dim_counter.update({d: 1 for d in input})

reduce_op, finitary_op = op, arg.op
operands = list(arg.operands)
operands = list(lhs.operands) + [rhs]
for (a, b) in path:
b, a = tuple(sorted((a, b), reverse=True))
tb = operands.pop(b)
@@ -197,17 +205,14 @@ def optimize_reduction(op, arg, reduced_vars):
# count new appearance of variables that aren't reduced
reduce_dim_counter.update({d: 1 for d in reduced_vars & (both_vars - path_end_reduced_vars)})

path_end = Binary(finitary_op, ta, tb)
if path_end_reduced_vars:
path_end = Reduce(reduce_op, path_end, path_end_reduced_vars)

path_end = Contract(sum_op, prod_op, ta, tb, path_end_reduced_vars)
[Review comment from the author] Note that the path optimizer can be called recursively here if ta or tb are Finitary.

operands.append(path_end)

# reduce any remaining dims, if necessary
final_reduced_vars = frozenset(d for (d, count) in reduce_dim_counter.items()
if count > 0) & reduced_vars
if final_reduced_vars:
path_end = Reduce(reduce_op, path_end, final_reduced_vars)
path_end = Reduce(sum_op, path_end, final_reduced_vars)
return path_end


@@ -218,6 +223,30 @@ def remove_single_finitary(op, operands):
return None


@optimize.register(Unary, ops.Op, Finitary)
def optimize_exp_finitary(op, arg):
# useful for handling Integrate...
if op is not ops.exp or arg.op is not ops.add:
return None
return Finitary(ops.mul, tuple(operand.exp() for operand in arg.operands))


@optimize.register(Contract, AssociativeOp, AssociativeOp, Funsor, Funsor, frozenset)
@contractor
def optimize_contract(sum_op, prod_op, lhs, rhs, reduced_vars):
return None


@optimize.register(Integrate, Funsor, Funsor, frozenset)
def optimize_integrate(log_measure, integrand, reduced_vars):
return Contract(ops.add, ops.mul, log_measure.exp(), integrand, reduced_vars)


@optimize.register(Contract, AssociativeOp, AssociativeOp, Funsor, Finitary, frozenset)
def optimize_contract_funsor_finitary(sum_op, prod_op, lhs, rhs, reduced_vars):
return Contract(sum_op, prod_op, rhs, lhs, reduced_vars)


@dispatched_interpretation
def desugar(cls, *args):
result = desugar.dispatch(cls, *args)
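The net effect of optimize_reduction plus the greedy path: a multi-way sum of products becomes a chain of pairwise Contracts, each reducing a variable as soon as its appearance counter hits zero. In ordinary-tensor terms (an illustrative sketch; the pairwise order here is chosen by hand rather than by opt_einsum's greedy path):

```python
import torch

a, b, c = torch.randn(2, 3), torch.randn(3, 4), torch.randn(4, 2)

# Reduce(ops.add, Finitary(ops.mul, (a, b, c)), {i, j, k}), done directly:
direct = torch.einsum('ij,jk,ki->', a, b, c)

# ...and as the pairwise Contract chain a greedy path might emit:
ab = torch.einsum('ij,jk->ik', a, b)      # contract out j
chain = torch.einsum('ik,ki->', ab, c)    # contract out i and k
assert torch.allclose(direct, chain)
```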
funsor/torch.py (18 changes: 13 additions & 5 deletions)
@@ -15,7 +15,7 @@
from funsor.domains import Domain, bint, find_domain, reals
from funsor.integrate import Integrate, integrator
from funsor.montecarlo import monte_carlo
from funsor.ops import GetitemOp, Op
from funsor.ops import AssociativeOp, GetitemOp, Op
from funsor.six import getargspec
from funsor.terms import Binary, Funsor, FunsorMeta, Number, Subs, Variable, eager, to_data, to_funsor

@@ -419,15 +419,23 @@ def eager_getitem_tensor_tensor(op, lhs, rhs):
return Tensor(data, inputs, lhs.dtype)


@eager.register(Contract, Tensor, Tensor, frozenset)
@eager.register(Contract, AssociativeOp, AssociativeOp, Tensor, Tensor, frozenset)
@contractor
def eager_contract(lhs, rhs, reduced_vars):
def eager_contract(sum_op, prod_op, lhs, rhs, reduced_vars):
if (sum_op, prod_op) == (ops.add, ops.mul):
backend = "torch"
elif (sum_op, prod_op) == (ops.logaddexp, ops.add):
backend = "pyro.ops.einsum.torch_log"
else:
return prod_op(lhs, rhs).reduce(sum_op, reduced_vars)

inputs = OrderedDict((k, d) for t in (lhs, rhs)
for k, d in t.inputs.items() if k not in reduced_vars)

data = opt_einsum.contract(lhs.data, list(lhs.inputs),
rhs.data, list(rhs.inputs),
list(inputs), backend="torch")
dtype = find_domain(ops.mul, lhs.output, rhs.output).dtype
list(inputs), backend=backend)
dtype = find_domain(prod_op, lhs.output, rhs.output).dtype
return Tensor(data, inputs, dtype)


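How the new Tensor pattern gets exercised end to end, sketched against funsor's API as of this PR (the constructor details below are assumptions; check them against the actual funsor version):

```python
from collections import OrderedDict

import torch

import funsor.ops as ops
from funsor.contract import Contract
from funsor.domains import bint
from funsor.torch import Tensor

x = Tensor(torch.randn(3, 4), OrderedDict([('i', bint(3)), ('j', bint(4))]))
y = Tensor(torch.randn(4, 5), OrderedDict([('j', bint(4)), ('k', bint(5))]))

# (add, mul) dispatches to opt_einsum's "torch" backend;
# (logaddexp, add) would pick "pyro.ops.einsum.torch_log";
# any other op pair falls back to prod_op(lhs, rhs).reduce(sum_op, ...).
z = Contract(ops.add, ops.mul, x, y, frozenset(['j']))
```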