cherry-pick fused_rope from develop #55931
Conversation
* style
* more
* update ctest
* Update legacy_backward.yaml
* Update legacy_ops.yaml
* Update legacy_ops.yaml
* update
* update
* update for move
Your PR has been submitted successfully. Thank you for your contribution to this open-source project!
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)
"""
if in_dynamic_mode():
    return _C_ops.fused_rotary_position_embedding(q, k, v)
Static graph is not supported? Add an assert?
done
MPType div_c) {
  int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
  int stride = gridDim.x * blockDim.x * VecSize;
  int size = batch_size * seq_len * num_heads * head_dim;
Use int64_t to prevent overflow.
done
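For reference, a minimal sketch of the 64-bit indexing the reviewer is asking for (the kernel name and signature here are hypothetical; the variable names mirror the snippet above):

```cuda
template <int VecSize>
__global__ void RopeIndexSketch(int batch_size,
                                int seq_len,
                                int num_heads,
                                int head_dim) {
  // Do the arithmetic in int64_t so neither the grid-stride index nor the
  // total element count can overflow a 32-bit int on large tensors.
  int64_t index =
      (blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x) * VecSize;
  int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x * VecSize;
  int64_t size =
      static_cast<int64_t>(batch_size) * seq_len * num_heads * head_dim;
  for (; index < size; index += stride) {
    // ... process VecSize contiguous elements starting at `index` ...
  }
}
```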
for (int nx = 0; nx < VecSize; ++nx) {
  // get sin_index and cos_index
  int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
  int pos_seq = index_wc / (num_heads * head_dim);
Use int64_t.
done
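The same concern applies to the decomposition above, since the product seq_len * num_heads * head_dim can itself overflow a 32-bit int. A sketch of the widened computation as a helper (the name and signature are hypothetical):

```cuda
// Decompose a flat element index with the products widened to 64-bit:
// index_wc is the offset within one batch, and the return value is the
// sequence position used to look up the sin/cos values.
__device__ __forceinline__ int64_t PosSeqSketch(int64_t flat_index,
                                                int seq_len,
                                                int num_heads,
                                                int head_dim) {
  int64_t index_wc =
      flat_index % (static_cast<int64_t>(seq_len) * num_heads * head_dim);
  return index_wc / (static_cast<int64_t>(num_heads) * head_dim);
}
```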
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
  int numel = dout_q.numel();
Use int64_t.
done
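Here the fix is a one-liner, since phi's `numel()` already returns a 64-bit count; the local variable just has to stop truncating it:

```cuda
int64_t numel = dout_q.numel();  // keep the full 64-bit element count
```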
MPType div_c) {
  int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
  int stride = gridDim.x * blockDim.x * VecSize;
  int size = batch_size * seq_len * num_heads * head_dim;
Use int64_t.
done
for (int nx = 0; nx < VecSize; ++nx) {
  // get sin_index and cos_index
  int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
  int pos_seq = index_wc / (num_heads * head_dim);
Use int64_t.
done
DenseTensor* out_q,
DenseTensor* out_k,
DenseTensor* out_v) {
  int numel = q.numel();
Use int64_t.
done
int numel = q.numel();
if (numel <= 0) return;
dev_ctx.template Alloc<T>(out_q);
out_q->Resize(q.dims());
Is this Resize redundant? The shape has already been set during InferMeta.
done
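A sketch of the cleanup being requested, assuming the op's InferMeta has already fixed the output dims (the usual phi convention):

```cuda
// The output shape was set during InferMeta, so allocation alone suffices;
// the out_q->Resize(q.dims()) call can simply be dropped.
dev_ctx.template Alloc<T>(out_q);
```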
if (k.get_ptr()) {
  dev_ctx.template Alloc<T>(out_k);
  out_k->Resize(q.dims());
Is this Resize redundant? The shape has already been set during InferMeta.
done
if (v.get_ptr()) {
  dev_ctx.template Alloc<T>(out_v);
  out_v->Resize(q.dims());
Is this Resize redundant? The shape has already been set during InferMeta.
done
phi::Array<T*, 3> outs_data,
int num_inputs,
MPType div_c) {
  int64_t index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
We have previously run into cases where the blockIdx.x * blockDim.x + threadIdx.x computation itself overflowed and produced negative values, so it is recommended to apply static_cast<int64_t>(blockDim.x) before doing the arithmetic. Same below.
done
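A sketch of the suggested fix: widening one operand first forces the whole expression into 64-bit before it can wrap:

```cuda
// blockIdx.x, blockDim.x, and threadIdx.x are 32-bit unsigned values;
// without the cast their product is evaluated in 32-bit arithmetic and can
// wrap around before being assigned to the 64-bit variable.
int64_t index =
    (blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x) * VecSize;
```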
…addlePaddle#55931)

* Add fused_rope forward op (PaddlePaddle#54351)
* style
* more
* update ctest
* Update legacy_backward.yaml
* Update legacy_ops.yaml
* Update legacy_ops.yaml
* update
* update
* update for move
* Update the rope op according to the comments (PaddlePaddle#54985)
* Update multiary.cc
* Update __init__.py
* for int64_t and assert
* more
* remove useless assert first

---------

Co-authored-by: sneaxiy <[email protected]>
PR types
Others
PR changes
Others
Description
Others