2nd order gradients for activations #1099

Closed
veqtor opened this issue Feb 17, 2020 · 14 comments
Labels: bug (Something isn't working), custom-ops, help wanted (Needs help as a contribution)

Comments

@veqtor

veqtor commented Feb 17, 2020

Describe the feature and the current behavior/state.
Currently the activation functions in tf-addons are missing 2nd order gradients, which makes it impossible to use them for training GANs that need various forms of gradient penalties (WGAN-GP, StyleGAN 1/2, etc.).
I suggest adding 2nd order gradients for these functions.
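
For context, a minimal sketch of the gradient-penalty use case; the toy critic, shapes, and unweighted penalty below are illustrative assumptions, not code from this issue:

import tensorflow as tf
import tensorflow_addons as tfa

# Toy critic built on a TFA activation; any model using tfa.activations
# runs into the same limitation.
critic = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation=tfa.activations.gelu),
    tf.keras.layers.Dense(1),
])

real = tf.random.normal([8, 16])
fake = tf.random.normal([8, 16])

with tf.GradientTape() as outer:
    eps = tf.random.uniform([8, 1])
    interp = eps * real + (1.0 - eps) * fake
    with tf.GradientTape() as inner:
        inner.watch(interp)
        scores = critic(interp)
    # First-order gradient of the critic w.r.t. its inputs ...
    grads = inner.gradient(scores, interp)
    penalty = tf.reduce_mean((tf.norm(grads, axis=1) - 1.0) ** 2)
# ... which must itself be differentiable, i.e. the activation needs a
# 2nd order gradient, for the penalty term to be trainable.
penalty_grads = outer.gradient(penalty, critic.trainable_variables)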

Relevant information

  • Are you willing to contribute it (yes/no):
    No
  • Are you willing to maintain it going forward? (yes/no):
    No
  • Is there a relevant academic paper? (if so, where):
    different for every activation function
  • Is there already an implementation in another framework? (if so, where):
    Unknown
  • Was it part of tf.contrib? (if so, where):
    No

Which API type would this fall under (layer, metric, optimizer, etc.)
activations
Who will benefit with this feature?
Anyone doing research and/or training GANs using activation functions in tf-addons.
Any other info.

@failure-to-thrive
Contributor

As for me, it's interesting, the math is understandable, and the coding is feasible. However, I would need guidance on how to integrate it into TFA seamlessly. So unless the maintainers or anyone else much more experienced wants to take this on, I would be happy to try.

@seanpmorgan
Member

Thanks @veqtor for bringing this up! From my understanding higher order gradients should be automatically differentiated if we have our setup correct:
https://www.tensorflow.org/tutorials/customization/autodiff#higher-order_gradients

If I run:

import tensorflow as tf
import tensorflow_addons as tfa

x = tf.Variable(1.0) 

with tf.GradientTape() as t:
  with tf.GradientTape() as t2:
    y = tfa.activations.gelu(x)
  # Compute the gradient inside the 't' context manager
  # which means the gradient computation is differentiable as well.
  dy_dx = t2.gradient(y, x)
  print(dy_dx)

d2y_dx2 = t.gradient(dy_dx, x)
print(d2y_dx2)

I get the correct first derivative, but the second order fails with:
LookupError: gradient registry has no entry for: Addons>GeluGrad

@failure-to-thrive It would be great if you want to look into this! I haven't fully investigated, but it seems to be related to properly registering the gradient op in the gradient registry. Hand-calculating 2nd order grads shouldn't be required except for some test cases (IIUC).
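
One hedged sketch of what such a registration could look like, here for the erf-based GELU; the (dy, x) input order and the closed-form second derivative are assumptions for illustration, not TFA's actual op definition:

import tensorflow as tf

# Assumed sketch: give the first-order gradient op "Addons>GeluGrad" a
# gradient of its own, written in plain TF ops so that any further
# differentiation stays automatic.
@tf.RegisterGradient("Addons>GeluGrad")
def _gelu_grad_grad(op, grad):
    # Assumption: GeluGrad takes (dy, x) and returns dy * gelu'(x).
    dy, x = op.inputs[0], op.inputs[1]
    cdf = 0.5 * (1.0 + tf.math.erf(x * tf.math.rsqrt(2.0)))
    pdf = tf.math.exp(-0.5 * x * x) * tf.math.rsqrt(2.0 * 3.141592653589793)
    gelu_prime = cdf + x * pdf                # d/dx gelu(x), erf variant
    gelu_double_prime = pdf * (2.0 - x * x)   # d^2/dx^2 gelu(x), erf variant
    # Gradients w.r.t. each input of GeluGrad(dy, x):
    return grad * gelu_prime, grad * dy * gelu_double_prime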

@failure-to-thrive
Contributor

From my understanding higher order gradients should be automatically differentiated if we have our setup correct:

That's true if the activation function is expressed with TensorFlow ops. However, TFA activations (most? all?) are implemented in C++, and every TFA C++ activation has its own *Grad successor op.

@seanpmorgan added the bug, custom-ops, and help wanted labels Feb 18, 2020
@seanpmorgan
Member

You're right. This may be helpful while working on this:
https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/python/ops/custom_gradient.py#L146-L168
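
For reference, a hedged sketch of nesting tf.custom_gradient to expose a second-order gradient, in the spirit of the linked code; this is a toy op with hand-written derivatives, not TFA code:

import tensorflow as tf

@tf.custom_gradient
def cube(x):
    y = x ** 3

    def grad(dy):
        # Wrap the first-order gradient in its own custom_gradient so an
        # outer tape knows how to differentiate it once more.
        @tf.custom_gradient
        def first_order(x):
            def grad2(ddy):
                return ddy * 6.0 * x  # d/dx (3x^2) = 6x
            return 3.0 * x ** 2, grad2
        return dy * first_order(x)

    return y, grad

x = tf.constant(2.0)
with tf.GradientTape() as outer:
    outer.watch(x)
    with tf.GradientTape() as inner:
        inner.watch(x)
        y = cube(x)
    dy_dx = inner.gradient(y, x)    # 3 * 2^2 = 12.0
d2y_dx2 = outer.gradient(dy_dx, x)  # 6 * 2   = 12.0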

@failure-to-thrive
Contributor

@veqtor Do you want to participate as a beta-tester?

@veqtor
Author

veqtor commented Feb 19, 2020

@failure-to-thrive Sure, I would, but I don't know if I can build TFA including the CUDA deps, etc.

@failure-to-thrive
Contributor

That complicates things. I would have to find out how to build a .whl for your OS, perhaps the same way the maintainers build the packages for PyPI.
@seanpmorgan what do you think about it?

@veqtor
Author

veqtor commented Feb 19, 2020

Perhaps writing tests for the 2nd order grads is better?
It should be quite easy to verify that the autodiff (pure Python) version matches the CUDA implementation.
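
A hedged sketch of what such a test could look like (not existing TFA test code; the pure-TF reference implementation and the tolerances are assumptions):

import tensorflow as tf
from tensorflow_addons.activations import mish

class MishSecondOrderTest(tf.test.TestCase):
    def test_second_order_gradient_matches_reference(self):
        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])

        def second_order(fn):
            # d^2/dx^2 fn(x) via two nested tapes.
            with tf.GradientTape() as outer:
                outer.watch(x)
                with tf.GradientTape() as inner:
                    inner.watch(x)
                    y = fn(x)
                dy_dx = inner.gradient(y, x)
            return outer.gradient(dy_dx, x)

        # Pure-TF reference that autodiff can differentiate twice on its own.
        reference = lambda t: t * tf.math.tanh(tf.math.softplus(t))
        self.assertAllClose(second_order(mish), second_order(reference),
                            rtol=1e-5, atol=1e-5)

if __name__ == "__main__":
    tf.test.main()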

@failure-to-thrive
Contributor

Of course, unit tests are a first line of defense against bugs. 🪲 🪲 🪲 But what if some of them sneak through anyway? 🪲 Pushing such changes through the main TFA repo is not a good idea. Although, it is too early to think about that.
OK. Could you please suggest what to implement first, along with math papers and some test values to test against?

@veqtor
Author

veqtor commented Feb 20, 2020

@veqtor
Author

veqtor commented Feb 22, 2020

I looked a bit at the code for the mish grads and the 2nd order derivatives; maybe it can help:
https://gist.github.com/veqtor/794434261abcbb51d67678d5a73caa1d

@failure-to-thrive
Contributor

So here it is! https://github.com/failure-to-thrive/addons/tree/2nd-order-gradients-for-activations
Clone and check out that branch; the rest is the same.
I was unable to find a unit-test infrastructure for testing 2nd order derivatives, so here is a small test program:

import tensorflow as tf

x = tf.Variable([-2.0, -1.0, 0.0, 1.0, 2.0])


def _mish_py(x):
    return x * tf.math.tanh(tf.math.softplus(x))

with tf.GradientTape() as gg:
  with tf.GradientTape() as g:
    y = _mish_py(x)
  dy_dx = g.gradient(y, x)
d2y_dx2 = gg.gradient(dy_dx, x)
print("_mish_py", d2y_dx2.numpy())


from tensorflow_addons.activations import mish

with tf.GradientTape() as gg:
  with tf.GradientTape() as g:
    y = mish(x)
  dy_dx = g.gradient(y, x)
d2y_dx2 = gg.gradient(dy_dx, x)
print("mish    ", d2y_dx2.numpy())

The output is almost identical:

_mish_py [ 0.03502709  0.3497057   0.64        0.18468581 -0.05772461]
mish     [ 0.03502715  0.34970567  0.64        0.18468583 -0.05772461]
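
As a possible complement, a hedged sketch of a numerical check that avoids hand-derived reference values, by running tf.test.compute_gradient on the first-order gradient (the tolerances are assumptions):

import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import mish

def mish_grad(x):
    # First-order gradient of mish; differentiating this once more
    # (analytically vs. numerically) exercises the 2nd order gradient.
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = mish(x)
    return tape.gradient(y, x)

x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])
theoretical, numerical = tf.test.compute_gradient(mish_grad, [x])
np.testing.assert_allclose(theoretical[0], numerical[0], rtol=1e-2, atol=1e-3)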

@veqtor
Author

veqtor commented Mar 5, 2020

LGTM, though I haven't tried it beyond small experiments.

@WindQAQ
Member

WindQAQ commented Dec 17, 2020

Closing the issue as we changed to pure Python ops.

@WindQAQ WindQAQ closed this as completed Dec 17, 2020