Enable custom AD primitives #1868

Open
FluxusMagna opened this issue Feb 5, 2023 · 9 comments
Labels: AD (Related to automatic differentiation), compiler, enhancement

@FluxusMagna

Some functions have a known expression for calculating the gradient that is better than what 'naive' AD produces. Perhaps the most obvious one is the FFT, which is in essence just a matrix multiplication, but surely other cases exist too. If the expression to be differentiated contains such a function, we currently need to split that component out by hand in order to apply our own differentiation scheme to it.

I think it would be neat if we could instead tell the compiler which expression should replace, for example, the vector-Jacobian product. One way to do this is through an attribute, like

def fft a = #[vjp(\_ a -> fft a)]  ???

where the expression in the attribute replaces vjp fft. This way efficient gradient definitions of relevant expressions can be defined in libraries and the user won't have to think about it.
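As a rough illustration of the idea (in Python rather than Futhark, and not how the compiler would implement it), the sketch below attaches a hand-written VJP rule to a function and lets a driver prefer that rule over a generic fallback. The names `custom_vjp`, `vjp_rule`, and `vjp` are hypothetical, and central finite differences stand in for 'naive' AD.

```python
def custom_vjp(rule):
    """Attach a hand-written VJP rule to a function (hypothetical helper)."""
    def attach(f):
        f.vjp_rule = rule
        return f
    return attach


def vjp(f, x, ybar, eps=1e-6):
    """Return ybar^T J_f(x), preferring a custom rule when one is attached."""
    rule = getattr(f, "vjp_rule", None)
    if rule is not None:
        return rule(x, ybar)
    # Fallback standing in for "naive" AD: central finite differences.
    grad = []
    for i in range(len(x)):
        hi = list(x)
        lo = list(x)
        hi[i] += eps
        lo[i] -= eps
        dy = [(a - b) / (2 * eps) for a, b in zip(f(hi), f(lo))]
        grad.append(sum(yb * d for yb, d in zip(ybar, dy)))
    return grad


def prefix_sums_plain(x):
    # y[i] = x[0] + ... + x[i], with no custom rule attached.
    out, acc = [], 0.0
    for v in x:
        acc += v
        out.append(acc)
    return out


def prefix_sums_rule(x, ybar):
    # The adjoint of a prefix sum is a suffix sum of the cotangent.
    out, acc = [], 0.0
    for yb in reversed(ybar):
        acc += yb
        out.append(acc)
    return out[::-1]


@custom_vjp(prefix_sums_rule)
def prefix_sums(x):
    return prefix_sums_plain(x)


x = [1.0, 2.0, 3.0]
ybar = [1.0, 10.0, 100.0]
custom = vjp(prefix_sums, x, ybar)        # uses the attached rule
naive = vjp(prefix_sums_plain, x, ybar)   # uses the generic fallback
```

Both calls compute the same vector, but the custom rule does it in a single pass over the cotangent, which is the kind of saving a library author could provide once for all users.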

Similar to the #[unsafe] attribute, it could be disabled with some compiler option.

I have no idea how well this would fit in with the current AD-machinery, but I think it makes sense syntactically.

@athas
Member

athas commented Feb 5, 2023

> I have no idea how well this would fit in with the current AD-machinery, but I think it makes sense syntactically.

This would likely be easy: just don't inline functions with custom derivatives until after AD is done.

@FluxusMagna
Author

I suppose it might make a bit more sense to use the structure

def fft = #[vjp(\_ a -> fft a)] (\a -> ???)

instead, since the attribute then applies to a function, unless the structure I suggested first is easy to handle.

zfnmxt added the AD (Related to automatic differentiation), compiler, and enhancement labels on Feb 7, 2023
zfnmxt self-assigned this on Feb 7, 2023
@zfnmxt
Collaborator

zfnmxt commented Apr 16, 2023

@nhey is adding something similar to JAX's stop_gradient function via an attribute; it just zeroes out (i.e., doesn't compute) the gradient of its argument. For example,

jvp2 (\x -> #[stop_gradient] x*x) x 1

returns 0 for all x.
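The behaviour can be mimicked with dual numbers. The Python sketch below is only an analogy: `Dual`, `stop_gradient`, and `jvp2` are stand-ins named after the Futhark and JAX concepts, and it assumes the attribute covers the whole product `x*x`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    primal: float
    tangent: float

    def __mul__(self, other):
        # Product rule on the tangent component.
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)


def stop_gradient(d):
    # Keep the primal value, drop the tangent.
    return Dual(d.primal, 0.0)


def jvp2(f, x, dx):
    # Return both the primal result and its directional derivative.
    out = f(Dual(x, dx))
    return out.primal, out.tangent


plain = jvp2(lambda x: x * x, 3.0, 1.0)                   # (9.0, 6.0)
stopped = jvp2(lambda x: stop_gradient(x * x), 3.0, 1.0)  # (9.0, 0.0)
```

The primal value is unchanged; only the derivative component is forced to zero.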

This is in theory very simple to implement: for any statement with a stop_gradient attribute, AD inserts only the primal statement and does nothing else. Unfortunately, this doesn't quite work for things like

jvp2 (\x ->#[stop_gradient] x) x 1

because the body of \x ->#[stop_gradient] x will have no statements in the IR and the attribute is forgotten. Even

jvp2 (\x ->#[stop_gradient] (id x)) x 1

doesn't fix it due to inlining. You can hack it to work by creating a stop_gradient function which doesn't inline the application

def stop_gradient 't (x: t): t =
  #[noinline] (#[stop_gradient] (id x))
  
jvp2 (\x -> stop_gradient x) x 1

but then you've polluted the program with non-inlined identity functions after the AD pass. So it seems like the right way to do this is to make the inlining passes aware of #[stop_gradient] (treating it like #[noinline]) and then, during AD, handle #[stop_gradient] as described and also remove it from the attribute set. Post-AD, you run the inlining passes again to clean things up.

@athas Is this the right way to do it?

I'm posting this on this issue because it seems like the basic implementation machinery is the same between the two features. You could even define stop_gradient x = #[vjp(\_ _ -> 0),jvp(\_ _ -> 0)] x (although the attribute system doesn't support expressions so I guess this would need to be done with identifiers).

@athas
Member

athas commented Apr 16, 2023

I can't think of any other way to do it. You might lose out on some pre-AD optimisation opportunities, but hopefully nothing significant - and it's difficult to see how passes such as fusion should propagate #[stop_gradient] anyway.

nhey added a commit to nhey/futhark that referenced this issue Apr 16, 2023
…ment.

Effectively stops the compiler from generating adjoint code for the
argument expression (as explained by @zfnmxt in diku-dk#1868). This is useful
for (amongst other things) implementing variational Bayesian methods
and hacking gradients together for non-differentiable functions
since it lets us treat any variable as a constant.

Co-authored-by: zfnmxt <[email protected]>
@samestep

I'm still learning about Futhark and haven't been able to find much about its AD from the docs, so excuse my ignorance here; are you folks using an approach that splits AD into forward-mode and transposition, as described in this POPL 2023 paper? In my (admittedly, limited) experience implementing AD systems, this is nice because it means that the user can simply specify a custom JVP for a function, and then the compiler can transpose that to produce a custom VJP that agrees with that JVP by construction. I think this is what JAX does, although I'm not entirely sure. Would this apply for the FFT example @FluxusMagna described? (I can imagine there existing an example where the JVP is not actually easier to specify than the VJP, although I have not encountered one yet.)
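To make "agrees by construction" concrete, here is a small Python sketch of deriving a VJP from a user-supplied JVP. It is an assumption-laden toy: `transpose_linear` and `jvp_f` are hypothetical names, and the Jacobian is materialised column by column, whereas a compiler would transpose the linear program itself.

```python
def transpose_linear(jvp_fn, n_in):
    """Build a VJP from a linear JVP by materialising the Jacobian columns."""
    columns = []
    for j in range(n_in):
        basis = [1.0 if i == j else 0.0 for i in range(n_in)]
        columns.append(jvp_fn(basis))  # column j of the Jacobian

    def vjp_fn(ybar):
        # Entry j of J^T ybar is the dot product of column j with ybar.
        return [sum(c * yb for c, yb in zip(col, ybar)) for col in columns]

    return vjp_fn


def jvp_f(x, dx):
    # Hand-written JVP of f(x0, x1) = (x0 * x1, x0 + x1); linear in dx.
    return [x[1] * dx[0] + x[0] * dx[1], dx[0] + dx[1]]


x = [2.0, 5.0]
vjp_f = transpose_linear(lambda dx: jvp_f(x, dx), n_in=2)

dx, ybar = [1.0, -3.0], [4.0, 0.5]
lhs = sum(a * b for a, b in zip(jvp_f(x, dx), ybar))  # <J dx, ybar>
rhs = sum(a * b for a, b in zip(dx, vjp_f(ybar)))     # <dx, J^T ybar>
```

Because the VJP is derived from the JVP, the identity `<J dx, ybar> = <dx, J^T ybar>` holds for any `dx` and `ybar` without the user writing a second rule.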

@zfnmxt
Collaborator

zfnmxt commented Dec 29, 2023

Futhark doesn't do reverse-mode AD via explicit transposition, no. There's very little on Futhark's forward-mode AD because it's so simple and just corresponds to the classic dual number formulation (see section 3.1.1 here). Reverse-mode AD in Futhark is discussed in our SC paper here as well as here.

At any rate, the whole point of custom derivatives is to spit out better code than the AD transformation can. The transposition approach doesn't inherently ameliorate this problem---the resulting code is only as good as the transposition transformation is!

@samestep

samestep commented Dec 29, 2023

Makes sense! Yeah, I'm familiar with the contrast between the very simple dual number formulation for forward-mode and the much more complicated reverse-mode. Thanks for the links to those papers about your reverse-mode AD approach, I'll check those out!

I'm not sure I quite agree with your point about that being "the whole point of custom derivatives"; for instance, let's say you have a function which takes a polynomial and computes the roots of that polynomial. If you use some sort of iterative method to compute the polynomial roots, then it'd be inefficient to transform that code, and would be much better to use implicit differentiation instead. In this case, if I'm not mistaken, transposition of the custom implicitly differentiated JVP would still be much better than direct reverse-mode AD of the iterative method, no?
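The polynomial-root case can be sketched as follows. This is a minimal Python illustration under stated assumptions: a simple real root, Newton's method as the iterative solver, and hypothetical helper names throughout. For a root r of p(x; c) = c[0] + c[1] x + ..., differentiating p(r; c) = 0 gives dr = -(sum_k dc[k] r^k) / p'(r), so no differentiation through the loop is needed.

```python
def poly(c, x):
    # Evaluate c[0] + c[1]*x + c[2]*x**2 + ... by Horner's rule.
    acc = 0.0
    for coeff in reversed(c):
        acc = acc * x + coeff
    return acc


def poly_dx(c, x):
    # Derivative of the polynomial with respect to x, also by Horner's rule.
    acc = 0.0
    for k in range(len(c) - 1, 0, -1):
        acc = acc * x + k * c[k]
    return acc


def newton_root(c, x0, iters=50):
    # Iterative solver; differentiating through this loop is what we avoid.
    x = x0
    for _ in range(iters):
        x -= poly(c, x) / poly_dx(c, x)
    return x


def root_jvp(c, dc, x0):
    # Implicit function theorem applied at the converged root.
    r = newton_root(c, x0)
    return r, -poly(dc, r) / poly_dx(c, r)


# x**2 - 2 has the root sqrt(2); perturb only the constant coefficient.
r, dr = root_jvp([-2.0, 0.0, 1.0], [1.0, 0.0, 0.0], x0=1.0)
```

The derivative costs one extra polynomial evaluation and one division, independent of how many Newton iterations were needed to find the root.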

@zfnmxt
Collaborator

zfnmxt commented Dec 29, 2023

That's a good point; there are definitely benefits to having a transposition transformation and, conceptually, the decomposition of reverse-mode into forward-mode + transposition is really nice (as well as having first-class linear map support in a language). And there surely are cases where a custom JVP is simple, but the VJP is still complex (because the JVP just happens to be something that's difficult to transpose) so you wouldn't want to write out the custom VJP.

@FluxusMagna
Author

FluxusMagna commented Jan 15, 2024

> I'm still learning about Futhark and haven't been able to find much about its AD from the docs, so excuse my ignorance here; are you folks using an approach that splits AD into forward-mode and transposition, as described in this POPL 2023 paper? In my (admittedly, limited) experience implementing AD systems, this is nice because it means that the user can simply specify a custom JVP for a function, and then the compiler can transpose that to produce a custom VJP that agrees with that JVP by construction. I think this is what JAX does, although I'm not entirely sure. Would this apply for the FFT example @FluxusMagna described? (I can imagine there existing an example where the JVP is not actually easier to specify than the VJP, although I have not encountered one yet.)

I think for the fft-case the JVP and VJP are actually the same, because the operation is equivalent to multiplication with a symmetric matrix. It's a very special case in that sense though.
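A quick numerical check of the symmetry claim, as a Python sketch using an explicit (unnormalised) DFT matrix rather than a real FFT. It assumes the cotangent pairing is the plain bilinear one, without complex conjugation; with a conjugating inner product the adjoint would instead be the conjugate transpose, which for the DFT is the inverse transform up to a factor of n.

```python
import cmath


def dft_matrix(n):
    # Entry (j, k) is exp(-2*pi*i*j*k/n); swapping j and k leaves it unchanged.
    return [[cmath.exp(-2j * cmath.pi * (j * k) / n) for k in range(n)]
            for j in range(n)]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


n = 4
F = dft_matrix(n)
v = [1 + 2j, -1j, 0.5, 3 - 1j]

jvp_out = matvec(F, v)             # F applied to a tangent
vjp_out = matvec(transpose(F), v)  # F^T applied to a cotangent
```

Since `F` equals its own transpose, the two results coincide, which is why the same transform can serve as both rules in this special case.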
