-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NPU] Support NPU kernel for TopKV2 op #34599
Conversation
✅ This PR's description meets the template requirements! |
Thanks for your contribution! |
@@ -0,0 +1,250 @@ | |||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2016 -> 2021
template <typename T> | ||
class TopkV2NPUKernel : public framework::OpKernel<T> { | ||
public: | ||
// Use Ascend TopKV2 operator to implement paddle TopKV2Op |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this comment is not literally necessary since it is easy for the reader to understand.
public: | ||
// Use Ascend TopKV2 operator to implement paddle TopKV2Op | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
// Read message from context |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same above.
indices->mutable_data<int64_t>(context.GetPlace()); | ||
|
||
// Allocate space for output indices of Ascend topkV2 operator | ||
framework::Tensor* indices_int32 = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No related delete
?
new Tensor(framework::proto::VarType::INT32); | ||
indices_int32->Resize(output_dims); | ||
indices_int32->mutable_data<int32_t>(context.GetPlace()); | ||
VLOG(4) << "input:" << *input; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LOG the tensor may cost much time.
.AddAttr("dim", axis) | ||
.AddAttr("largest", largest) | ||
.Run(npu_stream_topkv2); | ||
VLOG(4) << "output:" << *out; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same above.
auto npu_stream_cast = | ||
context.template device_context<paddle::platform::NPUDeviceContext>() | ||
.stream(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is exactly the same stream with npu_stream_topkv2
… add_npu_op_tok_k_v2
…into add_npu_op_tok_k_v2
…into add_npu_op_tok_k_v2
…into add_npu_op_tok_k_v2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
OPs
Describe
Develop the NPU kernel for TopKV2 op and reuse the Python unit test for CPU kernel to test the new kernel.
Unit test result
data:image/s3,"s3://crabby-images/66f3f/66f3f95e1be6f766406c02f9b4201ceba386000c" alt="image"
NPU op call result
data:image/s3,"s3://crabby-images/3a062/3a062200aa6c372907b0cb0eb3777dceaebd34a1" alt="image"