From 1cdb388c6a1245bd74ae58ca5f5256ba6bf47a5c Mon Sep 17 00:00:00 2001 From: Shahules786 <shahules786@gmail.com> Date: Tue, 11 Jul 2023 12:32:18 +0000 Subject: [PATCH 01/12] NTK scaled rope --- .../model_training/models/RWNTKScaledRope.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 model/model_training/models/RWNTKScaledRope.py diff --git a/model/model_training/models/RWNTKScaledRope.py b/model/model_training/models/RWNTKScaledRope.py new file mode 100644 index 0000000000..cbc18a7185 --- /dev/null +++ b/model/model_training/models/RWNTKScaledRope.py @@ -0,0 +1,61 @@ + +import torch +from typing import Optional + + +# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...) +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0 + +class RWNTKScaledRotary(torch.nn.Module): + + """Implementation of RotaryEmbedding from GPT-NeoX. + This implementation is design to operate on queries and keys that are compatible with + [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format). + """ + + def __init__( + self, + head_dim: int, + base=10000, + alpha:int=2, + ): + super().__init__() + self.alpha = alpha + base = base * self.alpha ** (head_dim / (head_dim-2)) + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.head_dim = head_dim + self.seq_len_cached = None + self.batch_size_cached = None + self.cos_cached: torch.Tensor | None = None + self.sin_cached: torch.Tensor | None = None + + def cos_sin( + self, + seq_len: int, + device="cuda", + dtype=torch.bfloat16, + ) -> torch.Tensor: + if seq_len != self.seq_len_cached: + self.seq_len_cached = seq_len + t = torch.arange(seq_len, device=device).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(device) + + if dtype in [torch.float16, torch.bfloat16]: + emb = emb.float() + + self.cos_cached = emb.cos()[None, :, :] + self.sin_cached = emb.sin()[None, :, :] + + self.cos_cached = self.cos_cached.type(dtype) + self.sin_cached = self.sin_cached.type(dtype) + + return self.cos_cached, self.sin_cached + + def forward(self, q, k): + batch, seq_len, head_dim = q.shape + cos, sin = self.cos_sin(seq_len, q.device, q.dtype) + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) From 53189e4b20021ccf497858359481403a13a0d8d1 Mon Sep 17 00:00:00 2001 From: Shahules786 <shahules786@gmail.com> Date: Tue, 11 Jul 2023 12:32:33 +0000 Subject: [PATCH 02/12] added patching for falcon --- model/model_training/models/patching.py | 51 +++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/model/model_training/models/patching.py b/model/model_training/models/patching.py index c8757beb8f..196d30830f 100644 --- a/model/model_training/models/patching.py +++ b/model/model_training/models/patching.py @@ -176,3 +176,54 @@ def patch_model( if resid_pdrop is not None and resid_pdrop > 0: add_dropout(getattr(layer, attention_key), _patched_attn_forward, resid_pdrop) add_dropout(getattr(layer, mlp_key), _patched_mlp_forward, resid_pdrop) + + +from .RWNTKScaledRope import RWNTKScaledRotary +ROPE_DICT = { + "RWForCausalLM":{ + "ntk": RWNTKScaledRotary + } +} +from transformers import AutoConfig +import numpy as np + +class RopePatch: + + def __init__(self, training_config): + if 
training_config.superhot: + self.do_patch = True + self.args = training_config.superhot_config + rope_type = self.args.pop("type") + config = AutoConfig.from_pretrained(training_config.model_name, trust_remote_code=True) + architecture = np.intersect1d(config.architectures, list(ROPE_DICT.keys())) + if architecture: + self.model_name = architecture[0] + self.patch_fun = ROPE_DICT.get(self.model_name)[rope_type] + else: + raise NotImplementedError() + else: + self.do_patch = False + + + def patch(self, model): + + if self.do_patch: + if self.model_name == "RWForCausalLM": + self.patch_rw_model(model, **self.args) + else: + raise NotImplementedError() + + + def patch_rw_model(self, model, **kwargs): + + for each in model.transformer.h: + each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs) + + + + + + + + + From 252e96b9c5a896c1a869dedb80804bed3a4c7e8d Mon Sep 17 00:00:00 2001 From: Shahules786 <shahules786@gmail.com> Date: Tue, 11 Jul 2023 12:33:08 +0000 Subject: [PATCH 03/12] fix for falcon layers --- model/model_training/models/peft_modeling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/model/model_training/models/peft_modeling.py b/model/model_training/models/peft_modeling.py index 54d51caa06..7130a3c903 100644 --- a/model/model_training/models/peft_modeling.py +++ b/model/model_training/models/peft_modeling.py @@ -47,10 +47,12 @@ def peft_model(model, peft_type="lora", int8_training=False, gradient_checkpoint config = LoraConfig( r=16, lora_alpha=32, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + target_modules=["query_key_value"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", + inference_mode= False, + ) elif peft_type == "prefix-tuning": config = PrefixTuningConfig( From ef45c26b29dc1fcbf6357d604db5d9788b256a35 Mon Sep 17 00:00:00 2001 From: Shahules786 <shahules786@gmail.com> Date: Tue, 11 Jul 2023 12:33:23 +0000 Subject: [PATCH 04/12] add rope scaling --- model/model_training/trainer_sft.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/model/model_training/trainer_sft.py b/model/model_training/trainer_sft.py index 547f1e5aed..1560f4c069 100755 --- a/model/model_training/trainer_sft.py +++ b/model/model_training/trainer_sft.py @@ -331,6 +331,7 @@ def main(): if not training_conf.deepspeed or training_conf.local_rank == 0: tokenizer_sanity_check(tokenizer) + print("POINT 1") train_collate_fn = DialogueDataCollator( tokenizer, max_length=training_conf.max_length, @@ -362,7 +363,6 @@ def main(): ) train, evals = get_dataset(training_conf) - show_dataset_stats = (training_conf.verbose or training_conf.show_dataset_stats) and ( not training_conf.deepspeed or training_conf.local_rank == 0 ) @@ -416,8 +416,11 @@ def main(): sampler = None metrics, preprocess_fns = get_metrics(training_conf, tokenizer) - model = get_model(training_conf, tokenizer) + + from model_training.models.patching import RopePatch + superhot = RopePatch(training_conf) + superhot.patch(model) if training_conf.peft_model: print("Using PEFT model") From 44b9522795ed770008e5dea653c26bc6420fd3f2 Mon Sep 17 00:00:00 2001 From: Shahules786 <shahules786@gmail.com> Date: Tue, 11 Jul 2023 18:28:02 +0000 Subject: [PATCH 05/12] add dynamic ntk --- model/model_training/models/patching.py | 68 +++++---- model/model_training/models/rope.py | 177 ++++++++++++++++++++++++ 2 files changed, 220 insertions(+), 25 deletions(-) create mode 100644 model/model_training/models/rope.py diff --git a/model/model_training/models/patching.py 
b/model/model_training/models/patching.py index 196d30830f..574c07a7ca 100644 --- a/model/model_training/models/patching.py +++ b/model/model_training/models/patching.py @@ -178,40 +178,54 @@ def patch_model( add_dropout(getattr(layer, mlp_key), _patched_mlp_forward, resid_pdrop) -from .RWNTKScaledRope import RWNTKScaledRotary -ROPE_DICT = { - "RWForCausalLM":{ - "ntk": RWNTKScaledRotary - } -} +from .rope import RWNTKScaledRope, LlamaLinearScaledRope, LlamaNTKScaledRope, LlamaDynamicScaledRotaryEmbedding from transformers import AutoConfig -import numpy as np class RopePatch: - def __init__(self, training_config): - if training_config.superhot: - self.do_patch = True - self.args = training_config.superhot_config - rope_type = self.args.pop("type") - config = AutoConfig.from_pretrained(training_config.model_name, trust_remote_code=True) - architecture = np.intersect1d(config.architectures, list(ROPE_DICT.keys())) - if architecture: - self.model_name = architecture[0] - self.patch_fun = ROPE_DICT.get(self.model_name)[rope_type] + def __init__(self, model_name, **kwargs): + + self.args = kwargs + rope_type = self.args.pop("type") + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + architecture = config.architectures + if architecture: + self.model_name = architecture[0] + if "RWForCausalLM" in architecture: + self.architecture = "RWForCausalLM" + if rope_type == "ntk": + self.patch_fun = RWNTKScaledRope + else: + raise NotImplementedError() + elif "LlamaForCausalLM" in architecture: + self.architecture = "LlamaForCausalLM" + if rope_type == "linear": + self.patch_fun = LlamaLinearScaledRope + elif rope_type == "ntk": + self.patch_fun = LlamaNTKScaledRope + elif rope_type == "dynamic-ntk": + self.patch_fun = LlamaDynamicScaledRotaryEmbedding + else: + raise NotImplementedError() else: raise NotImplementedError() - else: - self.do_patch = False + + @classmethod + def from_config(cls, config): + + model_name = config.model_name + args = config.superhot_config + return cls(model_name, **args) def patch(self, model): - if self.do_patch: - if self.model_name == "RWForCausalLM": - self.patch_rw_model(model, **self.args) - else: - raise NotImplementedError() + if self.architecture == "RWForCausalLM": + self.patch_rw_model(model, **self.args) + elif self.architecture == "LlamaForCausalLM": + self.patch_llama_model(model, **self.args) + else: + raise NotImplementedError() def patch_rw_model(self, model, **kwargs): @@ -220,7 +234,11 @@ def patch_rw_model(self, model, **kwargs): each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs) - + def patch_llama_model(self, model, **kwargs): + + kwargs.update({"device":model.device}) + for each in model.model.layers: + each.self_attn.rotary_emb = self.patch_fun(each.self_attn.head_dim, **kwargs) diff --git a/model/model_training/models/rope.py b/model/model_training/models/rope.py new file mode 100644 index 0000000000..1a13d0aebd --- /dev/null +++ b/model/model_training/models/rope.py @@ -0,0 +1,177 @@ + +import torch +from typing import Optional + + +# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...) +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0 + +class RWNTKScaledRope(torch.nn.Module): + + """Implementation of RotaryEmbedding from GPT-NeoX. 
+ This implementation is design to operate on queries and keys that are compatible with + [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format). + """ + + def __init__( + self, + head_dim: int, + base=10000, + alpha:int=2, + ): + super().__init__() + self.alpha = alpha + base = base * self.alpha ** (head_dim / (head_dim-2)) + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.head_dim = head_dim + self.seq_len_cached = None + self.batch_size_cached = None + self.cos_cached: torch.Tensor | None = None + self.sin_cached: torch.Tensor | None = None + + def cos_sin( + self, + seq_len: int, + device="cuda", + dtype=torch.bfloat16, + ) -> torch.Tensor: + if seq_len != self.seq_len_cached: + self.seq_len_cached = seq_len + t = torch.arange(seq_len, device=device).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(device) + + if dtype in [torch.float16, torch.bfloat16]: + emb = emb.float() + + self.cos_cached = emb.cos()[None, :, :] + self.sin_cached = emb.sin()[None, :, :] + + self.cos_cached = self.cos_cached.type(dtype) + self.sin_cached = self.sin_cached.type(dtype) + + return self.cos_cached, self.sin_cached + + def forward(self, q, k): + batch, seq_len, head_dim = q.shape + cos, sin = self.cos_sin(seq_len, q.device, q.dtype) + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + + +class LlamaLinearScaledRope(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None): + super().__init__() + self.scale = 1 / scale + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + t *= self.scale + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + dtype = torch.get_default_dtype() + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
+ if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + t *= self.scale + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + + +class LlamaNTKScaledRope(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None): + super().__init__() + base = base * alpha ** (dim / (dim-2)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + dtype = torch.get_default_dtype() + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +import math + +class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None): + super().__init__() + self.ntk = ntk + self.base = base + self.dim = dim + self.max_position_embeddings = max_position_embeddings + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. 
+ self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + dtype = torch.get_default_dtype() + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + if self.ntk: + base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + if not self.ntk: + t *= self.max_position_embeddings / seq_len + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) \ No newline at end of file From b40c93d2c0559b321d9e48e1423016d169bf2f41 Mon Sep 17 00:00:00 2001 From: Shahules786 <shahules786@gmail.com> Date: Tue, 11 Jul 2023 18:28:22 +0000 Subject: [PATCH 06/12] rename --- .../model_training/models/RWNTKScaledRope.py | 61 ------------------- 1 file changed, 61 deletions(-) delete mode 100644 model/model_training/models/RWNTKScaledRope.py diff --git a/model/model_training/models/RWNTKScaledRope.py b/model/model_training/models/RWNTKScaledRope.py deleted file mode 100644 index cbc18a7185..0000000000 --- a/model/model_training/models/RWNTKScaledRope.py +++ /dev/null @@ -1,61 +0,0 @@ - -import torch -from typing import Optional - - -# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...) -def rotate_half(x): - x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0 - -class RWNTKScaledRotary(torch.nn.Module): - - """Implementation of RotaryEmbedding from GPT-NeoX. - This implementation is design to operate on queries and keys that are compatible with - [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format). 
- """ - - def __init__( - self, - head_dim: int, - base=10000, - alpha:int=2, - ): - super().__init__() - self.alpha = alpha - base = base * self.alpha ** (head_dim / (head_dim-2)) - inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.head_dim = head_dim - self.seq_len_cached = None - self.batch_size_cached = None - self.cos_cached: torch.Tensor | None = None - self.sin_cached: torch.Tensor | None = None - - def cos_sin( - self, - seq_len: int, - device="cuda", - dtype=torch.bfloat16, - ) -> torch.Tensor: - if seq_len != self.seq_len_cached: - self.seq_len_cached = seq_len - t = torch.arange(seq_len, device=device).type_as(self.inv_freq) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1).to(device) - - if dtype in [torch.float16, torch.bfloat16]: - emb = emb.float() - - self.cos_cached = emb.cos()[None, :, :] - self.sin_cached = emb.sin()[None, :, :] - - self.cos_cached = self.cos_cached.type(dtype) - self.sin_cached = self.sin_cached.type(dtype) - - return self.cos_cached, self.sin_cached - - def forward(self, q, k): - batch, seq_len, head_dim = q.shape - cos, sin = self.cos_sin(seq_len, q.device, q.dtype) - return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) From e5620b345550a001adfa95ed31dec8c06314cc49 Mon Sep 17 00:00:00 2001 From: Shahules786 <shahules786@gmail.com> Date: Tue, 11 Jul 2023 18:28:52 +0000 Subject: [PATCH 07/12] added rope --- model/model_training/trainer_sft.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/model/model_training/trainer_sft.py b/model/model_training/trainer_sft.py index 1560f4c069..fa82967921 100755 --- a/model/model_training/trainer_sft.py +++ b/model/model_training/trainer_sft.py @@ -419,8 +419,9 @@ def main(): model = get_model(training_conf, tokenizer) from model_training.models.patching import RopePatch - superhot = RopePatch(training_conf) - superhot.patch(model) + superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None + if superhot: + superhot.patch(model) if training_conf.peft_model: print("Using PEFT model") From d0e95878b7f9e1602b31b40cf907d880402fd794 Mon Sep 17 00:00:00 2001 From: Shahules786 <shahules786@gmail.com> Date: Tue, 11 Jul 2023 18:30:21 +0000 Subject: [PATCH 08/12] added sample config --- model/model_training/configs/config.yaml | 32 ++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/model/model_training/configs/config.yaml b/model/model_training/configs/config.yaml index a8fc2ea92a..f172486cf5 100644 --- a/model/model_training/configs/config.yaml +++ b/model/model_training/configs/config.yaml @@ -97,6 +97,9 @@ webgpt_dataset_only: datasets: - webgpt +instruction_datasets: + datasets: + - dolly15k per_digit_tokens: per_digit_tokens: true @@ -779,3 +782,32 @@ debug: verbose: true num_train_epochs: 0.2 dtype: fp32 + +patching-test: + dtype: bf16 + log_dir: "llama_log_7b" + learning_rate: 1e-5 + model_name: "huggyllama/llama-7b" + deepspeed_config: configs/zero_config_falcon.json + output_dir: llama + weight_decay: 0.0 + max_length: 4048 + warmup_steps: 100 + gradient_checkpointing: true + gradient_accumulation_steps: 2 + per_device_train_batch_size: 1 + per_device_eval_batch_size: 1 + eval_steps: 100 + save_steps: 500 + num_train_epochs: 8 + save_total_limit: 4 + use_flash_attention: false + residual_dropout: 0.3 + residual_dropout_lima: true + log_wandb: true + peft_model: true + peft_type: "lora" + 
superhot: true + superhot_config: + type: linear + scale: 2 From 522594ef8b0b38d6696d0abfbabd6cff5618d0f3 Mon Sep 17 00:00:00 2001 From: Shahules786 <shahules786@gmail.com> Date: Tue, 11 Jul 2023 18:42:02 +0000 Subject: [PATCH 09/12] pre-commit --- model/model_training/configs/config.yaml | 2 +- model/model_training/models/patching.py | 35 +++++--------------- model/model_training/models/peft_modeling.py | 5 ++- model/model_training/models/rope.py | 20 +++++------ model/model_training/trainer_sft.py | 3 +- 5 files changed, 23 insertions(+), 42 deletions(-) diff --git a/model/model_training/configs/config.yaml b/model/model_training/configs/config.yaml index f172486cf5..b93d8965c5 100644 --- a/model/model_training/configs/config.yaml +++ b/model/model_training/configs/config.yaml @@ -807,7 +807,7 @@ patching-test: log_wandb: true peft_model: true peft_type: "lora" - superhot: true + superhot: true superhot_config: type: linear scale: 2 diff --git a/model/model_training/models/patching.py b/model/model_training/models/patching.py index 574c07a7ca..9f97514d03 100644 --- a/model/model_training/models/patching.py +++ b/model/model_training/models/patching.py @@ -6,12 +6,13 @@ import torch.nn as nn import transformers -from transformers import GPTNeoXForCausalLM, GPTNeoXModel, LlamaForCausalLM, LlamaModel +from transformers import AutoConfig, GPTNeoXForCausalLM, GPTNeoXModel, LlamaForCausalLM, LlamaModel from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead from .patching_llama import llama_forward_with_flash_attn from .patching_neox import neox_forward_with_flash_attn from .reward_model import GPTNeoXRewardModel +from .rope import LlamaDynamicScaledRotaryEmbedding, LlamaLinearScaledRope, LlamaNTKScaledRope, RWNTKScaledRope SUPPORTED_MODELS = [ GPTNeoXModel, @@ -176,15 +177,10 @@ def patch_model( if resid_pdrop is not None and resid_pdrop > 0: add_dropout(getattr(layer, attention_key), _patched_attn_forward, resid_pdrop) add_dropout(getattr(layer, mlp_key), _patched_mlp_forward, resid_pdrop) - - -from .rope import RWNTKScaledRope, LlamaLinearScaledRope, LlamaNTKScaledRope, LlamaDynamicScaledRotaryEmbedding -from transformers import AutoConfig - + + class RopePatch: - def __init__(self, model_name, **kwargs): - self.args = kwargs rope_type = self.args.pop("type") config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) @@ -210,38 +206,25 @@ def __init__(self, model_name, **kwargs): else: raise NotImplementedError() - @classmethod def from_config(cls, config): - - model_name = config.model_name + model_name = config.model_name args = config.superhot_config return cls(model_name, **args) - + def patch(self, model): - if self.architecture == "RWForCausalLM": self.patch_rw_model(model, **self.args) elif self.architecture == "LlamaForCausalLM": self.patch_llama_model(model, **self.args) else: raise NotImplementedError() - - + def patch_rw_model(self, model, **kwargs): - for each in model.transformer.h: each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs) - - + def patch_llama_model(self, model, **kwargs): - - kwargs.update({"device":model.device}) + kwargs.update({"device": model.device}) for each in model.model.layers: each.self_attn.rotary_emb = self.patch_fun(each.self_attn.head_dim, **kwargs) - - - - - - diff --git a/model/model_training/models/peft_modeling.py b/model/model_training/models/peft_modeling.py index 7130a3c903..51b1f334d5 100644 --- a/model/model_training/models/peft_modeling.py +++ 
b/model/model_training/models/peft_modeling.py @@ -47,12 +47,11 @@ def peft_model(model, peft_type="lora", int8_training=False, gradient_checkpoint config = LoraConfig( r=16, lora_alpha=32, - target_modules=["query_key_value"], + target_modules=["q_proj", "v_proj", "k_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", - inference_mode= False, - + inference_mode=False, ) elif peft_type == "prefix-tuning": config = PrefixTuningConfig( diff --git a/model/model_training/models/rope.py b/model/model_training/models/rope.py index 1a13d0aebd..fdc8195f74 100644 --- a/model/model_training/models/rope.py +++ b/model/model_training/models/rope.py @@ -1,6 +1,4 @@ - import torch -from typing import Optional # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...) @@ -8,6 +6,7 @@ def rotate_half(x): x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0 + class RWNTKScaledRope(torch.nn.Module): """Implementation of RotaryEmbedding from GPT-NeoX. @@ -19,11 +18,11 @@ def __init__( self, head_dim: int, base=10000, - alpha:int=2, + alpha: int = 2, ): super().__init__() self.alpha = alpha - base = base * self.alpha ** (head_dim / (head_dim-2)) + base = base * self.alpha ** (head_dim / (head_dim - 2)) inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self.head_dim = head_dim @@ -95,13 +94,12 @@ def forward(self, x, seq_len=None): self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), ) - class LlamaNTKScaledRope(torch.nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None): super().__init__() - base = base * alpha ** (dim / (dim-2)) + base = base * alpha ** (dim / (dim - 2)) inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) self.register_buffer("inv_freq", inv_freq) @@ -130,9 +128,7 @@ def forward(self, x, seq_len=None): self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), ) - - -import math + class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None): @@ -160,7 +156,9 @@ def forward(self, x, seq_len=None): if seq_len > self.max_seq_len_cached: self.max_seq_len_cached = seq_len if self.ntk: - base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2)) + base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** ( + self.dim / (self.dim - 2) + ) inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim)) self.register_buffer("inv_freq", inv_freq) t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) @@ -174,4 +172,4 @@ def forward(self, x, seq_len=None): return ( self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) \ No newline at end of file + ) diff --git a/model/model_training/trainer_sft.py b/model/model_training/trainer_sft.py index fa82967921..84ad3f4b60 100755 --- a/model/model_training/trainer_sft.py +++ b/model/model_training/trainer_sft.py @@ -417,8 +417,9 @@ def main(): metrics, preprocess_fns = get_metrics(training_conf, tokenizer) model = get_model(training_conf, tokenizer) - + from 
model_training.models.patching import RopePatch + superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None if superhot: superhot.patch(model) From cc14e5ad19fc9bb91ebd6ea17eb73473ce8a0d7e Mon Sep 17 00:00:00 2001 From: Shahules786 <shahules786@gmail.com> Date: Tue, 11 Jul 2023 18:47:22 +0000 Subject: [PATCH 10/12] rmv changes --- model/model_training/models/peft_modeling.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/model/model_training/models/peft_modeling.py b/model/model_training/models/peft_modeling.py index 51b1f334d5..54d51caa06 100644 --- a/model/model_training/models/peft_modeling.py +++ b/model/model_training/models/peft_modeling.py @@ -47,11 +47,10 @@ def peft_model(model, peft_type="lora", int8_training=False, gradient_checkpoint config = LoraConfig( r=16, lora_alpha=32, - target_modules=["q_proj", "v_proj", "k_proj"], + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", - inference_mode=False, ) elif peft_type == "prefix-tuning": config = PrefixTuningConfig( From eaae56c9242d43ead99db3382e52ab7d7aba508c Mon Sep 17 00:00:00 2001 From: Shahules786 <shahules786@gmail.com> Date: Wed, 12 Jul 2023 10:34:19 +0000 Subject: [PATCH 11/12] cleanup --- model/model_training/configs/config.yaml | 5 +---- model/model_training/trainer_sft.py | 4 +--- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/model/model_training/configs/config.yaml b/model/model_training/configs/config.yaml index b93d8965c5..a997c0d19a 100644 --- a/model/model_training/configs/config.yaml +++ b/model/model_training/configs/config.yaml @@ -97,9 +97,6 @@ webgpt_dataset_only: datasets: - webgpt -instruction_datasets: - datasets: - - dolly15k per_digit_tokens: per_digit_tokens: true @@ -783,7 +780,7 @@ debug: num_train_epochs: 0.2 dtype: fp32 -patching-test: +rope_scaling_test: dtype: bf16 log_dir: "llama_log_7b" learning_rate: 1e-5 diff --git a/model/model_training/trainer_sft.py b/model/model_training/trainer_sft.py index 84ad3f4b60..158e1c621e 100755 --- a/model/model_training/trainer_sft.py +++ b/model/model_training/trainer_sft.py @@ -11,6 +11,7 @@ # from model_training.custom_datasets.formatting import DatasetEntry from model_training.custom_datasets.dialogue_collator import DialogueDataCollator from model_training.efficiency_utils import fuse_gelu +from model_training.models.patching import RopePatch from model_training.models.peft_modeling import peft_model from model_training.utils.utils import ( PerDatasetSampler, @@ -331,7 +332,6 @@ def main(): if not training_conf.deepspeed or training_conf.local_rank == 0: tokenizer_sanity_check(tokenizer) - print("POINT 1") train_collate_fn = DialogueDataCollator( tokenizer, max_length=training_conf.max_length, @@ -418,8 +418,6 @@ def main(): metrics, preprocess_fns = get_metrics(training_conf, tokenizer) model = get_model(training_conf, tokenizer) - from model_training.models.patching import RopePatch - superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None if superhot: superhot.patch(model) From ef03d9b8e877dec0834c3aeb2bdd3f5e4ef3ea39 Mon Sep 17 00:00:00 2001 From: Shahules786 <shahules786@gmail.com> Date: Wed, 12 Jul 2023 10:36:41 +0000 Subject: [PATCH 12/12] added docs --- model/model_training/models/rope.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/model/model_training/models/rope.py b/model/model_training/models/rope.py index fdc8195f74..005a40c729 100644 --- 
a/model/model_training/models/rope.py +++ b/model/model_training/models/rope.py @@ -9,9 +9,8 @@ def rotate_half(x): class RWNTKScaledRope(torch.nn.Module): - """Implementation of RotaryEmbedding from GPT-NeoX. - This implementation is design to operate on queries and keys that are compatible with - [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format). + """ + NTK-Scaled RoPE for RefinedWebModel """ def __init__( @@ -61,6 +60,10 @@ def forward(self, q, k): class LlamaLinearScaledRope(torch.nn.Module): + """ + reference: https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test + """ + def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None): super().__init__() self.scale = 1 / scale @@ -97,6 +100,11 @@ def forward(self, x, seq_len=None): class LlamaNTKScaledRope(torch.nn.Module): + + """ + reference: https://github.com/jquesnelle/scaled-rope + """ + def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None): super().__init__() base = base * alpha ** (dim / (dim - 2)) @@ -131,6 +139,10 @@ def forward(self, x, seq_len=None): class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module): + """ + reference: https://github.com/jquesnelle/scaled-rope + """ + def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None): super().__init__() self.ntk = ntk
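
The three scaling rules added in patch 05 (rope.py) reduce to small changes in how the RoPE position indices and inverse frequencies are built. The helpers below are an illustrative, standalone restatement of those formulas (plain PyTorch, names chosen for clarity; they are not part of the patch series itself):

import torch


def rope_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # Standard RoPE inverse frequencies, as used by the unpatched models.
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))


def linear_scaled_positions(seq_len: int, scale: float) -> torch.Tensor:
    # Linear ("SuperHOT") scaling: position indices are compressed by 1/scale,
    # mirroring `t *= self.scale` with `self.scale = 1 / scale` in LlamaLinearScaledRope.
    return torch.arange(seq_len).float() / scale


def ntk_scaled_inv_freq(dim: int, alpha: float, base: float = 10000.0) -> torch.Tensor:
    # NTK scaling: the rotary base is enlarged before the frequencies are computed,
    # mirroring `base = base * alpha ** (dim / (dim - 2))` in RWNTKScaledRope / LlamaNTKScaledRope.
    return rope_inv_freq(dim, base * alpha ** (dim / (dim - 2)))


def dynamic_ntk_inv_freq(dim: int, seq_len: int, max_pos: int, ntk: float, base: float = 10000.0) -> torch.Tensor:
    # Dynamic NTK: the effective alpha grows with the current sequence length once it
    # exceeds the original context window, mirroring LlamaDynamicScaledRotaryEmbedding.
    if seq_len <= max_pos:
        return rope_inv_freq(dim, base)
    scaled_base = base * ((ntk * seq_len / max_pos) - (ntk - 1)) ** (dim / (dim - 2))
    return rope_inv_freq(dim, scaled_base)

All three variants leave the attention computation itself untouched; only the frequency/position schedule changes, which is why the series can apply them by swapping a single rotary module per layer.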
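
End to end, the feature is driven purely by the training config (patches 04, 07, 08 and 11): `superhot: true` turns the patch on, `superhot_config` selects the scaling type and its parameters, and RopePatch swaps the rotary module in every attention layer after the model is loaded. A rough usage sketch, using the sample values from the `rope_scaling_test` config entry and a SimpleNamespace as a stand-in for the parsed config object (not the actual trainer code):

from types import SimpleNamespace

from transformers import AutoModelForCausalLM

from model_training.models.patching import RopePatch

# Stand-in for the parsed training config; attribute access is all RopePatch needs.
training_conf = SimpleNamespace(
    model_name="huggyllama/llama-7b",
    superhot=True,
    superhot_config={"type": "linear", "scale": 2},  # or e.g. {"type": "ntk", "alpha": 2}
)

model = AutoModelForCausalLM.from_pretrained(training_conf.model_name, trust_remote_code=True)

# Same wiring as trainer_sft.py: build the patch from the config, then apply it to the
# loaded model. For Llama architectures this replaces each layer's `self_attn.rotary_emb`;
# for Falcon/RW architectures it replaces `self_attention.maybe_rotary`.
superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None
if superhot:
    superhot.patch(model)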