Refactor towards Gaussian._eager_subs_affine() #284

Merged
merged 3 commits on Oct 22, 2019
6 changes: 6 additions & 0 deletions docs/source/affine.rst
@@ -0,0 +1,6 @@
Affine Pattern Matching
-----------------------
.. automodule:: funsor.affine
:members:
:show-inheritance:
:member-order: bysource
8 changes: 0 additions & 8 deletions docs/source/funsors.rst
@@ -51,14 +51,6 @@ Joint
:show-inheritance:
:member-order: bysource

Affine
--------
.. automodule:: funsor.affine
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Contraction
-----------
.. automodule:: funsor.cnf
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -18,6 +18,7 @@ Funsor is a tensor-like library for functions and distributions
optimizer
adjoint
sum_product
affine
testing

.. toctree::
87 changes: 53 additions & 34 deletions funsor/affine.py
@@ -4,9 +4,8 @@
import opt_einsum
import torch

from funsor.cnf import Contraction
from funsor.interpreter import gensym, interpretation
from funsor.terms import Binary, Funsor, Lambda, Reduce, Unary, Variable, bint, reflect
from funsor.interpreter import gensym
from funsor.terms import Binary, Funsor, Lambda, Reduce, Unary, Variable, bint
from funsor.torch import Einsum, Tensor

from . import ops
@@ -20,46 +19,61 @@ def is_affine(fn):
:param Funsor fn: A funsor.
:rtype: bool
"""
assert isinstance(fn, Funsor)
return _affine_inputs(fn) == _real_inputs(fn)
return affine_inputs(fn) == _real_inputs(fn)


def _real_inputs(fn):
return frozenset(k for k, d in fn.inputs.items() if d.dtype == "real")


@singledispatch
def _affine_inputs(fn):
def affine_inputs(fn):
"""
Returns a [sound sub]set of real inputs of ``fn``
wrt which ``fn`` is known to be affine.

:param Funsor fn: A funsor.
:return: A set of input names wrt which ``fn`` is affine.
:rtype: frozenset
"""
result = getattr(fn, '_affine_inputs', None)
if result is None:
result = fn._affine_inputs = _affine_inputs(fn)
return result


@singledispatch
def _affine_inputs(fn):
assert isinstance(fn, Funsor)
return frozenset()


@_affine_inputs.register(Variable)
# Make registration public.
affine_inputs.register = _affine_inputs.register
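A minimal sketch of the new public entry point and its per-instance cache (illustrative only, not part of the diff; assumes funsor with PyTorch installed):

# affine_inputs() memoizes its result on the funsor instance via the
# `_affine_inputs` attribute, so repeated queries are O(1).
from funsor.affine import affine_inputs, is_affine
from funsor.domains import reals
from funsor.terms import Variable

x = Variable("x", reals())
fn = 2 * x + 1
assert is_affine(fn)
assert affine_inputs(fn) == frozenset(["x"])
assert fn._affine_inputs == frozenset(["x"])  # cached on first call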


@affine_inputs.register(Variable)
def _(fn):
return _real_inputs(fn)


@_affine_inputs.register(Unary)
@affine_inputs.register(Unary)
def _(fn):
if fn.op in (ops.neg, ops.add) or isinstance(fn.op, ops.ReshapeOp):
return _affine_inputs(fn.arg)
return affine_inputs(fn.arg)
return frozenset()


@_affine_inputs.register(Binary)
@affine_inputs.register(Binary)
def _(fn):
if fn.op in (ops.add, ops.sub):
return _affine_inputs(fn.lhs) | _affine_inputs(fn.rhs)
return affine_inputs(fn.lhs) | affine_inputs(fn.rhs)
if fn.op is ops.truediv:
return _affine_inputs(fn.lhs) - _real_inputs(fn.rhs)
return affine_inputs(fn.lhs) - _real_inputs(fn.rhs)
if isinstance(fn.op, ops.GetitemOp):
return _affine_inputs(fn.lhs)
return affine_inputs(fn.lhs)
if fn.op in (ops.mul, ops.matmul):
lhs_affine = _affine_inputs(fn.lhs) - _real_inputs(fn.rhs)
rhs_affine = _affine_inputs(fn.rhs) - _real_inputs(fn.lhs)
lhs_affine = affine_inputs(fn.lhs) - _real_inputs(fn.rhs)
rhs_affine = affine_inputs(fn.rhs) - _real_inputs(fn.lhs)
if not lhs_affine:
return rhs_affine
if not rhs_affine:
@@ -70,26 +84,19 @@ def _(fn):
return frozenset()
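For illustration, a hedged sketch of what these Binary rules imply (not part of the diff; assumes funsor with PyTorch): multiplying by a term with no real inputs preserves affineness, while a product of two real variables hits the bilinear fallback and returns the empty set.

import torch
from funsor.affine import affine_inputs
from funsor.domains import reals
from funsor.terms import Variable
from funsor.torch import Tensor

x = Variable("x", reals())
y = Variable("y", reals())
c = Tensor(torch.tensor(2.))
assert affine_inputs(x * c) == frozenset(["x"])  # c has no real inputs
assert affine_inputs(x * y) == frozenset()       # jointly bilinear, not affine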


@_affine_inputs.register(Reduce)
@affine_inputs.register(Reduce)
def _(fn):
return _affine_inputs(fn.arg) - fn.reduced_vars
return affine_inputs(fn.arg) - fn.reduced_vars


@_affine_inputs.register(Contraction)
def _(fn):
with interpretation(reflect):
flat = reduce(fn.bin_op, fn.terms).reduce(fn.red_op, fn.reduced_vars)
return _affine_inputs(flat)


@_affine_inputs.register(Einsum)
@affine_inputs.register(Einsum)
def _(fn):
# This is simply a multiary version of the above Binary(ops.mul, ...) case.
results = []
for i, x in enumerate(fn.operands):
others = fn.operands[:i] + fn.operands[i+1:]
other_inputs = reduce(ops.or_, map(_real_inputs, others), frozenset())
results.append(_affine_inputs(x) - other_inputs)
results.append(affine_inputs(x) - other_inputs)
# This multilinear case introduces incompleteness, since some vars
# could later be reduced, making remaining vars affine.
if sum(map(bool, results)) == 1:
@@ -101,35 +108,40 @@ def _(fn):

def extract_affine(fn):
"""
Extracts an affine representation of a funsor, which is exact for affine
funsors and approximate otherwise. For affine funsors this satisfies::
Extracts an affine representation of a funsor, satisfying::

x = ...
const, coeffs = extract_affine(x)
y = const + sum(Einsum(eqn, (coeff, Variable(var, coeff.output)))
for var, (coeff, eqn) in coeffs.items())
assert_close(y, x)
assert frozenset(coeffs) == affine_inputs(x)

The ``coeffs`` will have one key per input wrt which ``fn`` is known to be
affine (via :func:`affine_inputs` ), and ``const`` and ``coeffs.values``
will all be constant wrt these inputs.

The affine approximation is computed by evaluating ``fn`` at
zero and each basis vector. To improve performance, users may want to run
under the :func:`~funsor.memoize.memoize` interpretation.

:param Funsor fn: A funsor assumed to be affine wrt the (add,mul) semiring.
The affine assumption is not checked.
:param Funsor fn: A funsor that is affine wrt the (add,mul) semiring in
some subset of its inputs.
:return: A pair ``(const, coeffs)`` where const is a funsor with no real
inputs and ``coeffs`` is an OrderedDict mapping input name to a
``(coefficient, eqn)`` pair in einsum form.
:rtype: tuple
"""
# Determine constant part by evaluating fn at zero.
real_inputs = OrderedDict((k, v) for k, v in fn.inputs.items() if v.dtype == 'real')
zeros = {k: Tensor(torch.zeros(v.shape)) for k, v in real_inputs.items()}
inputs = affine_inputs(fn)
inputs = OrderedDict((k, v) for k, v in fn.inputs.items() if k in inputs)
zeros = {k: Tensor(torch.zeros(v.shape)) for k, v in inputs.items()}
const = fn(**zeros)

# Determine linear coefficients by evaluating fn on basis vectors.
name = gensym('probe')
coeffs = OrderedDict()
for k, v in real_inputs.items():
for k, v in inputs.items():
dim = v.num_elements
var = Variable(name, bint(dim))
subs = zeros.copy()
@@ -141,3 +153,10 @@ def extract_affine(fn):
eqn = f'{inputs1},{inputs2}->{output}'
coeffs[k] = coeff, eqn
return const, coeffs


__all__ = [
"affine_inputs",
"extract_affine",
"is_affine",
]
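A usage sketch of the contract documented in the extract_affine docstring above (illustrative only, not part of the diff; assumes funsor with PyTorch installed):

import torch
from funsor.affine import affine_inputs, extract_affine
from funsor.domains import reals
from funsor.terms import Variable
from funsor.testing import assert_close
from funsor.torch import Einsum, Tensor

x = Variable("x", reals(2))
fn = Tensor(torch.randn(3, 2)) @ x + Tensor(torch.randn(3))  # affine in x

const, coeffs = extract_affine(fn)
y = const + sum(Einsum(eqn, (coeff, Variable(var, coeff.output)))
                for var, (coeff, eqn) in coeffs.items())
assert_close(y, fn)                            # exact, since fn is affine
assert frozenset(coeffs) == affine_inputs(fn)  # one coefficient per affine input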
25 changes: 23 additions & 2 deletions funsor/cnf.py
@@ -9,12 +9,26 @@
from pyro.distributions.util import broadcast_shape

import funsor.ops as ops
from funsor.affine import affine_inputs
from funsor.delta import Delta
from funsor.domains import find_domain
from funsor.gaussian import Gaussian
from funsor.interpreter import recursion_reinterpret
from funsor.interpreter import interpretation, recursion_reinterpret
from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, nullop
from funsor.terms import Align, Binary, Funsor, Number, Reduce, Subs, Unary, Variable, eager, normalize, to_funsor
from funsor.terms import (
Align,
Binary,
Funsor,
Number,
Reduce,
Subs,
Unary,
Variable,
eager,
normalize,
reflect,
to_funsor
)
from funsor.torch import Tensor
from funsor.util import quote

@@ -269,6 +283,13 @@ def eager_contraction_gaussian(red_op, bin_op, reduced_vars, x, y):
return (x + y).reduce(red_op, reduced_vars)


@affine_inputs.register(Contraction)
Member Author comment: moved verbatim from affine.py

def _(fn):
with interpretation(reflect):
flat = reduce(fn.bin_op, fn.terms).reduce(fn.red_op, fn.reduced_vars)
return affine_inputs(flat)
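A quick hedged check that normalized sums, which funsor represents as Contraction terms, are covered by this handler (illustrative only, not part of the diff):

import torch
from funsor.affine import affine_inputs
from funsor.domains import reals
from funsor.terms import Variable
from funsor.torch import Tensor

x = Variable("x", reals())
y = Variable("y", reals())
fn = x + y + Tensor(torch.randn(()))  # normalizes to a Contraction
assert affine_inputs(fn) == frozenset(["x", "y"])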


##########################################
# Normalizing Contractions
##########################################
84 changes: 50 additions & 34 deletions funsor/gaussian.py
@@ -6,6 +6,7 @@
from pyro.distributions.util import broadcast_shape

import funsor.ops as ops
from funsor.affine import affine_inputs, extract_affine
from funsor.delta import Delta
from funsor.domains import reals
from funsor.ops import AddOp, NegOp, SubOp
@@ -334,47 +335,56 @@ def eager_subs(self, subs):
if not subs:
return self

# Constants and Variables are eagerly substituted;
# Constants and Affine funsors are eagerly substituted;
# everything else is lazily substituted.
lazy_subs = tuple((k, v) for k, v in subs
if not isinstance(v, (Number, Tensor, Variable, Slice)))
if not isinstance(v, (Number, Tensor, Variable, Slice))
and not affine_inputs(v))
var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
int_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor, Slice))
if v.dtype != 'real')
real_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
if v.dtype == 'real')
if not (var_subs or int_subs or real_subs):
return reflect(Subs, self, lazy_subs)

# First perform any variable substitutions.
affine_subs = tuple((k, v) for k, v in subs
if not isinstance(v, Variable) and affine_inputs(v))
if var_subs:
rename = {k: v.name for k, v in var_subs}
inputs = OrderedDict((rename.get(k, k), d) for k, d in self.inputs.items())
if len(inputs) != len(self.inputs):
raise ValueError("Variable substitution name conflict")
var_result = Gaussian(self.info_vec, self.precision, inputs)
new_subs = int_subs + real_subs + lazy_subs
return Subs(var_result, new_subs) if new_subs else var_result

# Next perform any integer substitution, i.e. slicing into a batch.
return self._eager_subs_var(var_subs, int_subs + real_subs + affine_subs + lazy_subs)
if int_subs:
int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
tensors = [self.info_vec, self.precision]
funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
inputs = funsors[0].inputs.copy()
inputs.update(real_inputs)
int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
new_subs = real_subs + lazy_subs
return Subs(int_result, new_subs) if new_subs else int_result

return self._eager_subs_int(int_subs, real_subs + affine_subs + lazy_subs)
if real_subs:
return self._eager_subs_real(real_subs, affine_subs + lazy_subs)
if affine_subs:
# TODO: return self._eager_subs_affine(affine_subs, lazy_subs)
lazy_subs = affine_subs + lazy_subs
return reflect(Subs, self, lazy_subs)
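A hypothetical example of the new routing (not part of the diff; assumes funsor with PyTorch): an affine value is detected via affine_inputs() and sent to the affine branch, which for now reflects back to a lazy Subs pending _eager_subs_affine().

from collections import OrderedDict
import torch
from funsor.domains import reals
from funsor.gaussian import Gaussian
from funsor.terms import Variable

g = Gaussian(torch.zeros(2), torch.eye(2), OrderedDict(x=reals(2)))
y = Variable("y", reals(2))
result = g(x=2 * y + 1)  # affine_inputs(2 * y + 1) == {"y"}: routed to the
                         # affine branch, currently a lazy Subs (see TODO)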

def _eager_subs_var(self, subs, remaining_subs):
# Perform variable substitution, i.e. renaming of inputs.
rename = {k: v.name for k, v in subs}
inputs = OrderedDict((rename.get(k, k), d) for k, d in self.inputs.items())
if len(inputs) != len(self.inputs):
raise ValueError("Variable substitution name conflict")
var_result = Gaussian(self.info_vec, self.precision, inputs)
return Subs(var_result, remaining_subs) if remaining_subs else var_result

def _eager_subs_int(self, subs, remaining_subs):
# Perform integer substitution, i.e. slicing into a batch.
int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
tensors = [self.info_vec, self.precision]
funsors = [Subs(Tensor(x, int_inputs), subs) for x in tensors]
inputs = funsors[0].inputs.copy()
inputs.update(real_inputs)
int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
return Subs(int_result, remaining_subs) if remaining_subs else int_result

def _eager_subs_real(self, subs, remaining_subs):
# Broadcast all component tensors.
real_subs = OrderedDict(subs)
assert real_subs and not int_subs
subs = OrderedDict(subs)
int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
tensors = [Tensor(self.info_vec, int_inputs),
Tensor(self.precision, int_inputs)]
tensors.extend(real_subs.values())
tensors.extend(subs.values())
int_inputs, tensors = align_tensors(*tensors)
batch_dim = tensors[0].dim() - 1
batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
@@ -384,15 +394,15 @@ def eager_subs(self, subs):
for k, offset in offsets.items()]

# Expand all substituted values.
values = OrderedDict(zip(real_subs, values))
values = OrderedDict(zip(subs, values))
for k, value in values.items():
value = value.reshape(value.shape[:batch_dim] + (-1,))
if not torch._C._get_tracing_state():
assert value.size(-1) == self.inputs[k].num_elements
values[k] = value.expand(batch_shape + value.shape[-1:])

# Try to perform a complete substitution of all real variables, resulting in a Tensor.
if all(k in real_subs for k, d in self.inputs.items() if d.dtype == 'real'):
if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
# Form the concatenated value.
value = BlockVector(batch_shape + (event_size,))
for k, i in slices:
@@ -405,11 +415,11 @@

result = Tensor(result, int_inputs)
assert result.output == reals()
return Subs(result, lazy_subs)
return Subs(result, remaining_subs) if remaining_subs else result

# Perform a partial substitution of a subset of real variables, resulting in a Joint.
# We split real inputs into two sets: a for the preserved and b for the substituted.
b = frozenset(k for k, v in real_subs.items())
b = frozenset(k for k, v in subs.items())
a = frozenset(k for k, d in self.inputs.items() if d.dtype == 'real' and k not in b)
prec_aa = torch.cat([torch.cat([
precision[..., i1, i2]
@@ -431,9 +441,15 @@
precision = prec_aa.expand(info_vec.shape + (-1,))
inputs = int_inputs.copy()
for k, d in self.inputs.items():
if k not in real_subs:
if k not in subs:
inputs[k] = d
return Gaussian(info_vec, precision, inputs) + Tensor(log_scale, int_inputs)
result = Gaussian(info_vec, precision, inputs) + Tensor(log_scale, int_inputs)
return Subs(result, remaining_subs) if remaining_subs else result

def _eager_subs_affine(self, subs, remaining_subs):
affine = OrderedDict((k, extract_affine(v)) for k, v in subs)
assert affine
raise NotImplementedError('TODO')

def eager_reduce(self, op, reduced_vars):
if op is ops.logaddexp:
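For readers following the refactor, a hedged sketch of the algebra _eager_subs_affine() would eventually implement, under the parametrization log_density(x) = -1/2 x'Px + x'i used by this Gaussian: substituting x = A y + b yields precision A'PA, info vector A'(i - Pb), and a constant b'i - 1/2 b'Pb absorbed into a Tensor term. The dense single-input function below is an illustrative assumption, not code from this PR.

import torch

def subs_affine_dense(i, P, A, b):
    """Single-input dense illustration; the real method must also handle
    batch dims, multiple affine inputs, and block offsets as done elsewhere
    in this file."""
    new_P = A.t() @ P @ A                    # precision' = A' P A
    new_i = A.t() @ (i - P @ b)              # info_vec'  = A' (i - P b)
    log_scale = b @ i - 0.5 * b @ (P @ b)    # constant term, a plain Tensor
    return new_i, new_P, log_scale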