From 34d94094279d2c903d9d8a51a65edb265f22c849 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 13 Jul 2023 16:47:30 +0100 Subject: [PATCH] Llama/GPTNeoX: add RoPE scaling (#24653) * add rope_scaling * tmp commit * add gptneox * add tests * GPTNeoX can now handle long inputs, so the pipeline test was wrong * Update src/transformers/models/open_llama/configuration_open_llama.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * remove ntk * remove redundant validation --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../models/gpt_neox/configuration_gpt_neox.py | 35 +++++ .../models/gpt_neox/modeling_gpt_neox.py | 126 ++++++++++++++---- .../modeling_gpt_neox_japanese.py | 25 ++-- .../models/llama/configuration_llama.py | 34 +++++ .../models/llama/modeling_llama.py | 91 +++++++++++-- .../open_llama/configuration_open_llama.py | 35 +++++ .../models/open_llama/modeling_open_llama.py | 96 +++++++++++-- .../models/gpt_neox/test_modeling_gpt_neox.py | 37 ++++- tests/models/llama/test_modeling_llama.py | 35 ++++- .../open_llama/test_modeling_open_llama.py | 35 ++++- .../test_pipelines_text_generation.py | 2 +- 11 files changed, 484 insertions(+), 67 deletions(-) diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index 046d858bf3b1..24c4b4a3df03 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -78,6 +78,15 @@ class GPTNeoXConfig(PretrainedConfig): use_parallel_residual (`bool`, *optional*, defaults to `True`): Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training speedup at large scales (e.g. 20B). + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling + strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format + is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + Example: ```python @@ -115,6 +124,7 @@ def __init__( eos_token_id=2, tie_word_embeddings=False, use_parallel_residual=True, + rope_scaling=None, **kwargs, ): super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @@ -135,7 +145,32 @@ def __init__( self.use_cache = use_cache self.tie_word_embeddings = tie_word_embeddings self.use_parallel_residual = use_parallel_residual + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + if self.hidden_size % self.num_attention_heads != 0: raise ValueError( "The hidden size is not divisble by the number of attention heads! Make sure to update them!" ) + + # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 8546d85ffa3e..b1d694e61de1 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -86,6 +86,7 @@ def _set_gradient_checkpointing(self, module, value=False): class GPTNeoXAttention(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.num_attention_heads = config.num_attention_heads self.hidden_size = config.hidden_size if self.hidden_size % self.num_attention_heads != 0: @@ -94,18 +95,11 @@ def __init__(self, config): ) self.head_size = self.hidden_size // self.num_attention_heads self.rotary_ndims = int(self.head_size * config.rotary_pct) - max_positions = config.max_position_embeddings - self.register_buffer( - "bias", - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( - 1, 1, max_positions, max_positions - ), - persistent=False, - ) + self._init_bias(config.max_position_embeddings) + self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False) - self.rotary_emb = RotaryEmbedding( - self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base - ) + self._init_rope() + self.register_buffer( "norm_factor", torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()), @@ -113,9 +107,44 @@ def __init__(self, config): ) self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.attention_dropout = nn.Dropout(config.attention_dropout) + def _init_bias(self, max_positions, device=None): + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + if device is not None: + self.bias = self.bias.to(device) + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = GPTNeoXRotaryEmbedding( + self.rotary_ndims, self.config.max_position_embeddings, base=self.config.rotary_emb_base + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = GPTNeoXLinearScalingRotaryEmbedding( + self.rotary_ndims, + self.config.max_position_embeddings, + base=self.config.rotary_emb_base, + scaling_factor=scaling_factor, + ) + elif scaling_type == "dynamic": + self.rotary_emb = GPTNeoXDynamicNTKScalingRotaryEmbedding( + self.rotary_ndims, + self.config.max_position_embeddings, + base=self.config.rotary_emb_base, + scaling_factor=scaling_factor, + ) + else: + raise 
ValueError(f"Unknown RoPE scaling type {scaling_type}") + def forward( self, hidden_states: torch.FloatTensor, @@ -210,6 +239,9 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): batch_size, num_attention_heads, query_length, attn_head_size = query.size() key_length = key.size(-2) + # dynamically increase the causal mask with the key length, if needed. + if key_length > self.bias.shape[-1]: + self._init_bias(key_length, device=key.device) causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) @@ -258,15 +290,23 @@ def attention_mask_func(attention_scores, ltor_mask): return attention_scores -class RotaryEmbedding(torch.nn.Module): +class GPTNeoXRotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings, base=10000, device=None): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device) + + def _set_cos_sin_cache(self, seq_len, device): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) @@ -275,18 +315,56 @@ def __init__(self, dim, max_position_embeddings, base=10000, device=None): def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.cos_cached = emb.cos()[None, None, :, :] - self.sin_cached = emb.sin()[None, None, :, :] + self._set_cos_sin_cache(seq_len=seq_len, device=x.device) return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device) +class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): + """GPTNeoXRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + + +class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): + """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index e7cb510e6222..7df64174a599 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -238,16 +238,24 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): return attn_output, attn_weights -# Copied from transformers.models.gpt_neox.modeling_gpt_neox.RotaryEmbedding +# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding class RotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings, base=10000, device=None): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # Build here to make `torch.jit.trace` work. 
- self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device) + + def _set_cos_sin_cache(self, seq_len, device): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) @@ -256,15 +264,8 @@ def __init__(self, dim, max_position_embeddings, base=10000, device=None): def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.cos_cached = emb.cos()[None, None, :, :] - self.sin_cached = emb.sin()[None, None, :, :] + self._set_cos_sin_cache(seq_len=seq_len, device=x.device) return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device) diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index e3d075310e84..d456b79e66fb 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -64,6 +64,15 @@ class LlamaConfig(PretrainedConfig): relevant if `config.is_decoder=True`. tie_word_embeddings(`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling + strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format + is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + Example: ```python @@ -97,6 +106,7 @@ def __init__( bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, + rope_scaling=None, **kwargs, ): self.vocab_size = vocab_size @@ -109,6 +119,9 @@ def __init__( self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, @@ -116,3 +129,24 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 41a077c32fc9..6cdbb2623a7d 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -91,36 +91,84 @@ def forward(self, hidden_states): class LlamaRotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - dtype = torch.get_default_dtype() self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False) + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + return ( self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), ) +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -176,7 +224,24 @@ def __init__(self, config: LlamaConfig): self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() diff --git a/src/transformers/models/open_llama/configuration_open_llama.py b/src/transformers/models/open_llama/configuration_open_llama.py index cbde4d67d498..c0629b31e812 100644 --- a/src/transformers/models/open_llama/configuration_open_llama.py +++ b/src/transformers/models/open_llama/configuration_open_llama.py @@ -67,6 +67,15 @@ class OpenLlamaConfig(PretrainedConfig): relevant if `config.is_decoder=True`. tie_word_embeddings(`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling + strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format + is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. 
See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + Example: ```python @@ -104,6 +113,7 @@ def __init__( attention_dropout_prob=0.1, use_stable_embedding=True, shared_input_output_embedding=True, + rope_scaling=None, **kwargs, ): self.vocab_size = vocab_size @@ -123,6 +133,9 @@ def __init__( self.attention_dropout_prob = attention_dropout_prob self.use_stable_embedding = use_stable_embedding self.shared_input_output_embedding = shared_input_output_embedding + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, @@ -130,3 +143,25 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) + + # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/open_llama/modeling_open_llama.py b/src/transformers/models/open_llama/modeling_open_llama.py index db05b868b066..34dc350dc556 100644 --- a/src/transformers/models/open_llama/modeling_open_llama.py +++ b/src/transformers/models/open_llama/modeling_open_llama.py @@ -102,36 +102,86 @@ def forward(self, hidden_states): class OpenLlamaRotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # Build here to make `torch.jit.trace` work. 
- self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - dtype = torch.get_default_dtype() self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False) + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + return ( self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), ) +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->OpenLlama +class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): + """OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama +class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): + """OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -190,7 +240,27 @@ def __init__(self, config: OpenLlamaConfig): self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.rotary_emb = OpenLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + self._init_rope() + + # Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->OpenLlama + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = OpenLlamaRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = OpenLlamaLinearScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + elif scaling_type == "dynamic": + self.rotary_emb = OpenLlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index ed9b5764a361..176970779ed5 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -17,7 +17,9 @@ import unittest -from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available +from parameterized import parameterized + +from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available, set_seed from transformers.testing_utils import require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin @@ -49,7 +51,7 @@ def __init__( use_token_type_ids=True, use_labels=True, vocab_size=99, - hidden_size=32, + hidden_size=64, num_hidden_layers=5, num_attention_heads=4, 
intermediate_size=37, @@ -298,6 +300,37 @@ def test_model_for_token_classification(self): def test_feed_forward_chunking(self): pass + @parameterized.expand([("linear",), ("dynamic",)]) + def test_model_rope_scaling(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = GPTNeoXModel(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = GPTNeoXModel(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + @require_torch class GPTNeoXLanguageGenerationTest(unittest.TestCase): diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 113f74d30976..e8b808461945 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -17,7 +17,9 @@ import unittest -from transformers import LlamaConfig, is_torch_available +from parameterized import parameterized + +from transformers import LlamaConfig, is_torch_available, set_seed from transformers.testing_utils import require_torch, torch_device from ...generation.test_utils import GenerationTesterMixin @@ -332,3 +334,34 @@ def test_llama_sequence_classification_model_for_multi_label(self): @unittest.skip("LLaMA buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass + + @parameterized.expand([("linear",), ("dynamic",)]) + def test_model_rope_scaling(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = LlamaModel(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = LlamaModel(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = 
scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) diff --git a/tests/models/open_llama/test_modeling_open_llama.py b/tests/models/open_llama/test_modeling_open_llama.py index 6cdf61a8729d..687b267b70ca 100644 --- a/tests/models/open_llama/test_modeling_open_llama.py +++ b/tests/models/open_llama/test_modeling_open_llama.py @@ -17,7 +17,9 @@ import unittest -from transformers import OpenLlamaConfig, is_torch_available +from parameterized import parameterized + +from transformers import OpenLlamaConfig, is_torch_available, set_seed from transformers.testing_utils import require_torch, torch_device from ...generation.test_utils import GenerationTesterMixin @@ -335,3 +337,34 @@ def test_open_llama_sequence_classification_model_for_multi_label(self): @unittest.skip("Open-Llama buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass + + @parameterized.expand([("linear",), ("dynamic",)]) + def test_model_rope_scaling(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = OpenLlamaModel(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = OpenLlamaModel(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index b4605164aefd..44a29a673d81 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -240,7 +240,7 @@ def run_pipeline_test(self, text_generator, _): # We don't care about infinite range models. # They already work. # Skip this test for XGLM, since it uses sinusoidal positional embeddings which are resized on-the-fly. 
- EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS = ["RwkvForCausalLM", "XGLMForCausalLM"] + EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS = ["RwkvForCausalLM", "XGLMForCausalLM", "GPTNeoXForCausalLM"] if ( tokenizer.model_max_length < 10000 and text_generator.model.__class__.__name__ not in EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS
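
For reviewers, a minimal sketch of how the new `rope_scaling` flag is meant to be used from the outside, based on the config changes in this patch. The tiny hyperparameters are arbitrary (chosen only to make instantiation cheap) and are not part of the patch itself:

```python
from transformers import GPTNeoXConfig, LlamaConfig, LlamaForCausalLM

# "linear" scaling: position indices are divided by `factor` inside the rotary embedding.
linear_cfg = LlamaConfig(
    hidden_size=128,            # illustrative toy sizes, not a real checkpoint
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    max_position_embeddings=2048,
    rope_scaling={"type": "linear", "factor": 2.0},  # factor must be a float strictly greater than 1
)
model = LlamaForCausalLM(linear_cfg)  # LlamaAttention._init_rope picks the scaled rotary class

# "dynamic" (NTK-aware) scaling: the RoPE base is only adjusted once the input grows past
# max_position_embeddings, so short inputs keep the original embeddings.
dynamic_cfg = LlamaConfig(rope_scaling={"type": "dynamic", "factor": 2.0})

# GPT-NeoX and Open-Llama expose the same knob through their configs.
neox_cfg = GPTNeoXConfig(rope_scaling={"type": "linear", "factor": 2.0})
```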
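The linear strategy only rescales the position index before it is multiplied with the inverse frequencies, exactly as in the new `_set_cos_sin_cache` overrides. A standalone sketch of that arithmetic (not the library classes themselves, just the same computation):

```python
import torch

def rope_cos_sin(dim, seq_len, base=10000.0, scaling_factor=1.0):
    """Build RoPE cos/sin caches; scaling_factor > 1 stretches positions (linear scaling)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len, dtype=inv_freq.dtype) / scaling_factor  # the linear scaling step
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]

# With factor=2.0, position 4094 is embedded exactly like position 2047 of the unscaled model,
# which is how a model trained on 2048 positions can be stretched towards 4096.
cos_scaled, _ = rope_cos_sin(dim=64, seq_len=4096, scaling_factor=2.0)
cos_plain, _ = rope_cos_sin(dim=64, seq_len=2048)
assert torch.allclose(cos_scaled[..., 4094, :], cos_plain[..., 2047, :], atol=1e-5)
```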
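The dynamic NTK strategy instead rewrites the RoPE base once the sequence outgrows the trained context, following the formula used in the new `*DynamicNTKScalingRotaryEmbedding` classes. A worked sketch of just that base adjustment (values illustrative):

```python
def ntk_scaled_base(base, dim, seq_len, max_position_embeddings, scaling_factor):
    """Return the adjusted RoPE base used by dynamic NTK scaling; unchanged for short inputs."""
    if seq_len <= max_position_embeddings:
        return base
    return base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

# Llama-style head dim of 128, base 10000, trained on 2048 positions, factor 2.0:
print(ntk_scaled_base(10000, 128, 2048, 2048, 2.0))  # 10000 -> no change at the trained length
print(ntk_scaled_base(10000, 128, 4096, 2048, 2.0))  # ~3.05e4 -> larger base, slower rotation,
                                                     # so the frequencies span the longer context
```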
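Finally, a quick illustration of what the new `_rope_scaling_validation` accepts and rejects, per the checks added to the configs (again only a sketch for reviewers, not part of the patch):

```python
from transformers import LlamaConfig

LlamaConfig(rope_scaling={"type": "linear", "factor": 4.0})       # accepted

try:
    LlamaConfig(rope_scaling={"type": "linear", "factor": 4})     # int instead of float -> rejected
except ValueError as err:
    print(err)

try:
    LlamaConfig(rope_scaling={"type": "ntk", "factor": 4.0})      # unknown strategy -> rejected
except ValueError as err:
    print(err)
```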