From 1cdb388c6a1245bd74ae58ca5f5256ba6bf47a5c Mon Sep 17 00:00:00 2001
From: Shahules786 <shahules786@gmail.com>
Date: Tue, 11 Jul 2023 12:32:18 +0000
Subject: [PATCH 01/12] NTK scaled rope

---
 .../model_training/models/RWNTKScaledRope.py  | 61 +++++++++++++++++++
 1 file changed, 61 insertions(+)
 create mode 100644 model/model_training/models/RWNTKScaledRope.py

diff --git a/model/model_training/models/RWNTKScaledRope.py b/model/model_training/models/RWNTKScaledRope.py
new file mode 100644
index 0000000000..cbc18a7185
--- /dev/null
+++ b/model/model_training/models/RWNTKScaledRope.py
@@ -0,0 +1,61 @@
+
+import torch
+from typing import Optional
+
+
+# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
+def rotate_half(x):
+    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in torch < 1.8.0
+
+class RWNTKScaledRotary(torch.nn.Module):
+
+    """Implementation of RotaryEmbedding from GPT-NeoX.
+    This implementation is designed to operate on queries and keys that are compatible with
+    [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
+    """
+
+    def __init__(
+        self,
+        head_dim: int,
+        base=10000,
+        alpha:int=2,
+    ):
+        super().__init__()
+        self.alpha = alpha
+        base = base * self.alpha ** (head_dim / (head_dim-2))
+        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.head_dim = head_dim
+        self.seq_len_cached = None
+        self.batch_size_cached = None
+        self.cos_cached: torch.Tensor | None = None
+        self.sin_cached: torch.Tensor | None = None
+
+    def cos_sin(
+        self,
+        seq_len: int,
+        device="cuda",
+        dtype=torch.bfloat16,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if seq_len != self.seq_len_cached:
+            self.seq_len_cached = seq_len
+            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1).to(device)
+
+            if dtype in [torch.float16, torch.bfloat16]:
+                emb = emb.float()
+
+            self.cos_cached = emb.cos()[None, :, :]
+            self.sin_cached = emb.sin()[None, :, :]
+
+            self.cos_cached = self.cos_cached.type(dtype)
+            self.sin_cached = self.sin_cached.type(dtype)
+
+        return self.cos_cached, self.sin_cached
+
+    def forward(self, q, k):
+        batch, seq_len, head_dim = q.shape
+        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
+        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
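
A note on the base scaling above: multiplying the rotary base by alpha ** (head_dim / (head_dim - 2)) leaves the fastest rotary component untouched while stretching the slowest one by a factor of alpha, which is where the extra usable context comes from without retraining. A standalone sketch of the arithmetic (not part of the patch; head_dim and alpha are illustrative values):

    import torch

    head_dim, base, alpha = 64, 10000, 2
    scaled_base = base * alpha ** (head_dim / (head_dim - 2))   # ~20451 for these values
    inv_freq = 1.0 / (scaled_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # index 0 keeps inv_freq == 1.0, while the slowest component's wavelength
    # grows by roughly a factor of alpha
    print(round(scaled_base), inv_freq[0].item(), inv_freq[-1].item())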

From 53189e4b20021ccf497858359481403a13a0d8d1 Mon Sep 17 00:00:00 2001
From: Shahules786 <shahules786@gmail.com>
Date: Tue, 11 Jul 2023 12:32:33 +0000
Subject: [PATCH 02/12] added patching for falcon

---
 model/model_training/models/patching.py | 51 +++++++++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/model/model_training/models/patching.py b/model/model_training/models/patching.py
index c8757beb8f..196d30830f 100644
--- a/model/model_training/models/patching.py
+++ b/model/model_training/models/patching.py
@@ -176,3 +176,54 @@ def patch_model(
         if resid_pdrop is not None and resid_pdrop > 0:
             add_dropout(getattr(layer, attention_key), _patched_attn_forward, resid_pdrop)
             add_dropout(getattr(layer, mlp_key), _patched_mlp_forward, resid_pdrop)
+            
+        
+from .RWNTKScaledRope import RWNTKScaledRotary
+ROPE_DICT = {
+    "RWForCausalLM":{
+        "ntk": RWNTKScaledRotary
+    }
+}
+from transformers import AutoConfig
+import numpy as np
+            
+class RopePatch:
+    
+    def __init__(self, training_config):
+        if training_config.superhot:
+            self.do_patch = True
+            self.args = training_config.superhot_config
+            rope_type = self.args.pop("type")
+            config = AutoConfig.from_pretrained(training_config.model_name, trust_remote_code=True)
+            architecture = np.intersect1d(config.architectures, list(ROPE_DICT.keys()))
+            if architecture:
+                self.model_name = architecture[0]
+                self.patch_fun = ROPE_DICT.get(self.model_name)[rope_type]
+            else:
+                raise NotImplementedError()
+        else:
+            self.do_patch = False
+
+        
+    def patch(self, model):
+        
+        if self.do_patch:
+            if self.model_name == "RWForCausalLM":
+                self.patch_rw_model(model, **self.args)
+            else:
+                raise NotImplementedError()
+                
+    
+    def patch_rw_model(self, model, **kwargs):
+        
+        for each in model.transformer.h:
+            each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs)
+    
+            
+            
+            
+            
+        
+        
+    
+    
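
The patch itself is mechanical: every RefinedWeb/Falcon decoder block carries a maybe_rotary module, and patch_rw_model swaps it for the scaled implementation. A minimal sketch of that swap done by hand, outside the trainer (the checkpoint id and alpha value are illustrative assumptions, not taken from the patch):

    from transformers import AutoModelForCausalLM

    from model_training.models.RWNTKScaledRope import RWNTKScaledRotary

    model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True)
    for layer in model.transformer.h:
        # replace the stock rotary helper on every decoder block with the NTK-scaled one
        layer.self_attention.maybe_rotary = RWNTKScaledRotary(model.config.head_dim, alpha=4)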

From 252e96b9c5a896c1a869dedb80804bed3a4c7e8d Mon Sep 17 00:00:00 2001
From: Shahules786 <shahules786@gmail.com>
Date: Tue, 11 Jul 2023 12:33:08 +0000
Subject: [PATCH 03/12] fix for falcon layers

---
 model/model_training/models/peft_modeling.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/model/model_training/models/peft_modeling.py b/model/model_training/models/peft_modeling.py
index 54d51caa06..7130a3c903 100644
--- a/model/model_training/models/peft_modeling.py
+++ b/model/model_training/models/peft_modeling.py
@@ -47,10 +47,12 @@ def peft_model(model, peft_type="lora", int8_training=False, gradient_checkpoint
         config = LoraConfig(
             r=16,
             lora_alpha=32,
-            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+            target_modules=["query_key_value"],
             lora_dropout=0.05,
             bias="none",
             task_type="CAUSAL_LM",
+            inference_mode= False,
+
         )
     elif peft_type == "prefix-tuning":
         config = PrefixTuningConfig(
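
The target_modules change matters because Falcon fuses its query/key/value projections into a single query_key_value linear layer, so the Llama-style q_proj/k_proj/v_proj/o_proj names do not exist on it. Side by side, with everything except the module names taken from this hunk:

    from peft import LoraConfig

    # Falcon / RefinedWeb: one fused QKV projection per attention block
    falcon_lora = LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
        task_type="CAUSAL_LM", target_modules=["query_key_value"],
    )

    # Llama: separate projections, as in the line this hunk replaces
    llama_lora = LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
        task_type="CAUSAL_LM", target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )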

From ef45c26b29dc1fcbf6357d604db5d9788b256a35 Mon Sep 17 00:00:00 2001
From: Shahules786 <shahules786@gmail.com>
Date: Tue, 11 Jul 2023 12:33:23 +0000
Subject: [PATCH 04/12] add rope scaling

---
 model/model_training/trainer_sft.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/model/model_training/trainer_sft.py b/model/model_training/trainer_sft.py
index 547f1e5aed..1560f4c069 100755
--- a/model/model_training/trainer_sft.py
+++ b/model/model_training/trainer_sft.py
@@ -331,6 +331,7 @@ def main():
     if not training_conf.deepspeed or training_conf.local_rank == 0:
         tokenizer_sanity_check(tokenizer)
 
+    print("POINT 1")
     train_collate_fn = DialogueDataCollator(
         tokenizer,
         max_length=training_conf.max_length,
@@ -362,7 +363,6 @@ def main():
     )
 
     train, evals = get_dataset(training_conf)
-
     show_dataset_stats = (training_conf.verbose or training_conf.show_dataset_stats) and (
         not training_conf.deepspeed or training_conf.local_rank == 0
     )
@@ -416,8 +416,11 @@ def main():
         sampler = None
 
     metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
-
     model = get_model(training_conf, tokenizer)
+    
+    from model_training.models.patching import RopePatch
+    superhot = RopePatch(training_conf)
+    superhot.patch(model)
 
     if training_conf.peft_model:
         print("Using PEFT model")

From 44b9522795ed770008e5dea653c26bc6420fd3f2 Mon Sep 17 00:00:00 2001
From: Shahules786 <shahules786@gmail.com>
Date: Tue, 11 Jul 2023 18:28:02 +0000
Subject: [PATCH 05/12] add dynamic ntk

---
 model/model_training/models/patching.py |  68 +++++----
 model/model_training/models/rope.py     | 177 ++++++++++++++++++++++++
 2 files changed, 220 insertions(+), 25 deletions(-)
 create mode 100644 model/model_training/models/rope.py

diff --git a/model/model_training/models/patching.py b/model/model_training/models/patching.py
index 196d30830f..574c07a7ca 100644
--- a/model/model_training/models/patching.py
+++ b/model/model_training/models/patching.py
@@ -178,40 +178,54 @@ def patch_model(
             add_dropout(getattr(layer, mlp_key), _patched_mlp_forward, resid_pdrop)
             
         
-from .RWNTKScaledRope import RWNTKScaledRotary
-ROPE_DICT = {
-    "RWForCausalLM":{
-        "ntk": RWNTKScaledRotary
-    }
-}
+from .rope import RWNTKScaledRope, LlamaLinearScaledRope, LlamaNTKScaledRope, LlamaDynamicScaledRotaryEmbedding
 from transformers import AutoConfig
-import numpy as np
             
 class RopePatch:
     
-    def __init__(self, training_config):
-        if training_config.superhot:
-            self.do_patch = True
-            self.args = training_config.superhot_config
-            rope_type = self.args.pop("type")
-            config = AutoConfig.from_pretrained(training_config.model_name, trust_remote_code=True)
-            architecture = np.intersect1d(config.architectures, list(ROPE_DICT.keys()))
-            if architecture:
-                self.model_name = architecture[0]
-                self.patch_fun = ROPE_DICT.get(self.model_name)[rope_type]
+    def __init__(self, model_name, **kwargs):
+
+        self.args = kwargs
+        rope_type = self.args.pop("type")
+        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        architecture = config.architectures
+        if architecture:
+            self.model_name = architecture[0]
+            if "RWForCausalLM" in architecture:
+                self.architecture = "RWForCausalLM"
+                if rope_type == "ntk":
+                    self.patch_fun = RWNTKScaledRope
+                else:
+                    raise NotImplementedError()
+            elif "LlamaForCausalLM" in architecture:
+                self.architecture = "LlamaForCausalLM"
+                if rope_type == "linear":
+                    self.patch_fun = LlamaLinearScaledRope
+                elif rope_type == "ntk":
+                    self.patch_fun = LlamaNTKScaledRope
+                elif rope_type == "dynamic-ntk":
+                    self.patch_fun = LlamaDynamicScaledRotaryEmbedding
+                else:
+                    raise NotImplementedError()
             else:
                 raise NotImplementedError()
-        else:
-            self.do_patch = False
 
+
+    @classmethod
+    def from_config(cls, config):
+        
+        model_name  = config.model_name
+        args = config.superhot_config
+        return cls(model_name, **args)
         
     def patch(self, model):
         
-        if self.do_patch:
-            if self.model_name == "RWForCausalLM":
-                self.patch_rw_model(model, **self.args)
-            else:
-                raise NotImplementedError()
+        if self.architecture == "RWForCausalLM":
+            self.patch_rw_model(model, **self.args)
+        elif self.architecture == "LlamaForCausalLM":
+            self.patch_llama_model(model, **self.args)
+        else:
+            raise NotImplementedError()
                 
     
     def patch_rw_model(self, model, **kwargs):
@@ -220,7 +234,11 @@ def patch_rw_model(self, model, **kwargs):
             each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs)
     
             
-            
+    def patch_llama_model(self, model, **kwargs):
+        
+        kwargs.update({"device":model.device})
+        for each in model.model.layers:
+            each.self_attn.rotary_emb = self.patch_fun(each.self_attn.head_dim, **kwargs)
             
             
         
diff --git a/model/model_training/models/rope.py b/model/model_training/models/rope.py
new file mode 100644
index 0000000000..1a13d0aebd
--- /dev/null
+++ b/model/model_training/models/rope.py
@@ -0,0 +1,177 @@
+
+import torch
+from typing import Optional
+
+
+# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
+def rotate_half(x):
+    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in torch < 1.8.0
+
+class RWNTKScaledRope(torch.nn.Module):
+
+    """Implementation of RotaryEmbedding from GPT-NeoX.
+    This implementation is designed to operate on queries and keys that are compatible with
+    [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
+    """
+
+    def __init__(
+        self,
+        head_dim: int,
+        base=10000,
+        alpha:int=2,
+    ):
+        super().__init__()
+        self.alpha = alpha
+        base = base * self.alpha ** (head_dim / (head_dim-2))
+        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.head_dim = head_dim
+        self.seq_len_cached = None
+        self.batch_size_cached = None
+        self.cos_cached: torch.Tensor | None = None
+        self.sin_cached: torch.Tensor | None = None
+
+    def cos_sin(
+        self,
+        seq_len: int,
+        device="cuda",
+        dtype=torch.bfloat16,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if seq_len != self.seq_len_cached:
+            self.seq_len_cached = seq_len
+            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1).to(device)
+
+            if dtype in [torch.float16, torch.bfloat16]:
+                emb = emb.float()
+
+            self.cos_cached = emb.cos()[None, :, :]
+            self.sin_cached = emb.sin()[None, :, :]
+
+            self.cos_cached = self.cos_cached.type(dtype)
+            self.sin_cached = self.sin_cached.type(dtype)
+
+        return self.cos_cached, self.sin_cached
+
+    def forward(self, q, k):
+        batch, seq_len, head_dim = q.shape
+        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
+        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+
+
+class LlamaLinearScaledRope(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
+        super().__init__()
+        self.scale = 1 / scale
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        t *= self.scale
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        dtype = torch.get_default_dtype()
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+            t *= self.scale
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
+            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        )
+        
+
+
+class LlamaNTKScaledRope(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
+        super().__init__()
+        base = base * alpha ** (dim / (dim-2))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        dtype = torch.get_default_dtype()
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
+            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        )
+        
+        
+import math
+
+class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
+        super().__init__()
+        self.ntk = ntk
+        self.base = base
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        dtype = torch.get_default_dtype()
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+            if self.ntk:
+                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))
+                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
+                self.register_buffer("inv_freq", inv_freq)
+            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+            if not self.ntk:
+                t *= self.max_position_embeddings / seq_len
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
+            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        )
\ No newline at end of file
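
The three Llama variants added here differ only in how the angle grid is stretched: linear scaling divides the position index by scale, NTK scaling enlarges the base once at construction, and the dynamic variant recomputes the base (or the position stretch) from the current seq_len, so sequences inside the trained window are left untouched. A compact sketch of the three frequency schedules (dim, lengths and factors are illustrative):

    import torch

    dim, base, trained_len = 128, 10000, 2048
    half = torch.arange(0, dim, 2).float() / dim

    # linear ("superhot"): positions divided by the scale factor, base unchanged
    scale = 2
    lin_t = torch.arange(4096).float() / scale
    lin_inv_freq = 1.0 / (base ** half)
    lin_freqs = torch.einsum("i,j->ij", lin_t, lin_inv_freq)  # angles fed to cos/sin

    # NTK: positions unchanged, base enlarged once at init
    alpha = 2
    ntk_inv_freq = 1.0 / ((base * alpha ** (dim / (dim - 2))) ** half)

    # dynamic NTK: base recomputed from the running sequence length, so it only
    # deviates from vanilla RoPE once seq_len exceeds the trained window
    seq_len, ntk = 4096, 2
    dyn_base = base * ((ntk * seq_len / trained_len) - (ntk - 1)) ** (dim / (dim - 2))
    dyn_inv_freq = 1.0 / (dyn_base ** half)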

From b40c93d2c0559b321d9e48e1423016d169bf2f41 Mon Sep 17 00:00:00 2001
From: Shahules786 <shahules786@gmail.com>
Date: Tue, 11 Jul 2023 18:28:22 +0000
Subject: [PATCH 06/12] rename

---
 .../model_training/models/RWNTKScaledRope.py  | 61 -------------------
 1 file changed, 61 deletions(-)
 delete mode 100644 model/model_training/models/RWNTKScaledRope.py

diff --git a/model/model_training/models/RWNTKScaledRope.py b/model/model_training/models/RWNTKScaledRope.py
deleted file mode 100644
index cbc18a7185..0000000000
--- a/model/model_training/models/RWNTKScaledRope.py
+++ /dev/null
@@ -1,61 +0,0 @@
-
-import torch
-from typing import Optional
-
-
-# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
-def rotate_half(x):
-    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in torch < 1.8.0
-
-class RWNTKScaledRotary(torch.nn.Module):
-
-    """Implementation of RotaryEmbedding from GPT-NeoX.
-    This implementation is designed to operate on queries and keys that are compatible with
-    [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
-    """
-
-    def __init__(
-        self,
-        head_dim: int,
-        base=10000,
-        alpha:int=2,
-    ):
-        super().__init__()
-        self.alpha = alpha
-        base = base * self.alpha ** (head_dim / (head_dim-2))
-        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.head_dim = head_dim
-        self.seq_len_cached = None
-        self.batch_size_cached = None
-        self.cos_cached: torch.Tensor | None = None
-        self.sin_cached: torch.Tensor | None = None
-
-    def cos_sin(
-        self,
-        seq_len: int,
-        device="cuda",
-        dtype=torch.bfloat16,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        if seq_len != self.seq_len_cached:
-            self.seq_len_cached = seq_len
-            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            emb = torch.cat((freqs, freqs), dim=-1).to(device)
-
-            if dtype in [torch.float16, torch.bfloat16]:
-                emb = emb.float()
-
-            self.cos_cached = emb.cos()[None, :, :]
-            self.sin_cached = emb.sin()[None, :, :]
-
-            self.cos_cached = self.cos_cached.type(dtype)
-            self.sin_cached = self.sin_cached.type(dtype)
-
-        return self.cos_cached, self.sin_cached
-
-    def forward(self, q, k):
-        batch, seq_len, head_dim = q.shape
-        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

From e5620b345550a001adfa95ed31dec8c06314cc49 Mon Sep 17 00:00:00 2001
From: Shahules786 <shahules786@gmail.com>
Date: Tue, 11 Jul 2023 18:28:52 +0000
Subject: [PATCH 07/12] added rope

---
 model/model_training/trainer_sft.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/model/model_training/trainer_sft.py b/model/model_training/trainer_sft.py
index 1560f4c069..fa82967921 100755
--- a/model/model_training/trainer_sft.py
+++ b/model/model_training/trainer_sft.py
@@ -419,8 +419,9 @@ def main():
     model = get_model(training_conf, tokenizer)
     
     from model_training.models.patching import RopePatch
-    superhot = RopePatch(training_conf)
-    superhot.patch(model)
+    superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None
+    if superhot:
+        superhot.patch(model)
 
     if training_conf.peft_model:
         print("Using PEFT model")

From d0e95878b7f9e1602b31b40cf907d880402fd794 Mon Sep 17 00:00:00 2001
From: Shahules786 <shahules786@gmail.com>
Date: Tue, 11 Jul 2023 18:30:21 +0000
Subject: [PATCH 08/12] added sample config

---
 model/model_training/configs/config.yaml | 32 ++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/model/model_training/configs/config.yaml b/model/model_training/configs/config.yaml
index a8fc2ea92a..f172486cf5 100644
--- a/model/model_training/configs/config.yaml
+++ b/model/model_training/configs/config.yaml
@@ -97,6 +97,9 @@ webgpt_dataset_only:
   datasets:
     - webgpt
 
+instruction_datasets:
+  datasets:
+    - dolly15k
 per_digit_tokens:
   per_digit_tokens: true
 
@@ -779,3 +782,32 @@ debug:
   verbose: true
   num_train_epochs: 0.2
   dtype: fp32
+
+patching-test:
+  dtype: bf16
+  log_dir: "llama_log_7b"
+  learning_rate: 1e-5
+  model_name: "huggyllama/llama-7b"
+  deepspeed_config: configs/zero_config_falcon.json
+  output_dir: llama
+  weight_decay: 0.0
+  max_length: 4048
+  warmup_steps: 100
+  gradient_checkpointing: true
+  gradient_accumulation_steps: 2
+  per_device_train_batch_size: 1
+  per_device_eval_batch_size: 1
+  eval_steps: 100
+  save_steps: 500
+  num_train_epochs: 8
+  save_total_limit: 4
+  use_flash_attention: false
+  residual_dropout: 0.3
+  residual_dropout_lima: true
+  log_wandb: true
+  peft_model: true
+  peft_type: "lora"
+  superhot: true 
+  superhot_config:
+    type: linear
+    scale: 2
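
The superhot_config block maps directly onto RopePatch: the type key picks the rope class and every remaining key is forwarded as a keyword argument to that class's constructor. Equivalent direct calls for the variants the patch supports (model ids and numeric values are illustrative):

    from model_training.models.patching import RopePatch

    # same as `superhot_config: {type: linear, scale: 2}` for a Llama checkpoint
    patch = RopePatch("huggyllama/llama-7b", type="linear", scale=2)

    # fixed and dynamic NTK variants for Llama
    patch = RopePatch("huggyllama/llama-7b", type="ntk", alpha=4)
    patch = RopePatch("huggyllama/llama-7b", type="dynamic-ntk", ntk=2)

    # Falcon-style checkpoints currently only take the "ntk" type
    patch = RopePatch("tiiuae/falcon-7b", type="ntk", alpha=4)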

From 522594ef8b0b38d6696d0abfbabd6cff5618d0f3 Mon Sep 17 00:00:00 2001
From: Shahules786 <shahules786@gmail.com>
Date: Tue, 11 Jul 2023 18:42:02 +0000
Subject: [PATCH 09/12] pre-commit

---
 model/model_training/configs/config.yaml     |  2 +-
 model/model_training/models/patching.py      | 35 +++++---------------
 model/model_training/models/peft_modeling.py |  5 ++-
 model/model_training/models/rope.py          | 20 +++++------
 model/model_training/trainer_sft.py          |  3 +-
 5 files changed, 23 insertions(+), 42 deletions(-)

diff --git a/model/model_training/configs/config.yaml b/model/model_training/configs/config.yaml
index f172486cf5..b93d8965c5 100644
--- a/model/model_training/configs/config.yaml
+++ b/model/model_training/configs/config.yaml
@@ -807,7 +807,7 @@ patching-test:
   log_wandb: true
   peft_model: true
   peft_type: "lora"
-  superhot: true 
+  superhot: true
   superhot_config:
     type: linear
     scale: 2
diff --git a/model/model_training/models/patching.py b/model/model_training/models/patching.py
index 574c07a7ca..9f97514d03 100644
--- a/model/model_training/models/patching.py
+++ b/model/model_training/models/patching.py
@@ -6,12 +6,13 @@
 
 import torch.nn as nn
 import transformers
-from transformers import GPTNeoXForCausalLM, GPTNeoXModel, LlamaForCausalLM, LlamaModel
+from transformers import AutoConfig, GPTNeoXForCausalLM, GPTNeoXModel, LlamaForCausalLM, LlamaModel
 from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead
 
 from .patching_llama import llama_forward_with_flash_attn
 from .patching_neox import neox_forward_with_flash_attn
 from .reward_model import GPTNeoXRewardModel
+from .rope import LlamaDynamicScaledRotaryEmbedding, LlamaLinearScaledRope, LlamaNTKScaledRope, RWNTKScaledRope
 
 SUPPORTED_MODELS = [
     GPTNeoXModel,
@@ -176,15 +177,10 @@ def patch_model(
         if resid_pdrop is not None and resid_pdrop > 0:
             add_dropout(getattr(layer, attention_key), _patched_attn_forward, resid_pdrop)
             add_dropout(getattr(layer, mlp_key), _patched_mlp_forward, resid_pdrop)
-            
-        
-from .rope import RWNTKScaledRope, LlamaLinearScaledRope, LlamaNTKScaledRope, LlamaDynamicScaledRotaryEmbedding
-from transformers import AutoConfig
-            
+
+
 class RopePatch:
-    
     def __init__(self, model_name, **kwargs):
-
         self.args = kwargs
         rope_type = self.args.pop("type")
         config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -210,38 +206,25 @@ def __init__(self, model_name, **kwargs):
             else:
                 raise NotImplementedError()
 
-
     @classmethod
     def from_config(cls, config):
-        
-        model_name  = config.model_name
+        model_name = config.model_name
         args = config.superhot_config
         return cls(model_name, **args)
-        
+
     def patch(self, model):
-        
         if self.architecture == "RWForCausalLM":
             self.patch_rw_model(model, **self.args)
         elif self.architecture == "LlamaForCausalLM":
             self.patch_llama_model(model, **self.args)
         else:
             raise NotImplementedError()
-                
-    
+
     def patch_rw_model(self, model, **kwargs):
-        
         for each in model.transformer.h:
             each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs)
-    
-            
+
     def patch_llama_model(self, model, **kwargs):
-        
-        kwargs.update({"device":model.device})
+        kwargs.update({"device": model.device})
         for each in model.model.layers:
             each.self_attn.rotary_emb = self.patch_fun(each.self_attn.head_dim, **kwargs)
-            
-            
-        
-        
-    
-    
diff --git a/model/model_training/models/peft_modeling.py b/model/model_training/models/peft_modeling.py
index 7130a3c903..51b1f334d5 100644
--- a/model/model_training/models/peft_modeling.py
+++ b/model/model_training/models/peft_modeling.py
@@ -47,12 +47,11 @@ def peft_model(model, peft_type="lora", int8_training=False, gradient_checkpoint
         config = LoraConfig(
             r=16,
             lora_alpha=32,
-            target_modules=["query_key_value"],
+            target_modules=["q_proj", "v_proj", "k_proj"],
             lora_dropout=0.05,
             bias="none",
             task_type="CAUSAL_LM",
-            inference_mode= False,
-
+            inference_mode=False,
         )
     elif peft_type == "prefix-tuning":
         config = PrefixTuningConfig(
diff --git a/model/model_training/models/rope.py b/model/model_training/models/rope.py
index 1a13d0aebd..fdc8195f74 100644
--- a/model/model_training/models/rope.py
+++ b/model/model_training/models/rope.py
@@ -1,6 +1,4 @@
-
 import torch
-from typing import Optional
 
 
 # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
@@ -8,6 +6,7 @@ def rotate_half(x):
     x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in torch < 1.8.0
 
+
 class RWNTKScaledRope(torch.nn.Module):
 
     """Implementation of RotaryEmbedding from GPT-NeoX.
@@ -19,11 +18,11 @@ def __init__(
         self,
         head_dim: int,
         base=10000,
-        alpha:int=2,
+        alpha: int = 2,
     ):
         super().__init__()
         self.alpha = alpha
-        base = base * self.alpha ** (head_dim / (head_dim-2))
+        base = base * self.alpha ** (head_dim / (head_dim - 2))
         inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.head_dim = head_dim
@@ -95,13 +94,12 @@ def forward(self, x, seq_len=None):
             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
         )
-        
 
 
 class LlamaNTKScaledRope(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
         super().__init__()
-        base = base * alpha ** (dim / (dim-2))
+        base = base * alpha ** (dim / (dim - 2))
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
         self.register_buffer("inv_freq", inv_freq)
 
@@ -130,9 +128,7 @@ def forward(self, x, seq_len=None):
             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
         )
-        
-        
-import math
+
 
 class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
@@ -160,7 +156,9 @@ def forward(self, x, seq_len=None):
         if seq_len > self.max_seq_len_cached:
             self.max_seq_len_cached = seq_len
             if self.ntk:
-                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))
+                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (
+                    self.dim / (self.dim - 2)
+                )
                 inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
                 self.register_buffer("inv_freq", inv_freq)
             t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
@@ -174,4 +172,4 @@ def forward(self, x, seq_len=None):
         return (
             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-        )
\ No newline at end of file
+        )
diff --git a/model/model_training/trainer_sft.py b/model/model_training/trainer_sft.py
index fa82967921..84ad3f4b60 100755
--- a/model/model_training/trainer_sft.py
+++ b/model/model_training/trainer_sft.py
@@ -417,8 +417,9 @@ def main():
 
     metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
     model = get_model(training_conf, tokenizer)
-    
+
     from model_training.models.patching import RopePatch
+
     superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None
     if superhot:
         superhot.patch(model)

From cc14e5ad19fc9bb91ebd6ea17eb73473ce8a0d7e Mon Sep 17 00:00:00 2001
From: Shahules786 <shahules786@gmail.com>
Date: Tue, 11 Jul 2023 18:47:22 +0000
Subject: [PATCH 10/12] rmv changes

---
 model/model_training/models/peft_modeling.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/model/model_training/models/peft_modeling.py b/model/model_training/models/peft_modeling.py
index 51b1f334d5..54d51caa06 100644
--- a/model/model_training/models/peft_modeling.py
+++ b/model/model_training/models/peft_modeling.py
@@ -47,11 +47,10 @@ def peft_model(model, peft_type="lora", int8_training=False, gradient_checkpoint
         config = LoraConfig(
             r=16,
             lora_alpha=32,
-            target_modules=["q_proj", "v_proj", "k_proj"],
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
             lora_dropout=0.05,
             bias="none",
             task_type="CAUSAL_LM",
-            inference_mode=False,
         )
     elif peft_type == "prefix-tuning":
         config = PrefixTuningConfig(

From eaae56c9242d43ead99db3382e52ab7d7aba508c Mon Sep 17 00:00:00 2001
From: Shahules786 <shahules786@gmail.com>
Date: Wed, 12 Jul 2023 10:34:19 +0000
Subject: [PATCH 11/12] cleanup

---
 model/model_training/configs/config.yaml | 5 +----
 model/model_training/trainer_sft.py      | 4 +---
 2 files changed, 2 insertions(+), 7 deletions(-)

diff --git a/model/model_training/configs/config.yaml b/model/model_training/configs/config.yaml
index b93d8965c5..a997c0d19a 100644
--- a/model/model_training/configs/config.yaml
+++ b/model/model_training/configs/config.yaml
@@ -97,9 +97,6 @@ webgpt_dataset_only:
   datasets:
     - webgpt
 
-instruction_datasets:
-  datasets:
-    - dolly15k
 per_digit_tokens:
   per_digit_tokens: true
 
@@ -783,7 +780,7 @@ debug:
   num_train_epochs: 0.2
   dtype: fp32
 
-patching-test:
+rope_scaling_test:
   dtype: bf16
   log_dir: "llama_log_7b"
   learning_rate: 1e-5
diff --git a/model/model_training/trainer_sft.py b/model/model_training/trainer_sft.py
index 84ad3f4b60..158e1c621e 100755
--- a/model/model_training/trainer_sft.py
+++ b/model/model_training/trainer_sft.py
@@ -11,6 +11,7 @@
 # from model_training.custom_datasets.formatting import DatasetEntry
 from model_training.custom_datasets.dialogue_collator import DialogueDataCollator
 from model_training.efficiency_utils import fuse_gelu
+from model_training.models.patching import RopePatch
 from model_training.models.peft_modeling import peft_model
 from model_training.utils.utils import (
     PerDatasetSampler,
@@ -331,7 +332,6 @@ def main():
     if not training_conf.deepspeed or training_conf.local_rank == 0:
         tokenizer_sanity_check(tokenizer)
 
-    print("POINT 1")
     train_collate_fn = DialogueDataCollator(
         tokenizer,
         max_length=training_conf.max_length,
@@ -418,8 +418,6 @@ def main():
     metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
     model = get_model(training_conf, tokenizer)
 
-    from model_training.models.patching import RopePatch
-
     superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None
     if superhot:
         superhot.patch(model)
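
Once patched, the swap is easy to sanity-check before training starts: each layer's rotary module should now be one of the scaled classes. A small hedged check that could sit right after the lines above (model and superhot are the objects already built in main(); the linear class is only the example matching the config in this series):

    from model_training.models.rope import LlamaLinearScaledRope

    if superhot is not None:
        # with `type: linear` in superhot_config, every Llama layer should now carry
        # the scaled embedding instead of the stock LlamaRotaryEmbedding
        assert isinstance(model.model.layers[0].self_attn.rotary_emb, LlamaLinearScaledRope)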

From ef03d9b8e877dec0834c3aeb2bdd3f5e4ef3ea39 Mon Sep 17 00:00:00 2001
From: Shahules786 <shahules786@gmail.com>
Date: Wed, 12 Jul 2023 10:36:41 +0000
Subject: [PATCH 12/12] added docs

---
 model/model_training/models/rope.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/model/model_training/models/rope.py b/model/model_training/models/rope.py
index fdc8195f74..005a40c729 100644
--- a/model/model_training/models/rope.py
+++ b/model/model_training/models/rope.py
@@ -9,9 +9,8 @@ def rotate_half(x):
 
 class RWNTKScaledRope(torch.nn.Module):
 
-    """Implementation of RotaryEmbedding from GPT-NeoX.
-    This implementation is designed to operate on queries and keys that are compatible with
-    [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
+    """
+    NTK-Scaled RoPE for RefinedWebModel
     """
 
     def __init__(
@@ -61,6 +60,10 @@ def forward(self, q, k):
 
 
 class LlamaLinearScaledRope(torch.nn.Module):
+    """
+    reference: https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test
+    """
+
     def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
         super().__init__()
         self.scale = 1 / scale
@@ -97,6 +100,11 @@ def forward(self, x, seq_len=None):
 
 
 class LlamaNTKScaledRope(torch.nn.Module):
+
+    """
+    reference: https://github.com/jquesnelle/scaled-rope
+    """
+
     def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
         super().__init__()
         base = base * alpha ** (dim / (dim - 2))
@@ -131,6 +139,10 @@ def forward(self, x, seq_len=None):
 
 
 class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
+    """
+    reference: https://github.com/jquesnelle/scaled-rope
+    """
+
     def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
         super().__init__()
         self.ntk = ntk