
Add LLaMA 3 Python support #725

Merged
merged 36 commits into karpathy:master on Aug 8, 2024

Conversation

gordicaleksa (Contributor) commented Aug 2, 2024

Add LLaMA 3 support in our Python code, acting as a reference.

The code supports only inference right now and is equivalent to nano llama 3.

# -----------------------------------------------------------------------------
# LLaMA building blocks

class RMSNorm(torch.nn.Module):

karpathy (Owner):
Add a comment about why we're not using nn.RMSNorm maybe?

gordicaleksa (Contributor Author):

Just a tiny bit different numerics compared to LLaMA's reference (tested). Will leave a comment; we can later swap in nn.RMSNorm, tbh.
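
(For readers skimming the thread: a minimal sketch of the LLaMA-style RMSNorm under discussion, following Meta's reference formulation; an illustration only, not necessarily the exact code in the PR.)

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learned per-channel gain

    def _norm(self, x):
        # scale by the reciprocal root-mean-square of the last dim (no mean subtraction, no bias)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # normalize in float32 for stable statistics, cast back, then apply the gain
        return self._norm(x.float()).type_as(x) * self.weight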

train_llama3.py Outdated
self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False)
self.c_fc2 = nn.Linear(config.n_embd, hidden_dim, bias=False)
self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1

karpathy (Owner):

unused. two places
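
(Context for the snippet above: the two parallel projections form a SwiGLU-style MLP. A rough, self-contained sketch of the forward pass under that assumption, not the PR's exact code.)

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUMLP(nn.Module):
    # illustrative stand-in; names mirror the PR's c_fc / c_fc2 / c_proj
    def __init__(self, n_embd, hidden_dim):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, hidden_dim, bias=False)
        self.c_fc2 = nn.Linear(n_embd, hidden_dim, bias=False)
        self.c_proj = nn.Linear(hidden_dim, n_embd, bias=False)

    def forward(self, x):
        # SwiGLU: SiLU-gated elementwise product of the two parallel projections
        return self.c_proj(F.silu(self.c_fc(x)) * self.c_fc2(x))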

return logits, loss

@staticmethod
def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig):

karpathy (Owner):
Maybe a small comment/docs on these defs?

gordicaleksa (Contributor Author):

Wasn't sure what to add that would be useful; should be fairly obvious (?)

karpathy (Owner):

I don't think so. Why adapt them?

gordicaleksa (Contributor Author):

Because their LLaMA class has different variable names compared to ours (we derive our naming from GPT-2).

OK, will add it, but tbh it feels redundant, since on a first skim people can see we're renaming keys in the checkpoint dict.
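
(A toy sketch of what "adapting the keys" means, with a few example mappings; the complete mapping lives in adapt_llama_state_dict_keys, and the specific names below are illustrative.)

def rename_llama_keys(state_dict):
    # Hypothetical excerpt: map Meta-style parameter names onto GPT-2-derived names.
    key_map = {
        "tok_embeddings.weight": "transformer.wte.weight",  # token embeddings
        "norm.weight": "transformer.ln_f.weight",           # final norm
        "output.weight": "lm_head.weight",                  # unembedding / LM head
    }
    return {key_map.get(k, k): v for k, v in state_dict.items()}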

self.eom_id: int = self.special_tokens["<|eom_id|>"]
self.python_tag_id = self.special_tokens["<|python_tag|>"]
self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
self.stop_tokens = [

karpathy (Owner):

I fixed this in llama31. These stop tokens are incorrect for the model.

gordicaleksa (Contributor Author):

I fixed it, just in a different place, by adding the right eos token; see from_pretrained_llama3_hf and from_pretrained_llama3_meta.

We can refactor later once we support the chat model?

karpathy (Owner):

Ohh you override it there. Hmm I think leaving this here is a bit dangerous and possibly confusing, just as setting class attributes is. Maybe we take it as an arg in our code here?

gordicaleksa (Contributor Author):

I fixed it using nano llama 3's solution.
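
(Sketch of the kind of fix being referenced: derive the stop tokens from the tokenizer's special-token table instead of hard-coding them. The ids shown are the usual Llama 3 values but are illustrative here.)

special_tokens = {"<|end_of_text|>": 128001, "<|eot_id|>": 128009}  # stand-in for the tokenizer's table
stop_tokens = [
    special_tokens["<|end_of_text|>"],  # ends base-model generations
    special_tokens["<|eot_id|>"],       # ends a chat turn
]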

# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader

def _peek_data_shard(filename):

karpathy (Owner):
This code is all messed up and outdated at the moment.
We need llama3-tokenizer-encoded data. As it stands, this actively introduces bugs if someone tries to run it, since it reads GPT-2-tokenized data in uint16.
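
(The gist of the concern, as a hypothetical helper rather than the repo's _peek_data_shard: the Llama 3 vocabulary has 128,256 ids, which no longer fit in uint16, so shards must be written and read with a wider dtype.)

import numpy as np

def read_token_shard(filename, dtype=np.uint32):
    # Hypothetical helper: GPT-2 ids fit in uint16 (vocab 50,257), but Llama 3 ids
    # overflow it, so shards written with the Llama 3 tokenizer need e.g. uint32.
    with open(filename, "rb") as f:
        return np.frombuffer(f.read(), dtype=dtype)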

# 2) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py
# 3) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py

Example launches to only benchmark the speed of bfloat16 compiled GPU training:

karpathy (Owner):
Delete the launch commands that are incorrect at the moment.

xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))

karpathy (Owner):

Can come later in a PR, but let's delete the use of complex. The use of complex here by Meta was a clear mistake: it created a lot of complexity for no good reason, and iirc it broke torch.compile for me once earlier. In their latest code (for llamachat) Meta fixed this and they're now using a fully real-valued impl.
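
(For illustration, a real-valued rotation that computes the same thing as the complex formulation quoted above; a sketch of the suggested direction, not the code that was merged.)

import torch

def apply_rotary_emb_real(x, cos, sin):
    # x: (B, T, n_head, head_dim); cos, sin: (T, head_dim // 2)
    # Treat consecutive (even, odd) pairs as 2D points and rotate each by the
    # per-position angle -- the same math as multiplying by a unit complex number.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = cos[None, :, None, :], sin[None, :, None, :]
    out1 = x1 * cos - x2 * sin
    out2 = x1 * sin + x2 * cos
    return torch.stack((out1, out2), dim=-1).flatten(-2).type_as(x)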

train_llama3.py Outdated
self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))
self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))

def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None):

karpathy (Owner):

Random thought: I think type hints are dumb; if you delete them in places, I will basically always accept that change.
I think they are OK for cases where the type is not obvious.
And I strictly prefer comments, because e.g. for Tensors the important thing is not that it's a tensor, but what its shape is, or what the dtype is, etc.

gordicaleksa (Contributor Author):

Kind of agree, especially because they don't help catch errors; types are not enforced like in C.

gordicaleksa (Contributor Author):

deleted the mask annotation everywhere
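
(Background on the cache_k / cache_v buffers quoted above: during incremental decoding, each new chunk of keys/values is written into the preallocated cache at start_pos, and attention then runs over everything cached so far. A standalone sketch under those assumptions, with illustrative shapes.)

import torch

B, T, start_pos, n_kv_head, hd = 1, 4, 8, 8, 64
cache_k = torch.zeros((4, 1024, n_kv_head, hd))  # (max_batch, max_seq, n_kv_head, head_dim)
k = torch.randn((B, T, n_kv_head, hd))           # keys for the current chunk
cache_k[:B, start_pos:start_pos + T] = k         # write the chunk into the cache
keys = cache_k[:B, :start_pos + T]               # attend over all cached positions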

train_llama3.py Outdated

next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(

karpathy (Owner):

black pollution (formatting-only diff noise from the Black autoformatter)
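
(For context, the torch.where quoted above keeps prompt tokens fixed during batched generation and only inserts sampled tokens past the prompt, following Meta's generation loop; a small runnable sketch with illustrative values.)

import torch

tokens = torch.tensor([[1, 2, 0, 0]])    # batch of sequences, 0 = padding past the prompt
input_text_mask = tokens != 0            # True where the prompt supplied a token
cur_pos = 1
next_token = torch.tensor([9])           # freshly sampled token for this position
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token          # the prompt token (2) is kept, not overwritten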

karpathy merged commit 6e6a528 into karpathy:master on Aug 8, 2024
13 checks passed