diff --git a/funsor/__init__.py b/funsor/__init__.py
index 64b0f3520..db3efb647 100644
--- a/funsor/__init__.py
+++ b/funsor/__init__.py
@@ -5,8 +5,8 @@
 from funsor.terms import Funsor, Number, Variable, of_shape, to_funsor
 from funsor.torch import Function, Tensor, arange, function, torch_einsum
 
-from . import (adjoint, contract, delta, distributions, domains, einsum, gaussian, handlers, interpreter, minipyro, ops,
-               terms, torch)
+from . import (adjoint, contract, delta, distributions, domains, einsum, gaussian, handlers, interpreter, joint,
+               minipyro, ops, terms, torch)
 
 __all__ = [
     'Domain',
@@ -29,6 +29,7 @@
     'gaussian',
     'handlers',
     'interpreter',
+    'joint',
     'minipyro',
     'of_shape',
     'ops',
diff --git a/funsor/distributions.py b/funsor/distributions.py
index 26ad08c86..3cc051ac0 100644
--- a/funsor/distributions.py
+++ b/funsor/distributions.py
@@ -198,11 +198,12 @@ def eager_normal(loc, scale, value):
 
     inputs, (loc, scale) = align_tensors(loc, scale)
     inputs.update(value.inputs)
+    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
 
-    log_density = -0.5 * math.log(2 * math.pi) - scale.log()
+    log_prob = -0.5 * math.log(2 * math.pi) - scale.log()
     loc = loc.unsqueeze(-1)
     precision = scale.pow(-2).unsqueeze(-1).unsqueeze(-1)
-    return Gaussian(log_density, loc, precision, inputs)
+    return Tensor(log_prob, int_inputs) + Gaussian(loc, precision, inputs)
 
 
 # Create a Gaussian from a noisy identity transform.
@@ -215,12 +216,13 @@ def eager_normal(loc, scale, value):
     inputs = loc.inputs.copy()
     inputs.update(scale.inputs)
     inputs.update(value.inputs)
+    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
 
-    log_density = -0.5 * math.log(2 * math.pi) - scale.data.log()
+    log_prob = -0.5 * math.log(2 * math.pi) - scale.data.log()
     loc = scale.data.new_zeros(scale.data.shape + (2,))
     p = scale.data.pow(-2)
     precision = torch.stack([p, -p, -p, p], -1).reshape(p.shape + (2, 2))
-    return Gaussian(log_density, loc, precision, inputs)
+    return Tensor(log_prob, int_inputs) + Gaussian(loc, precision, inputs)
 
 
 class MultivariateNormal(Distribution):
@@ -261,11 +263,12 @@ def eager_mvn(loc, scale_tril, value):
     dim, = loc.output.shape
     inputs, (loc, scale_tril) = align_tensors(loc, scale_tril)
     inputs.update(value.inputs)
+    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
 
-    log_density = -0.5 * dim * math.log(2 * math.pi) - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
+    log_prob = -0.5 * dim * math.log(2 * math.pi) - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
     inv_scale_tril = torch.inverse(scale_tril)
     precision = torch.matmul(inv_scale_tril.transpose(-1, -2), inv_scale_tril)
-    return Gaussian(log_density, loc, precision, inputs)
+    return Tensor(log_prob, int_inputs) + Gaussian(loc, precision, inputs)
 
 
 __all__ = [
diff --git a/funsor/gaussian.py b/funsor/gaussian.py
index 9a354f685..26e5f713d 100644
--- a/funsor/gaussian.py
+++ b/funsor/gaussian.py
@@ -9,7 +9,7 @@
 
 import funsor.ops as ops
 from funsor.domains import reals
-from funsor.ops import Op
+from funsor.ops import AddOp
 from funsor.terms import Binary, Funsor, FunsorMeta, Number, eager
 from funsor.torch import Tensor, align_tensor, align_tensors, materialize
 
@@ -24,7 +24,7 @@ def _issubshape(subshape, supershape):
 
 
 def _log_det_tril(x):
-    return x.diagonal(dim1=-1, dim2=-2).log().sum()
+    return x.diagonal(dim1=-1, dim2=-2).log().sum(-1)
 
 
 def _mv(mat, vec):
@@ -66,7 +66,6 @@ def align_gaussian(new_inputs, old):
     """
     assert isinstance(new_inputs, OrderedDict)
     assert isinstance(old, Gaussian)
-    log_density = old.log_density
    loc = old.loc
     precision = old.precision
 
@@ -75,7 +74,6 @@ def align_gaussian(new_inputs, old):
     new_ints = OrderedDict((k, d) for k, d in new_inputs.items() if d.dtype != 'real')
     old_ints = OrderedDict((k, d) for k, d in old.inputs.items() if d.dtype != 'real')
     if new_ints != old_ints:
-        log_density = align_tensor(new_ints, Tensor(log_density, old_ints))
         loc = align_tensor(new_ints, Tensor(loc, old_ints))
         precision = align_tensor(new_ints, Tensor(precision, old_ints))
 
@@ -106,17 +104,18 @@
             new_slice2 = slice(new_offset2, new_offset2 + num_elements2)
             precision[..., new_slice1, new_slice2] = old_precision[..., old_slice1, old_slice2]
 
-    return log_density, loc, precision
+    return loc, precision
 
 
 class GaussianMeta(FunsorMeta):
     """
     Wrapper to convert between OrderedDict and tuple.
     """
-    def __call__(cls, log_density, loc, precision, inputs):
+    def __call__(cls, loc, precision, inputs):
         if isinstance(inputs, OrderedDict):
             inputs = tuple(inputs.items())
-        return super(GaussianMeta, cls).__call__(log_density, loc, precision, inputs)
+        assert isinstance(inputs, tuple)
+        return super(GaussianMeta, cls).__call__(loc, precision, inputs)
 
 
 @add_metaclass(GaussianMeta)
@@ -125,8 +124,7 @@ class Gaussian(Funsor):
     Funsor representing a batched joint Gaussian distribution as a log-density
     function.
     """
-    def __init__(self, log_density, loc, precision, inputs):
-        assert isinstance(log_density, torch.Tensor)
+    def __init__(self, loc, precision, inputs):
         assert isinstance(loc, torch.Tensor)
         assert isinstance(precision, torch.Tensor)
         assert isinstance(inputs, tuple)
@@ -141,13 +139,11 @@ def __init__(self, log_density, loc, precision, inputs):
         # Compute total shape of all bint inputs.
         batch_shape = tuple(d.dtype for d in inputs.values()
                             if isinstance(d.dtype, integer_types))
-        assert _issubshape(log_density.shape, batch_shape)
         assert _issubshape(loc.shape, batch_shape + (dim,))
         assert _issubshape(precision.shape, batch_shape + (dim, dim))
 
         output = reals()
         super(Gaussian, self).__init__(inputs, output)
-        self.log_density = log_density
         self.loc = loc
         self.precision = precision
         self.batch_shape = batch_shape
@@ -173,11 +169,11 @@ def eager_subs(self, subs):
         if int_subs:
             int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
             real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
-            tensors = [self.log_density, self.loc, self.precision]
+            tensors = [self.loc, self.precision]
             funsors = [Tensor(x, int_inputs).eager_subs(int_subs) for x in tensors]
             inputs = funsors[0].inputs.copy()
             inputs.update(real_inputs)
-            int_result = Gaussian(funsors[0].data, funsors[1].data, funsors[2].data, inputs)
+            int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
             return int_result.eager_subs(real_subs)
 
         # Try to perform a complete substitution of all real variables, resulting in a Tensor.
@@ -185,14 +181,13 @@ def eager_subs(self, subs):
         if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
             # Broadcast all component tensors.
             int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
-            tensors = [Tensor(self.log_density, int_inputs),
-                       Tensor(self.loc, int_inputs),
+            tensors = [Tensor(self.loc, int_inputs),
                        Tensor(self.precision, int_inputs)]
             tensors.extend(subs.values())
             inputs, tensors = align_tensors(*tensors)
-            batch_dim = tensors[0].dim()
+            batch_dim = self.loc.dim() - 1
             batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
-            (log_density, loc, precision), values = tensors[:3], tensors[3:]
+            (loc, precision), values = tensors[:2], tensors[2:]
 
             # Form the concatenated value.
             offsets, event_size = _compute_offsets(self.inputs)
@@ -204,8 +199,10 @@
                 value[..., offset: offset + self.inputs[k].num_elements] = value_k
 
             # Evaluate the non-normalized log density.
-            result = log_density - 0.5 * _vmv(precision, value - loc)
-            return Tensor(result, inputs)
+            result = -0.5 * _vmv(precision, value - loc)
+            result = Tensor(result, inputs)
+            assert result.output == reals()
+            return result
 
         raise NotImplementedError('TODO implement partial substitution of real variables')
 
@@ -220,12 +217,12 @@ def eager_reduce(self, op, reduced_vars):
                 return None  # defer to default implementation
 
             inputs = OrderedDict((k, d) for k, d in self.inputs.items() if k not in reduced_reals)
-            log_density = self.log_density
+            int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
             if reduced_reals == real_vars:
                 dim = self.loc.size(-1)
                 log_det_term = _log_det_tril(torch.cholesky(self.precision))
-                data = log_density + log_det_term - 0.5 * math.log(2 * math.pi) * dim
-                result = Tensor(data, inputs)
+                data = log_det_term - 0.5 * math.log(2 * math.pi) * dim
+                result = Tensor(data, int_inputs)
             else:
                 offsets, _ = _compute_offsets(self.inputs)
                 index = []
@@ -243,8 +240,8 @@ def eager_reduce(self, op, reduced_vars):
                 precision = torch.matmul(inv_scale_tril, inv_scale_tril.transpose(-1, -2))
                 reduced_dim = sum(self.inputs[k].num_elements for k in reduced_reals)
                 log_det_term = _log_det_tril(scale_tril) - _log_det_tril(self_scale_tril)
-                log_density = log_density + log_det_term - 0.5 * math.log(2 * math.pi) * reduced_dim
-                result = Gaussian(log_density, loc, precision, inputs)
+                log_prob = Tensor(log_det_term - 0.5 * math.log(2 * math.pi) * reduced_dim, int_inputs)
+                result = log_prob + Gaussian(loc, precision, inputs)
 
             return result.reduce(ops.logaddexp, reduced_ints)
 
@@ -254,101 +251,27 @@ def eager_reduce(self, op, reduced_vars):
         return None  # defer to default implementation
 
 
-@eager.register(Binary, Op, Gaussian, Number)
-def eager_binary_gaussian_number(op, lhs, rhs):
-    if op is ops.add or op is ops.sub:
-        # Add a constant log_density term to a Gaussian.
-        log_density = op(lhs.log_density, rhs.data)
-        return Gaussian(log_density, lhs.loc, lhs.precision, lhs.inputs)
-
-    if op is ops.mul or op is ops.truediv:
-        # Scale a Gaussian, as under pyro.poutine.scale.
-        raise NotImplementedError('TODO')
-
-    return None  # defer to default implementation
-
-
-@eager.register(Binary, Op, Number, Gaussian)
-def eager_binary_number_gaussian(op, lhs, rhs):
-    if op is ops.add:
-        # Add a constant log_density term to a Gaussian.
-        log_density = op(lhs.data, rhs.log_density)
-        return Gaussian(log_density, rhs.loc, rhs.precision, rhs.inputs)
-
-    if op is ops.mul:
-        # Scale a Gaussian, as under pyro.poutine.scale.
-        raise NotImplementedError('TODO')
-
-    return None  # defer to default implementation
-
-
-@eager.register(Binary, Op, Gaussian, Tensor)
-def eager_binary_gaussian_tensor(op, lhs, rhs):
-    if op is ops.add or op is ops.sub:
-        # Add a batch-dependent log_density term to a Gaussian.
-        nonreal_inputs = OrderedDict((k, d) for k, d in lhs.inputs.items()
-                                     if d.dtype != 'real')
-        inputs, (rhs_data, log_density, loc, precision) = align_tensors(
-            rhs,
-            Tensor(lhs.log_density, nonreal_inputs),
-            Tensor(lhs.loc, nonreal_inputs),
-            Tensor(lhs.precision, nonreal_inputs))
-        log_density = op(log_density, rhs_data)
-        inputs.update(lhs.inputs)
-        return Gaussian(log_density, loc, precision, inputs)
-
-    if op is ops.mul or op is ops.truediv:
-        # Scale a Gaussian, as under pyro.poutine.scale.
-        raise NotImplementedError('TODO')
-
-    return None  # defer to default implementation
-
-
-@eager.register(Binary, Op, Tensor, Gaussian)
-def eager_binary_tensor_gaussian(op, lhs, rhs):
-    if op is ops.add:
-        # Add a batch-dependent log_density term to a Gaussian.
-        nonreal_inputs = OrderedDict((k, d) for k, d in rhs.inputs.items()
-                                     if d.dtype != 'real')
-        inputs, (lhs_data, log_density, loc, precision) = align_tensors(
-            lhs,
-            Tensor(rhs.log_density, nonreal_inputs),
-            Tensor(rhs.loc, nonreal_inputs),
-            Tensor(rhs.precision, nonreal_inputs))
-        log_density = op(lhs_data, log_density)
-        inputs.update(rhs.inputs)
-        return Gaussian(log_density, loc, precision, inputs)
-
-    if op is ops.mul:
-        # Scale a Gaussian, as under pyro.poutine.scale.
-        raise NotImplementedError('TODO')
-
-    return None  # defer to default implementation
-
-
-@eager.register(Binary, Op, Gaussian, Gaussian)
-def eager_binary_gaussian_gaussian(op, lhs, rhs):
-    if op is ops.add:
-        # Fuse two Gaussians by adding their log-densities pointwise.
-        # This is similar to a Kalman filter update, but also keeps track of
-        # the marginal likelihood which accumulates into log_density.
-
-        # Align data.
-        inputs = lhs.inputs.copy()
-        inputs.update(rhs.inputs)
-        lhs_log_density, lhs_loc, lhs_precision = align_gaussian(inputs, lhs)
-        rhs_log_density, rhs_loc, rhs_precision = align_gaussian(inputs, rhs)
-
-        # Fuse aligned Gaussians.
-        precision_loc = _mv(lhs_precision, lhs_loc) + _mv(rhs_precision, rhs_loc)
-        precision = lhs_precision + rhs_precision
-        scale_tril = torch.inverse(torch.cholesky(precision))
-        loc = _mv(scale_tril.transpose(-1, -2), _mv(scale_tril, precision_loc))
-        quadratic_term = _vmv(lhs_precision, loc - lhs_loc) + _vmv(rhs_precision, loc - rhs_loc)
-        log_density = lhs_log_density + rhs_log_density - 0.5 * quadratic_term
-        return Gaussian(log_density, loc, precision, inputs)
-
-    return None  # defer to default implementation
+@eager.register(Binary, AddOp, Gaussian, Gaussian)
+def eager_add_gaussian_gaussian(op, lhs, rhs):
+    # Fuse two Gaussians by adding their log-densities pointwise.
+    # This is similar to a Kalman filter update, but also keeps track of
+    # the marginal likelihood which accumulates into a Tensor.
+
+    # Align data.
+    inputs = lhs.inputs.copy()
+    inputs.update(rhs.inputs)
+    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
+    lhs_loc, lhs_precision = align_gaussian(inputs, lhs)
+    rhs_loc, rhs_precision = align_gaussian(inputs, rhs)
+
+    # Fuse aligned Gaussians.
+    precision_loc = _mv(lhs_precision, lhs_loc) + _mv(rhs_precision, rhs_loc)
+    precision = lhs_precision + rhs_precision
+    scale_tril = torch.inverse(torch.cholesky(precision))
+    loc = _mv(scale_tril.transpose(-1, -2), _mv(scale_tril, precision_loc))
+    quadratic_term = _vmv(lhs_precision, loc - lhs_loc) + _vmv(rhs_precision, loc - rhs_loc)
+    likelihood = Tensor(-0.5 * quadratic_term, int_inputs)
+    return likelihood + Gaussian(loc, precision, inputs)
 
 
 __all__ = [
diff --git a/funsor/joint.py b/funsor/joint.py
new file mode 100644
index 000000000..c3c6691cf
--- /dev/null
+++ b/funsor/joint.py
@@ -0,0 +1,227 @@
+from __future__ import absolute_import, division, print_function
+
+from collections import OrderedDict
+
+from six import add_metaclass
+
+import funsor.ops as ops
+from funsor.delta import Delta
+from funsor.domains import reals
+from funsor.gaussian import Gaussian
+from funsor.ops import AddOp, Op
+from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, eager, to_funsor
+from funsor.torch import Tensor
+
+
+class JointMeta(FunsorMeta):
+    """
+    Wrapper to fill in defaults and convert to funsor.
+    """
+    def __call__(cls, deltas=(), discrete=0, gaussian=0):
+        discrete = to_funsor(discrete)
+        gaussian = to_funsor(gaussian)
+        return super(JointMeta, cls).__call__(deltas, discrete, gaussian)
+
+
+@add_metaclass(JointMeta)
+class Joint(Funsor):
+    """
+    Normal form for a joint log probability density funsor.
+
+    :param tuple deltas: A possibly-empty tuple of degenerate distributions
+        represented as :class:`~funsor.delta.Delta` funsors.
+    :param Funsor discrete: A joint discrete log mass function represented as
+        a :class:`~funsor.terms.Number` or :class:`~funsor.torch.Tensor`.
+    :param Funsor gaussian: An optional joint multivariate normal distribution
+        represented as a :class:`~funsor.gaussian.Gaussian`, or ``Number(0)``
+        if absent.
+    """
+    def __init__(self, deltas, discrete, gaussian):
+        assert isinstance(deltas, tuple)
+        assert isinstance(discrete, (Number, Tensor))
+        assert discrete.output == reals()
+        assert gaussian is Number(0) or isinstance(gaussian, Gaussian)
+        inputs = OrderedDict()
+        for x in deltas:
+            assert isinstance(x, Delta)
+            assert x.name not in inputs
+            assert x.name not in discrete.inputs
+            assert x.name not in gaussian.inputs
+            inputs.update(x.inputs)
+        inputs.update(discrete.inputs)
+        inputs.update(gaussian.inputs)
+        output = reals()
+        super(Joint, self).__init__(inputs, output)
+        self.deltas = deltas
+        self.discrete = discrete
+        self.gaussian = gaussian
+
+    def eager_subs(self, subs):
+        discrete = self.discrete.eager_subs(subs)
+        gaussian = self.gaussian.eager_subs(subs)
+        assert isinstance(gaussian, (Number, Tensor, Gaussian))
+        deltas = []
+        for x in self.deltas:
+            x = x.eager_subs(subs)
+            if isinstance(x, Delta):
+                deltas.append(x)
+            elif isinstance(x, (Number, Tensor)):
+                discrete += x
+            else:
+                raise ValueError('Cannot substitute {}'.format(x))
+        deltas = tuple(deltas)
+        return Joint(deltas, discrete) + gaussian
+
+    def eager_reduce(self, op, reduced_vars):
+        if op is ops.logaddexp:
+            # Integrate out delayed discrete variables.
+            discrete_vars = reduced_vars.intersection(self.discrete.inputs)
+            mixture_params = frozenset(self.gaussian.inputs).union(*(x.point.inputs for x in self.deltas))
+            lazy_vars = discrete_vars & mixture_params  # Mixtures must remain lazy.
+            discrete_vars -= mixture_params
+            discrete = self.discrete.reduce(op, discrete_vars)
+
+            # Integrate out delayed gaussian variables.
+            gaussian_vars = reduced_vars.intersection(self.gaussian.inputs)
+            gaussian = self.gaussian.reduce(ops.logaddexp, gaussian_vars)
+            assert (reduced_vars - gaussian_vars).issubset(d.name for d in self.deltas)
+
+            # Integrate out delayed degenerate variables, i.e. drop them.
+            deltas = tuple(d for d in self.deltas if d.name not in reduced_vars)
+
+            assert not lazy_vars
+            return (Joint(deltas, discrete) + gaussian).reduce(ops.logaddexp, lazy_vars)
+
+        if op is ops.add:
+            raise NotImplementedError('TODO product-reduce along a plate dimension')
+
+        return None  # defer to default implementation
+
+
+@eager.register(Joint, tuple, Funsor, Funsor)
+def eager_joint(deltas, discrete, gaussian):
+    # Demote a Joint to a simpler elementary funsor.
+    if not deltas:
+        if gaussian is Number(0):
+            return discrete
+        elif discrete is Number(0):
+            return gaussian
+    elif len(deltas) == 1:
+        if discrete is Number(0) and gaussian is Number(0):
+            return deltas[0]
+
+    return None  # defer to default implementation
+
+
+################################################################################
+# Patterns to update a Joint with other funsors
+################################################################################
+
+@eager.register(Binary, AddOp, Joint, Joint)
+def eager_add(op, joint, other):
+    # Fuse two joint distributions.
+    for d in other.deltas:
+        joint += d
+    joint += other.discrete
+    joint += other.gaussian
+    return joint
+
+
+@eager.register(Binary, AddOp, Joint, Delta)
+def eager_add(op, joint, delta):
+    # Update with a degenerate distribution, typically a Monte Carlo sample.
+    if delta.name in joint.inputs:
+        joint = joint.eager_subs(((delta.name, delta.point),))
+        if not isinstance(joint, Joint):
+            return joint + delta
+    for d in joint.deltas:
+        if d.name in delta.inputs:
+            delta = delta.eager_subs(((d.name, d.point),))
+    deltas = joint.deltas + (delta,)
+    return Joint(deltas, joint.discrete, joint.gaussian)
+
+
+@eager.register(Binary, AddOp, Joint, (Number, Tensor))
+def eager_add(op, joint, other):
+    # Update with a delayed discrete random variable.
+    subs = tuple((d.name, d.point) for d in joint.deltas if d.name in other.inputs)
+    if subs:
+        return joint + other.eager_subs(subs)
+    return Joint(joint.deltas, joint.discrete + other, joint.gaussian)
+
+
+@eager.register(Binary, AddOp, Joint, Gaussian)
+def eager_add(op, joint, other):
+    # Update with a delayed gaussian random variable.
+    subs = tuple((d.name, d.point) for d in joint.deltas if d.name in other.inputs)
+    if subs:
+        other = other.eager_subs(subs)
+    if joint.gaussian is not Number(0):
+        other = joint.gaussian + other
+    if not isinstance(other, Gaussian):
+        return Joint(joint.deltas, joint.discrete) + other
+    return Joint(joint.deltas, joint.discrete, other)
+
+
+@eager.register(Binary, AddOp, (Funsor, Align, Delta), Joint)
+def eager_add(op, other, joint):
+    return joint + other
+
+
+################################################################################
+# Patterns to create a Joint from elementary funsors
+################################################################################
+
+@eager.register(Binary, AddOp, Delta, Delta)
+def eager_add(op, lhs, rhs):
+    if lhs.name == rhs.name:
+        raise NotImplementedError
+    if rhs.name in lhs.inputs:
+        assert lhs.name not in rhs.inputs
+        lhs = lhs(**{rhs.name: rhs.point})
+    elif lhs.name in rhs.inputs:
+        rhs = rhs(**{lhs.name: lhs.point})
+    return Joint(deltas=(lhs, rhs))
+
+
+@eager.register(Binary, AddOp, Delta, (Number, Tensor, Gaussian))
+def eager_add(op, delta, other):
+    if delta.name in other.inputs:
+        other = other.eager_subs(((delta.name, delta.point),))
+        assert isinstance(other, (Number, Tensor, Gaussian))
+    if isinstance(other, (Number, Tensor)):
+        return Joint((delta,), discrete=other)
+    else:
+        return Joint((delta,), gaussian=other)
+
+
+@eager.register(Binary, Op, Delta, (Number, Tensor))
+def eager_binary(op, delta, other):
+    if op is ops.sub:
+        return delta + -other
+
+
+@eager.register(Binary, AddOp, (Number, Tensor, Gaussian), Delta)
+def eager_add(op, other, delta):
+    return delta + other
+
+
+@eager.register(Binary, AddOp, Gaussian, (Number, Tensor))
+def eager_add(op, gaussian, discrete):
+    return Joint(discrete=discrete, gaussian=gaussian)
+
+
+@eager.register(Binary, AddOp, (Number, Tensor), Gaussian)
+def eager_add(op, discrete, gaussian):
+    return Joint(discrete=discrete, gaussian=gaussian)
+
+
+@eager.register(Binary, Op, Gaussian, (Number, Tensor))
+def eager_binary(op, gaussian, discrete):
+    if op is ops.sub:
+        return Joint(discrete=-discrete, gaussian=gaussian)
+
+
+__all__ = [
+    'Joint',
+]
diff --git a/funsor/ops.py b/funsor/ops.py
index 8dff446df..421333649 100644
--- a/funsor/ops.py
+++ b/funsor/ops.py
@@ -29,6 +29,10 @@ class AssociativeOp(Op):
     pass
 
 
+class AddOp(AssociativeOp):
+    pass
+
+
 eq = Op(operator.eq)
 ge = Op(operator.ge)
 getitem = Op(operator.getitem)
@@ -41,7 +45,7 @@ class AssociativeOp(Op):
 sub = Op(operator.sub)
 truediv = Op(operator.truediv)
 
-add = AssociativeOp(operator.add)
+add = AddOp(operator.add)
 and_ = AssociativeOp(operator.and_)
 mul = AssociativeOp(operator.mul)
 or_ = AssociativeOp(operator.or_)
diff --git a/funsor/terms.py b/funsor/terms.py
index cc8493a8c..edab81213 100644
--- a/funsor/terms.py
+++ b/funsor/terms.py
@@ -138,7 +138,20 @@ def __hash__(self):
         return id(self)
 
     def __repr__(self):
-        return '{}({})'.format(type(self).__name__, ', '.join(map(repr, self._ast_args)))
+        return '{}({})'.format(type(self).__name__, ', '.join(map(repr, self._ast_values)))
+
+    def _pretty(self, lines, indent=0):
+        lines.append((indent, type(self).__name__))
+        for arg in self._ast_values:
+            if isinstance(arg, Funsor):
+                arg._pretty(lines, indent + 1)
+            else:
+                lines.append((indent + 1, str(arg)))
+
+    def pretty(self):
+        lines = []
+        self._pretty(lines)
+        return '\n'.join('| ' * indent + text for indent, text in lines)
 
     def __call__(self, *args, **kwargs):
         """
diff --git a/funsor/testing.py b/funsor/testing.py
index 99aecb01a..2c53d4519 100644
--- a/funsor/testing.py
+++ b/funsor/testing.py
@@ -6,13 +6,14 @@
 from collections import OrderedDict, namedtuple
 
 import numpy as np
+import opt_einsum
 import pytest
 import torch
-import opt_einsum
 from six.moves import reduce
 
 from funsor.domains import Domain, bint, reals
 from funsor.gaussian import Gaussian
+from funsor.joint import Joint
 from funsor.numpy import Array
 from funsor.terms import Funsor
 from funsor.torch import Tensor
@@ -46,9 +47,16 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6):
     if isinstance(actual, Tensor):
         assert_close(actual.data, expected.data, atol=atol, rtol=rtol)
     elif isinstance(actual, Gaussian):
-        assert_close(actual.log_density, expected.log_density, atol=atol, rtol=rtol)
         assert_close(actual.loc, expected.loc, atol=atol, rtol=rtol)
         assert_close(actual.precision, expected.precision, atol=atol, rtol=rtol)
+    elif isinstance(actual, Joint):
+        actual_deltas = {d.name: d.point for d in actual.deltas}
+        expected_deltas = {d.name: d.point for d in expected.deltas}
+        assert set(actual_deltas) == set(expected_deltas)
+        for name, actual_point in actual_deltas.items():
+            assert_close(actual_point, expected_deltas[name], atol=atol, rtol=rtol)
+        assert_close(actual.discrete, expected.discrete, atol=atol, rtol=rtol)
+        assert_close(actual.gaussian, expected.gaussian, atol=atol, rtol=rtol)
     elif isinstance(actual, torch.Tensor):
         assert actual.dtype == expected.dtype, msg
         if actual.dtype in (torch.long, torch.uint8):
@@ -160,11 +168,10 @@ def random_gaussian(inputs):
     assert isinstance(inputs, OrderedDict)
     batch_shape = tuple(d.dtype for d in inputs.values() if d.dtype != 'real')
     event_shape = (sum(d.num_elements for d in inputs.values() if d.dtype == 'real'),)
-    log_density = torch.randn(batch_shape)
     loc = torch.randn(batch_shape + event_shape)
     prec_sqrt = torch.randn(batch_shape + event_shape + event_shape)
     precision = torch.matmul(prec_sqrt, prec_sqrt.transpose(-1, -2))
-    return Gaussian(log_density, loc, precision, inputs)
+    return Gaussian(loc, precision, inputs)
 
 
 def make_plated_hmm_einsum(num_steps, num_obs_plates=1, num_hidden_plates=0):
diff --git a/funsor/torch.py b/funsor/torch.py
index aba56b770..c6dc86dfb 100644
--- a/funsor/torch.py
+++ b/funsor/torch.py
@@ -92,6 +92,7 @@ def __init__(self, data, inputs=None, dtype="real"):
         assert isinstance(data, torch.Tensor)
         assert isinstance(inputs, tuple)
         assert all(isinstance(d.dtype, integer_types) for k, d in inputs)
+        assert len(inputs) <= data.dim()
         inputs = OrderedDict(inputs)
         output = Domain(data.shape[len(inputs):], dtype)
         super(Tensor, self).__init__(inputs, output)
diff --git a/test/test_delta.py b/test/test_delta.py
index abe4f9c8f..d5908b571 100644
--- a/test/test_delta.py
+++ b/test/test_delta.py
@@ -44,7 +44,7 @@ def test_reduce():
 
 
 @pytest.mark.parametrize('log_density', [0, 1.234])
-def test_reduce(log_density):
+def test_reduce_density(log_density):
     point = Tensor(torch.randn(3))
     d = Delta('foo', point, log_density)
     # Note that log_density affects ground substitution but does not affect reduction.
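
An illustrative aside, not part of the patch: the fusion rule in eager_add_gaussian_gaussian follows from the fact that adding two Gaussian log-densities multiplies the densities, so precisions add, the fused mean is the precision-weighted average, and the leftover quadratic constant is exactly the term this patch now returns as a separate Tensor rather than storing in a log_density field. A minimal sketch of that arithmetic in plain torch, with made-up inputs:

    import torch

    def fuse_gaussians(loc1, prec1, loc2, prec2):
        # Precisions add when log-densities add.
        precision = prec1 + prec2
        # The fused mean solves: precision @ loc = prec1 @ loc1 + prec2 @ loc2,
        # computed via the Cholesky factor as in eager_add_gaussian_gaussian.
        precision_loc = torch.mv(prec1, loc1) + torch.mv(prec2, loc2)
        scale_tril = torch.inverse(torch.cholesky(precision))
        loc = torch.mv(scale_tril.t(), torch.mv(scale_tril, precision_loc))
        # Constant left over after completing the square; the patch keeps this
        # in a Tensor term so the Gaussian itself carries no normalization state.
        quadratic = (torch.dot(loc - loc1, torch.mv(prec1, loc - loc1)) +
                     torch.dot(loc - loc2, torch.mv(prec2, loc - loc2)))
        return -0.5 * quadratic, loc, precision

    a, b = torch.randn(3, 3), torch.randn(3, 3)
    likelihood, loc, precision = fuse_gaussians(
        torch.randn(3), torch.matmul(a, a.t()) + 3 * torch.eye(3),
        torch.randn(3), torch.matmul(b, b.t()) + 3 * torch.eye(3))
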
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 542940b42..8f9de6a72 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -10,7 +10,7 @@
 import funsor.distributions as dist
 from funsor.delta import Delta
 from funsor.domains import bint, reals
-from funsor.gaussian import Gaussian
+from funsor.joint import Joint
 from funsor.terms import Variable
 from funsor.testing import assert_close, check_funsor, random_tensor
 from funsor.torch import Tensor
@@ -89,7 +89,7 @@ def test_delta_delta():
     assert d is Delta('v', point, log_density)
 
 
-def test_mvn_defaults():
+def test_normal_defaults():
     loc = Variable('loc', reals())
     scale = Variable('scale', reals())
     value = Variable('value', reals())
@@ -97,20 +97,20 @@
 
 
 @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
-def test_mvn_density(batch_shape):
+def test_normal_density(batch_shape):
     batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
     inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
 
     @funsor.of_shape(reals(), reals(), reals())
-    def mvn(loc, scale, value):
+    def normal(loc, scale, value):
         return -((value - loc) ** 2) / (2 * scale ** 2) - scale.log() - math.log(math.sqrt(2 * math.pi))
 
-    check_funsor(mvn, {'loc': reals(), 'scale': reals(), 'value': reals()}, reals())
+    check_funsor(normal, {'loc': reals(), 'scale': reals(), 'value': reals()}, reals())
 
     loc = Tensor(torch.randn(batch_shape), inputs)
     scale = Tensor(torch.randn(batch_shape).exp(), inputs)
     value = Tensor(torch.randn(batch_shape), inputs)
-    expected = mvn(loc, scale, value)
+    expected = normal(loc, scale, value)
     check_funsor(expected, inputs, reals())
 
     actual = dist.Normal(loc, scale, value)
@@ -119,7 +119,7 @@
 
 
 @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
-def test_mvn_gaussian_1(batch_shape):
+def test_normal_gaussian_1(batch_shape):
     batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
     inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
 
@@ -132,7 +132,7 @@
     check_funsor(expected, inputs, reals())
 
     g = dist.Normal(loc, scale)
-    assert isinstance(g, Gaussian)
+    assert isinstance(g, Joint)
     actual = g(value=value)
     check_funsor(actual, inputs, reals())
 
@@ -140,7 +140,7 @@
 
 
 @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
-def test_mvn_gaussian_2(batch_shape):
+def test_normal_gaussian_2(batch_shape):
     batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
     inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
 
@@ -153,7 +153,7 @@
     check_funsor(expected, inputs, reals())
 
     g = dist.Normal(Variable('value', reals()), scale, loc)
-    assert isinstance(g, Gaussian)
+    assert isinstance(g, Joint)
     actual = g(value=value)
     check_funsor(actual, inputs, reals())
 
@@ -161,7 +161,7 @@
 
 
 @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
-def test_mvn_gaussian_3(batch_shape):
+def test_normal_gaussian_3(batch_shape):
     batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
     inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
 
@@ -174,7 +174,7 @@
     check_funsor(expected, inputs, reals())
 
     g = dist.Normal(Variable('loc', reals()), scale)
-    assert isinstance(g, Gaussian)
+    assert isinstance(g, Joint)
     actual = g(loc=loc, value=value)
     check_funsor(actual, inputs, reals())
 
@@ -229,7 +229,7 @@ def test_mvn_gaussian(batch_shape):
     check_funsor(expected, inputs, reals())
 
     g = dist.MultivariateNormal(loc, scale_tril)
-    assert isinstance(g, Gaussian)
+    assert isinstance(g, Joint)
     actual = g(value=value)
     check_funsor(actual, inputs, reals())
diff --git a/test/test_gaussian.py b/test/test_gaussian.py
index 14e91ef95..113c6f414 100644
--- a/test/test_gaussian.py
+++ b/test/test_gaussian.py
@@ -9,6 +9,7 @@
 import funsor.ops as ops
 from funsor.domains import bint, reals
 from funsor.gaussian import Gaussian
+from funsor.joint import Joint
 from funsor.terms import Number
 from funsor.testing import assert_close, random_gaussian, random_tensor, xfail_if_not_implemented
 from funsor.torch import Tensor
@@ -21,30 +22,29 @@ def id_from_inputs(inputs):
 
 
 @pytest.mark.parametrize('expr,expected_type', [
-    ('g1 + 1', Gaussian),
-    ('g1 - 1', Gaussian),
-    ('1 + g1', Gaussian),
-    ('g1 + shift', Gaussian),
-    ('g1 - shift', Gaussian),
-    ('shift + g1', Gaussian),
-    ('g1 + g1', Gaussian),
+    ('g1 + 1', Joint),
+    ('g1 - 1', Joint),
+    ('1 + g1', Joint),
+    ('g1 + shift', Joint),
+    ('g1 - shift', Joint),
+    ('shift + g1', Joint),
+    ('g1 + g1', Joint),
     ('g1(i=i0)', Gaussian),
     ('g2(i=i0)', Gaussian),
-    ('g1(i=i0) + g2(i=i0)', Gaussian),
-    ('g1(i=i0) + g2', Gaussian),
+    ('g1(i=i0) + g2(i=i0)', Joint),
+    ('g1(i=i0) + g2', Joint),
     ('g1(x=x0)', Tensor),
     ('g2(y=y0)', Tensor),
-    ('(g1 + g2)(i=i0)', Gaussian),
+    ('(g1 + g2)(i=i0)', Joint),
     ('(g1 + g2)(x=x0, y=y0)', Tensor),
     ('(g2 + g1)(x=x0, y=y0)', Tensor),
     ('g1.reduce(ops.logaddexp, "x")', Tensor),
-    ('(g1 + g2).reduce(ops.logaddexp, "x")', Gaussian),
-    ('(g1 + g2).reduce(ops.logaddexp, "y")', Gaussian),
+    ('(g1 + g2).reduce(ops.logaddexp, "x")', Joint),
+    ('(g1 + g2).reduce(ops.logaddexp, "y")', Joint),
     ('(g1 + g2).reduce(ops.logaddexp, frozenset(["x", "y"]))', Tensor),
 ])
 def test_smoke(expr, expected_type):
     g1 = Gaussian(
-        log_density=torch.tensor([0.0, 1.0]),
         loc=torch.tensor([[0.0, 0.1, 0.2],
                           [2.0, 3.0, 4.0]]),
         precision=torch.tensor([[[1.0, 0.1, 0.2],
@@ -57,7 +57,6 @@ def test_smoke(expr, expected_type):
     assert isinstance(g1, Gaussian)
 
     g2 = Gaussian(
-        log_density=torch.tensor([0.0, 1.0]),
         loc=torch.tensor([[0.0, 0.1],
                           [2.0, 3.0]]),
         precision=torch.tensor([[[1.0, 0.2],
@@ -230,5 +229,5 @@ def test_logsumexp(int_inputs, real_inputs):
     g = random_gaussian(inputs)
 
     g_xy = g.reduce(ops.logaddexp, frozenset(['x', 'y']))
-    assert_close(g_xy, g.reduce(ops.logaddexp, 'x').reduce(ops.logaddexp, 'y'), atol=1e-4, rtol=None)
-    assert_close(g_xy, g.reduce(ops.logaddexp, 'y').reduce(ops.logaddexp, 'x'), atol=1e-4, rtol=None)
+    assert_close(g_xy, g.reduce(ops.logaddexp, 'x').reduce(ops.logaddexp, 'y'), atol=1e-3, rtol=None)
+    assert_close(g_xy, g.reduce(ops.logaddexp, 'y').reduce(ops.logaddexp, 'x'), atol=1e-3, rtol=None)
diff --git a/test/test_joint.py b/test/test_joint.py
new file mode 100644
index 000000000..4837ffad6
--- /dev/null
+++ b/test/test_joint.py
@@ -0,0 +1,152 @@
+from __future__ import absolute_import, division, print_function
+
+from collections import OrderedDict
+
+import pytest
+import torch
+
+import funsor.ops as ops
+from funsor.delta import Delta
+from funsor.domains import bint, reals
+from funsor.gaussian import Gaussian
+from funsor.joint import Joint
+from funsor.terms import Number
+from funsor.testing import assert_close, random_gaussian, random_tensor, xfail_if_not_implemented
+from funsor.torch import Tensor
+
+
+def id_from_inputs(inputs):
+    if not inputs:
+        return '()'
+    return ','.join(k + ''.join(map(str, d.shape)) for k, d in inputs.items())
+
+
+SMOKE_TESTS = [
+    ('dx + dy', Joint),
+    ('dx + g', Joint),
+    ('dy + g', Joint),
+    ('g + dx', Joint),
+    ('g + dy', Joint),
+    ('dx + t', Joint),
+    ('dy + t', Joint),
+    ('dx - t', Joint),
+    ('dy - t', Joint),
+    ('t + dx', Joint),
+    ('t + dy', Joint),
+    ('g + 1', Joint),
+    ('g - 1', Joint),
+    ('1 + g', Joint),
+    ('g + t', Joint),
+    ('g - t', Joint),
+    ('t + g', Joint),
+    ('g + g', Joint),
+    ('(dx + dy)(i=i0)', Joint),
+    ('(dx + g)(i=i0)', Joint),
+    ('(dy + g)(i=i0)', Joint),
+    ('(g + dx)(i=i0)', Joint),
+    ('(g + dy)(i=i0)', Joint),
+    ('(dx + t)(i=i0)', Joint),
+    ('(dy + t)(i=i0)', Joint),
+    ('(dx - t)(i=i0)', Joint),
+    ('(dy - t)(i=i0)', Joint),
+    ('(t + dx)(i=i0)', Joint),
+    ('(t + dy)(i=i0)', Joint),
+    ('(g + 1)(i=i0)', Joint),
+    ('(g - 1)(i=i0)', Joint),
+    ('(1 + g)(i=i0)', Joint),
+    ('(g + t)(i=i0)', Joint),
+    ('(g - t)(i=i0)', Joint),
+    ('(t + g)(i=i0)', Joint),
+    ('(g + g)(i=i0)', Joint),
+    ('(dx + dy)(x=x0)', Joint),
+    ('(dx + g)(x=x0)', Tensor),
+    ('(dy + g)(x=x0)', Joint),
+    ('(g + dx)(x=x0)', Tensor),
+    ('(g + dy)(x=x0)', Joint),
+    ('(dx + t)(x=x0)', Tensor),
+    ('(dy + t)(x=x0)', Joint),
+    ('(dx - t)(x=x0)', Tensor),
+    ('(dy - t)(x=x0)', Joint),
+    ('(t + dx)(x=x0)', Tensor),
+    ('(t + dy)(x=x0)', Joint),
+    ('(g + 1)(x=x0)', Tensor),
+    ('(g - 1)(x=x0)', Tensor),
+    ('(1 + g)(x=x0)', Tensor),
+    ('(g + t)(x=x0)', Tensor),
+    ('(g - t)(x=x0)', Tensor),
+    ('(t + g)(x=x0)', Tensor),
+    ('(g + g)(x=x0)', Tensor),
+    ('(g + dy).reduce(ops.logaddexp, "x")', Joint),
+    ('(g + dy).reduce(ops.logaddexp, "y")', Gaussian),
+    ('(t + g + dy).reduce(ops.logaddexp, "x")', Joint),
+    ('(t + g + dy).reduce(ops.logaddexp, "y")', Joint),
+    ('(t + g).reduce(ops.logaddexp, "x")', Tensor),
+]
+
+
+@pytest.mark.parametrize('expr,expected_type', SMOKE_TESTS)
+def test_smoke(expr, expected_type):
+    dx = Delta('x', Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2))])))
+    assert isinstance(dx, Delta)
+
+    dy = Delta('y', Tensor(torch.randn(3, 4), OrderedDict([('j', bint(3))])))
+    assert isinstance(dy, Delta)
+
+    t = Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2)), ('j', bint(3))]))
+    assert isinstance(t, Tensor)
+
+    g = Gaussian(
+        loc=torch.tensor([[0.0, 0.1, 0.2],
+                          [2.0, 3.0, 4.0]]),
+        precision=torch.tensor([[[1.0, 0.1, 0.2],
+                                 [0.1, 1.0, 0.3],
+                                 [0.2, 0.3, 1.0]],
+                                [[1.0, 0.1, 0.2],
+                                 [0.1, 1.0, 0.3],
+                                 [0.2, 0.3, 1.0]]]),
+        inputs=OrderedDict([('i', bint(2)), ('x', reals(3))]))
+    assert isinstance(g, Gaussian)
+
+    i0 = Number(1, 2)
+    assert isinstance(i0, Number)
+
+    x0 = Tensor(torch.tensor([0.5, 0.6, 0.7]))
+    assert isinstance(x0, Tensor)
+
+    result = eval(expr)
+    assert isinstance(result, expected_type)
+
+
+@pytest.mark.parametrize('int_inputs', [
+    {},
+    {'i': bint(2)},
+    {'i': bint(2), 'j': bint(3)},
+], ids=id_from_inputs)
+@pytest.mark.parametrize('real_inputs', [
+    {'x': reals()},
+    {'x': reals(4)},
+    {'x': reals(2, 3)},
+    {'x': reals(), 'y': reals()},
+    {'x': reals(2), 'y': reals(3)},
+    {'x': reals(4), 'y': reals(2, 3), 'z': reals()},
+], ids=id_from_inputs)
+def test_reduce(int_inputs, real_inputs):
+    int_inputs = OrderedDict(sorted(int_inputs.items()))
+    real_inputs = OrderedDict(sorted(real_inputs.items()))
+    inputs = int_inputs.copy()
+    inputs.update(real_inputs)
+
+    t = random_tensor(int_inputs)
+    g = random_gaussian(inputs)
+    truth = {name: random_tensor(int_inputs, domain) for name, domain in real_inputs.items()}
+
+    state = 0
+    state += g
+    state += t
+    for name, point in truth.items():
+        with xfail_if_not_implemented():
+            state += Delta(name, point)
+    actual = state.reduce(ops.logaddexp, frozenset(truth))
+
+    expected = t + g(**truth)
+    assert_close(actual, expected)
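
As a usage note, the intended algebra is that sums of Delta, Tensor, and Gaussian log-density terms normalize into a single Joint, and logaddexp-reduction then integrates each component analytically. A small sketch mirroring the smoke tests above (illustrative shapes and names, not part of the patch):

    from collections import OrderedDict

    import torch

    import funsor.ops as ops
    from funsor.delta import Delta
    from funsor.domains import bint, reals
    from funsor.gaussian import Gaussian
    from funsor.joint import Joint
    from funsor.torch import Tensor

    t = Tensor(torch.randn(2), OrderedDict([('i', bint(2))]))
    g = Gaussian(loc=torch.zeros(2, 3),
                 precision=torch.eye(3).repeat(2, 1, 1),
                 inputs=OrderedDict([('i', bint(2)), ('x', reals(3))]))
    d = Delta('y', Tensor(torch.randn(2, 4), OrderedDict([('i', bint(2))])))

    state = t + g + d                 # normalizes to a Joint
    assert isinstance(state, Joint)

    # The Gaussian component integrates out analytically; the Delta on 'y'
    # and the discrete Tensor part survive, so the result is still a Joint.
    marginal = state.reduce(ops.logaddexp, frozenset(['x']))
    assert isinstance(marginal, Joint)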