Some MHA and RoPE refactoring, llama-3.1 rope_scaling #91
Conversation
Not sure about that.
Did you check the speed performance of this refactoring? Recomputing / applying rotary is quite impactful. My only concern is the reassignment of cos/sin, which was previously performed only when shifting every 32 positions.
See eole/eole/modules/multi_headed_attn.py, line 28 (commit 9c8cc5b), and [refactor], line 139 (commit 4eb4853).
Afterwards, both the "interleave" and non-interleave paths use the real/imag parts of the rope tensor to access cos/sin.
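For context, here is a minimal sketch (not eole's actual code; names and shapes are illustrative) of what "using the real/imag parts of the rope tensor to access cos/sin" looks like with `torch.polar`:

```python
import torch

def build_rope(maxseqlen: int, dim: int, base: int = 10000) -> torch.Tensor:
    # inverse frequencies over half the head dimension
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(maxseqlen, dtype=torch.float)
    angles = torch.outer(t, inv_freq)  # (maxseqlen, dim // 2)
    # complex tensor: real part ~ cos(angles), imaginary part ~ sin(angles)
    return torch.polar(torch.ones_like(angles), angles)

rope = build_rope(1024, 64)
# both the interleaved and non-interleaved paths can read cos/sin from here
cos, sin = rope.real, rope.imag
```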
Not in depth. It might be worth a look indeed. We can probably have some similar sort of cache to prevent unnecessary recomputations.
Force-pushed from 2604c7c to 09a7b49.
Force-pushed from 09a7b49 to 3df5ef9.
d79b6c3 -> as in the previous implementation, we pre-compute the rope further than needed to avoid recomputing it at each step. Not sure about the true impact, as my setup might not be the most stable, but it seems we lost ~5-10% inference speed, which we recover here.
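Roughly, the idea is something like the hypothetical sketch below: keep a cached rope tensor and only rebuild it (past the requested length) when it becomes too short. Class and parameter names here are illustrative, not eole's actual API.

```python
import torch

class RopeCache:
    def __init__(self, dim: int, base: int = 10000, prefetch: int = 32):
        self.dim, self.base, self.prefetch = dim, base, prefetch
        self.rope = None  # lazily built complex tensor of shape (cached_len, dim // 2)

    def get(self, seqlen: int) -> torch.Tensor:
        if self.rope is None or self.rope.size(0) < seqlen:
            # build further than requested to amortize the cost across decoding steps
            target = seqlen + self.prefetch
            inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
            angles = torch.outer(torch.arange(target, dtype=torch.float), inv_freq)
            self.rope = torch.polar(torch.ones_like(angles), angles)
        return self.rope[:seqlen]
```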
Force-pushed from 3091a4e to d79b6c3.
What is the benefit of setting position_embedding in transformer_decoder.py, transformer_lm_decoder.py, and transformer_encoder.py vs directly in mha.py?
That's a good question. It seemed cleaner to have a single "base" RotaryPosition object that computes rope at a higher level and is reused by all underlying layers/MHA, but that's debatable.
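A rough sketch of that design choice (all names illustrative, assuming inputs of shape `(batch, seqlen, model_dim)`; not the actual eole classes): the decoder owns one rotary module, computes the embeddings once per forward pass, and passes them down to every attention layer instead of each MHA building its own.

```python
import torch
import torch.nn as nn

class RotaryPosition(nn.Module):
    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        angles = torch.outer(torch.arange(seqlen, dtype=torch.float), self.inv_freq)
        return torch.polar(torch.ones_like(angles), angles)

class DecoderLayer(nn.Module):
    """Stub standing in for a real transformer layer with MHA."""
    def forward(self, x, position_embeddings):
        # a real layer would apply the rotary embeddings inside its MHA here
        return x

class Decoder(nn.Module):
    def __init__(self, num_layers: int, dim_per_head: int):
        super().__init__()
        self.rope = RotaryPosition(dim_per_head)  # single "base" object
        self.layers = nn.ModuleList(DecoderLayer() for _ in range(num_layers))

    def forward(self, x):
        position_embeddings = self.rope(x.size(1))  # computed once per forward
        for layer in self.layers:
            x = layer(x, position_embeddings)       # shared by every MHA
        return x
```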
good to merge. |
Notes:
- cos/sin are obtained from the complex rope tensor (.polar, then .real / .imag), where HF just applies the .sin() / .cos() methods, which is not numerically equivalent;
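A quick way to check that note (illustrative only; the actual deltas, if any, depend on backend and dtype) is to compare the two paths directly:

```python
import torch

angles = torch.outer(
    torch.arange(4096, dtype=torch.float),
    1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64)),
)
rope = torch.polar(torch.ones_like(angles), angles)

# may not be bit-for-bit identical to applying .cos() / .sin() directly
print(torch.equal(rope.real, angles.cos()))
print((rope.real - angles.cos()).abs().max())
print((rope.imag - angles.sin()).abs().max())
```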