Passing in attn_score_scaling_factor into tvm_wrapper (#126)
In GPT-2, the attention calculation requires an additional feature,
scale_attn_by_inverse_layer_idx, which provides a per-layer scaling factor
that is applied to the attention scores before the softmax.
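
For context, this is the rule the factor comes from in GPT-2 (following the
Hugging Face GPT2Attention behavior): with scale_attn_by_inverse_layer_idx
enabled, the raw QK^T scores of the 0-based layer layer_idx are divided by
(layer_idx + 1) before the softmax. A minimal C++ sketch; the helper name is
hypothetical and not part of this PR:

// Hypothetical helper, not part of this PR: the per-layer value a GPT-2
// caller would pass as attn_score_scaling_factor.
double gpt2_attn_score_scaling_factor(int layer_idx, bool scale_attn_by_inverse_layer_idx) {
  // GPT-2 divides the pre-softmax scores of layer `layer_idx` (0-based) by (layer_idx + 1).
  return scale_attn_by_inverse_layer_idx ? 1.0 / static_cast<double>(layer_idx + 1) : 1.0;
}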

This PR adds support for this additional parameter in tvm_wrapper.

See: apache/tvm#16606
rickzx authored Feb 19, 2024
1 parent 1b75874 commit f1f6a0d
Showing 1 changed file with 8 additions and 3 deletions.
src/tvm_wrapper.cu (8 additions, 3 deletions)

@@ -221,7 +221,8 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q
 int64_t causal = 1, //
 int64_t rotary_mode = 0, //
 double rope_scale = 1.0f, //
-double rope_theta = 1e4) {
+double rope_theta = 1e4,
+double attn_score_scaling_factor = 1.0f) {
 CHECK(handler_id < max_num_handlers) << "The handler id must be less than " << max_num_handlers;
 CHECK_EQ(q_data->device.device_type, kDLCUDA) << "The device of q_data must be CUDA.";
 CHECK_EQ(pages->device.device_type, kDLCUDA) << "The device of kv pages must be CUDA.";
@@ -238,6 +239,7 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q
 CHECK_EQ(qo_indptr->device.device_type, kDLCUDA)
 << "The device of qo_indptr matrix must be CUDA.";
 CHECK_EQ(output->device.device_type, kDLCUDA) << "The device of output must be CUDA.";
+CHECK_EQ(attn_score_scaling_factor, 1.0f) << "The attention score scaling factor must be 1.0.";
 
 int32_t dev_id = q_data->device.device_id;
 CHECK_EQ(pages->device.device_id, dev_id);
@@ -355,7 +357,8 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_
 DLTensor* lse, //
 int64_t rotary_mode = 0, //
 double rope_scale = 1.0f, //
-double rope_theta = 1e4) {
+double rope_theta = 1e4,
+double attn_score_scaling_factor = 1.0f) {
 CHECK_LT(handler_id, max_num_handlers) << "The handler id must be less than " << max_num_handlers;
 CHECK_EQ(q_data->device.device_type, kDLCUDA) << "The device of q_data must be CUDA.";
 CHECK_EQ(pages->device.device_type, kDLCUDA) << "The device of kv pages must be CUDA.";
@@ -370,6 +373,7 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_
 CHECK_EQ(k_rope_pos_offset->device.device_type, kDLCUDA)
 << "The device of k_rope_pos_offset matrix must be CUDA.";
 CHECK_EQ(output->device.device_type, kDLCUDA) << "The device of output must be CUDA.";
+CHECK_EQ(attn_score_scaling_factor, 1.0f) << "The attention score scaling factor must be 1.0.";
 
 int32_t dev_id = q_data->device.device_id;
 CHECK_EQ(pages->device.device_id, dev_id);
@@ -511,7 +515,7 @@ void _FlashInferAttentionPrefillWithRaggedKVCache(
 DLTensor* q_data, DLTensor* qo_indptr, DLTensor* k_data, DLTensor* v_data, DLTensor* kv_indptr,
 DLTensor* q_rope_position_map, DLTensor* k_rope_pos_offset, DLTensor* output, DLTensor* lse,
 int64_t causal = 1, int64_t rotary_mode = 0, double rope_scale = 1.0f,
-double rope_theta = 1e4) {
+double rope_theta = 1e4, double attn_score_scaling_factor = 1.0f) {
 CHECK_EQ(q_data->device.device_type, kDLCUDA) << "The device of q_data must be CUDA.";
 CHECK_EQ(qo_indptr->device.device_type, kDLCUDA) << "The device of qo_indptr must be CUDA.";
 CHECK_EQ(k_data->device.device_type, kDLCUDA) << "The device of k_data must be CUDA.";
@@ -523,6 +527,7 @@
 << "The device of q_rope_position_map must be CUDA.";
 CHECK_EQ(k_rope_pos_offset->device.device_type, kDLCUDA)
 << "The device of k_rope_pos_offset must be CUDA.";
+CHECK_EQ(attn_score_scaling_factor, 1.0f) << "The attention score scaling factor must be 1.0.";
 
 int dev_id = q_data->device.device_id;
 CHECK_EQ(qo_indptr->device.device_id, dev_id);
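
Note that each wrapper guards the new argument with
CHECK_EQ(attn_score_scaling_factor, 1.0f), so only the default value is
accepted for now; the parameter is threaded through the signatures,
presumably so it can later be plumbed into the kernels. As a rough
illustration of its semantics (a conceptual C++ sketch, not FlashInfer kernel
code), the factor simply multiplies the pre-softmax logits together with the
usual 1/sqrt(head_dim) scale, so a value of 1.0 leaves the result unchanged:

#include <algorithm>
#include <cmath>
#include <vector>

// Conceptual sketch: apply sm_scale and attn_score_scaling_factor to one row of
// raw QK^T logits, then take a numerically stable softmax over that row.
std::vector<float> scaled_softmax(std::vector<float> logits, int head_dim,
                                  float attn_score_scaling_factor) {
  const float sm_scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  float max_logit = -INFINITY;
  for (float& x : logits) {
    x *= sm_scale * attn_score_scaling_factor;  // the extra scaling happens here
    max_logit = std::max(max_logit, x);
  }
  float denom = 0.0f;
  for (float& x : logits) {
    x = std::exp(x - max_logit);  // subtract the max for numerical stability
    denom += x;
  }
  for (float& x : logits) x /= denom;
  return logits;
}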
