Implement ops.matmul and align_tensors(..., expand=True) (#250)
fritzo authored Sep 20, 2019
1 parent 5b2f238 commit 9a5cf22
Showing 7 changed files with 105 additions and 16 deletions.
6 changes: 2 additions & 4 deletions funsor/distributions.py
@@ -368,8 +368,7 @@ def eager_normal(loc, scale, value):
if isinstance(loc, Variable):
loc, value = value, loc

inputs, (loc, scale) = align_tensors(loc, scale)
loc, scale = torch.broadcast_tensors(loc, scale)
inputs, (loc, scale) = align_tensors(loc, scale, expand=True)
inputs.update(value.inputs)
int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')

@@ -407,8 +406,7 @@ def eager_normal(loc, scale, value):
coeffs[c] = affine(**{k: 1. if c == k else 0. for k in real_inputs.keys()}) - const

tensors = [const] + list(coeffs.values())
inputs, tensors = align_tensors(*tensors)
tensors = torch.broadcast_tensors(*tensors)
inputs, tensors = align_tensors(*tensors, expand=True)
const, coeffs = tensors[0], tensors[1:]

dim = sum(d.num_elements for d in real_inputs.values())
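For context, `expand=True` folds the explicit `torch.broadcast_tensors` step into `align_tensors`. A minimal sketch of the equivalence, assuming two funsor Tensors with hypothetical, disjoint inputs `i` and `j`:

```python
from collections import OrderedDict

import torch

from funsor.domains import bint
from funsor.torch import Tensor, align_tensors

x = Tensor(torch.randn(3), OrderedDict(i=bint(3)))
y = Tensor(torch.randn(4), OrderedDict(j=bint(4)))

# Old pattern: align, then broadcast explicitly.
inputs, (xd, yd) = align_tensors(x, y)
xd_old, yd_old = torch.broadcast_tensors(xd, yd)

# New pattern: one call aligns and expands.
inputs, (xd_new, yd_new) = align_tensors(x, y, expand=True)
assert xd_new.shape == yd_new.shape == (3, 4)
assert torch.equal(xd_new, xd_old) and torch.equal(yd_new, yd_old)
```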
12 changes: 12 additions & 0 deletions funsor/domains.py
@@ -97,6 +97,18 @@ def find_domain(op, *domains):
dtype = lhs.dtype
shape = lhs.shape[:op.offset] + lhs.shape[1 + op.offset:]
return Domain(shape, dtype)
elif op == ops.matmul:
assert lhs.shape and rhs.shape
if len(rhs.shape) == 1:
assert lhs.shape[-1] == rhs.shape[-1]
shape = lhs.shape[:-1]
elif len(lhs.shape) == 1:
assert lhs.shape[-1] == rhs.shape[-2]
shape = rhs.shape[:-2] + rhs.shape[-1:]
else:
assert lhs.shape[-1] == rhs.shape[-2]
shape = broadcast_shape(lhs.shape[:-1], rhs.shape[:-2] + (1,)) + rhs.shape[-1:]
return Domain(shape, 'real')

if lhs.dtype == 'real' or rhs.dtype == 'real':
dtype = 'real'
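A quick sketch of the three shape branches above, using hypothetical shapes and the imports seen elsewhere in this commit:

```python
import funsor.ops as ops
from funsor.domains import find_domain, reals

assert find_domain(ops.matmul, reals(3, 2), reals(2)) == reals(3)              # matrix @ vector
assert find_domain(ops.matmul, reals(2), reals(2, 5)) == reals(5)              # vector @ matrix
assert find_domain(ops.matmul, reals(4, 3, 2), reals(2, 5)) == reals(4, 3, 5)  # batched matmul
```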
6 changes: 6 additions & 0 deletions funsor/ops.py
@@ -65,6 +65,10 @@ class MulOp(AssociativeOp):
pass


class MatmulOp(AssociativeOp):
pass


class LogAddExpOp(AssociativeOp):
pass

@@ -157,6 +161,7 @@ def _default(self, x, y):
add = AddOp(operator.add)
and_ = AssociativeOp(operator.and_)
mul = MulOp(operator.mul)
matmul = MatmulOp(operator.matmul)
or_ = AssociativeOp(operator.or_)
xor = AssociativeOp(operator.xor)

@@ -325,6 +330,7 @@ def reciprocal(x):
'log',
'log1p',
'lt',
'matmul',
'max',
'min',
'mul',
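`MatmulOp` subclasses `AssociativeOp` because matmul is associative (though not commutative), and the op wraps `operator.matmul`, so on raw tensors it is just `@`. A small sketch, assuming the imports above:

```python
import torch

import funsor.ops as ops

a, b = torch.randn(3, 2), torch.randn(2, 5)
assert torch.allclose(ops.matmul(a, b), a @ b)  # default dispatch is operator.matmul
```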
7 changes: 2 additions & 5 deletions funsor/pyro/convert.py
@@ -254,16 +254,13 @@ def eager_affine_normal(matrix, loc, scale, value_x, value_y):
assert len(matrix.output.shape) == 2
assert value_x.output == reals(matrix.output.shape[0])
assert value_y.output == reals(matrix.output.shape[1])
tensors = (matrix, loc, scale, value_x)
int_inputs, tensors = align_tensors(*tensors)
matrix, loc, scale, value_x = tensors
loc += value_x @ matrix
int_inputs, (matrix, loc, scale, value_x) = align_tensors(matrix, loc, scale, value_x, expand=True)

loc = loc + value_x.unsqueeze(-2).matmul(matrix).squeeze(-2)
i_name = gensym("i")
y_name = gensym("y")
y_i_name = gensym("y_i")
int_inputs[i_name] = bint(value_y.output.shape[0])
loc, scale = torch.broadcast_tensors(loc, scale)
loc = Tensor(loc, int_inputs)
scale = Tensor(scale, int_inputs)
y_dist = Independent(Normal(loc, scale, y_i_name), y_name, i_name, y_i_name)
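As an aside, the new funsor-level `@` can express the same affine mean without dropping to raw tensors; a sketch with hypothetical shapes and no batch inputs:

```python
import torch

from funsor.domains import reals
from funsor.torch import Tensor

value_x = Tensor(torch.randn(3))    # output reals(3)
matrix = Tensor(torch.randn(3, 2))  # output reals(3, 2)
loc = Tensor(torch.randn(2))        # output reals(2)

out = loc + value_x @ matrix        # vector @ matrix -> reals(2), evaluated eagerly
assert out.output == reals(2)
assert torch.allclose(out.data, loc.data + value_x.data @ matrix.data)
```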
6 changes: 6 additions & 0 deletions funsor/terms.py
@@ -620,6 +620,12 @@ def __truediv__(self, other):
def __rtruediv__(self, other):
return Binary(ops.truediv, to_funsor(other), self)

def __matmul__(self, other):
return Binary(ops.matmul, self, to_funsor(other))

def __rmatmul__(self, other):
return Binary(ops.matmul, to_funsor(other), self)

def __pow__(self, other):
return Binary(ops.pow, self, to_funsor(other))

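These dunder methods simply build `Binary(ops.matmul, ...)` terms; when no eager pattern applies (e.g. for free `Variable`s) the term stays lazy and its output comes from `find_domain`. A sketch with hypothetical names:

```python
from funsor.domains import reals
from funsor.terms import Variable

m = Variable("m", reals(3, 3))
v = Variable("v", reals(3))
y = m @ v                    # lazy Binary(ops.matmul, m, v): no eager rule for Variables
assert y.output == reals(3)  # shape comes from find_domain(ops.matmul, ...)
```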
64 changes: 57 additions & 7 deletions funsor/torch.py
Expand Up @@ -13,7 +13,7 @@
import funsor.ops as ops
from funsor.delta import Delta
from funsor.domains import Domain, bint, find_domain, reals
from funsor.ops import GetitemOp, Op, ReshapeOp
from funsor.ops import GetitemOp, MatmulOp, Op, ReshapeOp
from funsor.terms import (
Binary,
Funsor,
@@ -50,13 +50,15 @@ def _(x, indent, out):
out.append((indent, f"torch.tensor({repr(x.tolist())}, dtype={x.dtype})"))


def align_tensor(new_inputs, x):
def align_tensor(new_inputs, x, expand=False):
r"""
Permute and expand a tensor to match desired ``new_inputs``.
Permute and add dims to a tensor to match desired ``new_inputs``.
:param OrderedDict new_inputs: A target set of inputs.
:param funsor.terms.Funsor x: A :class:`Tensor` or
:class:`~funsor.terms.Number` .
:param bool expand: If False (default), set result size to 1 for any input
of ``new_inputs`` missing from ``x``; if True, expand to the full ``new_inputs`` size.
:return: a number or :class:`torch.Tensor` that can be broadcast to other
tensors with inputs ``new_inputs``.
:rtype: tuple
@@ -81,26 +83,33 @@ def align_tensor(new_inputs, x):
# Unsquash multivariate input dims by filling in ones.
data = data.reshape(tuple(old_inputs[k].dtype if k in old_inputs else 1 for k in new_inputs) +
x.output.shape)

# Optionally expand new dims.
if expand:
data = data.expand(tuple(d.dtype for d in new_inputs.values()) + x.output.shape)
return data


def align_tensors(*args):
def align_tensors(*args, **kwargs):
r"""
Permute multiple tensors before applying a broadcasted op.
This is mainly useful for implementing eager funsor operations.
:param funsor.terms.Funsor \*args: Multiple :class:`Tensor` s and
:class:`~funsor.terms.Number` s.
:param bool expand: Whether to expand input tensors. Defaults to False.
:return: a pair ``(inputs, tensors)`` where tensors are all
:class:`torch.Tensor` s that can be broadcast together to a single data
with given ``inputs``.
:rtype: tuple
"""
expand = kwargs.pop('expand', False)
assert not kwargs
inputs = OrderedDict()
for x in args:
inputs.update(x.inputs)
tensors = [align_tensor(inputs, x) for x in args]
tensors = [align_tensor(inputs, x, expand=expand) for x in args]
return inputs, tensors
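To make the flag concrete: with the default `expand=False` the aligned data keep size-1 dims for inputs a tensor lacks and rely on later broadcasting, while `expand=True` materializes the full batch shape. A minimal sketch with hypothetical inputs `a` and `b`:

```python
from collections import OrderedDict

import torch

from funsor.domains import bint
from funsor.torch import Tensor, align_tensors

x = Tensor(torch.randn(3), OrderedDict(a=bint(3)))
y = Tensor(torch.randn(4), OrderedDict(b=bint(4)))

inputs, (xd, yd) = align_tensors(x, y)            # expand=False (default)
assert xd.shape == (3, 1) and yd.shape == (1, 4)  # size-1 dims, broadcasting deferred

inputs, (xd, yd) = align_tensors(x, y, expand=True)
assert xd.shape == (3, 4) and yd.shape == (3, 4)  # fully expanded
```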


@@ -415,8 +424,41 @@ def eager_binary_tensor_tensor(op, lhs, rhs):

# Reshape to support broadcasting of output shape.
if inputs:
lhs_dim = len(lhs.output.shape)
rhs_dim = len(rhs.output.shape)
lhs_dim = len(lhs.shape)
rhs_dim = len(rhs.shape)
if lhs_dim < rhs_dim:
cut = lhs_data.dim() - lhs_dim
shape = lhs_data.shape
shape = shape[:cut] + (1,) * (rhs_dim - lhs_dim) + shape[cut:]
lhs_data = lhs_data.reshape(shape)
elif rhs_dim < lhs_dim:
cut = rhs_data.dim() - rhs_dim
shape = rhs_data.shape
shape = shape[:cut] + (1,) * (lhs_dim - rhs_dim) + shape[cut:]
rhs_data = rhs_data.reshape(shape)

data = op(lhs_data, rhs_data)
return Tensor(data, inputs, dtype)


@eager.register(Binary, MatmulOp, Tensor, Tensor)
def eager_binary_tensor_tensor(op, lhs, rhs):
# Compute inputs and outputs.
dtype = find_domain(op, lhs.output, rhs.output).dtype
if lhs.inputs == rhs.inputs:
inputs = lhs.inputs
lhs_data, rhs_data = lhs.data, rhs.data
else:
inputs, (lhs_data, rhs_data) = align_tensors(lhs, rhs)
if len(lhs.shape) == 1:
lhs_data = lhs_data.unsqueeze(-2)
if len(rhs.shape) == 1:
rhs_data = rhs_data.unsqueeze(-1)

# Reshape to support broadcasting of output shape.
if inputs:
lhs_dim = max(2, len(lhs.shape))
rhs_dim = max(2, len(rhs.shape))
if lhs_dim < rhs_dim:
cut = lhs_data.dim() - lhs_dim
shape = lhs_data.shape
shape = shape[:cut] + (1,) * (rhs_dim - lhs_dim) + shape[cut:]
lhs_data = lhs_data.reshape(shape)
elif rhs_dim < lhs_dim:
cut = rhs_data.dim() - rhs_dim
shape = rhs_data.shape
shape = shape[:cut] + (1,) * (lhs_dim - rhs_dim) + shape[cut:]
rhs_data = rhs_data.reshape(shape)

print(f"lhs.data.shape = {lhs.data.shape}")
print(f"rhs.data.shape = {rhs.data.shape}")
print(f"lhs_data.shape = {lhs_data.shape}")
print(f"rhs_data.shape = {rhs_data.shape}")
data = op(lhs_data, rhs_data)
if len(lhs.shape) == 1:
data = data.squeeze(-2)
if len(rhs.shape) == 1:
data = data.squeeze(-1)
return Tensor(data, inputs, dtype)


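Why the handler unsqueezes 1-D operands before calling the op: once funsor inputs contribute batch dims, `torch.matmul` would misread a batched vector as a matrix. A plain-PyTorch sketch of the failure mode and the fix, with hypothetical shapes:

```python
import torch

lhs = torch.randn(7, 4, 3)  # batch dim of size 7 from a funsor input, output shape (4, 3)
rhs = torch.randn(7, 3)     # same batch dim, output shape (3,)

# lhs @ rhs raises, because torch.matmul treats the (7, 3) rhs as a single matrix.
out = (lhs @ rhs.unsqueeze(-1)).squeeze(-1)  # batched matrix-vector product
assert out.shape == (7, 4)
```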
20 changes: 20 additions & 0 deletions test/test_torch.py
@@ -322,6 +322,26 @@ def test_binary_broadcast(inputs1, inputs2, output_shape1, output_shape2):
assert_close(actual_block, expected_block)


@pytest.mark.parametrize('output_shape2', [(2,), (2, 5), (4, 2, 5)], ids=str)
@pytest.mark.parametrize('output_shape1', [(2,), (3, 2), (4, 3, 2)], ids=str)
@pytest.mark.parametrize('inputs2', [(), ('a',), ('b', 'a'), ('b', 'c', 'a')], ids=str)
@pytest.mark.parametrize('inputs1', [(), ('a',), ('a', 'b'), ('b', 'a', 'c')], ids=str)
def test_matmul(inputs1, inputs2, output_shape1, output_shape2):
sizes = {'a': 6, 'b': 7, 'c': 8}
inputs1 = OrderedDict((k, bint(sizes[k])) for k in inputs1)
inputs2 = OrderedDict((k, bint(sizes[k])) for k in inputs2)
x1 = random_tensor(inputs1, reals(*output_shape1))
x2 = random_tensor(inputs2, reals(*output_shape2))

actual = x1 @ x2
assert actual.output == find_domain(ops.matmul, x1.output, x2.output)

block = {'a': 1, 'b': 2, 'c': 3}
actual_block = actual(**block)
expected_block = Tensor(x1(**block).data @ x2(**block).data)
assert_close(actual_block, expected_block)


@pytest.mark.parametrize('scalar', [0.5])
@pytest.mark.parametrize('dims', [(), ('a',), ('a', 'b'), ('b', 'a', 'c')])
@pytest.mark.parametrize('symbol', BINARY_OPS)
