Support aligned batch dims on the left
ricardoV94 committed May 7, 2024
1 parent 180ef9d commit 2b6bfc3
Showing 3 changed files with 110 additions and 12 deletions.
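For context, here is a minimal sketch (not part of the commit) of the kind of contraction this change enables: an einsum whose shared batch label is aligned on the left of both operands, which previously raised NotImplementedError in the tensordot path.

import numpy as np
import pytensor.tensor as pt

# Batched matrix multiply: batch label "b" sits leftmost in both operands.
x = pt.tensor("x", shape=(8, 3, 4))
y = pt.tensor("y", shape=(8, 4, 5))
out = pt.einsum("bij,bjk->bik", x, y)

rng = np.random.default_rng(0)
x_val = rng.normal(size=(8, 3, 4))
y_val = rng.normal(size=(8, 4, 5))
np.testing.assert_allclose(
    out.eval({x: x_val, y: y_val}),
    np.einsum("bij,bjk->bik", x_val, y_val),
)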
78 changes: 72 additions & 6 deletions pytensor/tensor/einsum.py
@@ -1,11 +1,17 @@
import collections
from collections.abc import Sequence
from functools import reduce
from functools import partial, reduce
from itertools import pairwise
from typing import cast

from numpy.core.numeric import normalize_axis_index # type: ignore
import numpy as np
from numpy.core.numeric import ( # type: ignore
normalize_axis_index,
normalize_axis_tuple,
)

from pytensor.compile.builders import OpFromGraph
from pytensor.tensor import vectorize
from pytensor.tensor.basic import (
arange,
get_vector_length,
@@ -54,6 +60,62 @@ def _removechars(s, chars):
return s.translate(str.maketrans(dict.fromkeys(chars)))


def _batched_tensordot(
vars: tuple[TensorVariable, TensorVariable],
axes: Sequence[Sequence[int]],  # Should be length 2
batch_axes: Sequence[Sequence[int]],  # Should be length 2
) -> TensorVariable:
# Shortcut for non batched case
if not batch_axes[0] and not batch_axes[1]:
return tensordot(*vars, axes=axes)

# Normalize axes; thankfully the numpy helper does not sort the axes!
axes = [
normalize_axis_tuple(var_axes, var.ndim) for var, var_axes in zip(vars, axes)
]
batch_axes = [
normalize_axis_tuple(var_axes, var.ndim)
for var, var_axes in zip(vars, batch_axes)
]
n_batch_axes = [len(var_batch_axes) for var_batch_axes in batch_axes]
if any(
var_batch_axes != tuple(range(var_n_batch_axes))
for var_batch_axes, var_n_batch_axes in zip(batch_axes, n_batch_axes)
):
# Will need to transpose/expand_dims to align batch dims on the left and then transpose back
raise NotImplementedError(
f"Arbitrary batch dims location not yet supported, got: {batch_axes}"
)

lhs, rhs = vars
lhs_axes, rhs_axes = axes
lhs_n_batch_axes, rhs_n_batch_axes = n_batch_axes

# Create signature of tensordot
lhs_signature = [f"l{i}" for i in range(lhs.type.ndim)]
rhs_signature = [f"r{i}" for i in range(rhs.type.ndim)]
# Aligned axes get the same dimension name
for i, (lhs_axis, rhs_axis) in enumerate(zip(lhs_axes, rhs_axes)):
lhs_signature[lhs_axis] = rhs_signature[rhs_axis] = f"a{i}"
# Trim away the batch ndims
lhs_signature = lhs_signature[lhs_n_batch_axes:]
rhs_signature = rhs_signature[rhs_n_batch_axes:]
out_signature = [
lhs_dim for lhs_dim in lhs_signature if not lhs_dim.startswith("a")
] + [rhs_dim for rhs_dim in rhs_signature if not rhs_dim.startswith("a")]
signature = f"({','.join(lhs_signature)}),({','.join(rhs_signature)})->({','.join(out_signature)})"
# Adjust axes for core case
core_lhs_axes = tuple(np.array(lhs_axes) - lhs_n_batch_axes)
core_rhs_axes = tuple(np.array(rhs_axes) - rhs_n_batch_axes)

# TODO: Make sure this looks reasonable after optimizations
# Right now we have some Blockwise(Reshape) that will slow things down!
out = vectorize(
partial(tensordot, axes=[core_lhs_axes, core_rhs_axes]), signature=signature
)(lhs, rhs)
return cast(TensorVariable, out)


def einsum(subscripts: str, *operands):
"""
Multiplication and summation of tensors using the Einstein summation convention.
@@ -199,8 +261,6 @@ def sum_repeats(
lhs_batch, rhs_batch = tuple(
zip(*[(lhs_names.find(n), rhs_names.find(n)) for n in batch_names])
)
if lhs_batch or rhs_batch:
raise NotImplementedError("Batch dimensions are not yet supported")
else:
lhs_batch = rhs_batch = ()

@@ -226,10 +286,16 @@ def sum_repeats(
# needing a transpose.
names = batch_names_str + remaining_rhs_names + remaining_lhs_names
if names == result_names:
operand = tensordot(rhs, lhs, (rhs_cont, lhs_cont))
operand = _batched_tensordot(
(rhs, lhs), (rhs_cont, lhs_cont), (rhs_batch, lhs_batch)
)
else:
names = batch_names_str + remaining_lhs_names + remaining_rhs_names
operand = tensordot(lhs, rhs, axes=(lhs_cont, rhs_cont))
operand = _batched_tensordot(
(lhs, rhs),
axes=(lhs_cont, rhs_cont),
batch_axes=(lhs_batch, rhs_batch),
)

# the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
assert len(names) == len(result_names) == len(set(names))
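To make the signature construction in _batched_tensordot concrete, here is a standalone trace (an illustration, not part of the diff) for the shapes exercised in test_batched_tensordot below: the lhs has two batch axes and contraction axes (3, 4), the rhs has one batch axis and contraction axes (4, 1).

lhs_ndim, rhs_ndim = 6, 5
lhs_axes, rhs_axes = (3, 4), (4, 1)  # normalized contraction axes
lhs_n_batch_axes, rhs_n_batch_axes = 2, 1

lhs_signature = [f"l{i}" for i in range(lhs_ndim)]
rhs_signature = [f"r{i}" for i in range(rhs_ndim)]
# Contracted pairs share a name a0, a1, ...
for i, (lhs_axis, rhs_axis) in enumerate(zip(lhs_axes, rhs_axes)):
    lhs_signature[lhs_axis] = rhs_signature[rhs_axis] = f"a{i}"
# Batch dims are trimmed from the core signature; vectorize handles them.
lhs_signature = lhs_signature[lhs_n_batch_axes:]
rhs_signature = rhs_signature[rhs_n_batch_axes:]
out_signature = [d for d in lhs_signature if not d.startswith("a")] + [
    d for d in rhs_signature if not d.startswith("a")
]
print(f"({','.join(lhs_signature)}),({','.join(rhs_signature)})->({','.join(out_signature)})")
# (l2,a0,a1,l5),(a1,r2,r3,a0)->(l2,l5,r2,r3)
# Up to renaming of the free labels, this is the signature used in the test.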
6 changes: 3 additions & 3 deletions pytensor/tensor/shape.py
@@ -821,13 +821,13 @@ def c_code(self, node, name, inputs, outputs, sub):

@_vectorize_node.register(Reshape)
def _vectorize_reshape(op, node, x, shape):
from pytensor.tensor.blockwise import vectorize_node_fallback

old_x, old_shape = node.inputs
batched_ndims = x.type.ndim - old_x.type.ndim

if as_tensor_variable(shape).type.ndim != 1:
raise NotImplementedError(
"It is not possible to vectorize the shape argument of Reshape"
)
return vectorize_node_fallback(op, node, x, shape)

if len(tuple(old_shape)) == len(tuple(shape)):
new_shape = [*x.shape[:batched_ndims], *shape]
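For intuition about the Reshape change above (an illustration, not from the diff): when the shape argument itself cannot be vectorized, the node now falls back to the generic Blockwise path, which amounts to applying the core reshape to each batch entry. A NumPy-level sketch of that semantics:

import numpy as np

# Core reshape (4, 6) -> (2, 12), vectorized over a batch of 3.
x = np.arange(3 * 4 * 6).reshape(3, 4, 6)
blockwise = np.stack([xi.reshape(2, 12) for xi in x])  # per-batch core reshape
assert blockwise.shape == (3, 2, 12)
# For a contiguous array this agrees with reshaping batch and core together.
np.testing.assert_array_equal(blockwise, x.reshape(3, 2, 12))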
38 changes: 35 additions & 3 deletions tests/tensor/test_einsum.py
@@ -1,11 +1,12 @@
from functools import partial
from string import ascii_lowercase

import numpy as np
import pytest

import pytensor.tensor as pt
from pytensor import Mode
from pytensor.tensor.einsum import _delta, _iota
from pytensor.tensor.einsum import _batched_tensordot, _delta, _iota


def test_iota():
@@ -39,6 +40,38 @@ def test_delta():
)


def test_batched_tensordot():
mode = Mode(linker="py", optimizer=None)
rng = np.random.default_rng(45)

signature = "(l0,a0,a1,l1),(a1,r0,r1,a0)->(l0,l1,r0,r1)"
tensordot_axes = [(-3, -2), (-1, -4)]

# X has two batch dims
# Y has one batch dim
x = pt.tensor("x", shape=(5, 4, 2, 11, 13, 3))
y = pt.tensor("y", shape=(4, 13, 5, 7, 11))
out = _batched_tensordot((x, y), tensordot_axes, [(0, 1), (0,)])

# FIXME: Not a satisfactory graph!
# import pytensor
# fn = pytensor.function([x, y], out)
# print()
# pytensor.dprint(fn, print_type=True)

x_test = rng.normal(size=x.type.shape)
y_test = rng.normal(size=y.type.shape)

np_batched_tensordot = np.vectorize(
partial(np.tensordot, axes=tensordot_axes), signature=signature
)

np.testing.assert_allclose(
out.eval({x: x_test, y: y_test}, mode=mode),
np_batched_tensordot(x_test, y_test),
)


@pytest.mark.parametrize(
"signature",
[
@@ -67,14 +100,13 @@ def test_parse_einsum_input(signature):
operands = [
pt.tensor(name, shape=shape) for name, shape in zip(ascii_lowercase, shapes)
]
print(len(operands))
out = pt.einsum(signature, *operands)

rng = np.random.default_rng(37)
test_values = [rng.normal(size=shape) for shape in shapes]
np_out = np.einsum(signature, *test_values)

assert out.type.shape == np_out.shape
# assert out.type.shape == np_out.shape # Reshape operations lose static shape
np.testing.assert_allclose(out.eval(dict(zip(operands, test_values))), np_out)


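For orientation (not part of the commit), the contraction exercised in test_batched_tensordot can also be written as a plain NumPy einsum; the labels below are chosen only for this illustration, with "y"/"z" as the leading batch labels ("y" appears only on the lhs) and "m"/"n" as the contracted pairs.

from functools import partial

import numpy as np

rng = np.random.default_rng(45)
x_val = rng.normal(size=(5, 4, 2, 11, 13, 3))
y_val = rng.normal(size=(4, 13, 5, 7, 11))

ref = np.einsum("yzimnj,znpqm->yzijpq", x_val, y_val)
assert ref.shape == (5, 4, 2, 3, 5, 7)

# Same result as the vectorized tensordot used in the test above.
np_batched = np.vectorize(
    partial(np.tensordot, axes=[(-3, -2), (-1, -4)]),
    signature="(l0,a0,a1,l1),(a1,r0,r1,a0)->(l0,l1,r0,r1)",
)(x_val, y_val)
np.testing.assert_allclose(ref, np_batched)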
