Why is dividing by e**(1/4) for both keys and queries more memory efficient? #5
Comments
Sure. This is because pytorch retains the history of all computation steps (it remembers all inputs and outputs, to compute the gradients). For long sequences, the memory use of a transformer is dominated by the matrix `dot` (the dot product of all queries and keys). Every time you operate on this matrix, pytorch needs to remember another `t * t` float values. If we move the scaling to the keys and queries, it needs to remember only `2 * t * e` float values. If we set `t` (much) bigger than `e`, this is more efficient.

I didn't test this very thoroughly, so there may be some optimizations that I don't know about. But some quick tests with the debugger seem to bear this out.
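For concreteness, here is a minimal sketch of the two strategies being compared (batch and sequence sizes are made up for illustration; this is not code from the repo):

```python
import torch

b, t, e = 4, 1024, 128  # made-up batch size, sequence length, embedding dimension

queries = torch.randn(b, t, e, requires_grad=True)
keys = torch.randn(b, t, e, requires_grad=True)

# Strategy 1: compute the (b, t, t) dot-product matrix, then divide by sqrt(e).
# Per the reasoning above, the extra scaling step operates on a t-by-t matrix.
dot_scaled_after = torch.bmm(queries, keys.transpose(1, 2)) / (e ** (1 / 2))

# Strategy 2: divide queries and keys by e**(1/4) first, then take the dot product.
# The tensors touched by the extra step are only t-by-e, i.e. 2 * t * e values
# per batch element, which is much smaller when t >> e.
queries_scaled = queries / (e ** (1 / 4))
keys_scaled = keys / (e ** (1 / 4))
dot_scaled_before = torch.bmm(queries_scaled, keys_scaled.transpose(1, 2))

# Mathematically the two are the same:
# (q / e^(1/4)) . (k / e^(1/4)) = (q . k) / e^(1/2)
assert torch.allclose(dot_scaled_after, dot_scaled_before, atol=1e-4)
```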
Excellent, thank you for the explanation!
Does this not make the dividing-by-square-root trick an artefact of the pytorch implementation rather than a general thing? I think we also gain a computational advantage by not having to divide the bigger `dot` matrix. Could you please share your thoughts on this?
I think this is easier to do with a static computation graph (like in tf 1.0). If you build the graph dynamically, and do something like computing `dot1` and then deriving a scaled `dot2` from it, you don't know whether `dot1` might be re-used in some other branch of the computation. With a static computation graph you can do more optimization during compilation, when you know that `dot1` is never used again. With the version that re-binds the same name (scaling `dot` back into `dot`), you can tell by inspecting the code, of course, that the first `dot` object will never be used, but I don't think pytorch has that kind of runtime access to the structure of the code. Also, I don't know if you could easily compress the two modules to save memory for the backward pass in a dynamic computation graph.

However, thinking about this again: for multiplying by a constant you can work out the gradients without storing the input values, so it shouldn't matter at any rate. I don't remember how I tested the memory use, but I did notice a clear jump in memory. I'll reopen this ticket to try again.

As for moving the multiplication even further back, I don't expect it would make a big difference. Multiplying by a constant is at most linear in the size of the `dot` matrix, so it will vanish compared to multiplying the `dot` matrix by the values. Still, it might be worth testing to see what the impact is.
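If someone wants to re-run that memory test, a rough sketch along these lines could serve (made-up sizes, function name is hypothetical, assumes a CUDA device; whether a gap actually shows up depends on what autograd saves for a constant multiplication, as noted above):

```python
import torch

def peak_attention_memory(scale_inputs: bool, b=4, t=2048, e=128, device="cuda"):
    """Rough probe of peak GPU memory for the two scaling strategies.

    scale_inputs=True  -> divide queries and keys by e**(1/4) first
    scale_inputs=False -> divide the (t, t) dot matrix by sqrt(e) afterwards
    """
    torch.cuda.reset_peak_memory_stats(device)

    queries = torch.randn(b, t, e, device=device, requires_grad=True)
    keys = torch.randn(b, t, e, device=device, requires_grad=True)

    if scale_inputs:
        q = queries / (e ** (1 / 4))
        k = keys / (e ** (1 / 4))
        dot = torch.bmm(q, k.transpose(1, 2))
    else:
        dot = torch.bmm(queries, keys.transpose(1, 2))
        dot = dot / (e ** (1 / 2))

    dot.sum().backward()  # run a backward pass so that saved tensors actually matter
    return torch.cuda.max_memory_allocated(device)

if torch.cuda.is_available():
    print("scale dot matrix   :", peak_attention_memory(scale_inputs=False))
    print("scale queries/keys :", peak_attention_memory(scale_inputs=True))
```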
Hi,
Would you mind explaining why the following code is more memory efficient than just dividing one of them by `sqrt(e)`?

former/former/modules.py, lines 48 to 52 in 7b12ae6
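Roughly, the referenced lines scale the queries and keys before the batched matrix multiplication, along these lines (a paraphrase, not the verbatim source):

```python
import torch

def scaled_dot(queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
    # Scale queries and keys by e**(1/4) each, instead of dividing the
    # (t, t) dot-product matrix by sqrt(e) afterwards.
    e = queries.size(-1)
    queries = queries / (e ** (1 / 4))
    keys = keys / (e ** (1 / 4))
    # Batched dot product of every query against every key.
    return torch.bmm(queries, keys.transpose(1, 2))
```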
Thank you.