diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py index 743bf8d8ba..0d21cc4b08 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -214,11 +214,16 @@ def _general_dot( import pytensor.tensor as pt from pytensor.tensor.einsum import _general_dot + import numpy as np + A = pt.tensor(shape = (3, 4, 5)) B = pt.tensor(shape = (3, 5, 2)) result = _general_dot((A, B), axes=[[2], [1]], batch_axes=[[0], [0]]) - print(result.type.shape) + + A_val = np.empty((3, 4, 5)) + B_val = np.empty((3, 5, 2)) + print(result.shape.eval({A:A_val, B:B_val})) .. testoutput::