Commit: add dynamic part ntk

jquesnelle committed Jul 11, 2023
1 parent 114adf9 commit b6aed01
Showing 4 changed files with 105 additions and 3 deletions.
model_loader.py (10 changes: 8 additions & 2 deletions)

@@ -32,7 +32,7 @@ def load_model(model, load_in_8bit, load_in_4bit, length):
 
     return loaded
 
-def apply_patches(loaded, length, dynamic_ntk, dynamic_linear, ntk, linear, part_ntk):
+def apply_patches(loaded, length, dynamic_ntk, dynamic_linear, dynamic_part_ntk, ntk, linear, part_ntk):
     if "GPTNeoXForCausalLM" in loaded.config.architectures:
         patch_gptneox_for_longer_sequences(loaded, length)
     if dynamic_linear:
@@ -49,6 +49,12 @@ def apply_patches(loaded, length, dynamic_ntk, dynamic_linear, ntk, linear, part
         else:
             raise RuntimeError(
                 f"Unsupported architecture {loaded.config.architectures} for dynamic ntk")
+    elif dynamic_part_ntk:
+        if "LlamaForCausalLM" in loaded.config.architectures:
+            patch_llama_for_dynamic_part_ntk_rotary_embeddings(loaded)
+        else:
+            raise RuntimeError(
+                f"Unsupported architecture {loaded.config.architectures} for dynamic part ntk")
     elif ntk:
         if "GPTNeoXForCausalLM" in loaded.config.architectures:
             patch_gptneox_for_ntk_scaled_rotary_embeddings(
@@ -69,4 +75,4 @@ def apply_patches(loaded, length, dynamic_ntk, dynamic_linear, ntk, linear, part
             patch_llama_for_part_ntk_scaled_rotary_embeddings(loaded, scale=part_ntk)
         else:
             raise RuntimeError(
-                f"Unsupported architecture {loaded.config.architectures} for linear")
+                f"Unsupported architecture {loaded.config.architectures} for part ntk")
perplexity.py (3 changes: 2 additions & 1 deletion)

@@ -138,7 +138,7 @@ def main(args):
     loaded = load_model(model, args.load_in_8bit,
                         args.load_in_4bit, args.max_tokens)
     apply_patches(loaded, args.max_tokens, args.dynamic_ntk,
-                  args.dynamic_linear, args.ntk, args.linear, args.part_ntk)
+                  args.dynamic_linear, args.dynamic_part_ntk, args.ntk, args.linear, args.part_ntk)
 
     result = []
     for max_length in tokens:
@@ -175,6 +175,7 @@ def main(args):
     parser.add_argument("--ntk", type=float)
     parser.add_argument("--part-ntk", type=float)
     parser.add_argument("--linear", type=float)
+    parser.add_argument("--dynamic-part-ntk", action="store_true")
     parser.add_argument("--output-file", type=str)
     parser.add_argument("--load-in-8bit", action="store_true")
     parser.add_argument("--load-in-4bit", action="store_true")
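
Since --dynamic-part-ntk is declared with action="store_true", args.dynamic_part_ntk defaults to False and flips to True only when the flag is passed, which is exactly the boolean that apply_patches branches on. A tiny self-contained check of that behavior, using a stand-in parser rather than the script's full argument list:

import argparse

# Stand-in parser containing only the flag added in this commit.
parser = argparse.ArgumentParser()
parser.add_argument("--dynamic-part-ntk", action="store_true")

assert parser.parse_args([]).dynamic_part_ntk is False
assert parser.parse_args(["--dynamic-part-ntk"]).dynamic_part_ntk is True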
scaled_rope/LlamaDynamicPartNTKScaledRotaryEmbedding.py (89 changes: 89 additions & 0 deletions)

@@ -0,0 +1,89 @@
+import torch
+import math
+
+def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
+    return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base))  # Inverse dim formula to find number of rotations
+
+def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
+    low = math.floor(find_correction_factor(low_rot, dim, base, max_position_embeddings))
+    high = math.ceil(find_correction_factor(high_rot, dim, base, max_position_embeddings))
+    return max(low, 0), min(high, dim-1)  # Clamp values just in case
+
+def linear_ramp_mask(min, max, dim):
+    if min == max:
+        max += 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+def find_newbase_ntk(dim, base=10000, scale=1):
+    return base * scale ** (dim / (dim-2))
+
+class LlamaDynamicPartNTKScaledRotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk_factor=1, extrapolation_factor=1, device=None):
+        super().__init__()
+        self.dim = dim
+        self.base = base
+        self.ntk_factor = ntk_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.max_position_embeddings = max_position_embeddings
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        dtype = torch.get_default_dtype()
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+
+            # Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
+            # Do not change unless there is a good reason for doing so!
+            beta_0 = 1.25
+            beta_1 = 0.75
+            gamma_0 = 16
+            gamma_1 = 2
+
+            # the "dynamic" part
+            scale = seq_len / self.max_position_embeddings
+
+            # Three RoPE extrapolation/interpolation methods
+            inv_freq_base = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
+            inv_freq_linear = 1.0 / (scale * (self.base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim)))
+            inv_freq_ntk = 1.0 / (find_newbase_ntk(self.dim, self.base, scale) ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
+
+            current_dtype = inv_freq_ntk.dtype
+            current_device = inv_freq_ntk.device
+
+            # Combine NTK and Linear
+            low, high = find_correction_range(beta_0, beta_1, self.dim, self.base, self.max_position_embeddings)
+            inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).type(current_dtype).to(current_device)) * self.ntk_factor
+            inv_freq = inv_freq_linear * (1 - inv_freq_mask) + inv_freq_ntk * inv_freq_mask
+
+            # Combine Extrapolation and NTK and Linear
+            low, high = find_correction_range(gamma_0, gamma_1, self.dim, self.base, self.max_position_embeddings)
+            inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).type(current_dtype).to(current_device)) * self.extrapolation_factor
+            inv_freq = inv_freq * (1 - inv_freq_mask) + inv_freq_base * inv_freq_mask
+
+            self.register_buffer("inv_freq", inv_freq)
+
+            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
+            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        )
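
To see what the new module produces end to end, here is a small sketch, assuming the repository root is on the Python path so scaled_rope is importable; the head dimension and sequence length are arbitrary. Because the requested length exceeds the 2048 positions cached in __init__, forward() recomputes the frequencies: it blends the linearly interpolated and NTK-scaled inverse frequencies with a ramp mask over the beta range, blends that result with the unscaled base frequencies over the gamma range, and then rebuilds the cos/sin caches.

import torch
from scaled_rope.LlamaDynamicPartNTKScaledRotaryEmbedding import LlamaDynamicPartNTKScaledRotaryEmbedding

head_dim = 128   # arbitrary per-head dimension for illustration
seq_len = 4096   # longer than the 2048-position cache built in __init__

rope = LlamaDynamicPartNTKScaledRotaryEmbedding(head_dim)

# Dummy activations shaped [batch, num_heads, seq_len, head_dim]; forward()
# only uses the tensor's device and dtype plus the seq_len argument.
x = torch.zeros(1, 32, seq_len, head_dim)

cos, sin = rope(x, seq_len=seq_len)
print(cos.shape, sin.shape)  # torch.Size([1, 1, 4096, 128]) for both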
scaled_rope/patch.py (6 changes: 6 additions & 0 deletions)

@@ -6,6 +6,12 @@ def patch_llama_for_dynamic_scaled_rotary_embeddings(model, ntk):
     for each in model.model.layers:
         each.self_attn.rotary_emb = LlamaDynamicScaledRotaryEmbedding(
             each.self_attn.head_dim, device=each.self_attn.rotary_emb.inv_freq.device, ntk=ntk)
+
+def patch_llama_for_dynamic_part_ntk_rotary_embeddings(model):
+    from .LlamaDynamicPartNTKScaledRotaryEmbedding import LlamaDynamicPartNTKScaledRotaryEmbedding
+    for each in model.model.layers:
+        each.self_attn.rotary_emb = LlamaDynamicPartNTKScaledRotaryEmbedding(
+            each.self_attn.head_dim, device=each.self_attn.rotary_emb.inv_freq.device)
 
 
 def patch_llama_for_ntk_scaled_rotary_embeddings(model, alpha):
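
In practice, the new patch helper would be applied to an already loaded Hugging Face LLaMA model, along these lines; the checkpoint name is a placeholder and the import path assumes scaled_rope is importable as a package. The helper simply walks model.model.layers and swaps each layer's self_attn.rotary_emb in place.

from transformers import AutoModelForCausalLM
from scaled_rope.patch import patch_llama_for_dynamic_part_ntk_rotary_embeddings

# Placeholder checkpoint; any LlamaForCausalLM-style model exposing
# model.model.layers[*].self_attn.rotary_emb should accept this patch.
model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")
patch_llama_for_dynamic_part_ntk_rotary_embeddings(model)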
