Add LLaMA 3 Python support #725
Conversation
# -----------------------------------------------------------------------------
# LLaMA building blocks

class RMSNorm(torch.nn.Module):
Add a comment about why we're not using nn.RMSNorm maybe?
Just a tiny bit different numerics compared to LLaMA's reference implementation (tested). I'll leave a comment; we can swap in nn.RMSNorm later tbh.
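(For context, a minimal sketch of a reference-style RMSNorm of the kind being discussed; the eps default and the exact dtype handling below are illustrative assumptions, not necessarily what train_llama3.py does.)

```python
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    # Sketch of a Meta-style RMSNorm: normalize in float32, cast back, then scale.
    # Where exactly the cast back happens is what gives slightly different
    # numerics versus nn.RMSNorm.
    def __init__(self, dim, eps=1e-5):  # eps value is an assumption
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # x / sqrt(mean(x^2) + eps), computed in float32 for stability
        xf = x.float()
        norm = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        return norm.type_as(x) * self.weight
```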
train_llama3.py (outdated)
self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False)
self.c_fc2 = nn.Linear(config.n_embd, hidden_dim, bias=False)
self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
Unused, in two places.
return logits, loss

@staticmethod
def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig):
Maybe a small comment/docs on these defs?
Wasn't sure what to add that would be useful; it should be fairly obvious (?)
I don't think so. Why adapt them?
Because their LLaMA class has different variable names compared to ours (we derive our naming from GPT-2) (?)
kk, will add it, but tbh it feels redundant, since on a first skim people can see we're renaming keys in the checkpoint dict.
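(As a hedged illustration of the renaming being discussed, something like the sketch below; the specific key names are examples and may not match the PR's actual mapping.)

```python
def adapt_llama_state_dict_keys_sketch(checkpoint):
    # Illustrative only: map Meta-style LLaMA checkpoint keys onto the
    # GPT-2-derived names this codebase uses. A real mapping would also
    # cover the per-layer attention/MLP weights.
    renames = {
        "tok_embeddings.weight": "transformer.wte.weight",  # token embedding
        "norm.weight": "transformer.ln_f.weight",           # final RMSNorm
        "output.weight": "lm_head.weight",                   # unembedding
    }
    return {renames.get(k, k): v for k, v in checkpoint.items()}
```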
self.eom_id: int = self.special_tokens["<|eom_id|>"]
self.python_tag_id = self.special_tokens["<|python_tag|>"]
self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
self.stop_tokens = [
I fixed this in llama31. These stop tokens are incorrect for the model.
I fixed it, just in a different place, by adding the right EOS token; see from_pretrained_llama3_hf and from_pretrained_llama3_meta.
We can refactor later once we support the chat model?
Ohh you override it there. Hmm I think leaving this here is a bit dangerous and possibly confusing, just as setting class attributes is. Maybe we take it as an arg in our code here?
I fixed it using nano llama 3's solution.
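(For illustration, one hedged way to make the stop tokens explicit rather than hard-coding a single list on the tokenizer; the helper name and the base-vs-chat split below are assumptions.)

```python
def get_stop_tokens(special_tokens, chat_model=False):
    # Sketch: the base model should stop on <|end_of_text|>; the instruct/chat
    # model additionally stops on <|eot_id|> / <|eom_id|>. Taking this as an
    # argument avoids a hard-coded list that only suits one model variant.
    stop = [special_tokens["<|end_of_text|>"]]
    if chat_model:
        stop += [special_tokens["<|eot_id|>"], special_tokens["<|eom_id|>"]]
    return stop
```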
# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader

def _peek_data_shard(filename):
This code is all messed up and outdated atm.
We need data encoded with the LLaMA 3 tokenizer; as written, this actively introduces bugs if someone tries to run it, since it reads GPT-2-tokenized data as uint16.
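(A hedged sketch of the kind of header check that would catch this; the magic number, version, and 256-int32 header layout below are assumptions for illustration, not necessarily the repo's actual on-disk format.)

```python
import numpy as np

def _peek_llama3_data_shard(filename):
    # Illustrative header: 256 int32s, [0]=magic, [1]=version, [2]=token count.
    # A LLaMA 3 vocab (~128K ids) does not fit in uint16, so shards would have
    # to store uint32 tokens, unlike the GPT-2 uint16 shards this loader expects.
    with open(filename, "rb") as f:
        header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
    assert header[0] == 20240801, "magic mismatch: not a LLaMA 3 token shard?"  # assumed magic
    assert header[1] == 7, "unsupported shard version"                          # assumed version
    return int(header[2])  # number of uint32 tokens following the header
```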
# 2) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py
# 3) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py

Example launches to only benchmark the speed of bfloat16 compiled GPU training:
Delete the launch commands that are incorrect atm.
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
Can come later in a PR, but let's delete the use of complex numbers. Meta's use of complex here was a clear mistake: it created a lot of complexity for no good reason, and iirc it broke torch.compile for me once earlier. In their latest code (for llamachat) Meta fixed this and they're now using a fully real-valued impl.
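(For reference, a minimal real-valued rotary-embedding sketch of the kind being suggested; the function name and the interleaved pairing convention are illustrative.)

```python
import torch

def apply_rotary_emb_real(x, cos, sin):
    # x: (B, T, n_head, head_dim); cos, sin: (T, head_dim // 2)
    # Rotate adjacent channel pairs directly, no complex dtypes involved.
    x1, x2 = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # (B, T, n_head, head_dim//2) each
    cos, sin = cos[None, :, None, :], sin[None, :, None, :]      # broadcast over batch and heads
    y1 = x1 * cos - x2 * sin
    y2 = x1 * sin + x2 * cos
    return torch.stack((y1, y2), dim=-1).flatten(-2).type_as(x)  # back to (B, T, n_head, head_dim)
```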
train_llama3.py (outdated)
self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))
self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))

def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None):
Random thought: I think type hints are dumb; if you delete them in places, I will basically always accept that change.
I think they are OK for cases where the type is not obvious.
And I always strictly prefer comments, because e.g. with Tensors the important thing is not that it's a tensor, but what its shape is, what the dtype is, etc.
Kind of agree, esp. because they don't help catch errors; types are not enforced like in C.
deleted the mask annotation everywhere
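(A tiny illustration of the shape-comment style being preferred over type hints here; the function and shapes are made up for the example.)

```python
import torch

def apply_causal_mask(att, mask):
    # att:  (B, n_head, T, T) raw attention scores
    # mask: (T, T) additive mask, 0 on allowed positions and -inf elsewhere
    # returns: (B, n_head, T, T) masked scores, ready for softmax
    return att + mask
```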
train_llama3.py (outdated)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
black pollution
Add LLaMA 3 support in our Python code, acting as a reference.
The code currently supports only inference and is equivalent to nano llama 3.