Add einsum
#722
Conversation
Are the current tests failing supposed to fail?
Looks like it's related to the changes to the …
Codecov Report — Attention: Patch coverage is …
Additional details and impacted files:
@@            Coverage Diff             @@
##             main     #722      +/-   ##
==========================================
+ Coverage   81.60%   81.69%   +0.09%
==========================================
  Files         179      182       +3
  Lines       47271    47585     +314
  Branches    11481    11584     +103
==========================================
+ Hits        38574    38875     +301
- Misses       6511     6520       +9
- Partials     2186     2190       +4
(force-pushed from dd08faa to 180ef9d)
All cases except those requiring tensordot with batch dims not on the left are passing. We may need more tests soon enough.
Can we reuse how numpy implements tensordot?
We already do that, but numpy doesn't have batched tensordot (except, of course, through einsum). We already have batched tensordot working in this PR, just not with arbitrary batch axes. It should just need some extra transposes to get the job done.
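A minimal NumPy sketch of the transpose idea (illustrative only, not the PR's actual implementation): move the batch axis to the front so a left-batched contraction applies directly.

```python
import numpy as np

# a: batch axis in the middle; b: batch axis on the left
a = np.random.rand(4, 7, 5)    # (m, batch, k)
b = np.random.rand(7, 5, 3)    # (batch, k, n)

# move the batch axis of `a` to the front, then a left-batched
# matmul-style contraction applies directly
a_left = np.moveaxis(a, 1, 0)                # (batch, m, k)
out = np.einsum("bmk,bkn->bmn", a_left, b)   # (batch, m, n)

# same result as contracting each batch element separately
ref = np.stack([a[:, i, :] @ b[i] for i in range(7)])
assert np.allclose(out, ref)
```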
Huhu, convolutions via einsum work :D
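For instance, a minimal NumPy sketch of a 1D "valid" convolution expressed as an einsum over sliding windows (purely illustrative, not the PR's implementation):

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

x = np.random.rand(10)
k = np.random.rand(3)

# each row of `windows` is one kernel-sized slice of the signal
windows = sliding_window_view(x, k.shape[0])   # shape (8, 3)
out = np.einsum("nk,k->n", windows, k)         # cross-correlation

# equals np.convolve with a flipped kernel ("valid" mode)
assert np.allclose(out, np.convolve(x, k[::-1], mode="valid"))
```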
(force-pushed from aa36518 to b1fb4ec)
Now Einsum also works with inputs with unknown static shape (unoptimized, of course). We can add a rewrite for when such an Op is found with inputs that now have static shapes (this can be quite relevant in PyMC, when users use …
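For illustration, a small sketch of the two situations (assuming the `pt.einsum` entry point this PR adds; exact import paths may differ):

```python
import pytensor.tensor as pt

# inputs with unknown static shape: the contraction path cannot be
# optimized at graph-construction time
x = pt.tensor("x", shape=(None, None))
y = pt.tensor("y", shape=(None, None))
out = pt.einsum("ij,jk->ik", x, y)

# with fully static shapes, an optimized contraction path can be
# computed up front
a = pt.tensor("a", shape=(16, 8))
b = pt.tensor("b", shape=(8, 4))
out_static = pt.einsum("ij,jk->ik", a, b)
```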
(force-pushed from 2e95049 to ca8bf54)
The ellipsis case is failing due to a bug in opt_einsum: dgasmith/opt_einsum#235
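A hedged reproduction sketch (based on the linked issue; the exact failure mode is described in dgasmith/opt_einsum#235):

```python
import opt_einsum as oe

# with shapes=True, operands are given as shape tuples rather than arrays
path, info = oe.contract_path("ij,jk->ik", (3, 4), (4, 5), shapes=True)

# the ellipsis variant reportedly fails under shapes=True:
# oe.contract_path("...ij,...jk->...ik", (2, 3, 4), (2, 4, 5), shapes=True)
```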
Oh, I misunderstood!

> Not sure what you mean @zaxtax, I'm talking about allowing the "optimize" kwarg like there is in numpy, which defines what kind of optimization to do: optimize{bool, list, tuple, 'greedy', 'optimal'}; users can pass their custom contraction path as well.
> If users pass a contraction path, we don't need to know static shapes. If users set it to greedy/optimal (optimal should be the default), we need to know them, though we may only find them later. If they don't want optimization, then we obviously don't need to.
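For reference, the NumPy behaviour being mirrored (standard NumPy API):

```python
import numpy as np

a, b = np.random.rand(50, 40), np.random.rand(40, 30)

# `optimize` selects the contraction-path strategy
np.einsum("ij,jk->ik", a, b, optimize="optimal")

# a precomputed contraction path avoids needing shapes at call time
path, info = np.einsum_path("ij,jk->ik", a, b, optimize="optimal")
np.einsum("ij,jk->ik", a, b, optimize=path)
```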
Some unrelated jax test is failing, probably something that changed in a recent release? https://github.com/pymc-devs/pytensor/actions/runs/10161294793/job/28099514375?pr=722#step:6:778
Yes, I am also looking at this now. It's a jax bug that can be recreated easily: …
We can ignore it. Looks like their …
(force-pushed from 459bb77 to 48c663a)
I think I fixed the tests (not the JAX one) and appeased mypy. @jessegrabowski, the docstring extensions are left to you.
Stopped force-pushing if you want to take over.
Opened an issue here: jax-ml/jax#22751. I'll hit the docstrings ASAP if that's all that's holding this up.
Great, let's just mark it as xfail then.
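For example (a generic pytest sketch; the test name here is illustrative, not the PR's actual test):

```python
import pytest

@pytest.mark.xfail(reason="JAX bug, see jax-ml/jax#22751")
def test_einsum_jax_case():
    # body of the affected JAX-backend test would go here
    ...
```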
First pass on docstrings. Working on the doctests revealed two things:

>>> from jax._src.lax.lax import _delta as jax_delta
>>> from pytensor.tensor.einsum import _delta as pt_delta
>>> jax_delta(int, (3, 3, 3), (0, 1))
Array([[[1, 1, 1],
        [0, 0, 0],
        [0, 0, 0]],
       [[0, 0, 0],
        [1, 1, 1],
        [0, 0, 0]],
       [[0, 0, 0],
        [0, 0, 0],
        [1, 1, 1]]], dtype=int32)
>>> pt_delta((3, 3, 3), (0, 1)).astype(int).eval()
array([[[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]],
       [[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]],
       [[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]]])

We seem to always output the …
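For reference, a hypothetical NumPy analogue of the generalized Kronecker delta matching JAX's convention (1 wherever the indices along the given axes agree), purely to illustrate what the two `_delta` helpers are expected to compute:

```python
import numpy as np

def delta(shape, axes):
    # Hypothetical NumPy sketch: output is 1 wherever the indices
    # along all of `axes` coincide, broadcast over the remaining axes.
    idx = np.indices(shape)
    mask = np.ones(shape, dtype=bool)
    for ax in axes[1:]:
        mask &= idx[ax] == idx[axes[0]]
    return mask.astype(int)

# matches the JAX output above: ones where i == j, broadcast over k
print(delta((3, 3, 3), (0, 1)))
```

By this reading, the pytensor output above appears to place the delta over axes (1, 2) rather than the requested (0, 1).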
Static shape is optional, not a requirement. In our case it probably has to do with the reshape introduced by tensordot and/or Blockwise, which doesn't do any special shape inference (static or at rewrite time) for core shapes. That's something we probably want to address for Blockwise in the Numba backend.
I understand it's optional, but it also shouldn't be discarded if available, no?
We are not discarding anything on purpose, but an intermediate Op (or Blockwise) doesn't know how to provide a more precise output shape. There can also be a tradeoff: quite some effort may be needed to figure out static shapes, and that may not be worth it at define time. Anyway, the main point is that it shouldn't be a blocker. We can open an issue for whatever Op is losing the static shape and then assess whether it's worth the cost.
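A small sketch of how such a case shows up (assuming the `pt.einsum` API from this PR; the printed shapes are illustrative):

```python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(3, 4))
y = pt.tensor("y", shape=(4, 5))
out = pt.einsum("ij,jk->ik", x, y)

# if an intermediate Op (e.g. a Reshape or Blockwise) loses static shape
# information, this may print (None, None) instead of (3, 5)
print(out.type.shape)
```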
@jessegrabowski I think I fixed the …
(force-pushed from 70cc6a3 to a902974)
Co-authored-by: Adrian Seyboldt <[email protected]>
Co-authored-by: Jesse Grabowski <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Rob Zinkov <[email protected]>
🎆
Description

TODO:
- contract_path with ellipsis fails when shapes=True (dgasmith/opt_einsum#235)
- optimize kwarg

Related Issue
- einsum equivalent #57

Checklist

Type of change