Skip to content

Commit

Permalink
Fix _general_dot doctest
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Jul 31, 2024
1 parent 5ba3cc7 commit a8993d6
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion pytensor/tensor/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,16 @@ def _general_dot(
import pytensor.tensor as pt
from pytensor.tensor.einsum import _general_dot
import numpy as np
A = pt.tensor(shape = (3, 4, 5))
B = pt.tensor(shape = (3, 5, 2))
result = _general_dot((A, B), axes=[[2], [1]], batch_axes=[[0], [0]])
print(result.type.shape)
A_val = np.empty((3, 4, 5))
B_val = np.empty((3, 5, 2))
print(result.shape.eval({A:A_val, B:B_val}))
.. testoutput::
Expand Down

0 comments on commit a8993d6

Please sign in to comment.