Enable custom AD primitives #1868
This would likely be easy: just don't inline functions with custom derivatives until after AD is done.
I suppose it might make a bit more sense to put the attribute on the function definition itself, since the attribute describes a function, unless the previously stated structure can easily be handled.
@nhey is adding something similar to JAX's `stop_gradient`: `jvp2 (\x -> #[stop_gradient] x*x) x 1` returns a tangent of `0` for all `x`. This is in theory very simple to implement: for any statement with a `#[stop_gradient]` attribute, simply don't generate any adjoint code. The problem is that `jvp2 (\x -> #[stop_gradient] x) x 1` doesn't work, because the body of the lambda is just a variable, so there is no statement for the attribute to attach to, and `jvp2 (\x -> #[stop_gradient] (id x)) x 1` doesn't fix it due to inlining. You can hack it to work by creating a

```futhark
def stop_gradient 't (x: t): t =
  #[noinline] (#[stop_gradient] (id x))
```

and writing `jvp2 (\x -> stop_gradient x) x 1`, but then you've polluted things with non-inlined functions that have to be replaced with identity functions in a post-AD pass. So it seems like the right way to do this is to make the inlining passes aware of `#[stop_gradient]` (and of custom-derivative attributes generally). @athas Is this the right way to do it? I'm posting this on this issue because it seems like the basic implementation machinery is the same between the two features. You could even define `stop_gradient` in terms of the custom-derivative machinery.
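A minimal self-contained sketch of the intended semantics, assuming the `#[noinline]` workaround above and a `#[stop_gradient]` attribute as in the commits referenced below (the expected results are my reading of the thread, not verified output):

```futhark
-- Sketch only: assumes #[stop_gradient] is recognised by the AD pass
-- and that the #[noinline] hack keeps it from being inlined away.
def stop_gradient 't (x: t): t =
  #[noinline] (#[stop_gradient] (id x))

-- jvp2 returns (primal, tangent); wrapping a subexpression in
-- stop_gradient should zero the tangent flowing through it.
def main (x: f64) =
  (jvp2 (\x -> x * x) x 1,                  -- expected: (x*x, 2*x)
   jvp2 (\x -> stop_gradient (x * x)) x 1)  -- expected: (x*x, 0)
```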
I can't think of any other way to do it. You might lose out on some pre-AD optimisation opportunities, but hopefully nothing significant, and it's difficult to see how passes such as fusion should propagate such attributes anyway.
…ment. Effectively stops the compiler from generating adjoint code for the argument expression (as explained by @zfnmxt in diku-dk#1868). This is useful for (amongst other things) implementing variational Bayesian methods and hacking gradients together for non-differentiable functions since it lets us treat any variable as a constant. Co-authored-by: zfnmxt <[email protected]>
I'm still learning about Futhark and haven't been able to find much about its AD in the docs, so excuse my ignorance here: are you folks using an approach that splits AD into forward mode and transposition, as described in this POPL 2023 paper? In my (admittedly limited) experience implementing AD systems, this is nice because it means the user can simply specify a custom JVP for a function, and the compiler can then transpose it to produce a custom VJP that agrees with the JVP by construction. I think this is what JAX does, although I'm not entirely sure. Would this apply to the FFT example @FluxusMagna described? (I can imagine there being an example where the JVP is not actually easier to specify than the VJP, although I have not encountered one yet.)
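For reference, Futhark already exposes both modes as built-in higher-order functions; only the example function `f` below is invented for illustration:

```futhark
-- jvp f x dx : forward mode, pushes the tangent dx through f.
-- vjp f x dy : reverse mode, pulls the cotangent dy back through f.
def f (x: f64) = f64.sin (x * x)

-- For a scalar function, both compute f'(x) when seeded with 1.
def main (x: f64) = (jvp f x 1, vjp f x 1)
```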
Futhark doesn't do reverse-mode AD via explicit transposition, no. There's very little on Futhark's forward-mode AD because it's so simple and just corresponds to the classic dual-number formulation (see section 3.1.1 here). Reverse-mode AD in Futhark is discussed in our SC paper here as well as here. At any rate, the whole point of custom derivatives is to spit out better code than the AD transformation can. The transposition approach doesn't inherently ameliorate this problem: the resulting code is only as good as the transposition transformation is!
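As a minimal sketch of that dual-number formulation (illustrative only; Futhark's actual forward mode is the built-in `jvp`, and all names below are invented):

```futhark
-- Forward mode with dual numbers: carry (value, tangent) pairs and
-- apply the chain rule per primitive operation.
type dual = (f64, f64)

def dmul ((x, dx): dual) ((y, dy): dual) : dual =
  (x * y, x * dy + dx * y)        -- product rule

def dsin ((x, dx): dual) : dual =
  (f64.sin x, f64.cos x * dx)     -- chain rule

-- Derivative of sin(x*x): seed the input with tangent 1 and read
-- off the tangent component of the result.
def main (x: f64) = (dsin (dmul (x, 1) (x, 1))).1
```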
Makes sense! Yeah, I'm familiar with the contrast between the very simple dual-number formulation for forward mode and the much more complicated reverse mode. Thanks for the links to those papers about your reverse-mode AD approach; I'll check those out! I'm not sure I quite agree with your point about that being "the whole point of custom derivatives". For instance, say you have a function that takes a polynomial and computes its roots. If you use some sort of iterative method to compute the roots, it would be inefficient to transform that code directly, and much better to use implicit differentiation instead. In this case, if I'm not mistaken, transposing the custom implicitly-differentiated JVP would still be much better than direct reverse-mode AD of the iterative method, no?
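To spell out the implicit-differentiation step (standard calculus, not specific to Futhark): if a root $r(\theta)$ of a polynomial $p$ with coefficients $\theta$ satisfies $p(\theta, r(\theta)) = 0$, then differentiating that identity gives

$$
\frac{\partial p}{\partial \theta} + \frac{\partial p}{\partial r}\,\frac{\mathrm{d}r}{\mathrm{d}\theta} = 0
\qquad\Longrightarrow\qquad
\frac{\mathrm{d}r}{\mathrm{d}\theta} = -\left(\frac{\partial p}{\partial r}\right)^{-1}\frac{\partial p}{\partial \theta},
$$

so the derivative of the root-finder comes from two partial derivatives of $p$ evaluated at the converged root, with no need to differentiate through the iterations themselves.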
That's a good point; there are definitely benefits to having a transposition transformation, and, conceptually, the decomposition of reverse mode into forward mode + transposition is really nice (as is having first-class linear map support in a language). And there surely are cases where a custom JVP is simple but the VJP is still complex (because the JVP just happens to be something that's difficult to transpose), so you wouldn't want to write out the custom VJP.
I think for the FFT case the JVP and VJP are actually the same, because the operation is equivalent to multiplication with a symmetric matrix. It's a very special case in that sense, though.
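To spell that out (standard DFT facts): the length-$n$ DFT is the linear map $x \mapsto Fx$ with

$$
F_{jk} = \omega^{jk}, \qquad \omega = e^{-2\pi i / n},
$$

and since $F_{jk} = F_{kj}$, the matrix is symmetric, $F^{\mathsf{T}} = F$. The JVP of a linear map is the map itself, and the VJP is multiplication by the transpose, so here both are multiplication by $F$.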
Some functions have a known expression for calculating the gradient that is better than what 'naive' AD produces. Perhaps the most obvious one is the FFT, which is in essence just a matrix multiplication, but surely other cases exist too. If the expression to be differentiated contains such a function, we currently need to manually separate out the differentiation of this component in order to apply our own differentiation scheme to it.
I think it would be neat if we could instead provide the compiler with information about what expression should replace, for example, the vector-Jacobian product. One idea for how this could be done is through an attribute on the function definition, where the expression in the attribute replaces `vjp fft` (see the hypothetical sketch below). This way, efficient gradient definitions of relevant expressions can be defined in libraries and the user won't have to think about it. Similar to the `#[unsafe]` attribute, it could be disabled with some compiler option. I have no idea how well this would fit in with the current AD machinery, but I think it makes sense syntactically.
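A hypothetical sketch of what such an attribute might look like; `#[vjp(...)]` is not an existing Futhark attribute and `square_vjp` is an invented name, so this only illustrates the shape of the proposal:

```futhark
-- Hand-written adjoint: given the input x and the output cotangent
-- y_bar, return the input cotangent of square.
def square_vjp (x: f64) (y_bar: f64) : f64 =
  2 * x * y_bar

-- The attribute would tell the compiler to use square_vjp wherever
-- it would otherwise generate code for `vjp square`.
#[vjp(square_vjp)]
def square (x: f64) : f64 = x * x
```

The same scheme would let a library ship a hand-written adjoint for `fft`, so that `vjp fft` never goes through the generic reverse-mode transformation.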