Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix einsum bug #1185

Merged
merged 3 commits into from
Feb 3, 2025
Merged

Fix einsum bug #1185

merged 3 commits into from
Feb 3, 2025

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Feb 3, 2025

A shortcut in the numpy implementation of einsum_path when there's nothing to optimize, creates a default path that can combine more than 2 operands. Our implementation only works with 2 or 1 operand operations at each step.

https://github.com/numpy/numpy/blob/cc5851e654bfd82a23f2758be4bd224be84fc1c3/numpy/_core/einsumfunc.py#L945-L951

Closes #1184


📚 Documentation preview 📚: https://pytensor--1185.org.readthedocs.build/en/1185/

@@ -546,8 +552,6 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
"If you need this functionality open an issue in https://github.com/pymc-devs/pytensor/issues to let us know. "
)

# TODO: Is this doing something clever about unknown shapes?
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked, they are not doing anything clever, simply set unknown dims to 8 and use whatever comes out of it

@ricardoV94 ricardoV94 force-pushed the fix_bug_einsum branch 2 times, most recently from 902487b to 43a0cc5 Compare February 3, 2025 11:13
Nothing clever was going on, unknown dims were simply faked as having length 8 for JAX polymorphism export
A shortcut in the numpy implementation of einsum_path when there's nothing to optimize, creates a default path that can combine more than 2 operands. Our implementation only works with 2 or 1 operand operations at each step.

https://github.com/numpy/numpy/blob/cc5851e654bfd82a23f2758be4bd224be84fc1c3/numpy/_core/einsumfunc.py#L945-L951
Copy link

codecov bot commented Feb 3, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.29%. Comparing base (884dee9) to head (5de2c22).
Report is 3 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1185   +/-   ##
=======================================
  Coverage   82.29%   82.29%           
=======================================
  Files         186      186           
  Lines       48039    48045    +6     
  Branches     8632     8633    +1     
=======================================
+ Hits        39533    39539    +6     
  Misses       6348     6348           
  Partials     2158     2158           
Files with missing lines Coverage Δ
pytensor/tensor/einsum.py 97.01% <100.00%> (+0.09%) ⬆️

@ricardoV94 ricardoV94 merged commit c22e79e into pymc-devs:main Feb 3, 2025
64 checks passed
@ricardoV94 ricardoV94 changed the title Fix bug einsum Fix einsum bug Feb 3, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Bug in einsum
2 participants