Demo equivalence - tmp #730

Closed
wants to merge 26 commits into from
26 commits
41bf8e0
Equivalent with nano llama 3
gordicaleksa Aug 2, 2024
838cd13
Refactor
gordicaleksa Aug 2, 2024
c414d02
Minor refactor
gordicaleksa Aug 2, 2024
465aac4
Equivalent to nano llama 3 reference code
gordicaleksa Aug 3, 2024
f50f2de
Refactor attn, change numerics but equivalent
gordicaleksa Aug 3, 2024
c0c08ba
Have prompts in a file instead of inline, prompt 4 is different
gordicaleksa Aug 3, 2024
de879d1
Refactor checkpoint state dict map func
gordicaleksa Aug 3, 2024
0199e51
Refactor MLP
gordicaleksa Aug 3, 2024
fdd5345
Refactor attn mechanism
gordicaleksa Aug 3, 2024
fa7bcc3
One more minor attn fix
gordicaleksa Aug 3, 2024
180215f
Unify generate and generate_llama
gordicaleksa Aug 3, 2024
8919b66
Fix generate for gpt-2
gordicaleksa Aug 3, 2024
ccdbdfd
Going towards pure llama 3 file - fixed attn
gordicaleksa Aug 3, 2024
8a48df7
MLP GPT2->LLaMA3
gordicaleksa Aug 3, 2024
c1d2b7f
Removed from pretrained for GPT-2
gordicaleksa Aug 3, 2024
d855c96
Refactoring - got to main
gordicaleksa Aug 3, 2024
b1acb59
Got to llama 3 inference (end)
gordicaleksa Aug 3, 2024
bad7857
Done - need to test train loop and saving model
gordicaleksa Aug 3, 2024
879cc5f
Remove init weights as it's gpt-2 specific
gordicaleksa Aug 4, 2024
7768a36
Add prompts file
gordicaleksa Aug 4, 2024
cd90273
Fix saving model / state logic
gordicaleksa Aug 4, 2024
4b386a2
Test training loop works
gordicaleksa Aug 4, 2024
0749a4a
Minor refactor - remove wpe pos array from fwd
gordicaleksa Aug 4, 2024
8e55d16
Support HF & Meta models
gordicaleksa Aug 4, 2024
72dcfeb
Remove float(-inf)
gordicaleksa Aug 4, 2024
298a49a
Demo equivalence
gordicaleksa Aug 8, 2024
8 changes: 8 additions & 0 deletions llmc_py/prompts.json
@@ -0,0 +1,8 @@
{
    "prompts": [
        "Clearly, the meaning of life is",
        "Simply put, the theory of relativity states that",
        "The repo llm.c on GitHub is",
        "Translate English to French:\n\nsea otter => loutre de mer\npeppermint => menthe poivrée\nplush girafe => girafe peluche\ncheese =>"
    ]
}
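For reference, a minimal sketch (not part of this PR) of how a script might consume this file; only the "prompts" key comes from the JSON above, the path and loop are illustrative:

import json

# Load the prompt list added by this PR; the relative path is an assumption.
with open("llmc_py/prompts.json") as f:
    prompts = json.load(f)["prompts"]

for i, prompt in enumerate(prompts):
    print(f"prompt {i}: {prompt!r}")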
59 changes: 59 additions & 0 deletions llmc_py/rope.py
@@ -0,0 +1,59 @@
# From: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py

import math
from typing import Tuple
import torch

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_scaling(freqs: torch.Tensor):
    # Values obtained from grid search
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scale_factor)
        else:
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    if use_scaled:
        freqs = apply_scaling(freqs)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
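For reference, a minimal sketch (not part of this PR) of how these RoPE helpers compose. The shapes (batch=1, seq=8, n_heads=4, head_dim=64) and theta=500000.0 are illustrative assumptions chosen to match the Llama 3 convention of rotating query/key heads pairwise as complex numbers:

import torch
from llmc_py.rope import precompute_freqs_cis, apply_rotary_emb

# (batch, seq, n_heads, head_dim); each head_dim is viewed as 32 (real, imag) pairs
xq = torch.randn(1, 8, 4, 64)
xk = torch.randn(1, 8, 4, 64)

# One complex rotation per (position, frequency); use_scaled enables the 3.1-style scaling above.
freqs_cis = precompute_freqs_cis(dim=64, end=8, theta=500000.0, use_scaled=True)

xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape  # shapes are preserved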
189 changes: 189 additions & 0 deletions llmc_py/tokenizer.py
@@ -0,0 +1,189 @@
# From: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/tokenizer.py

import os
from pathlib import Path
from typing import (
    AbstractSet,
    Callable,
    Collection,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
    cast,
)

import tiktoken
from tiktoken.load import load_tiktoken_bpe

# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000

# https://github.com/openai/tiktoken/issues/195
# Here we iterate over subsequences and split if we exceed the limit
# of max consecutive non-whitespace or whitespace characters.
MAX_NO_WHITESPACES_CHARS = 25_000


class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
    """

    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 256

    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.
        """
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|finetune_right_pad_id|>",
            "<|step_id|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|eom_id|>",  # end of message
            "<|eot_id|>",  # end of turn
            "<|python_tag|>",
        ]
        reserved_tokens = [
            f"<|reserved_special_token_{2 + i}|>"
            for i in range(self.num_reserved_special_tokens - len(special_tokens))
        ]
        special_tokens = special_tokens + reserved_tokens

        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        self.n_words: int = num_base_tokens + len(special_tokens)
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.eot_id: int = self.special_tokens["<|eot_id|>"]
        self.eom_id: int = self.special_tokens["<|eom_id|>"]
        self.python_tag_id = self.special_tokens["<|python_tag|>"]
        self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
        self.stop_tokens = [
            self.special_tokens["<|eom_id|>"],
            self.special_tokens["<|eot_id|>"],
        ]

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): special tokens allowed in the string
            disallowed_special ("all"|set[str]): special tokens that raise an error when in the string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will cause all text corresponding
          to special tokens to be encoded as special tokens.
        """
        if allowed_special is None:
            allowed_special = set()
        assert type(s) is str

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than
        `max_consecutive_slice_len` consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
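For reference, a minimal round-trip sketch (not part of this PR), assuming a Llama 3 tiktoken BPE file has been downloaded locally; the "tokenizer.model" path is a placeholder, not something this PR provides:

from llmc_py.tokenizer import Tokenizer

tok = Tokenizer("tokenizer.model")  # hypothetical local path to the Llama 3 BPE file
ids = tok.encode("Clearly, the meaning of life is", bos=True, eos=False)

print(ids[0] == tok.bos_id)  # True: the <|begin_of_text|> id is prepended
print(tok.decode(ids))       # decodes back to the prompt, with the BOS marker in front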
57 changes: 57 additions & 0 deletions llmc_py/utils.py
@@ -0,0 +1,57 @@
# Taken from:
# 1) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py
# 2) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py

import torch
from torch import nn

# Special modules
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

# Sampling
def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

# GQA
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
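For reference, a quick sketch (not part of this PR) of how these utilities are typically called; all shapes, the temperature, and the top-p value are illustrative assumptions:

import torch
from llmc_py.utils import RMSNorm, sample_top_p, repeat_kv

x = torch.randn(1, 8, 4096)                  # (batch, seq, model_dim)
normed = RMSNorm(4096)(x)                    # per-token RMS normalization, learned per-channel scale

logits = torch.randn(1, 128256)              # (batch, vocab) logits for the next token
probs = torch.softmax(logits / 0.8, dim=-1)  # apply an illustrative temperature before sampling
next_token = sample_top_p(probs, p=0.9)      # (1, 1) tensor with the sampled vocab index

kv = torch.randn(1, 8, 8, 128)               # (batch, seq, n_kv_heads, head_dim)
kv_full = repeat_kv(kv, n_rep=4)             # (1, 8, 32, 128): each KV head shared by 4 query heads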