Adaptive pooling is a pain in JAX/XLA, at least if you want the same results as torch in the general case. (Equinox doesn't aim for torch equivalence here, and the other major JAX frameworks don't implement it at all, except sometimes for the boring case where input % output == 0.)
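For reference, here is roughly why the divisible case is the boring one. When L % out == 0, every window has size k = L // out and stride k, so the windows tile the input exactly and pooling is just a reshape plus a reduction. In JAX this corresponds to `jax.lax.reduce_window` with window and stride both equal to k, followed by dividing by k for the average. This is a minimal NumPy sketch with a hypothetical function name. It only uses `reshape` and `mean`, which `jax.numpy` also provides.

```python
import numpy as np

def adaptive_avg_pool_1d_divisible(x, out_size):
    """Hypothetical helper. Only valid when x.shape[-1] % out_size == 0: every
    window then has the same size k = L // out_size and stride k."""
    L = x.shape[-1]
    if L % out_size != 0:
        raise ValueError("input length must be divisible by out_size")
    k = L // out_size
    # Split the last axis into (out_size, k) non-overlapping windows, then average each.
    return x.reshape(*x.shape[:-1], out_size, k).mean(axis=-1)
```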
https://stackoverflow.com/a/63603993/1736826 seems to be a correct description of how it works, at least in the 1-d case. You end up with overlapping windows of differing sizes, which (I think?) can't be expressed as a single reduce_window call. It might be possible with a second, mask-like argument to reduce_window, but I haven't figured that out yet.
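A sketch of the general case, assuming the floor/ceil rule from that answer: output index i pools over input positions [floor(i*L/out), ceil((i+1)*L/out)). Instead of reduce_window, it builds a dense (out, L) boolean mask of the windows and reduces through it, using a matmul for average pooling and a masked max for max pooling. All function names are hypothetical. It is written in NumPy, but it only uses arange, broadcasting, where, matmul and reductions, which jax.numpy also has. The mask costs O(out * L) memory, so this is only sensible for modest sizes.

```python
import numpy as np

def adaptive_windows(in_size, out_size):
    """[start, end) bounds for each output index under the floor/ceil rule."""
    i = np.arange(out_size)
    starts = (i * in_size) // out_size                      # floor(i * L / out)
    ends = ((i + 1) * in_size + out_size - 1) // out_size   # ceil((i + 1) * L / out)
    return starts, ends

def window_mask(in_size, out_size):
    starts, ends = adaptive_windows(in_size, out_size)
    pos = np.arange(in_size)
    # mask[o, j] is True when input position j lies inside output window o
    return (pos[None, :] >= starts[:, None]) & (pos[None, :] < ends[:, None])

def adaptive_avg_pool_1d(x, out_size):
    """x has shape (..., L); returns shape (..., out_size)."""
    mask = window_mask(x.shape[-1], out_size).astype(x.dtype)
    # Sum each window via matmul, then divide by that window's size.
    return (x @ mask.T) / mask.sum(axis=-1)

def adaptive_max_pool_1d(x, out_size):
    mask = window_mask(x.shape[-1], out_size)
    # Broadcast x to (..., out_size, L) and hide positions outside each window.
    masked = np.where(mask, x[..., None, :], -np.inf)
    return masked.max(axis=-1)
```

With L = 5 and out = 3 the windows are [0, 2), [1, 4) and [3, 5): sizes 2, 3, 2, with overlaps, which is what prevents a single uniform-stride window from matching.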
Maybe going for torch equivalence here isn't worth it? Maybe it can be done with Pallas?
The easiest thing seems like it would be ~symmetric padding, but no one seems to do it that way?
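Here is one reading of the ~symmetric padding idea, as a sketch: pad the input roughly evenly on both sides up to the next multiple of the output size, then pool over equal, non-overlapping windows, which a single reduce_window with window = stride can express. Edge padding is my own assumption; zero or -inf padding are other options with their own trade-offs. The function name is hypothetical. This uses the same windows as the floor/ceil rule only when L % out == 0, so it is not torch-equivalent in general.

```python
import numpy as np

def padded_adaptive_avg_pool_1d(x, out_size):
    """Hypothetical non-torch-equivalent variant: pad to a multiple of
    out_size, then average over equal, non-overlapping windows."""
    L = x.shape[-1]
    k = -(-L // out_size)           # ceil(L / out_size): the uniform window size
    pad = k * out_size - L          # total padding needed, always < out_size
    left, right = pad // 2, pad - pad // 2
    widths = [(0, 0)] * (x.ndim - 1) + [(left, right)]
    # Edge padding repeats the boundary values; unlike zero padding it does not
    # drag averages toward 0, and unlike -inf it stays finite when a window is all padding.
    xp = np.pad(x, widths, mode="edge")
    return xp.reshape(*x.shape[:-1], out_size, k).mean(axis=-1)
```

On [0, 1, 2, 3, 4] with out = 3 this gives [0.5, 2.5, 4.0], whereas the floor/ceil windows give [0.5, 2.0, 3.5].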