Skip to content

Commit

Permalink
Fix _delta on non rightmost axes
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Aug 4, 2024
1 parent a8993d6 commit 2413d99
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pytensor/tensor/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pytensor.tensor.basic import (
arange,
as_tensor,
expand_dims,
get_vector_length,
moveaxis,
stack,
Expand Down Expand Up @@ -176,7 +177,8 @@ def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable:
iotas = [_iota(base_shape, i) for i in range(len(axes))]
eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)]
result = reduce(and_, eyes)
return broadcast_to(result, shape)
non_axes = [i for i in range(len(tuple(shape))) if i not in axes]
return broadcast_to(expand_dims(result, non_axes), shape)


def _general_dot(
Expand Down
5 changes: 5 additions & 0 deletions tests/tensor/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ def test_delta():
[[1.0, 0.0], [0.0, 1.0]],
)

np.testing.assert_allclose(
_delta((2, 2, 2), (0, 1)).eval(mode=mode),
[[[1, 1], [0, 0]], [[0, 0], [1, 1]]],
)


def test_general_dot():
rng = np.random.default_rng(45)
Expand Down

0 comments on commit 2413d99

Please sign in to comment.