
cherry-pick fused_rope from develop #55931

Merged

8 commits merged into PaddlePaddle:incubate/new_frl on Aug 7, 2023

Conversation

AnnaTrainingG (Contributor)

PR types: Others

PR changes: Others

Description: Others

* style

* more

* update ctest

* Update legacy_backward.yaml

* Update legacy_ops.yaml

* Update legacy_ops.yaml

* update

* update

* update for move
@paddle-bot commented Aug 3, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@Xreki requested a review from sneaxiy on August 3, 2023 02:29
        out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)
    """
    if in_dynamic_mode():
        return _C_ops.fused_rotary_position_embedding(q, k, v)
Collaborator:

Static graph is not supported? Add an assert?

Contributor Author:

done

    MPType div_c) {
  int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
  int stride = gridDim.x * blockDim.x * VecSize;
  int size = batch_size * seq_len * num_heads * head_dim;
Collaborator:

Use int64_t to prevent overflow.

Contributor Author:

done
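For concreteness, a minimal sketch of the requested widening, matching the pattern that shows up in a later hunk of this PR (not the exact merged code):

// Widen index, stride and size to int64_t so they can hold products larger
// than 2^31 - 1. The right-hand sides are still evaluated in 32-bit
// arithmetic; a later comment in this review tightens that with an
// explicit static_cast.
int64_t index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
int64_t stride = gridDim.x * blockDim.x * VecSize;
int64_t size = batch_size * seq_len * num_heads * head_dim;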

for (int nx = 0; nx < VecSize; ++nx) {
  // get sin_index and cos_index
  int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
  int pos_seq = index_wc / (num_heads * head_dim);
Collaborator:

Use int64_t.

Contributor Author:

done
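The same widening applies inside the per-element loop; a hedged sketch (assumed context, not the exact merged lines), with an explicit cast so the modulus operand is not first computed as a 32-bit product:

for (int nx = 0; nx < VecSize; ++nx) {
  // get sin_index and cos_index; index is already int64_t, and the cast
  // keeps the seq_len * num_heads * head_dim product in 64 bits as well
  int64_t index_wc =
      (index + nx) % (static_cast<int64_t>(seq_len) * num_heads * head_dim);
  int64_t pos_seq = index_wc / (num_heads * head_dim);
}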

    DenseTensor* dq,
    DenseTensor* dk,
    DenseTensor* dv) {
  int numel = dout_q.numel();
Collaborator:

Use int64_t.

Contributor Author:

done
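Here the widening is a one-line change; numel() on a phi::DenseTensor returns int64_t, so storing it in an int silently truncates once a tensor exceeds 2^31 - 1 elements. A hedged sketch:

int64_t numel = dout_q.numel();  // keep the 64-bit count, no narrowing to int
if (numel <= 0) return;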

    MPType div_c) {
  int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
  int stride = gridDim.x * blockDim.x * VecSize;
  int size = batch_size * seq_len * num_heads * head_dim;
Collaborator:

Use int64_t.

Contributor Author:

done

for (int nx = 0; nx < VecSize; ++nx) {
  // get sin_index and cos_index
  int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
  int pos_seq = index_wc / (num_heads * head_dim);
Collaborator:

Use int64_t.

Contributor Author:

done

    DenseTensor* out_q,
    DenseTensor* out_k,
    DenseTensor* out_v) {
  int numel = q.numel();
Collaborator:

Use int64_t.

Contributor Author:

done

int numel = q.numel();
if (numel <= 0) return;
dev_ctx.template Alloc<T>(out_q);
out_q->Resize(q.dims());
Collaborator:

Isn't this Resize redundant? The shape was already set during InferMeta.

Contributor Author:

done
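A sketch of the body with the redundant call dropped, assuming (as the comment says) that InferMeta already set the output shape:

int64_t numel = q.numel();
if (numel <= 0) return;
// out_q's dims were set during InferMeta, so allocating is enough here;
// the Resize after Alloc was a no-op and is removed.
dev_ctx.template Alloc<T>(out_q);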


if (k.get_ptr()) {
  dev_ctx.template Alloc<T>(out_k);
  out_k->Resize(q.dims());
Collaborator:

Isn't this Resize redundant? The shape was already set during InferMeta.

Contributor Author:

done


if (v.get_ptr()) {
  dev_ctx.template Alloc<T>(out_v);
  out_v->Resize(q.dims());
Collaborator:

Isn't this Resize redundant? The shape was already set during InferMeta.

Contributor Author:

done

    phi::Array<T*, 3> outs_data,
    int num_inputs,
    MPType div_c) {
  int64_t index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
Collaborator:

We have previously seen blockIdx.x * blockDim.x + threadIdx.x itself overflow and come out negative, so it is recommended to static_cast<int64_t>(blockDim.x) before doing the arithmetic; same below.

Contributor Author:

done
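Concretely, the pattern the reviewer asks for looks like the following sketch (assumed context, not the exact merged lines): widening one operand before the multiply keeps the whole expression in 64-bit arithmetic, instead of widening a 32-bit product that may already have wrapped.

// blockIdx.x, blockDim.x and threadIdx.x are 32-bit values; casting the
// first operand makes the multiply-add execute in int64_t, so very large
// grids can no longer wrap the intermediate result.
int64_t index =
    (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x * VecSize;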

@sneaxiy merged commit 8d3a988 into PaddlePaddle:incubate/new_frl on Aug 7, 2023
hitywt pushed a commit to hitywt/Paddle that referenced this pull request Nov 20, 2023
…addlePaddle#55931)

* Add fused_rope forward op (PaddlePaddle#54351)

* style

* more

* update ctest

* Update legacy_backward.yaml

* Update legacy_ops.yaml

* Update legacy_ops.yaml

* update

* update

* update for move

* Update the rope op according to the comments (PaddlePaddle#54985)

* Update multiary.cc

* Update __init__.py

* for int64_t and assert

* more

* remove useless assert first

---------

Co-authored-by: sneaxiy <[email protected]>
hitywt pushed a commit to hitywt/Paddle that referenced this pull request Nov 22, 2023

…addlePaddle#55931)

(Same commit message as the Nov 20 push above.)