From 558f0845ac6f307b77451bc91725f0aa6913b681 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 27 Jan 2025 11:35:40 +0100
Subject: [PATCH 1/3] Allow importing `rewrite_graph` from `rewriting` submodule

---
 pytensor/graph/rewriting/__init__.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/pytensor/graph/rewriting/__init__.py b/pytensor/graph/rewriting/__init__.py
index e69de29bb2..52cfca4cfe 100644
--- a/pytensor/graph/rewriting/__init__.py
+++ b/pytensor/graph/rewriting/__init__.py
@@ -0,0 +1,4 @@
+from pytensor.graph.rewriting.utils import rewrite_graph
+
+
+__all__ = ("rewrite_graph",)

From b61d31c90fdd38293bb245bc79cee5132a460881 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 27 Jan 2025 12:15:43 +0100
Subject: [PATCH 2/3] Don't run unrelated tests in alternative backends

---
 tests/link/jax/test_elemwise.py       |  4 +-
 tests/link/jax/test_tensor_basic.py   |  4 +-
 tests/link/numba/test_elemwise.py     |  7 +++-
 tests/link/numba/test_tensor_basic.py |  4 +-
 tests/tensor/test_basic.py            | 54 ++++++++++++------------
 tests/tensor/test_elemwise.py         | 58 +++++++++++++--------------
 6 files changed, 67 insertions(+), 64 deletions(-)

diff --git a/tests/link/jax/test_elemwise.py b/tests/link/jax/test_elemwise.py
index 88d5c21925..687049f7e1 100644
--- a/tests/link/jax/test_elemwise.py
+++ b/tests/link/jax/test_elemwise.py
@@ -15,11 +15,11 @@
 from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
 from pytensor.tensor.type import matrix, tensor, vector, vectors
 from tests.link.jax.test_basic import compare_jax_and_py
-from tests.tensor.test_elemwise import TestElemwise
+from tests.tensor.test_elemwise import check_elemwise_runtime_broadcast
 
 
 def test_elemwise_runtime_broadcast():
-    TestElemwise.check_runtime_broadcast(get_mode("JAX"))
+    check_elemwise_runtime_broadcast(get_mode("JAX"))
 
 
 def test_jax_Dimshuffle():
diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py
index 0ee4a236d9..75ca673d78 100644
--- a/tests/link/jax/test_tensor_basic.py
+++ b/tests/link/jax/test_tensor_basic.py
@@ -14,7 +14,7 @@
 from pytensor.graph.op import get_test_value
 from pytensor.tensor.type import iscalar, matrix, scalar, vector
 from tests.link.jax.test_basic import compare_jax_and_py
-from tests.tensor.test_basic import TestAlloc
+from tests.tensor.test_basic import check_alloc_runtime_broadcast
 
 
 def test_jax_Alloc():
@@ -54,7 +54,7 @@ def compare_shape_dtype(x, y):
 
 
 def test_alloc_runtime_broadcast():
-    TestAlloc.check_runtime_broadcast(get_mode("JAX"))
+    check_alloc_runtime_broadcast(get_mode("JAX"))
 
 
 def test_jax_MakeVector():
diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py
index 862ea1a2e2..1da34ff392 100644
--- a/tests/link/numba/test_elemwise.py
+++ b/tests/link/numba/test_elemwise.py
@@ -24,7 +24,10 @@
     scalar_my_multi_out,
     set_test_value,
 )
-from tests.tensor.test_elemwise import TestElemwise, careduce_benchmark_tester
+from tests.tensor.test_elemwise import (
+    careduce_benchmark_tester,
+    check_elemwise_runtime_broadcast,
+)
 
 
 rng = np.random.default_rng(42849)
@@ -124,7 +127,7 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
 
 @pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults")
 def test_elemwise_runtime_broadcast():
-    TestElemwise.check_runtime_broadcast(get_mode("NUMBA"))
+    check_elemwise_runtime_broadcast(get_mode("NUMBA"))
 
 
 def test_elemwise_speed(benchmark):
diff --git a/tests/link/numba/test_tensor_basic.py b/tests/link/numba/test_tensor_basic.py
index 269fc57940..95ab5799c1 100644
--- a/tests/link/numba/test_tensor_basic.py
+++ b/tests/link/numba/test_tensor_basic.py
@@ -16,7 +16,7 @@
     compare_shape_dtype,
     set_test_value,
 )
-from tests.tensor.test_basic import TestAlloc
+from tests.tensor.test_basic import check_alloc_runtime_broadcast
 
 
 pytest.importorskip("numba")
@@ -52,7 +52,7 @@ def test_Alloc(v, shape):
 
 
 def test_alloc_runtime_broadcast():
-    TestAlloc.check_runtime_broadcast(get_mode("NUMBA"))
+    check_alloc_runtime_broadcast(get_mode("NUMBA"))
 
 
 def test_AllocEmpty():
diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py
index ff8751e411..754859fa6f 100644
--- a/tests/tensor/test_basic.py
+++ b/tests/tensor/test_basic.py
@@ -716,6 +716,32 @@ def test_masked_array_not_implemented(
         ptb.as_tensor(x)
 
 
+def check_alloc_runtime_broadcast(mode):
+    """Check we emit a clear error when runtime broadcasting would occur according to NumPy rules."""
+    floatX = config.floatX
+    x_v = vector("x", shape=(None,))
+
+    out = alloc(x_v, 5, 3)
+    f = pytensor.function([x_v], out, mode=mode)
+    TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
+
+    np.testing.assert_array_equal(
+        f(x=np.zeros((3,), dtype=floatX)),
+        np.zeros((5, 3), dtype=floatX),
+    )
+    with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
+        f(x=np.zeros((1,), dtype=floatX))
+
+    out = alloc(specify_shape(x_v, (1,)), 5, 3)
+    f = pytensor.function([x_v], out, mode=mode)
+    TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
+
+    np.testing.assert_array_equal(
+        f(x=np.zeros((1,), dtype=floatX)),
+        np.zeros((5, 3), dtype=floatX),
+    )
+
+
 class TestAlloc:
     dtype = config.floatX
     mode = mode_opt
@@ -729,32 +755,6 @@ def check_allocs_in_fgraph(fgraph, n):
             == n
         )
 
-    @staticmethod
-    def check_runtime_broadcast(mode):
-        """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
-        floatX = config.floatX
-        x_v = vector("x", shape=(None,))
-
-        out = alloc(x_v, 5, 3)
-        f = pytensor.function([x_v], out, mode=mode)
-        TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
-
-        np.testing.assert_array_equal(
-            f(x=np.zeros((3,), dtype=floatX)),
-            np.zeros((5, 3), dtype=floatX),
-        )
-        with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
-            f(x=np.zeros((1,), dtype=floatX))
-
-        out = alloc(specify_shape(x_v, (1,)), 5, 3)
-        f = pytensor.function([x_v], out, mode=mode)
-        TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
-
-        np.testing.assert_array_equal(
-            f(x=np.zeros((1,), dtype=floatX)),
-            np.zeros((5, 3), dtype=floatX),
-        )
-
     def setup_method(self):
         self.rng = np.random.default_rng(seed=utt.fetch_seed())
 
@@ -912,7 +912,7 @@ def test_alloc_of_view_linker(self):
 
     @pytest.mark.parametrize("mode", (Mode("py"), Mode("c")))
     def test_runtime_broadcast(self, mode):
-        self.check_runtime_broadcast(mode)
+        check_alloc_runtime_broadcast(mode)
 
 
 def test_infer_static_shape():
diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py
index 081e495127..bd208c5848 100644
--- a/tests/tensor/test_elemwise.py
+++ b/tests/tensor/test_elemwise.py
@@ -705,6 +705,33 @@ def test_any_grad(self):
         assert np.all(gx_val == 0)
 
 
+def check_elemwise_runtime_broadcast(mode):
+    """Check we emit a clear error when runtime broadcasting would occur according to NumPy rules."""
+    x_v = matrix("x")
+    m_v = vector("m")
+
+    z_v = x_v - m_v
+    f = pytensor.function([x_v, m_v], z_v, mode=mode)
+
+    # Test invalid broadcasting by either x or m
+    for x_sh, m_sh in [((2, 1), (3,)), ((2, 3), (1,))]:
+        x = np.ones(x_sh).astype(config.floatX)
+        m = np.zeros(m_sh).astype(config.floatX)
+
+        # This error is introduced by PyTensor, so it's the same across different backends
+        with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
+            f(x, m)
+
+    x = np.ones((2, 3)).astype(config.floatX)
+    m = np.zeros((1,)).astype(config.floatX)
+
+    x = np.ones((2, 4)).astype(config.floatX)
+    m = np.zeros((3,)).astype(config.floatX)
+    # This error is backend specific, and may have different types
+    with pytest.raises((ValueError, TypeError)):
+        f(x, m)
+
+
 class TestElemwise(unittest_tools.InferShapeTester):
     def test_elemwise_grad_bool(self):
         x = scalar(dtype="bool")
@@ -750,42 +777,15 @@ def test_input_dimensions_overflow(self):
         g = pytensor.function([a, b, c, d, e, f], s, mode=Mode(linker="py"))
         g(*[np.zeros(2**11, config.floatX) for i in range(6)])
 
-    @staticmethod
-    def check_runtime_broadcast(mode):
-        """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
-        x_v = matrix("x")
-        m_v = vector("m")
-
-        z_v = x_v - m_v
-        f = pytensor.function([x_v, m_v], z_v, mode=mode)
-
-        # Test invalid broadcasting by either x or m
-        for x_sh, m_sh in [((2, 1), (3,)), ((2, 3), (1,))]:
-            x = np.ones(x_sh).astype(config.floatX)
-            m = np.zeros(m_sh).astype(config.floatX)
-
-            # This error is introduced by PyTensor, so it's the same across different backends
-            with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
-                f(x, m)
-
-        x = np.ones((2, 3)).astype(config.floatX)
-        m = np.zeros((1,)).astype(config.floatX)
-
-        x = np.ones((2, 4)).astype(config.floatX)
-        m = np.zeros((3,)).astype(config.floatX)
-        # This error is backend specific, and may have different types
-        with pytest.raises((ValueError, TypeError)):
-            f(x, m)
-
     def test_runtime_broadcast_python(self):
-        self.check_runtime_broadcast(Mode(linker="py"))
+        check_elemwise_runtime_broadcast(Mode(linker="py"))
 
     @pytest.mark.skipif(
         not pytensor.config.cxx,
         reason="G++ not available, so we need to skip this test.",
     )
     def test_runtime_broadcast_c(self):
-        self.check_runtime_broadcast(Mode(linker="c"))
+        check_elemwise_runtime_broadcast(Mode(linker="c"))
 
     def test_str(self):
         op = Elemwise(ps.add, inplace_pattern={0: 0}, name=None)

From 8e08d2f3940cf5c1cba80919e1dd49fff373efc3 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 27 Jan 2025 11:38:17 +0100
Subject: [PATCH 3/3] Rewrite batched dots that do not reduce as multiplication

---
 pytensor/tensor/math.py             | 42 ++++++++++++++++----
 pytensor/tensor/rewriting/math.py   | 60 +++++++++++++++++++++++++++++
 tests/tensor/rewriting/test_math.py | 53 ++++++++++++++++++++++++-
 3 files changed, 146 insertions(+), 9 deletions(-)

diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index efcc2500a7..f11e33b41d 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -29,7 +29,7 @@
     stack,
     switch,
 )
-from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import (
     CAReduce,
     Elemwise,
@@ -2726,6 +2726,22 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))
 
 
+# Predefine all batched variations of Dot
+_inner_prod = Blockwise(
+    _dot,
+    signature="(n),(n)->()",
+)
+
+_matrix_vec_prod = Blockwise(
+    _dot,
+    signature="(m,k),(k)->(m)",
+)
+
+_vec_matrix_prod = Blockwise(
+    _dot,
+    signature="(k),(k,n)->(n)",
+)
+
 _matrix_matrix_matmul = Blockwise(
     _dot,
     signature="(m,k),(k,n)->(m,n)",
@@ -2795,14 +2811,24 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
 
 
 @_vectorize_node.register(Dot)
-def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
+def vectorize_node_dot(op, node, batched_x, batched_y):
     old_x, old_y = node.inputs
-    if old_x.type.ndim == 2 and old_y.type.ndim == 2:
-        # If original input is equivalent to a matrix-matrix product,
-        # return specialized Matmul Op to avoid unnecessary new Ops.
-        return matmul(batched_x, batched_y).owner
-    else:
-        return vectorize_node_fallback(op, node, batched_x, batched_y)
+    old_x_ndim = old_x.type.ndim
+    old_y_ndim = old_y.type.ndim
+    match (old_x_ndim, old_y_ndim):
+        case (1, 1):
+            batch_op = _inner_prod
+        case (2, 1):
+            batch_op = _matrix_vec_prod
+        case (1, 2):
+            batch_op = _vec_matrix_prod
+        case (2, 2):
+            batch_op = _matrix_matrix_matmul
+        case _:
+            raise ValueError(
+                f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
+            )
+    return batch_op(batched_x, batched_y).owner
 
 
 def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 03fa1ae094..065ecfc0b1 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -44,6 +44,10 @@
     Prod,
     Sum,
     _conj,
+    _inner_prod,
+    _matrix_matrix_matmul,
+    _matrix_vec_prod,
+    _vec_matrix_prod,
     add,
     digamma,
     dot,
@@ -242,6 +246,62 @@ def local_batched_matmul_to_core_matmul(fgraph, node):
     return None
 
 
+@register_canonicalize
+@register_specialize
+@node_rewriter([_inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul])
+def local_blockwise_dot_to_mul(fgraph, node):
+    """Rewrite blockwise dots that correspond to multiplication without summation.
+
+    We don't touch the regular dot, to not interfere with the BLAS optimizations.
+    """
+    a, b = node.inputs
+    a_static_shape = a.type.shape
+    b_static_shape = b.type.shape
+    core_a_ndim = len(node.op.inputs_sig[0])
+    core_b_ndim = len(node.op.inputs_sig[1])
+
+    if core_a_ndim > 2 or core_b_ndim > 2:
+        # Shouldn't happen, but here just in case
+        return None
+
+    if core_b_ndim == 1:
+        if a_static_shape[-1] == 1 or b_static_shape[-1] == 1:
+            if core_a_ndim == 1:
+                # inner product: (..., 1) * (..., 1) -> (...)
+                # just squeeze the last dimensions of a and b
+                new_a = a.squeeze(-1)
+                new_b = b.squeeze(-1)
+            else:
+                # matrix-vector product: (..., m, 1) * (..., 1) -> (..., m)
+                # the last dimension of b is already aligned for the elemwise multiplication
+                # after we squeeze the last dimension of a
+                new_a = a.squeeze(-1)
+                new_b = b
+        else:
+            return None
+
+    else:
+        if a_static_shape[-1] == 1 or b_static_shape[-2] == 1:
+            if core_a_ndim == 1:
+                # vector-matrix product: (..., 1) * (..., 1, n) -> (..., n)
+                # the last dimension of a is already aligned for the elemwise multiplication
+                # after we squeeze the second-to-last dimension of b
+                new_a = a
+                new_b = b.squeeze(-2)
+            else:
+                # matrix-matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
+                # the dimensions of a and b are already aligned for the elemwise multiplication
+                new_a = a
+                new_b = b
+        else:
+            return None
+
+    new_a = copy_stack_trace(a, new_a)
+    new_b = copy_stack_trace(b, new_b)
+    new_out = copy_stack_trace(node.out, mul(new_a, new_b))
+    return [new_out]
+
+
 def is_inverse_pair(node_op, prev_op, inv_pair):
     """
     Given two consecutive operations, check if they are the
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index d344d29dad..a1759ef81b 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -16,7 +16,8 @@
 from pytensor.compile.mode import Mode, get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp, deep_copy_op
 from pytensor.configdefaults import config
-from pytensor.graph.basic import Apply, equal_computations
+from pytensor.graph import vectorize_graph
+from pytensor.graph.basic import Apply, ancestors, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import (
     SequentialNodeRewriter,
@@ -4590,3 +4591,53 @@ def test_pow_1_rewrite(shape):
     x_val = np.random.random(shape).astype(config.floatX)
 
     np.testing.assert_allclose(z.eval({x: x_val}), f(x_val))
+
+
+@pytest.mark.parametrize(
+    "a_shape,b_shape",
+    [
+        ((1,), (1,)),
+        ((3, 1), (1,)),
+        ((1,), (1, 3)),
+        ((3, 1), (1, 3)),
+    ],
+    ids=str,
+)
+@pytest.mark.parametrize("batched", (False, True))
+def test_local_dot_to_mul(batched, a_shape, b_shape):
+    a = tensor("a", shape=a_shape)
+    b = tensor("b", shape=b_shape)
+
+    out = dot(a, b)
+    if batched:
+        batch_a = tensor("batch_a", shape=(1, 5, *a_shape))
+        batch_b = tensor("batch_b", shape=(7, 1, *b_shape))
+        out = vectorize_graph(out, {a: batch_a, b: batch_b})
+        a = batch_a
+        b = batch_b
+
+    assert (
+        sum(
+            isinstance(var.owner.op, (Blockwise | Dot))
+            for var in ancestors([out])
+            if var.owner
+        )
+        == 1
+    )
+
+    # For now, the rewrite only applies to batched Dots
+    rewritten_out = rewrite_graph(out)
+    assert rewritten_out.type.shape == out.type.shape
+    assert sum(
+        isinstance(var.owner.op, (Blockwise | Dot))
+        for var in ancestors([rewritten_out])
+        if var.owner
+    ) == (0 if batched else 1)
+
+    a_test = np.random.normal(size=a.type.shape).astype(a.type.dtype)
+    b_test = np.random.normal(size=b.type.shape).astype(b.type.dtype)
+    test_mode = Mode(linker="py", optimizer=None)
+    np.testing.assert_allclose(
+        out.eval({a: a_test, b: b_test}, mode=test_mode),
+        rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
+    )
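
Usage sketch (illustrative only, not part of the patches above): with the three patches applied, vectorizing a core `dot` whose contracted dimension is statically 1 should produce one of the new `Blockwise` dot Ops introduced in PATCH 3/3, and `local_blockwise_dot_to_mul` should then replace it with a broadcasted elementwise multiplication. The snippet mirrors the new `test_local_dot_to_mul` test; the variable names are made up for the example, and `rewrite_graph` is imported from the location enabled by PATCH 1/3.

    import pytensor.tensor as pt
    from pytensor.graph import vectorize_graph
    from pytensor.graph.rewriting import rewrite_graph  # import path added in PATCH 1/3

    # Core graph: (3, 1) @ (1, 3) contracts over a length-1 dimension,
    # so the dot performs no actual summation.
    a = pt.tensor("a", shape=(3, 1))
    b = pt.tensor("b", shape=(1, 3))
    out = pt.dot(a, b)

    # Vectorize it: vectorize_node_dot (PATCH 3/3) dispatches to the new
    # Blockwise matrix-matrix dot.
    batch_a = pt.tensor("batch_a", shape=(5, 3, 1))
    batch_b = pt.tensor("batch_b", shape=(5, 1, 3))
    batched_out = vectorize_graph(out, {a: batch_a, b: batch_b})

    # local_blockwise_dot_to_mul should replace the Blockwise dot with a mul
    # that broadcasts (5, 3, 1) * (5, 1, 3); shape and values are unchanged.
    rewritten = rewrite_graph(batched_out)
    print(rewritten.type.shape)  # (5, 3, 3)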