[Bug fix] Add rope_theta for llama config #4480

Merged Oct 19, 2023 (9 commits)
Changes from 5 commits
20 changes: 16 additions & 4 deletions csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
@@ -32,6 +32,7 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
float rope_theta,
int max_out_tokens)
{
constexpr int T_per_thread = granularity / sizeof(T);
@@ -61,7 +62,7 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
const int neuron_idx = base_neuron_idx + i;
if (neuron_idx < rotary_dim) {
float inv_freq = (float)((neuron_idx % half_dim) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_idx;
inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_idx;

float rotary_sign = (neuron_idx > (half_dim - 1) ? -1.0 : 1.0);
float q_rot = conversion::to<float>(q[i]) * rotary_sign;
@@ -95,6 +96,7 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
num_heads, \
head_size, \
total_count, \
rope_theta, \
max_out_tokens);

#ifdef __HIP_PLATFORM_HCC__
@@ -136,6 +138,7 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
unsigned offset,
unsigned num_heads,
unsigned batch,
float rope_theta,
cudaStream_t stream,
int max_out_tokens)
{
@@ -176,9 +179,18 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
}
}

#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T) \
template void launch_apply_rotary_pos_emb<T>( \
T*, T*, unsigned, unsigned, unsigned, unsigned, unsigned, unsigned, cudaStream_t, int);
#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T) \
template void launch_apply_rotary_pos_emb<T>(T*, \
T*, \
unsigned, \
unsigned, \
unsigned, \
unsigned, \
unsigned, \
unsigned, \
float, \
cudaStream_t, \
int);

INSTANTIATE_LAUNCH_ROTARY_POS_EMB(float);
#ifdef BF16_AVAILABLE
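For reference, the per-position angle the kernel now computes can be sketched in plain Python. This is a minimal illustration of how the base enters as 1 / theta^(2i/d); the helper name and layout are illustrative, not part of the kernel:

import numpy as np

def rotary_angles(seq_idx: int, rotary_dim: int, rope_theta: float = 10000.0) -> np.ndarray:
    # One angle per (first-half, second-half) element pair, as in apply_rotary_pos_half.
    pair = np.arange(rotary_dim // 2)
    inv_freq = 1.0 / np.power(rope_theta, (pair * 2) / rotary_dim)  # 1 / theta^(2i/d)
    return inv_freq * seq_idx  # the kernel scales by the token position (seq_idx)

With rope_theta = 10000.0 this reproduces the previously hard-coded behaviour; checkpoints trained with a different base (CodeLlama uses 1,000,000, for example) now get their own value instead of silently falling back to 10000.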
12 changes: 9 additions & 3 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -445,7 +445,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
bool no_masking,
unsigned layer_id,
unsigned num_layers,
at::Tensor& alibi)
at::Tensor& alibi,
float rope_theta)
{
unsigned bsz = query_key_value.size(0);
unsigned seq_len = query_key_value.size(1);
@@ -493,7 +494,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
rotate_every_two,
InferenceContext::Instance().GetCurrentStream(),
3,
InferenceContext::Instance().GetMaxTokenLength());
InferenceContext::Instance().GetMaxTokenLength(),
rope_theta);
if (rotary_dim > 0 && rotate_half)
launch_apply_rotary_pos_emb(query_cont,
kv_cache,
@@ -503,6 +505,7 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
(is_prompt ? 0 : soft_len - 1),
heads,
bsz,
rope_theta,
InferenceContext::Instance().GetCurrentStream(),
InferenceContext::Instance().GetMaxTokenLength());

@@ -1847,7 +1850,8 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
bool rotate_half)
bool rotate_half,
float rope_theta)
{
auto query_cont = mixed_query.contiguous();
auto key_cont = key_layer.contiguous();
@@ -1865,6 +1869,7 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
offset,
num_heads,
bsz,
rope_theta,
InferenceContext::Instance().GetCurrentStream(),
InferenceContext::Instance().GetMaxTokenLength());
else
@@ -1876,6 +1881,7 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
offset,
num_heads,
bsz,
rope_theta,
InferenceContext::Instance().GetCurrentStream(),
InferenceContext::Instance().GetMaxTokenLength());
return {query_cont, key_cont};
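Callers of the exposed op pass the base explicitly as the new trailing argument. A rough sketch of an invocation, assuming the op is registered to Python under the same name and argument order as the C++ function above (builder/loading details are illustrative, not taken from this diff):

import torch
from deepspeed.ops.op_builder import InferenceBuilder

ops = InferenceBuilder().load()
bsz, seq, heads, head_dim, rotary_dim = 1, 16, 32, 128, 128
query = torch.randn(bsz, seq, heads * head_dim, dtype=torch.float16, device="cuda")
key = torch.randn(bsz, seq, heads * head_dim, dtype=torch.float16, device="cuda")
# mixed_query, key_layer, rotary_dim, offset, num_heads, rotate_half, rope_theta
query, key = ops.apply_rotary_pos_emb(query, key, rotary_dim, 0, heads, True, 1000000.0)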
25 changes: 16 additions & 9 deletions csrc/transformer/inference/csrc/transform.cu
@@ -32,7 +32,8 @@ __global__ void bias_add_transform_0213(float* output,
bool rotate_half,
bool rotate_every_two,
int head_ext,
int max_out_tokens)
int max_out_tokens,
float rope_theta)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
@@ -70,7 +71,7 @@
#pragma unroll
for (int o = 0; o < 2; o++) {
float inv_freq = (float)(((d3 << 1) + o) * 2) / (float)(rotary_dim << 2);
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_id;
q_f[o].x = (-1.0 * q_f[o].y * sinf(inv_freq) + q_f[o].x * cosf(inv_freq));
q_f[o].y = (q_f[o].x * sinf(inv_freq) + q_f[o].y * cosf(inv_freq));
}
@@ -100,7 +101,8 @@ __global__ void bias_add_transform_0213(T* output, // q
bool rotate_half,
bool rotate_every_two,
int head_ext,
int max_out_tokens)
int max_out_tokens,
float rope_theta)
{
using T2 =
typename std::conditional<std::is_same<T, __half>::value, __half2, __nv_bfloat162>::type;
@@ -147,7 +149,7 @@ __global__ void bias_add_transform_0213(T* output, // q
#pragma unroll
for (int o = 0; o < 4; o++) {
float inv_freq = (float)(((d3 << 2) + o) * 2) / (float)(rotary_dim << 3);
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_id;
float q_data[2];
q_data[0] = conversion::to<float>(q_h[o].x);
q_data[1] = conversion::to<float>(q_h[o].y);
@@ -181,7 +183,8 @@ void launch_bias_add_transform_0213<float>(float* output,
bool rotate_every_two,
cudaStream_t stream,
int trans_count,
int max_out_tokens)
int max_out_tokens,
float rope_theta)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
@@ -204,7 +207,8 @@ void launch_bias_add_transform_0213<float>(float* output,
rotate_half,
rotate_every_two,
head_ext,
max_out_tokens);
max_out_tokens,
rope_theta);
}

template <typename T>
@@ -225,7 +229,8 @@ void launch_bias_add_transform_0213(T* output,
bool rotate_every_two,
cudaStream_t stream,
int trans_count,
int max_out_tokens)
int max_out_tokens,
float rope_theta)
{
hidden_dim >>= 3;
int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1;
@@ -247,7 +252,8 @@ void launch_bias_add_transform_0213(T* output,
rotate_half,
rotate_every_two,
head_ext,
max_out_tokens);
max_out_tokens,
rope_theta);
}

#define INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(T) \
@@ -268,7 +274,8 @@ void launch_bias_add_transform_0213(T* output,
bool, \
cudaStream_t, \
int, \
int)
int, \
float)

#ifdef BF16_AVAILABLE
INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(__nv_bfloat16);
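These kernels cover the rotate_every_two layout, where consecutive (even, odd) elements of each head form a rotation pair, and the same configurable base is threaded through. A standalone Python sketch of that pairing (it ignores the bias add and the 0213 transpose the kernel fuses in, and is not meant to be bit-exact with the vectorized CUDA path):

import numpy as np

def rope_every_two(x: np.ndarray, seq_id: int, rotary_dim: int, rope_theta: float = 10000.0) -> np.ndarray:
    # Rotate consecutive (even, odd) element pairs of a single head vector.
    out = x.astype(np.float32).copy()
    for pair in range(rotary_dim // 2):
        angle = 1.0 / np.power(rope_theta, (pair * 2) / rotary_dim) * seq_id
        c, s = np.cos(angle), np.sin(angle)
        x0, x1 = out[2 * pair], out[2 * pair + 1]
        out[2 * pair], out[2 * pair + 1] = x0 * c - x1 * s, x0 * s + x1 * c
    return out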
@@ -168,6 +168,7 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
unsigned offset,
unsigned num_heads,
unsigned batch,
float rope_theta,
cudaStream_t stream,
int max_out_tokens);

1 change: 1 addition & 0 deletions deepspeed/module_inject/containers/llama.py
@@ -34,6 +34,7 @@ def create_module(self, config=None):
_config.rotate_half = True
_config.rotate_every_two = False
_config.rotary_dim = self.hidden_size // self.num_attention_heads
_config.rope_theta = self.policy.client_module.self_attn.rope_theta
self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)

return self.module
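The value is read straight off the Hugging Face attention module that the policy wraps. A hedged equivalent when starting from the checkpoint config instead (the getattr fallback guards against older transformers versions that predate the field; the checkpoint name is only an example):

from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained("codellama/CodeLlama-7b-hf")
rope_theta = getattr(hf_config, "rope_theta", 10000.0)  # 1e6 for CodeLlama, 10000.0 for Llama-2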
4 changes: 3 additions & 1 deletion deepspeed/ops/transformer/inference/config.py
@@ -79,7 +79,8 @@ def __init__(self,
transposed_mode=False,
use_triton=False,
triton_autotune=False,
num_kv=-1):
num_kv=-1,
rope_theta=10000):
super(DeepSpeedInferenceConfig,
self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
num_hidden_layers)
@@ -114,6 +115,7 @@ def __init__(self,
self.use_triton = use_triton
self.triton_autotune = triton_autotune
self.num_kv = num_kv
self.rope_theta = rope_theta

@classmethod
def from_dict(cls, json_object):
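A minimal sketch of constructing the config with a non-default base. hidden_size, heads, and rope_theta are constructor parameters visible in this diff; rotary_dim and rotate_half are assumed to be accepted as keywords as well (the llama container simply assigns them after construction):

from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig

config = DeepSpeedInferenceConfig(
    hidden_size=4096,
    heads=32,
    rotary_dim=128,
    rotate_half=True,
    rope_theta=1_000_000.0,  # default stays 10000, matching the old hard-coded value
)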
@@ -23,9 +23,9 @@ def __init__(self, config: DeepSpeedInferenceConfig):
except AttributeError:
self.softmax_context_func = self.softmax_context_fallback

def softmax_context_fallback(self, query_key_value, attn_mask, rotary_dim, rotate_half, roteate_every_two, heads,
norm_factor, triangular_masking, local_attention, window_size, no_masking, layer_id,
num_layers, alibi):
def softmax_context_fallback(self, query_key_value, attn_mask, rotary_dim, rotate_half, rotate_every_two, heads,
num_kv, norm_factor, triangular_masking, local_attention, window_size, no_masking,
layer_id, num_layers, alibi, rope_theta):
raise NotImplementedError

def forward(self, query_key_value: torch.Tensor, attn_mask: torch.Tensor, heads: int, num_kv: int,
@@ -41,6 +41,7 @@ def forward(self, query_key_value: torch.Tensor, attn_mask: torch.Tensor, heads:
output = self.softmax_context_func(query_key_value, attn_mask, self.config.rotary_dim, self.config.rotate_half,
self.config.rotate_every_two, heads, num_kv, norm_factor,
self.config.triangular_masking, self.config.local_attention,
self.config.window_size, no_masking, layer_id, num_layers, alibi)
self.config.window_size, no_masking, layer_id, num_layers, alibi,
self.config.rope_theta)

return output
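End to end, no user-facing change is needed: kernel injection picks the value up from the model and forwards it through the config into the fused ops. A sketch of the path this PR fixes (checkpoint name is illustrative; any Llama-family model with a non-default rope_theta applies):

import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("codellama/CodeLlama-7b-hf", torch_dtype=torch.float16)
engine = deepspeed.init_inference(model, dtype=torch.float16, replace_with_kernel_inject=True)
# The llama container now reads rope_theta (1e6 here) from the attention modules and passes it
# via DeepSpeedInferenceConfig into the fused rotary kernels, instead of assuming 10000.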