-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: implement backward computation for more operators #921
base: master
Are you sure you want to change the base?
Conversation
We use the following formulas to compute the gradients: Let g be `tensor->grad`, let x be `src0`, and let y be `tensor`. For tanh, `g * (1 - tanh^2(x)) = g * (1 - y^2) = g - gy^2`. For sigmoid, `g * (sigmoid(x) * (1 - sigmoid(x))) = g * (y * (1 - y)) = gy - gy^2`.
This comes with a breaking change: `ggml_clamp` is no longer an in-place operation. If you still want/need that behavior, use `ggml_clamp_inplace`. I hope no one depended on that. Also introduces `GGML_OP_CLAMP_BACK`, whose implementations for other backends will be added in a subsequent commit. The definition of `clamp_back` is as follows: { 0 if x < min d/dx(clamp(x, min, max)) = { 1 if min <= x <= max { 0 if x > max
Slice the gradient using a view operation, reshape, and then add to the inputs' gradients.
Introduces `GGML_UNARY_OP_ELU_BACK`, defined as the following: ELU'(x) = { e^x if x <= 0 { x if x > 0
d/dx(LeakyRELU(x, negative_slope)) = { 1 if x > 0 { negative_slope if x <= 0 The equivalent formula `negative_slope * step(-x) + step(x)` is used for backward computation.
…GELU_BACK` Introduces corresponding `*_BACK` operators for both. Backend-specific accelerated implementations forthcoming.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should add tests to tests/test-grad0.cpp
I'm currently working on adding training support for the MNIST example in #908 . I have a working backward pass for |
d/dx(ELU(x)) is 1 if x >= 0, not x
It might be better to wait for @JohannesGaessler to merge #908 and then continue this PR? |
That's probably best, considering the changes needed for the tests. |
I extended the code in |
Perfect. I plan to finish this PR this weekend. |
This PR will add backward computations for most operators once completed.
Leaving
pad
,im2col
, andnorm
for a future PR now.Currently unsure if I should fuse the multiply + gradient computation for
gelu_back
/gelu_quick_back
like withsilu_back
.