
Some MHA and RoPE refactoring, llama-3.1 rope_scaling #91

Merged
merged 9 commits into main on Aug 30, 2024

Conversation

francoishernandez
Member

Notes:

  • the RoPE refactoring was probably not strictly necessary, but it makes things clearer IMO;
  • models using rotary embeddings will need to be reconverted, because all rope-related settings were moved to a sub-config for clarity;
  • this is not extensively tested, but it seems to work fine (tested on the example prompt here, for instance);
  • our RoPE implementation is not strictly equivalent to the HF one: we rely on some of the "original" rope tricks based on complex-space computation (.polar/.real/.imag), whereas HF just applies .sin()/.cos(), which is not numerically equivalent (see the sketch after this list);
  • additional scaling types can be implemented (e.g. taking inspiration around here... https://github.com/huggingface/transformers/blob/f1a385b1de7e83e2be9b087d1c0646c0c426e2fc/src/transformers/modeling_rope_utils.py)
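
To illustrate the point above, here is a minimal, self-contained sketch (not the actual code from this PR; function names and shapes are illustrative only) of the two ways of applying rotary embeddings to interleaved pairs: via complex multiplication built with torch.polar, and via explicit cos/sin. The two express the same rotation mathematically, but floating-point round-off can differ slightly, and HF additionally uses a non-interleaved ("rotate half") layout.

```python
import torch

def rope_complex(x, theta=1e4):
    # x: (batch, seq, heads, dim), dim even; rotate interleaved pairs in complex space
    b, s, h, d = x.shape
    inv_freq = 1.0 / (theta ** (torch.arange(0, d, 2).float() / d))
    angles = torch.outer(torch.arange(s).float(), inv_freq)          # (s, d/2)
    rope = torch.polar(torch.ones_like(angles), angles)              # exp(i * angles)
    x_ = torch.view_as_complex(x.float().reshape(b, s, h, d // 2, 2))
    out = torch.view_as_real(x_ * rope.view(1, s, 1, d // 2))
    return out.reshape(b, s, h, d).type_as(x)

def rope_cos_sin(x, theta=1e4):
    # same rotation written out with explicit cos/sin on the real side
    b, s, h, d = x.shape
    inv_freq = 1.0 / (theta ** (torch.arange(0, d, 2).float() / d))
    angles = torch.outer(torch.arange(s).float(), inv_freq)          # (s, d/2)
    cos = angles.cos().view(1, s, 1, -1)
    sin = angles.sin().view(1, s, 1, -1)
    x1, x2 = x[..., 0::2], x[..., 1::2]                              # interleaved pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```

For a random float32 tensor of shape (batch, seq, heads, even_dim), `torch.allclose(rope_complex(x), rope_cos_sin(x), atol=1e-6)` should hold; the residual difference is only rounding, which is what "not numerically equivalent" refers to.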

@vince62s
Contributor

our RoPE implementation is not strictly equivalent to the HF one: we rely on some of the "original" rope tricks based on complex-space computation (.polar/.real/.imag), whereas HF just applies .sin()/.cos(), which is not numerically equivalent

Not sure about that.
When the model comes from a Hugging Face format, rotary_interleave is False, and in that case we also use cos/sin (not the polar formula).
But maybe I am missing something.

@vince62s
Contributor

Did you check the speed performance of this refactoring? Recomputing/applying rotary is quite impactful. My only concern is the reassignment of cos/sin, which was previously performed only when shifting every 32 positions.

@francoishernandez
Member Author

francoishernandez commented Aug 28, 2024

our RoPE implementation is not strictly equivalent to the HF one: we rely on some of the "original" rope tricks based on complex-space computation (.polar/.real/.imag), whereas HF just applies .sin()/.cos(), which is not numerically equivalent

Not sure about that. When the model comes from a Hugging Face format, rotary_interleave is False, and in that case we also use cos/sin (not the polar formula). But maybe I am missing something.

The .polar trick I'm mentioning is in the initial rope computation:
[main]

rope = torch.polar(torch.ones_like(rope), rope)

[refactor]

rope = torch.polar(torch.ones_like(rope), rope)

Afterwards, both the "interleave" and "non-interleave" paths use the real/imag parts of the rope tensor to access cos/sin.
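
For reference, a tiny standalone check (not code from the PR; the dimensions are arbitrary) of why the polar form still yields cos/sin: torch.polar(ones, angles) builds exp(i * angles), so its real and imaginary parts are exactly the cos/sin tables the non-interleaved path reads.

```python
import torch

# Illustrative frequencies/angles for a head dim of 64 and 8 positions.
inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
angles = torch.outer(torch.arange(8).float(), inv_freq)

rope = torch.polar(torch.ones_like(angles), angles)   # complex: exp(i * angles)
cos, sin = rope.real, rope.imag                       # same tables as angles.cos()/angles.sin()

print(torch.allclose(cos, angles.cos()), torch.allclose(sin, angles.sin()))  # expected: True True
```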

Did you check the speed performance of this refactoring? Recomputing/applying rotary is quite impactful. My only concern is the reassignment of cos/sin, which was previously performed only when shifting every 32 positions.

Not in depth. It might indeed be worth a look. We can probably keep some similar sort of cache to prevent unnecessary recomputation.
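
Something along these lines could work; a hypothetical sketch (none of these names exist in the codebase) of a cache that only recomputes the cos/sin tables when the requested length exceeds what was already precomputed:

```python
import torch

class RotaryCache:
    """Hypothetical rope cache: recompute the tables only when the requested
    length exceeds what has already been precomputed."""

    def __init__(self, dim, theta=1e4, chunk=512):
        self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.chunk = chunk
        self.cos = self.sin = None
        self.cached_len = 0

    def get(self, seqlen):
        if seqlen > self.cached_len:
            # round up to the next chunk so we do not recompute at every decoding step
            new_len = ((seqlen // self.chunk) + 1) * self.chunk
            angles = torch.outer(torch.arange(new_len).float(), self.inv_freq)
            rope = torch.polar(torch.ones_like(angles), angles)
            self.cos, self.sin = rope.real, rope.imag
            self.cached_len = new_len
        return self.cos[:seqlen], self.sin[:seqlen]
```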

@francoishernandez francoishernandez force-pushed the mha_refactor_rope_scaling branch from 2604c7c to 09a7b49 on August 28, 2024 13:16
@francoishernandez francoishernandez force-pushed the mha_refactor_rope_scaling branch from 09a7b49 to 3df5ef9 on August 28, 2024 13:22
@francoishernandez
Member Author

francoishernandez commented Aug 28, 2024

d79b6c3 -> similarly to the previous implementation, we pre-compute rope further than needed, to avoid recomputing it at every step.

Not sure about the exact impact, as my setup might not be the most stable, but it seems we had lost ~5-10% inference speed, which we regain here.
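
For illustration, with something like the hypothetical RotaryCache sketched above and a chunk of 512, incremental decoding would only trigger a recomputation every 512 positions rather than at each step:

```python
cache = RotaryCache(dim=128, chunk=512)
for step in range(1, 2049):
    cos, sin = cache.get(step)  # recomputes only at steps 1, 513, 1025, 1537
```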

@francoishernandez francoishernandez force-pushed the mha_refactor_rope_scaling branch from 3091a4e to d79b6c3 on August 28, 2024 14:52
@vince62s
Contributor

What is the benefit of setting position_embedding in transformer_decoder.py, transformer_lm_decoder.py, and transformer_encoder.py vs. directly in mha.py?

@francoishernandez
Member Author

What is the benefit of setting position_embedding in transformer_decoder.py, transformer_lm_decoder.py, and transformer_encoder.py vs. directly in mha.py?

That's a good question. It seemed cleaner to have a single "base" RotaryPosition object compute rope at a higher level and pass it to all underlying layers/MHA, but that's debatable (a rough sketch of the pattern follows).
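
To make the trade-off concrete, here is a rough, hypothetical sketch of that pattern (names, signatures, and the placeholder layer are illustrative, not the PR's actual classes): the decoder owns one RotaryPosition module, computes the rope table once per forward pass, and hands it down to every layer/MHA.

```python
import torch
import torch.nn as nn

class RotaryPosition(nn.Module):
    # hypothetical stand-in for a shared rope module owned by the encoder/decoder
    def __init__(self, head_dim, theta=1e4):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen):
        angles = torch.outer(
            torch.arange(seqlen, device=self.inv_freq.device).float(), self.inv_freq
        )
        return torch.polar(torch.ones_like(angles), angles)  # complex rope table

class DecoderLayer(nn.Module):
    # minimal placeholder: a real layer would hold an MHA that applies `rope`
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, rope):
        return self.proj(x)  # rope would be consumed inside MHA here

class TransformerDecoder(nn.Module):
    def __init__(self, num_layers, dim, heads):
        super().__init__()
        self.rope = RotaryPosition(dim // heads)  # single shared rope owner
        self.layers = nn.ModuleList([DecoderLayer(dim) for _ in range(num_layers)])

    def forward(self, x):
        rope = self.rope(x.size(1))   # computed once per forward...
        for layer in self.layers:
            x = layer(x, rope=rope)   # ...and passed down to each layer/MHA
        return x
```

The alternative, building the tables inside mha.py, keeps MHA self-contained but duplicates the computation per layer; sharing it at the encoder/decoder level computes it once per forward pass.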

@francoishernandez francoishernandez marked this pull request as ready for review August 29, 2024 13:28
@vince62s
Contributor

good to merge.

@francoishernandez francoishernandez merged commit b81cce1 into main Aug 30, 2024
4 checks passed
@francoishernandez francoishernandez deleted the mha_refactor_rope_scaling branch February 7, 2025 08:56