Minor LLaMA 3 refactor #735

Merged
merged 4 commits on Aug 26, 2024
train_llama3.py (82 changes: 25 additions & 57 deletions)
@@ -16,17 +16,17 @@
TODO: add the actual commands
"""

import argparse
import os
import math
import glob
import inspect
from contextlib import nullcontext
from dataclasses import dataclass
import json
from pathlib import Path
import time
from typing import (
AbstractSet,
Callable,
Collection,
Dict,
Iterator,
@@ -55,9 +55,6 @@
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the LLaMA 3.x model

# using a global to toggle flash-attention
FLASH = 0

# Used in Grouped Query Attention (GQA), broadcasts the key and value tensors
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
@@ -157,6 +154,7 @@ def __init__(self, config):
self.n_rep = self.n_head // self.n_kv_head
self.hd = config.n_embd // config.n_head
self.use_kv = config.use_kv
self.flash = config.flash

self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=False) # key, query, value projections
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # output projection
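For intuition on the head arithmetic above, plugging in the published LLaMA 3.1 8B shapes (n_embd=4096, n_head=32, n_kv_head=8; these numbers come from the model card, not from this diff):

n_embd, n_head, n_kv_head = 4096, 32, 8      # LLaMA 3.1 8B
hd = n_embd // n_head                        # 128-dim heads
n_rep = n_head // n_kv_head                  # each KV head is shared by 4 query heads
c_attn_out = (n_head + 2 * n_kv_head) * hd   # 6144 = 4096 (Q) + 1024 (K) + 1024 (V)
print(hd, n_rep, c_attn_out)                 # 128 4 6144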
@@ -186,9 +184,12 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None):

q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v)) # (B, NH, T, HD)

if FLASH:
if self.flash:
# flashattention
y = F.scaled_dot_product_attention(q, k, v, mask)
# if T == 1 no need to mask, otherwise the function complains
# scaled_dot_product_attention expects a mask where value of True indicates that the element should take part in attention
# our mask is the opposite, so we need to invert it
y = F.scaled_dot_product_attention(q, k, v, mask == 0 if T > 1 else None)
else:
# manual implementation of attention
# this materializes the large (T,T) matrix for all the queries and keys
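A quick standalone check of the mask inversion (not part of the PR), assuming the additive causal mask of 0s and -inf that Meta-style LLaMA code builds: mask == 0 yields exactly the boolean "may attend" mask that scaled_dot_product_attention expects, and reproduces the manual path.

import torch
import torch.nn.functional as F

B, NH, T, HD = 1, 2, 4, 8
q, k, v = (torch.randn(B, NH, T, HD) for _ in range(3))

# additive causal mask: 0 where attention is allowed, -inf strictly above the diagonal
mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

# scaled_dot_product_attention wants True == "may attend", hence the mask == 0 inversion
y_flash = F.scaled_dot_product_attention(q, k, v, attn_mask=(mask == 0))

# reference: the manual path, which adds the same mask to the raw scores
att = (q @ k.transpose(-2, -1)) / (HD ** 0.5) + mask
y_manual = F.softmax(att, dim=-1) @ v
assert torch.allclose(y_flash, y_manual, atol=1e-5)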
@@ -257,6 +258,7 @@ class LlamaConfig:
use_scaled_rope: bool = True
max_gen_batch_size: int = 4
use_kv: bool = True
flash: bool = False # use flashattention?

def __init__(self, **kwargs):
for k, v in kwargs.items():
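With the module-level FLASH global gone, the toggle now travels with the config and is read as self.flash in CausalSelfAttention. A minimal usage sketch (assuming, as the truncated __init__ above suggests, that the kwargs loop simply sets each passed field on the instance):

config = LlamaConfig(flash=True)   # enable flash attention for this model instance
assert config.flash                # picked up by CausalSelfAttention.__init__ as self.flash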
@@ -402,7 +404,7 @@ def unpermute(w, n_heads, dim1, dim2):
def from_pretrained_llama3_hf(cls, model_id):
"""Loads pretrained LLaMA model weights from HuggingFace"""
from transformers import AutoModelForCausalLM, AutoTokenizer
assert model_id == "meta-llama/Meta-Llama-3.1-8B", "Only the 8B-bae model is supported for now"
assert model_id == "meta-llama/Meta-Llama-3.1-8B", "Only the 8B-base model is supported for now"
model_args = LlamaConfig()

model = AutoModelForCausalLM.from_pretrained(model_id)
@@ -477,7 +479,6 @@ def generate(
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
"""
@@ -488,45 +489,35 @@
max_gen_len (int): Maximum length of the generated text sequence.
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

Returns:
Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.
Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences.

Note:
This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
If logprobs is True, token log probabilities are computed for each generated token.

"""
bsz = len(prompt_tokens)
assert bsz <= self.config.max_gen_batch_size, (bsz, self.config.max_gen_batch_size)
assert bsz <= self.config.max_gen_batch_size, f"Batch size {bsz} exceeds the maximum generation batch size {self.config.max_gen_batch_size}"
device = next(self.parameters()).device

min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
assert max_prompt_len <= self.config.block_size
assert max_prompt_len <= self.config.block_size, f"Prompt length {max_prompt_len} exceeds the maximum block size {self.config.block_size}"
total_len = min(self.config.block_size, max_gen_len + max_prompt_len)

pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
for idx, t in enumerate(prompt_tokens):
tokens[idx, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)

prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device=device)
input_text_mask = tokens != pad_id

if min_prompt_len == total_len:
logits, _ = self.forward(tokens, start_pos=prev_pos)
token_logprobs = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens,
reduction="none",
ignore_index=pad_id,
)

stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)).to(device)

@@ -542,41 +533,25 @@ def generate(
# only replace token if prompt has already been generated
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens[:, prev_pos + 1 : cur_pos + 1],
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
eos_reached |= ~input_text_mask[:, cur_pos] & torch.isin(next_token, stop_tokens)
prev_pos = cur_pos
if all(eos_reached):
break

if logprobs:
token_logprobs = token_logprobs.tolist()
out_tokens, out_logprobs = [], []
out_tokens = []
for i, toks in enumerate(tokens.tolist()):
# cut to max gen len
start = 0 if echo else len(prompt_tokens[i])
toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
probs = None
if logprobs:
probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
# cut to after eos tok if any
for stop_token in self.tokenizer.stop_tokens:
try:
eos_idx = toks.index(stop_token)
toks = toks[:eos_idx]
probs = probs[:eos_idx] if logprobs else None
except ValueError:
pass
out_tokens.append(toks)
out_logprobs.append(probs)
return (out_tokens, out_logprobs if logprobs else None)
return out_tokens
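Since generate() now returns only the token lists, call sites drop the tuple unpacking; the one caller in this file is updated in the last hunk of the diff, i.e.:

# before this PR:
generation_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False)
# after this PR:
generation_tokens = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, echo=False)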

# -----------------------------------------------------------------------------
# sampling utils
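The sampling helpers themselves are unchanged and collapsed here; for reference, the nucleus (top-p) sampling described in the generate() docstring is typically implemented along these lines (sample_top_p_sketch is illustrative, not necessarily the file's exact helper):

import torch

def sample_top_p_sketch(probs: torch.Tensor, p: float) -> torch.Tensor:
    # probs: (B, vocab) softmax probabilities; keep the smallest prefix of the
    # sorted distribution whose cumulative mass covers p, then renormalize
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p          # tokens outside the nucleus
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)  # map back to vocab ids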
@@ -956,16 +931,14 @@ def print0(*args, **kwargs):
print(*args, **kwargs)

if __name__ == "__main__":
import time
import argparse
print0(f"Running pytorch {torch.version.__version__}")

# default settings will overfit a tiny batch of data
# and save model weights and debug state to disk on the first iteration
parser = argparse.ArgumentParser()
parser.add_argument("--use_hf", type=int, default=1, help="use HuggingFace (default) or use Meta's model")
parser.add_argument("--ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint")
parser.add_argument("--tokenizer_path", type=str, default=None, help="path to llama3 tokenizer")
parser.add_argument("--ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint (needed if use_hf=0)")
parser.add_argument("--tokenizer_path", type=str, default=None, help="path to llama3 tokenizer (needed if use_hf=0)")
# file system input / output
parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on")
parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on")
@@ -995,7 +968,6 @@ def print0(*args, **kwargs):
# memory management
parser.add_argument("--device", type=str, default="", help="by default we autodetect, or set it here")
parser.add_argument("--compile", type=int, default=0, help="torch.compile the model")
parser.add_argument("--flash", type=int, default=0, help="use flash attention")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|float16|bfloat16")
parser.add_argument("--zero_stage", type=int, default=0, help="zero redundancy optimizer stage (0/1/2/3)")
# python -> C bridge
@@ -1049,9 +1021,9 @@ def print0(*args, **kwargs):
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
print(f"using device: {device}")
device_type = 'cuda' if 'cuda' in device else 'cpu'
assert device_type in {'cuda'} # we need to load LLaMA as bf16 on CUDA
assert device_type in {'cuda'}, "GPU required to run LLaMA 3" # we need to load LLaMA as bf16 on CUDA
print(f"using device: {device}")

# calculate gradient accumulation from the desired total batch size and the current run configuration
tokens_per_fwdbwd = B * T * ddp_world_size
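A worked example of the gradient-accumulation arithmetic behind that comment (the numbers are illustrative, and the assert/division shown are the usual llm.c-style continuation of the collapsed lines, not verbatim from them):

# e.g. a desired total batch of 262,144 tokens per step, micro-batch B=4, T=2048, 8 GPUs:
total_batch_size = 262144
B, T, ddp_world_size = 4, 2048, 8
tokens_per_fwdbwd = B * T * ddp_world_size                  # 65,536 tokens per fwd/bwd across ranks
assert total_batch_size % tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // tokens_per_fwdbwd    # 4 accumulation steps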
@@ -1074,16 +1046,12 @@ def print0(*args, **kwargs):
if args.tensorcores:
torch.set_float32_matmul_precision('high')

# turn on/off flash attention
assert args.flash in {0, 1}
FLASH = args.flash

# init the model
assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist"
assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"
if args.use_hf:
model = LLaMA.from_pretrained_llama3_hf(args.model)
else: # use Meta's checkpoint
assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist"
assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"
model = LLaMA.from_pretrained_llama3_meta(args.ckpt_dir, args.tokenizer_path)

model.train()
Expand Down Expand Up @@ -1198,7 +1166,7 @@ def get_lr(it):
else: # Meta
prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

generation_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False)
generation_tokens = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, echo=False)
results = [{"generation": model.tokenizer.decode(t)} for t in generation_tokens]
for prompt, result in zip(prompts, results):
print(prompt, end="")