-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support npu kernel for tile op #34606
Conversation
Thanks for your contribution! |
LGTM |
|
||
namespace paddle { | ||
namespace operators { | ||
inline std::vector<int> get_repeat_times_npu( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个函数不需要重写,修改一下tile_op.h的判断,改为就可以了
if (platform::is_gpu_place(repeat_tensor->place()) || platform::is_npu_place(repeat_tensor->place())) {
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到,我之前以为不能改已有的代码呢
self.check_output_with_place(self.place) | ||
|
||
def test_check_grad(self): | ||
self.check_grad(['X'], 'Out') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
self.check_output_with_place(self.place) | ||
|
||
def test_check_grad(self): | ||
self.check_grad(['X'], 'Out') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没有实现tile_grad的话,把这个函数改成 pass就好了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已改
self.check_output_with_place(self.place) | ||
|
||
def test_check_grad(self): | ||
self.check_grad(['X'], 'Out') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
# Test python API | ||
class TestTileAPI(unittest.TestCase): | ||
def test_api(self): | ||
with fluid.dygraph.guard(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以改成 fluid.dygraph.guard(paddle.NPUPlace(0)): 后面 to_tensor就不需要加NPUPlace了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
"must be less than or equal to %d, but the value received is %d.", | ||
MAX_RANK_SUPPORTED, repeat_times_size)); | ||
rank = std::max(rank, repeat_times_size); | ||
switch (rank) { REP_TILE_TEMPLATE(MAX_RANK_SUPPORTED) } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个在NPU下有编译错误,需要根据最新develop分支的代码修改一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
OPs
Describe
Support npu kernel for tile op
data:image/s3,"s3://crabby-images/2f32b/2f32b573e552fc854adecf28fa4c63eab4abdd3e" alt="image"
data:image/s3,"s3://crabby-images/c3075/c3075bbb62a748dc47a1af4d09bfb52e3299378c" alt="tileop"