We need to generate random numbers and map them onto the input at some point, which requires splitting the input key so that it matches the structure of the pytree.
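A minimal sketch of the flat-list approach to this problem: flatten the pytree, split the key once per leaf, sample each leaf, and unflatten. The helper name `normal_like_tree` is hypothetical (not an existing blackjax or JAX function), and it assumes every leaf is a floating-point array.

```python
import jax
import jax.numpy as jnp


def normal_like_tree(rng_key, tree):
    """Hypothetical helper: standard-normal noise with the same structure as `tree`.

    Assumes every leaf is a floating-point array (jax.random.normal needs a float dtype).
    """
    # Flatten the pytree into a plain list of leaves plus its structure.
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # Split the key exactly once per leaf, as one would for a flat list.
    keys = jax.random.split(rng_key, len(leaves))
    # Sample each leaf independently with its own key, keeping shape and dtype.
    noise = [
        jax.random.normal(key, leaf.shape, leaf.dtype)
        for key, leaf in zip(keys, leaves)
    ]
    # Rebuild the original pytree structure from the flat list.
    return jax.tree_util.tree_unflatten(treedef, noise)


position = {"mu": jnp.zeros(3), "log_sigma": jnp.zeros(())}
noise = normal_like_tree(jax.random.PRNGKey(0), position)
```

No single concatenated vector is built here; each leaf keeps its own shape, which leaves the compiler free to handle leaves separately.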
Per communication with @mattjj offline, there are some general guidelines on how we should improve this:
Raveling pytrees may not be the most performant way to do things.
One thing you can ask yourself is: how would you do the right thing with plain old flat lists? If you know how to do something with flat lists, like splitting a key the right number of times, then just tree-flattening, working with the flat lists, and tree-unflattening is often the best approach.
That said, we had this helper in Autograd: https://github.com/HIPS/autograd/blob/c6d81ce7eede6db801d4e9a92b27ec5d409d0eab/autograd/misc/flatten.py#L30
The venerable issue #190 has me saying it might be a good idea to add: jax-ml/jax#190
Flattening seems like it could be a performance footgun, because it really ties the compiler's hands to pack everything into one flat vector.
So yeah, unless you're sure you want to work with flat stuff, maybe avoid it and try to work with pytree mapping/flattening/unflattening.
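For comparison, JAX provides a counterpart to the Autograd helper as `jax.flatten_util.ravel_pytree`, which concatenates all leaves into one 1-D vector and returns an `unravel` function. The sketch below applies the same elementwise operation both ways; the `params` pytree is an illustrative assumption.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# Flat-vector view: every leaf is packed into a single 1-D array of length 6.
flat, unravel = ravel_pytree(params)
scaled_via_ravel = unravel(2.0 * flat)

# Pytree-native view: the same operation applied leaf by leaf, with no packing.
scaled_via_tree_map = jax.tree_util.tree_map(lambda x: 2.0 * x, params)
```

Both give the same values; the `tree_map` version avoids the concatenate/slice round trip that the flat view introduces.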
junpenglao changed the title from "Refactor internal flatten array and usage of ravel_pytree" to "Improve usage of ravel_pytree for handling flatten view of PyTree internally" on Oct 10, 2022.
I did some light benchmarking, and using ravel_pytree doesn't seem too terrible. Given that there are quite a few places where a flattened view is unavoidable (mostly when we multiply by some dense matrix), until we have a good solution for matrix operations on a PyTree (e.g., something better than what tree-math provides), we will need to work with the flattened PyTree output by ravel_pytree.
Instead, I will send in a PR that just refactors out some common patterns when using ravel_pytree.
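A minimal sketch of the case where the flattened view is hard to avoid: multiplying a dense matrix (for example, a dense inverse mass matrix) by a pytree. The helper name `dense_matvec_tree` is hypothetical, and it assumes the matrix is square with size equal to the total number of elements across all leaves.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def dense_matvec_tree(matrix, tree):
    """Hypothetical helper: compute `matrix @ tree` on the raveled view of a pytree.

    Assumes `matrix` has shape (d, d), where d is the total leaf size of `tree`.
    """
    # Pack all leaves into one 1-D vector so the dense matrix can act on it.
    flat, unravel = ravel_pytree(tree)
    # Multiply, then restore the original pytree structure and leaf shapes.
    return unravel(matrix @ flat)


momentum = {"a": jnp.ones(2), "b": jnp.ones(())}
matrix = jnp.diag(jnp.array([1.0, 2.0, 3.0]))
result = dense_matvec_tree(matrix, momentum)
```

A dense (d, d) matrix has no natural per-leaf decomposition, which is why the raveled vector is the simplest representation here.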
junpenglao added a commit to junpenglao/blackjax that referenced this issue on Oct 10, 2022.
Currently, we have multiple places where we use a flattened array. One common pattern is:
In general this is needed because *args, **kwargs already contain flattened arrays.