I've been giving this some thought lately. Initially I was planning to do something like nquad in scipy, which is basically just a wrapper around recursive calls to quad. However, doing this in jax ended up causing lots of issues with jit and AD because of the large number of local function definitions, and it is also likely very inefficient on GPU since it's almost entirely sequential. There is probably still some way to make it work, but I haven't had time to experiment with it further, so I would welcome contributions on that.
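For reference, here is a minimal sketch of the nquad-style iterated approach in JAX. It assumes a fixed-order Gauss-Legendre rule (the hypothetical helper `quad_fixed`) in place of an adaptive quad. The outer 1D rule runs over the first variable, and for each of its nodes the remaining variables are integrated recursively, which after tracing becomes nested `jax.vmap` calls. With a fixed rule there is no data-dependent control flow, so this traces under `jit` and differentiates with `jax.grad`; the trade-off is that it cannot adapt to difficult integrands. This is not `scipy.integrate.nquad` or this library's API, and `nquad_iterated` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
import numpy as np

def quad_fixed(g, a, b, n=16):
    """Hypothetical 1D stand-in for an adaptive quad: n-point Gauss-Legendre on [a, b]."""
    x, w = np.polynomial.legendre.leggauss(n)
    half = 0.5 * (b - a)
    t = half * jnp.asarray(x) + 0.5 * (a + b)  # map nodes from [-1, 1] to [a, b]
    return half * jnp.sum(jnp.asarray(w) * jax.vmap(g)(t))

def nquad_iterated(f, bounds, n=16):
    """Integrate f(x0, x1, ...) over a box given as [(a0, b0), (a1, b1), ...].

    The outer 1D rule runs over x0; for each of its nodes, the remaining
    variables are integrated by a recursive call (nested vmaps after tracing).
    """
    (a, b), rest = bounds[0], bounds[1:]
    if not rest:
        return quad_fixed(f, a, b, n)
    # For a fixed x0, the remaining integral is a lower-dimensional problem.
    inner = lambda x0: nquad_iterated(lambda *xs: f(x0, *xs), rest, n)
    return quad_fixed(inner, a, b, n)

# Usage: (integral of exp(-x^2) over [0, 1])^2, about 0.55775.
f = lambda x, y: jnp.exp(-(x**2 + y**2))
val = jax.jit(lambda: nquad_iterated(f, [(0.0, 1.0), (0.0, 1.0)]))()
```

Note that the cost grows as n**d function evaluations, and an adaptive inner quad would reintroduce the per-node control flow that makes this awkward to batch.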
The other main approach is "proper" nd quadrature using genuine nd rules rather than iterated 1d rules. I'm still reading up on the theory (it's mostly the same as in 1d, but there's the additional issue of deciding which axis to split cells along, and many of the rules/algorithms are specific to a particular number of dimensions).
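One simple way to decide which axis to split a cell along, sketched here for a single 2D cell under assumptions of my own (it is not taken from this library or from a specific published algorithm): compute a tensor-product Gauss-Legendre estimate, then recompute it with a lower order along one axis at a time, and split along the axis where the estimate moves the most. `gauss_rule`, `tensor_gauss_2d` and `choose_split_axis` are hypothetical names.

```python
import jax.numpy as jnp
import numpy as np

def gauss_rule(n, lo, hi):
    """n-point Gauss-Legendre nodes and weights mapped from [-1, 1] to [lo, hi]."""
    x, w = np.polynomial.legendre.leggauss(n)
    half = 0.5 * (hi - lo)
    return jnp.asarray(half * x + 0.5 * (hi + lo)), jnp.asarray(half * w)

def tensor_gauss_2d(f, lo, hi, nx, ny):
    """Tensor-product rule with nx nodes along x and ny nodes along y."""
    x, wx = gauss_rule(nx, lo[0], hi[0])
    y, wy = gauss_rule(ny, lo[1], hi[1])
    X, Y = jnp.meshgrid(x, y, indexing="ij")  # shape (nx, ny)
    return jnp.sum(jnp.outer(wx, wy) * f(X, Y))

def choose_split_axis(f, lo, hi, n=5, m=3):
    """Heuristic split-axis choice for one 2D cell [lo, hi].

    Lower the rule order along one axis at a time; the axis whose reduction
    changes the estimate the most is the one where f is least resolved.
    Returns (axis index, full-order estimate of the cell integral).
    """
    full = tensor_gauss_2d(f, lo, hi, n, n)
    change_x = jnp.abs(full - tensor_gauss_2d(f, lo, hi, m, n))
    change_y = jnp.abs(full - tensor_gauss_2d(f, lo, hi, n, m))
    return jnp.argmax(jnp.stack([change_x, change_y])), full

# Usage: f varies rapidly in x and only linearly in y, so axis 0 (x) is chosen.
axis, est = choose_split_axis(lambda x, y: jnp.exp(10 * x) + y, (0.0, 0.0), (1.0, 1.0))
```

The same idea extends to d dimensions with one reduced-order evaluation per axis, which is part of why the cost of choosing a split grows with dimension.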
Thank you for sharing your thoughts! I guess the main problem with the recursive approach is that you would be vmapping over conditional control flow, which means you always have to wait for the slowest computation in the batch to finish. Therefore, I think some kind of batched nd quadrature rule would be ideal. Maybe it would be sufficient to start with 2 and 3 dimensions?
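A minimal sketch of the "batched" idea, under my own assumptions: split the box into a fixed grid of equal cells, apply the same fixed tensor-product Gauss-Legendre rule to every cell inside a single `jax.vmap`, and compare two rule orders per cell as a crude error estimate. Because every cell does exactly the same work, nothing in the batch waits on a slow branch. It works for any dimension d, though the n**d nodes per cell make it practical mainly for 2 and 3 dimensions. `tensor_rule` and `batched_cells_quad` are hypothetical names, not this library's API, and the adaptive part (deciding which cells to refine next while keeping array shapes fixed) is not shown.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np

def tensor_rule(d, n):
    """Gauss-Legendre tensor rule on [-1, 1]^d: nodes (n**d, d), weights (n**d,)."""
    x, w = np.polynomial.legendre.leggauss(n)
    nodes = np.array(list(itertools.product(x, repeat=d)))
    weights = np.prod(np.array(list(itertools.product(w, repeat=d))), axis=1)
    return jnp.asarray(nodes), jnp.asarray(weights)

def batched_cells_quad(f, lo, hi, cells_per_axis=4, n=5, m=3):
    """Integrate f over the box [lo, hi] by splitting it into cells_per_axis**d
    equal cells and evaluating every cell in one batch.

    f maps a point of shape (d,) to a scalar. Each cell gets an n-point and an
    m-point tensor rule; their difference is a crude per-cell error estimate.
    """
    lo, hi = jnp.asarray(lo), jnp.asarray(hi)
    d, k = lo.shape[0], cells_per_axis
    h = (hi - lo) / k                                   # cell side lengths, (d,)
    idx = jnp.asarray(list(itertools.product(range(k), repeat=d)))
    corners = lo + idx * h                              # lower corners, (k**d, d)

    def cell_estimate(corner, nodes, weights):
        pts = corner + 0.5 * h * (nodes + 1.0)          # map [-1, 1]^d onto the cell
        return jnp.prod(0.5 * h) * jnp.sum(weights * jax.vmap(f)(pts))

    # Same work for every cell: no per-cell branching for vmap to wait on.
    per_cell = jax.vmap(cell_estimate, in_axes=(0, None, None))
    fine = per_cell(corners, *tensor_rule(d, n))
    coarse = per_cell(corners, *tensor_rule(d, m))
    return jnp.sum(fine), jnp.abs(fine - coarse)

# Usage: 3D Gaussian over the unit cube (exact value is about 0.41654).
total, cell_errors = batched_cells_quad(
    lambda p: jnp.exp(-jnp.sum(p**2)), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
)
```

The per-cell error array is what an adaptive version would use to pick cells to subdivide; keeping that step jit-friendly probably means working with a fixed-capacity cell buffer rather than growing arrays.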
Hi, big fan of this repository! Are there any updates on n-dimensional quadrature? It would be very useful for me!