From b6aed01702e6e05ba4586b53b89b74d8fdf754c2 Mon Sep 17 00:00:00 2001
From: Jeffrey Quesnelle
Date: Tue, 11 Jul 2023 00:39:47 -0400
Subject: [PATCH] add dynamic part ntk

---
 model_loader.py                               | 10 ++-
 perplexity.py                                 |  3 +-
 ...lamaDynamicPartNTKScaledRotaryEmbedding.py | 89 +++++++++++++++++++
 scaled_rope/patch.py                          |  6 ++
 4 files changed, 105 insertions(+), 3 deletions(-)
 create mode 100644 scaled_rope/LlamaDynamicPartNTKScaledRotaryEmbedding.py

diff --git a/model_loader.py b/model_loader.py
index c551b0b..007a220 100644
--- a/model_loader.py
+++ b/model_loader.py
@@ -32,7 +32,7 @@ def load_model(model, load_in_8bit, load_in_4bit, length):
     return loaded
 
 
-def apply_patches(loaded, length, dynamic_ntk, dynamic_linear, ntk, linear, part_ntk):
+def apply_patches(loaded, length, dynamic_ntk, dynamic_linear, dynamic_part_ntk, ntk, linear, part_ntk):
     if "GPTNeoXForCausalLM" in loaded.config.architectures:
         patch_gptneox_for_longer_sequences(loaded, length)
     if dynamic_linear:
@@ -49,6 +49,12 @@ def apply_patches(loaded, length, dynamic_ntk, dynamic_linear, ntk, linear, part
         else:
             raise RuntimeError(
                 f"Unsupported architecture {loaded.config.architectures} for dyanmic ntk")
+    elif dynamic_part_ntk:
+        if "LlamaForCausalLM" in loaded.config.architectures:
+            patch_llama_for_dynamic_part_ntk_rotary_embeddings(loaded)
+        else:
+            raise RuntimeError(
+                f"Unsupported architecture {loaded.config.architectures} for dynamic part ntk")
     elif ntk:
         if "GPTNeoXForCausalLM" in loaded.config.architectures:
             patch_gptneox_for_ntk_scaled_rotary_embeddings(
@@ -69,4 +75,4 @@ def apply_patches(loaded, length, dynamic_ntk, dynamic_linear, ntk, linear, part
             patch_llama_for_part_ntk_scaled_rotary_embeddings(loaded, scale=part_ntk)
         else:
             raise RuntimeError(
-                f"Unsupported architecture {loaded.config.architectures} for linear")
\ No newline at end of file
+                f"Unsupported architecture {loaded.config.architectures} for part ntk")
\ No newline at end of file
diff --git a/perplexity.py b/perplexity.py
index 1d7c48f..823ec32 100644
--- a/perplexity.py
+++ b/perplexity.py
@@ -138,7 +138,7 @@ def main(args):
     loaded = load_model(model, args.load_in_8bit,
                         args.load_in_4bit, args.max_tokens)
     apply_patches(loaded, args.max_tokens, args.dynamic_ntk,
-                  args.dynamic_linear, args.ntk, args.linear, args.part_ntk)
+                  args.dynamic_linear, args.dynamic_part_ntk, args.ntk, args.linear, args.part_ntk)
 
     result = []
     for max_length in tokens:
@@ -175,6 +175,7 @@ def main(args):
     parser.add_argument("--ntk", type=float)
     parser.add_argument("--part-ntk", type=float)
     parser.add_argument("--linear", type=float)
+    parser.add_argument("--dynamic-part-ntk", action="store_true")
     parser.add_argument("--output-file", type=str)
     parser.add_argument("--load-in-8bit", action="store_true")
     parser.add_argument("--load-in-4bit", action="store_true")
diff --git a/scaled_rope/LlamaDynamicPartNTKScaledRotaryEmbedding.py b/scaled_rope/LlamaDynamicPartNTKScaledRotaryEmbedding.py
new file mode 100644
index 0000000..02279ce
--- /dev/null
+++ b/scaled_rope/LlamaDynamicPartNTKScaledRotaryEmbedding.py
@@ -0,0 +1,89 @@
+import torch
+import math
+
+def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
+    return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) #Inverse dim formula to find number of rotations
+
+def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
+    low = math.floor(find_correction_factor(low_rot, dim, base, max_position_embeddings))
+    high = math.ceil(find_correction_factor(high_rot, dim, base, max_position_embeddings))
+    return max(low, 0), min(high, dim-1) #Clamp values just in case
+
+def linear_ramp_mask(min, max, dim):
+    if min == max:
+        max += 0.001 #Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+def find_newbase_ntk(dim, base=10000, scale=1):
+    return base * scale ** (dim / (dim-2))
+
+class LlamaDynamicPartNTKScaledRotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk_factor=1, extrapolation_factor=1, device=None):
+        super().__init__()
+        self.dim = dim
+        self.base = base
+        self.ntk_factor = ntk_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.max_position_embeddings = max_position_embeddings
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        dtype = torch.get_default_dtype()
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+
+            #Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
+            #Do not change unless there is a good reason for doing so!
+            beta_0 = 1.25
+            beta_1 = 0.75
+            gamma_0 = 16
+            gamma_1 = 2
+
+            # the "dynamic" part
+            scale = seq_len / self.max_position_embeddings
+
+            #Three RoPE extrapolation/interpolation methods
+            inv_freq_base = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
+            inv_freq_linear = 1.0 / (scale * (self.base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim)))
+            inv_freq_ntk = 1.0 / (find_newbase_ntk(self.dim, self.base, scale) ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
+
+            current_dtype = inv_freq_ntk.dtype
+            current_device = inv_freq_ntk.device
+
+            #Combine NTK and Linear
+            low, high = find_correction_range(beta_0, beta_1, self.dim, self.base, self.max_position_embeddings)
+            inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).type(current_dtype).to(current_device)) * self.ntk_factor
+            inv_freq = inv_freq_linear * (1 - inv_freq_mask) + inv_freq_ntk * inv_freq_mask
+
+            #Combine Extrapolation and NTK and Linear
+            low, high = find_correction_range(gamma_0, gamma_1, self.dim, self.base, self.max_position_embeddings)
+            inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).type(current_dtype).to(current_device)) * self.extrapolation_factor
+            inv_freq = inv_freq * (1 - inv_freq_mask) + inv_freq_base * inv_freq_mask
+
+            self.register_buffer("inv_freq", inv_freq)
+
+            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
+            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        )
diff --git a/scaled_rope/patch.py b/scaled_rope/patch.py
index 11e0a6c..0a617e6 100644
--- a/scaled_rope/patch.py
+++ b/scaled_rope/patch.py
@@ -6,6 +6,12 @@ def patch_llama_for_dynamic_scaled_rotary_embeddings(model, ntk):
     for each in model.model.layers:
         each.self_attn.rotary_emb = LlamaDynamicScaledRotaryEmbedding(
             each.self_attn.head_dim, device=each.self_attn.rotary_emb.inv_freq.device, ntk=ntk)
+
+def patch_llama_for_dynamic_part_ntk_rotary_embeddings(model):
+    from .LlamaDynamicPartNTKScaledRotaryEmbedding import LlamaDynamicPartNTKScaledRotaryEmbedding
+    for each in model.model.layers:
+        each.self_attn.rotary_emb = LlamaDynamicPartNTKScaledRotaryEmbedding(
+            each.self_attn.head_dim, device=each.self_attn.rotary_emb.inv_freq.device)
 
 
 def patch_llama_for_ntk_scaled_rotary_embeddings(model, alpha):
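
Usage note (not part of the patch above): a minimal sketch of how the new code path can be exercised once the patch is applied. Only patch_llama_for_dynamic_part_ntk_rotary_embeddings comes from this patch; AutoModelForCausalLM is the standard Hugging Face transformers loader, and the checkpoint name below is a placeholder, so adapt both to your setup.

    # Hypothetical usage sketch; the checkpoint name is a placeholder.
    from transformers import AutoModelForCausalLM
    from scaled_rope.patch import patch_llama_for_dynamic_part_ntk_rotary_embeddings

    # Load any LlamaForCausalLM checkpoint.
    model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")

    # Swap every layer's rotary embedding for the dynamic part-NTK variant added above.
    patch_llama_for_dynamic_part_ntk_rotary_embeddings(model)

    # Each forward pass now recomputes the RoPE frequencies whenever seq_len exceeds
    # the longest sequence seen so far (initially max_position_embeddings, 2048 by
    # default), blending the linear, NTK, and unscaled (extrapolation) frequency bands
    # as implemented in the new module.

The same path is reachable from the evaluation script by passing the new --dynamic-part-ntk flag to perplexity.py; the rest of its arguments are unchanged by this patch.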