Add RoPE Interpolation #3564

Merged · 12 commits · Jul 12, 2023

Changes from 1 commit
add dynamic ntk
shahules786 committed Jul 11, 2023
commit 44b9522795ed770008e5dea653c26bc6420fd3f2
model/model_training/models/patching.py — 68 changes: 43 additions & 25 deletions
@@ -178,40 +178,54 @@ def patch_model(
             add_dropout(getattr(layer, mlp_key), _patched_mlp_forward, resid_pdrop)


-from .RWNTKScaledRope import RWNTKScaledRotary
-ROPE_DICT = {
-    "RWForCausalLM":{
-        "ntk": RWNTKScaledRotary
-    }
-}
+from .rope import RWNTKScaledRope, LlamaLinearScaledRope, LlamaNTKScaledRope, LlamaDynamicScaledRotaryEmbedding
 from transformers import AutoConfig
-import numpy as np


 class RopePatch:

-    def __init__(self, training_config):
-        if training_config.superhot:
-            self.do_patch = True
-            self.args = training_config.superhot_config
-            rope_type = self.args.pop("type")
-            config = AutoConfig.from_pretrained(training_config.model_name, trust_remote_code=True)
-            architecture = np.intersect1d(config.architectures, list(ROPE_DICT.keys()))
-            if architecture:
-                self.model_name = architecture[0]
-                self.patch_fun = ROPE_DICT.get(self.model_name)[rope_type]
-        else:
-            self.do_patch = False
+    def __init__(self, model_name, **kwargs):
+
+        self.args = kwargs
+        rope_type = self.args.pop("type")
+        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        architecture = config.architectures
+        if architecture:
+            self.model_name = architecture[0]
+            if "RWForCausalLM" in architecture:
+                self.architecture = "RWForCausalLM"
+                if rope_type == "ntk":
+                    self.patch_fun = RWNTKScaledRope
+                else:
+                    raise NotImplementedError()
+            elif "LlamaForCausalLM" in architecture:
+                self.architecture = "LlamaForCausalLM"
+                if rope_type == "linear":
+                    self.patch_fun = LlamaLinearScaledRope
+                elif rope_type == "ntk":
+                    self.patch_fun = LlamaNTKScaledRope
+                elif rope_type == "dynamic-ntk":
+                    self.patch_fun = LlamaDynamicScaledRotaryEmbedding
+                else:
+                    raise NotImplementedError()
+            else:
+                raise NotImplementedError()
+
+    @classmethod
+    def from_config(cls, config):
+
+        model_name = config.model_name
+        args = config.superhot_config
+        return cls(model_name, **args)

     def patch(self, model):

-        if self.do_patch:
-            if self.model_name == "RWForCausalLM":
-                self.patch_rw_model(model, **self.args)
-            else:
-                raise NotImplementedError()
+        if self.architecture == "RWForCausalLM":
+            self.patch_rw_model(model, **self.args)
+        elif self.architecture == "LlamaForCausalLM":
+            self.patch_llama_model(model, **self.args)
+        else:
+            raise NotImplementedError()

     def patch_rw_model(self, model, **kwargs):
@@ -220,7 +234,11 @@ def patch_rw_model(self, model, **kwargs):
             each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs)


+    def patch_llama_model(self, model, **kwargs):
+
+        kwargs.update({"device":model.device})
+        for each in model.model.layers:
+            each.self_attn.rotary_emb = self.patch_fun(each.self_attn.head_dim, **kwargs)



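For reference, a minimal sketch of how the reworked class is expected to be wired in. The trainer call site is not part of this diff, so the import path, model id, and config values below are illustrative; only the keys (model_name, superhot_config, type) and the RopePatch API come from the code above. superhot_config supplies the rope "type" plus the keyword arguments forwarded to the matching class in rope.py.

# Hypothetical usage sketch -- not part of this diff.
from types import SimpleNamespace

from transformers import AutoModelForCausalLM

from models.patching import RopePatch  # import path assumed from the package layout above

config = SimpleNamespace(
    model_name="huggyllama/llama-7b",                    # illustrative model id
    superhot_config={"type": "dynamic-ntk", "ntk": 2},   # kwargs for LlamaDynamicScaledRotaryEmbedding
)

model = AutoModelForCausalLM.from_pretrained(config.model_name)
rope_patch = RopePatch.from_config(config)   # resolves the patch class from config.architectures
rope_patch.patch(model)                      # swaps each layer's rotary embedding in place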
model/model_training/models/rope.py — 177 changes: 177 additions & 0 deletions (new file)
@@ -0,0 +1,177 @@

import torch
from typing import Optional


# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in torch < 1.8.0

class RWNTKScaledRope(torch.nn.Module):

    """Implementation of RotaryEmbedding from GPT-NeoX.
    This implementation is designed to operate on queries and keys that are compatible with
    [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
    """

    def __init__(
        self,
        head_dim: int,
        base=10000,
        alpha: int = 2,
    ):
        super().__init__()
        self.alpha = alpha
        base = base * self.alpha ** (head_dim / (head_dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.head_dim = head_dim
        self.seq_len_cached = None
        self.batch_size_cached = None
        self.cos_cached: torch.Tensor | None = None
        self.sin_cached: torch.Tensor | None = None

    def cos_sin(
        self,
        seq_len: int,
        device="cuda",
        dtype=torch.bfloat16,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(device)

            if dtype in [torch.float16, torch.bfloat16]:
                emb = emb.float()

            self.cos_cached = emb.cos()[None, :, :]
            self.sin_cached = emb.sin()[None, :, :]

            self.cos_cached = self.cos_cached.type(dtype)
            self.sin_cached = self.sin_cached.type(dtype)

        return self.cos_cached, self.sin_cached

    def forward(self, q, k):
        batch, seq_len, head_dim = q.shape
        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
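A quick shape check for the class above (a minimal smoke test, not part of the PR; the sizes are arbitrary). RWNTKScaledRope.forward rotates q and k directly, so the inputs follow the [batch * n_heads_per_partition, seq_len, head_dim] layout named in the docstring.

# Illustrative smoke test only: rotate dummy queries/keys with the NTK-scaled rotary embedding above.
import torch

rope = RWNTKScaledRope(head_dim=64, alpha=2)
q = torch.randn(2 * 8, 16, 64)   # [batch * n_heads_per_partition, seq_len, head_dim]
k = torch.randn_like(q)
q_rot, k_rot = rope(q, k)
print(q_rot.shape, k_rot.shape)  # shapes are unchanged: torch.Size([16, 16, 64]) each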


class LlamaLinearScaledRope(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
        super().__init__()
        self.scale = 1 / scale
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        t *= self.scale
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            t *= self.scale
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )



class LlamaNTKScaledRope(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
        super().__init__()
        base = base * alpha ** (dim / (dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )



class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
        super().__init__()
        self.ntk = ntk
        self.base = base
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            if self.ntk:
                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim - 2))
                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
                self.register_buffer("inv_freq", inv_freq)
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            if not self.ntk:
                t *= self.max_position_embeddings / seq_len
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
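A similar hedged sketch for the Llama-style classes (illustrative sizes, not part of the diff). Unlike RWNTKScaledRope, these return the (cos, sin) tables, which the Llama attention layer then applies through its usual rotary-embedding path; requesting a sequence longer than max_position_embeddings is what triggers the dynamic re-scaling branch in forward.

# Hypothetical shape check for the dynamic-NTK variant (not part of the PR).
import torch

rope = LlamaDynamicScaledRotaryEmbedding(dim=128, max_position_embeddings=2048, ntk=2)
x = torch.randn(1, 32, 4096, 128)           # [bs, num_attention_heads, seq_len, head_size]
cos, sin = rope(x, seq_len=x.shape[-2])     # 4096 > 2048, so inv_freq is recomputed with a larger base
print(cos.shape, sin.shape)                 # torch.Size([1, 1, 4096, 128]) each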