diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py index 0d21cc4b08..ab0c399f8c 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -17,6 +17,7 @@ from pytensor.tensor.basic import ( arange, as_tensor, + expand_dims, get_vector_length, moveaxis, stack, @@ -176,7 +177,8 @@ def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable: iotas = [_iota(base_shape, i) for i in range(len(axes))] eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)] result = reduce(and_, eyes) - return broadcast_to(result, shape) + non_axes = [i for i in range(len(tuple(shape))) if i not in axes] + return broadcast_to(expand_dims(result, non_axes), shape) def _general_dot( diff --git a/tests/tensor/test_einsum.py b/tests/tensor/test_einsum.py index ecd6169a1e..9131cda056 100644 --- a/tests/tensor/test_einsum.py +++ b/tests/tensor/test_einsum.py @@ -67,6 +67,11 @@ def test_delta(): [[1.0, 0.0], [0.0, 1.0]], ) + np.testing.assert_allclose( + _delta((2, 2, 2), (0, 1)).eval(mode=mode), + [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], + ) + def test_general_dot(): rng = np.random.default_rng(45)