2nd order gradients for activations #1099
Comments
As for me, it's interesting; the math is understandable and the coding is feasible. However, I would need guidance on how to integrate it into TFA seamlessly. So unless the maintainers or anyone else much more experienced wants to take this on, I would be happy to try. |
Thanks @veqtor for bringing this up! From my understanding, higher-order gradients should be automatically differentiated if we have our setup correct. If I run:
I get the correct first derivative, but the second order fails for: @failure-to-thrive It would be great if you want to look into this! I haven't fully looked into it, but it seems to be related to properly registering in the gradient registry. Hand-calculating 2nd-order grads shouldn't be required except for some test cases (IIUC). |
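For reference, a minimal sketch of this kind of check (assuming the `mish` activation; the exact snippet and activation used above are an assumption):

```python
import tensorflow as tf
from tensorflow_addons.activations import mish  # assumed activation for illustration

x = tf.Variable([-1.0, 0.0, 1.0, 2.0])
with tf.GradientTape() as outer:
    with tf.GradientTape() as inner:
        y = mish(x)
    dy_dx = inner.gradient(y, x)      # first-order gradient: works
d2y_dx2 = outer.gradient(dy_dx, x)    # second-order gradient: this was the failing step
                                      # (the C++ gradient op had no gradient of its own)
print(dy_dx.numpy(), d2y_dx2)
```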
It's true if the activation function is expressed with TensorFlow ops. However, the TFA activations (most? all?) are implemented in C++ code. Every TFA C++ activation has its gradient implemented in C++ as well, so there is nothing for autodiff to differentiate a second time. |
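For illustration, a minimal sketch of how the gradient could be expressed in plain TensorFlow ops so that autodiff can produce the second-order gradient automatically (the wrapper name `mish_with_py_grad` is hypothetical; this is not the actual TFA implementation):

```python
import tensorflow as tf

@tf.custom_gradient
def mish_with_py_grad(x):
    # Forward pass: mish(x) = x * tanh(softplus(x)) (could equally call a C++ kernel).
    y = x * tf.math.tanh(tf.math.softplus(x))

    def grad(dy):
        # Gradient written in TF ops: a tape can differentiate through these ops,
        # so the second-order gradient comes for free.
        sp = tf.math.softplus(x)
        t = tf.math.tanh(sp)
        return dy * (t + x * tf.math.sigmoid(x) * (1.0 - t * t))

    return y, grad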
You're right. This may be helpful while working on this: |
@veqtor Do you want to participate as a beta-tester? |
@failure-to-thrive Sure would, but I don't know if I can build TFA incl. the CUDA deps etc. |
This complicates things. I have to find out how to build a .whl for your OS, perhaps the same way the maintainers build packages for PyPI. |
Perhaps writing tests for the 2nd order grads is better? |
Of course, unit tests are a first-line defense against bugs. 🪲 🪲 🪲 But what if some of them sneak through anyway? 🪲 Pushing changes through the main TFA repo is not a good idea. Although, it is too early to think about that. |
I can try to build TFA for my platform. Maybe start with Mish.
Definition: https://www.wolframalpha.com/input/?i=x+*+tanh%28log%281+%2B+exp%28x%29%29%29
First-order derivative:
Second-order derivative: |
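For reference, with $s(x) = \mathrm{softplus}(x) = \ln(1+e^x)$ and $\sigma$ the sigmoid, the derivatives work out as follows (derived directly from the definition above, not taken from the linked Wolfram Alpha output):

```latex
\operatorname{mish}(x)   = x \tanh\big(s(x)\big), \qquad s'(x) = \sigma(x)

\operatorname{mish}'(x)  = \tanh\big(s(x)\big) + x\,\sigma(x)\,\operatorname{sech}^2\big(s(x)\big)

\operatorname{mish}''(x) = \sigma(x)\,\operatorname{sech}^2\big(s(x)\big)
    \Big[\, 2 + x\big(1-\sigma(x)\big) - 2x\,\sigma(x)\tanh\big(s(x)\big) \Big]
```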
I looked a bit at the code for mish grads and the 2nd order derivatives, maybe it can help: |
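A sketch of the closed-form second derivative (the same expression as the formulas above), which could be compared against the autodiff values in the next comment; `mish_grad2` is an illustrative name, not an existing TFA function:

```python
import tensorflow as tf

def mish_grad2(x):
    """Analytic second derivative of mish(x) = x * tanh(softplus(x))."""
    s = tf.math.softplus(x)
    t = tf.math.tanh(s)
    sig = tf.math.sigmoid(x)
    sech2 = 1.0 - t * t  # sech^2(s) = 1 - tanh^2(s)
    return sig * sech2 * (2.0 + x * (1.0 - sig) - 2.0 * x * sig * t)

print(mish_grad2(tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])).numpy())
```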
So here it is! https://github.com/failure-to-thrive/addons/tree/2nd-order-gradients-for-activations

```python
import tensorflow as tf

x = tf.Variable([-2.0, -1.0, 0.0, 1.0, 2.0])

def _mish_py(x):
    return x * tf.math.tanh(tf.math.softplus(x))

with tf.GradientTape() as gg:
    with tf.GradientTape() as g:
        y = _mish_py(x)
    dy_dx = g.gradient(y, x)
d2y_dx2 = gg.gradient(dy_dx, x)
print("_mish_py", d2y_dx2.numpy())

from tensorflow_addons.activations import mish

with tf.GradientTape() as gg:
    with tf.GradientTape() as g:
        y = mish(x)
    dy_dx = g.gradient(y, x)
d2y_dx2 = gg.gradient(dy_dx, x)
print("mish    ", d2y_dx2.numpy())
```

The output is almost identical:
|
LGTM, only I haven't tried it beyond small experiments. |
Closing the issue, as we changed to pure Python ops.
Describe the feature and the current behavior/state.
Currently the activation functions in tf-addons are missing 2nd-order gradients, which makes it impossible to use them for training GANs that need various forms of gradient penalties (WGAN-GP, StyleGAN 1/2, etc.).
I suggest adding 2nd-order gradients for these functions.
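To make the motivation concrete, here is a hedged sketch of a WGAN-GP style gradient penalty (`critic` is a hypothetical `tf.keras.Model` on NHWC image batches). Minimizing this penalty requires differentiating `grads` again with respect to the critic's weights, i.e. second-order gradients through every activation inside the critic:

```python
import tensorflow as tf

def gradient_penalty(critic, real, fake):
    # Random interpolation between real and fake samples.
    batch_size = tf.shape(real)[0]
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    interpolated = alpha * real + (1.0 - alpha) * fake

    # First-order gradient of the critic output w.r.t. the inputs.
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        scores = critic(interpolated, training=True)
    grads = tape.gradient(scores, interpolated)

    # Penalize deviation of the gradient norm from 1; training on this loss
    # needs gradients of `grads` themselves, hence second-order gradients.
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean(tf.square(norm - 1.0))
```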
Relevant information
No
No
different for every activation function
Unknown
No
Which API type would this fall under (layer, metric, optimizer, etc.)
activations
Who will benefit with this feature?
Anyone doing research and/or training GANs using activation functions in tf-addons
Any other info.