From 7b0b002b89a9fa9cf7fca1d59091187fc71f973c Mon Sep 17 00:00:00 2001 From: Ilya Vologin Date: Mon, 9 Oct 2023 17:23:30 +0200 Subject: [PATCH 1/4] Add rope_theta for llama config --- .../inference/csrc/apply_rotary_pos_emb.cu | 20 +++++++++++++++---- .../transformer/inference/csrc/pt_binding.cpp | 9 +++++++-- .../includes/inference_cuda_layers.h | 1 + deepspeed/module_inject/containers/llama.py | 1 + deepspeed/ops/transformer/inference/config.py | 4 +++- .../inference/op_binding/softmax_context.py | 9 +++++---- 6 files changed, 33 insertions(+), 11 deletions(-) diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu index e743ffc3f64f..e326c762c0f3 100644 --- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu +++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu @@ -32,6 +32,7 @@ __global__ void apply_rotary_pos_half(T* mixed_query, unsigned num_heads, unsigned head_size, unsigned total_count, + float rope_theta, int max_out_tokens) { constexpr int T_per_thread = granularity / sizeof(T); @@ -61,7 +62,7 @@ __global__ void apply_rotary_pos_half(T* mixed_query, const int neuron_idx = base_neuron_idx + i; if (neuron_idx < rotary_dim) { float inv_freq = (float)((neuron_idx % half_dim) * 2) / (float)rotary_dim; - inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_idx; + inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_idx; float rotary_sign = (neuron_idx > (half_dim - 1) ? -1.0 : 1.0); float q_rot = conversion::to(q[i]) * rotary_sign; @@ -95,6 +96,7 @@ __global__ void apply_rotary_pos_half(T* mixed_query, num_heads, \ head_size, \ total_count, \ + rope_theta, \ max_out_tokens); #ifdef __HIP_PLATFORM_HCC__ @@ -136,6 +138,7 @@ void launch_apply_rotary_pos_emb(T* mixed_query, unsigned offset, unsigned num_heads, unsigned batch, + float rope_theta, cudaStream_t stream, int max_out_tokens) { @@ -176,9 +179,18 @@ void launch_apply_rotary_pos_emb(T* mixed_query, } } -#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T) \ - template void launch_apply_rotary_pos_emb( \ - T*, T*, unsigned, unsigned, unsigned, unsigned, unsigned, unsigned, cudaStream_t, int); +#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T) \ + template void launch_apply_rotary_pos_emb(T*, \ + T*, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + float, \ + cudaStream_t, \ + int); INSTANTIATE_LAUNCH_ROTARY_POS_EMB(float); #ifdef BF16_AVAILABLE diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 634b6e3adbbb..61f89090c4cc 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -445,7 +445,8 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, bool no_masking, unsigned layer_id, unsigned num_layers, - at::Tensor& alibi) + at::Tensor& alibi, + float rope_theta) { unsigned bsz = query_key_value.size(0); unsigned seq_len = query_key_value.size(1); @@ -503,6 +504,7 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, (is_prompt ? 
0 : soft_len - 1), heads, bsz, + rope_theta, InferenceContext::Instance().GetCurrentStream(), InferenceContext::Instance().GetMaxTokenLength()); @@ -1847,7 +1849,8 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, unsigned rotary_dim, unsigned offset, unsigned num_heads, - bool rotate_half) + bool rotate_half, + float rope_theta) { auto query_cont = mixed_query.contiguous(); auto key_cont = key_layer.contiguous(); @@ -1865,6 +1868,7 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, offset, num_heads, bsz, + rope_theta, InferenceContext::Instance().GetCurrentStream(), InferenceContext::Instance().GetMaxTokenLength()); else @@ -1876,6 +1880,7 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, offset, num_heads, bsz, + rope_theta, InferenceContext::Instance().GetCurrentStream(), InferenceContext::Instance().GetMaxTokenLength()); return {query_cont, key_cont}; diff --git a/csrc/transformer/inference/includes/inference_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h index 5240ebb1d524..9bf7b26f3b2f 100644 --- a/csrc/transformer/inference/includes/inference_cuda_layers.h +++ b/csrc/transformer/inference/includes/inference_cuda_layers.h @@ -168,6 +168,7 @@ void launch_apply_rotary_pos_emb(T* mixed_query, unsigned offset, unsigned num_heads, unsigned batch, + float rope_theta, cudaStream_t stream, int max_out_tokens); diff --git a/deepspeed/module_inject/containers/llama.py b/deepspeed/module_inject/containers/llama.py index af99d658017c..f6157e5cdfed 100644 --- a/deepspeed/module_inject/containers/llama.py +++ b/deepspeed/module_inject/containers/llama.py @@ -34,6 +34,7 @@ def create_module(self, config=None): _config.rotate_half = True _config.rotate_every_two = False _config.rotary_dim = self.hidden_size // self.num_attention_heads + _config.rope_theta = self.policy.client_module.self_attn.rope_theta self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group) return self.module diff --git a/deepspeed/ops/transformer/inference/config.py b/deepspeed/ops/transformer/inference/config.py index 4e29a2137c64..d5aff4f541f7 100644 --- a/deepspeed/ops/transformer/inference/config.py +++ b/deepspeed/ops/transformer/inference/config.py @@ -79,7 +79,8 @@ def __init__(self, transposed_mode=False, use_triton=False, triton_autotune=False, - num_kv=-1): + num_kv=-1, + rope_theta=10000): super(DeepSpeedInferenceConfig, self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads, num_hidden_layers) @@ -114,6 +115,7 @@ def __init__(self, self.use_triton = use_triton self.triton_autotune = triton_autotune self.num_kv = num_kv + self.rope_theta = rope_theta @classmethod def from_dict(cls, json_object): diff --git a/deepspeed/ops/transformer/inference/op_binding/softmax_context.py b/deepspeed/ops/transformer/inference/op_binding/softmax_context.py index 012399ea1ef3..0dc4e08a3633 100644 --- a/deepspeed/ops/transformer/inference/op_binding/softmax_context.py +++ b/deepspeed/ops/transformer/inference/op_binding/softmax_context.py @@ -23,9 +23,9 @@ def __init__(self, config: DeepSpeedInferenceConfig): except AttributeError: self.softmax_context_func = self.softmax_context_fallback - def softmax_context_fallback(self, query_key_value, attn_mask, rotary_dim, rotate_half, roteate_every_two, heads, - norm_factor, triangular_masking, local_attention, window_size, no_masking, layer_id, - num_layers, alibi): + def softmax_context_fallback(self, query_key_value, attn_mask, rotary_dim, rotate_half, rotate_every_two, 
heads, + num_kv, norm_factor, triangular_masking, local_attention, window_size, no_masking, + layer_id, num_layers, alibi, rope_theta): raise NotImplementedError def forward(self, query_key_value: torch.Tensor, attn_mask: torch.Tensor, heads: int, num_kv: int, @@ -41,6 +41,7 @@ def forward(self, query_key_value: torch.Tensor, attn_mask: torch.Tensor, heads: output = self.softmax_context_func(query_key_value, attn_mask, self.config.rotary_dim, self.config.rotate_half, self.config.rotate_every_two, heads, num_kv, norm_factor, self.config.triangular_masking, self.config.local_attention, - self.config.window_size, no_masking, layer_id, num_layers, alibi) + self.config.window_size, no_masking, layer_id, num_layers, alibi, + self.config.rope_theta) return output From f1b119ef52fda7d4fcb3615ef976ed0c6cbb1549 Mon Sep 17 00:00:00 2001 From: Ilya Vologin Date: Thu, 12 Oct 2023 12:22:52 +0200 Subject: [PATCH 2/4] Add rope_theta to bias_add_transform_0213 --- .../transformer/inference/csrc/pt_binding.cpp | 3 ++- csrc/transformer/inference/csrc/transform.cu | 25 ++++++++++++------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 61f89090c4cc..06319852c513 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -494,7 +494,8 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, rotate_every_two, InferenceContext::Instance().GetCurrentStream(), 3, - InferenceContext::Instance().GetMaxTokenLength()); + InferenceContext::Instance().GetMaxTokenLength(), + rope_theta); if (rotary_dim > 0 && rotate_half) launch_apply_rotary_pos_emb(query_cont, kv_cache, diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu index 0b8bffa643c6..06b29647ab2a 100644 --- a/csrc/transformer/inference/csrc/transform.cu +++ b/csrc/transformer/inference/csrc/transform.cu @@ -32,7 +32,8 @@ __global__ void bias_add_transform_0213(float* output, bool rotate_half, bool rotate_every_two, int head_ext, - int max_out_tokens) + int max_out_tokens, + float rope_theta) { int d0_stride = hidden_dim * seq_length; int d1_stride = hidden_dim; @@ -70,7 +71,7 @@ __global__ void bias_add_transform_0213(float* output, #pragma unroll for (int o = 0; o < 2; o++) { float inv_freq = (float)(((d3 << 1) + o) * 2) / (float)(rotary_dim << 2); - inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; + inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_id; q_f[o].x = (-1.0 * q_f[o].y * sinf(inv_freq) + q_f[o].x * cosf(inv_freq)); q_f[o].y = (q_f[o].x * sinf(inv_freq) + q_f[o].y * cosf(inv_freq)); } @@ -100,7 +101,8 @@ __global__ void bias_add_transform_0213(T* output, // q bool rotate_half, bool rotate_every_two, int head_ext, - int max_out_tokens) + int max_out_tokens, + float rope_theta) { using T2 = typename std::conditional::value, __half2, __nv_bfloat162>::type; @@ -147,7 +149,7 @@ __global__ void bias_add_transform_0213(T* output, // q #pragma unroll for (int o = 0; o < 4; o++) { float inv_freq = (float)(((d3 << 2) + o) * 2) / (float)(rotary_dim << 3); - inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; + inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_id; float q_data[2]; q_data[0] = conversion::to(q_h[o].x); q_data[1] = conversion::to(q_h[o].y); @@ -181,7 +183,8 @@ void launch_bias_add_transform_0213(float* output, bool rotate_every_two, cudaStream_t stream, int trans_count, - int 
max_out_tokens) + int max_out_tokens, + float rope_theta) { hidden_dim >>= 2; int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; @@ -204,7 +207,8 @@ void launch_bias_add_transform_0213(float* output, rotate_half, rotate_every_two, head_ext, - max_out_tokens); + max_out_tokens, + rope_theta); } template @@ -225,7 +229,8 @@ void launch_bias_add_transform_0213(T* output, bool rotate_every_two, cudaStream_t stream, int trans_count, - int max_out_tokens) + int max_out_tokens, + float rope_theta) { hidden_dim >>= 3; int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1; @@ -247,7 +252,8 @@ void launch_bias_add_transform_0213(T* output, rotate_half, rotate_every_two, head_ext, - max_out_tokens); + max_out_tokens, + rope_theta); } #define INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(T) \ @@ -268,7 +274,8 @@ void launch_bias_add_transform_0213(T* output, bool, \ cudaStream_t, \ int, \ - int) + int, \ + float) #ifdef BF16_AVAILABLE INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(__nv_bfloat16); From 6e131a6c048dea3e13b82ab247b57fa5b2686dae Mon Sep 17 00:00:00 2001 From: Ilya Vologin Date: Mon, 16 Oct 2023 12:49:03 +0200 Subject: [PATCH 3/4] Fix CI problems --- csrc/transformer/inference/includes/inference_cuda_layers.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/transformer/inference/includes/inference_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h index 9bf7b26f3b2f..dcc020483687 100644 --- a/csrc/transformer/inference/includes/inference_cuda_layers.h +++ b/csrc/transformer/inference/includes/inference_cuda_layers.h @@ -208,7 +208,8 @@ void launch_bias_add_transform_0213(T* outputs, bool rotate_every_two, cudaStream_t stream, int trans_count, - int max_out_tokens); + int max_out_tokens, + float rope_theta); template void pad_data(T* padded_output, T* output, From 0e12d1d44eae987f87b4cc603873546114bcbe1e Mon Sep 17 00:00:00 2001 From: Ilya Vologin Date: Tue, 17 Oct 2023 12:42:12 +0200 Subject: [PATCH 4/4] Add rope_theta to linear layer --- csrc/transformer/inference/csrc/pt_binding.cpp | 9 ++++++--- deepspeed/ops/transformer/inference/op_binding/linear.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 06319852c513..4fd64112e148 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -1103,7 +1103,8 @@ at::Tensor ds_linear_layer(at::Tensor& input, bool add_bias, bool do_flash_attn, int num_heads, - bool transposed_mode) + bool transposed_mode, + float rope_theta) { auto input_cont = input.contiguous(); auto options = at::TensorOptions() @@ -1177,7 +1178,8 @@ at::Tensor ds_linear_layer(at::Tensor& input, false, InferenceContext::Instance().GetCurrentStream(), 3, - input.size(1)); + input.size(1), + rope_theta); return at::from_blob(final_output, {3, input.size(0), num_heads, input.size(1), padded_head_size}, options); @@ -1203,7 +1205,8 @@ at::Tensor ds_linear_layer(at::Tensor& input, false, InferenceContext::Instance().GetCurrentStream(), 3, - input.size(1)); + input.size(1), + rope_theta); return at::from_blob( final_output, {3, input.size(0), num_heads, input.size(1), head_size}, options); // return at::from_blob(workspace, {input.size(0) * input.size(1), 3, num_heads, diff --git a/deepspeed/ops/transformer/inference/op_binding/linear.py b/deepspeed/ops/transformer/inference/op_binding/linear.py index e970b562c6d6..b8decb6dc5ea 100644 --- 
a/deepspeed/ops/transformer/inference/op_binding/linear.py +++ b/deepspeed/ops/transformer/inference/op_binding/linear.py @@ -31,7 +31,7 @@ def __init__(self, config: DeepSpeedInferenceConfig): except AttributeError: self.linear_func = self.linear_fallback - def linear_fallback(self, input, weight, bias, add_bias, do_flash_attn, num_heads, transpose): + def linear_fallback(self, input, weight, bias, add_bias, do_flash_attn, num_heads, transpose, rope_theta): raise NotImplementedError def forward(self, @@ -44,7 +44,7 @@ def forward(self, external_cache: bool = None, num_layers: int = None): qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads, - self.config.transposed_mode) + self.config.transposed_mode, self.config.rope_theta) return qkv_out @staticmethod
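
For reference while reading the kernel changes above, here is a minimal Python sketch (not part of the patch series) of the rotation that rope_theta parameterizes. It assumes the rotate_half layout used by the llama container; the function name and tensor shapes are illustrative only, not DeepSpeed API.

import torch

def apply_rotary_pos_emb_ref(q, seq_idx, rotary_dim, rope_theta=10000.0):
    # Same frequency formula as apply_rotary_pos_half:
    #   inv_freq = 1 / rope_theta ** (2*i / rotary_dim), scaled by the token position.
    half_dim = rotary_dim // 2
    idx = torch.arange(half_dim, dtype=torch.float32)
    inv_freq = 1.0 / (rope_theta ** (2.0 * idx / rotary_dim))
    angles = inv_freq * seq_idx          # rotation angle per frequency at this position
    cos, sin = angles.cos(), angles.sin()
    q1, q2 = q[..., :half_dim], q[..., half_dim:rotary_dim]
    # rotate_half: the first half of the rotary dims is paired with the second half
    return torch.cat([q1 * cos - q2 * sin, q2 * cos + q1 * sin], dim=-1)

The default of 10000 in DeepSpeedInferenceConfig matches the previously hard-coded base, so existing checkpoints behave as before; models that ship a different rope_theta in their Hugging Face config (e.g. CodeLlama) now have it read from self_attn.rope_theta and passed through to the kernels.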