【PIR】Add AutoLayoutPass and AutoLayoutSimplifyPass #67576
Conversation
Your PR has been submitted successfully. Thank you for contributing to the open-source project!
Sorry to inform you that 70f53ac's CIs passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.
@@ -1368,6 +1368,10 @@ PHI_DEFINE_EXPORTED_bool(
    false,
    "EinsumOp backward will be speedup at the expense of more gpu memory.");

PHI_DEFINE_EXPORTED_bool(enable_auto_layout_pass,
Done
if (gpu_pass == "transfer_layout_pass" &&
    config_.autolayout_enabled())
  continue;
This kind of logic is a bit tricky, and it appears in more than one place. Developers unfamiliar with the background will be confused when they see it; it would be best to find a way to improve it.
Splitting the PR: the inference-related flag will be handled in a separate, independent PR.
void TransferLayout(pir::Builder builder, pir::Block* block) {
  for (auto& op_item : *block) {
    auto op = &op_item;
    auto op_name = op->name();
const auto&
Since DoTransposeOpOperand, RewriteLayout, and DoTransposeOpResult require a pir::Operation*, this is changed to auto&& instead.
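A minimal sketch of the resulting loop under that reply (illustrative only; the helper signatures are taken from the diff hunks in this thread):

for (auto&& op_item : *block) {
  pir::Operation* op = &op_item;  // the helpers below take pir::Operation*
  auto&& op_name = op->name();    // bind the returned name without copying
  if (op_name == "builtin.parameter") continue;  // layout-immutable, skip
  DoTransposeOpResult(op, builder);  // signature as used later in the diff
}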
if (op_name == "builtin.parameter" || op_name == "pd_op.feed" ||
    op_name == "builtin.shadow_output")
For this check, take a look at the ImmutableLayoutTrait trait.
Indeed, ImmutableLayoutTrait can be used to assist with this check.
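A sketch of the trait-based check (the trait name comes from the review above; its exact namespace and the HasTrait query are assumptions here):

// Ops carrying ImmutableLayoutTrait must keep their layout, so skip them
// instead of comparing op names one by one.
if (op->HasTrait<paddle::dialect::ImmutableLayoutTrait>()) {
  continue;
}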
    RewriteLayout(op, op->operands_source());
    DoTransposeOpResult(op, builder);
  }
} else if (IsInsertTransposeOpBefore(op)) {
How should this else branch be understood?
True: the current op is not NHWC, and at least one of its operands has had a transpose inserted, making that operand NCHW. But the op is not layout-sensitive and can run under NHWC, so a transpose needs to be inserted in front of it to cancel out the NCHW transpose.
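A hypothetical sketch of that predicate, following the explanation above (the real criteria in the PR may be stricter):

bool IsInsertTransposeOpBefore(pir::Operation* op) {
  // The op itself stayed in NCHW, but at least one operand is produced by a
  // transpose this pass inserted. A layout-agnostic op can run in NHWC, so
  // putting a transpose in front of it lets the two transposes cancel later.
  for (auto value : op->operands_source()) {
    auto* def = value ? value.defining_op() : nullptr;
    if (def && def->isa<paddle::dialect::TransposeOp>()) return true;
  }
  return false;
}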
if (op->HasAttribute("data_format")) {
  op->set_attribute("data_format", pir::StrAttribute::get(ctx_, "NHWC"));
}
auto p_attribute_map = op->attributes();
const auto&
Done
// Skip operands that are not dense tensors or not 4-D tensors; they don't
// need a transpose.
bool JudgeValue(const pir::Value& value) {
  if (auto type = value.type().dyn_cast<paddle::dialect::DenseTensorType>()) {
It's best to first check whether value and value.type() are null.
Done
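The guarded version might look like this (a sketch; the 4-D rank check mirrors the comment in the snippet above):

bool JudgeValue(const pir::Value& value) {
  if (!value || !value.type()) return false;  // null checks added per review
  if (auto type = value.type().dyn_cast<paddle::dialect::DenseTensorType>()) {
    return type.dims().size() == 4;  // only 4-D dense tensors need a transpose
  }
  return false;
}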
builder.SetInsertionPointAfter(op);
for (auto& result : op->results()) {
  // Could be optimized with a cache when the transpose op is not eliminated.
  if (!JudgeValue(result)) continue;
Another filter condition could be added: when the value has no users, don't insert a transpose at all.
Done
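With that extra filter, the result loop could read (sketch; use_empty() on pir::Value is an assumption here):

builder.SetInsertionPointAfter(op);
for (auto& result : op->results()) {
  if (!JudgeValue(result)) continue;
  if (result.use_empty()) continue;  // no users: a transpose would be dead code
  // ... insert the transpose for `result` as before ...
}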
bool Match(paddle::dialect::TransposeOp op) const override {
  auto before_transpose = op.x().defining_op();
  if (before_transpose->dyn_cast<paddle::dialect::TransposeOp>()) {
This could be written as:
if (!before_transpose->isa<paddle::dialect::TransposeOp>()) {
  return false;
}
to avoid indenting the large block of code that follows.
Done
std::vector<int32_t> before_perm;
for (size_t i = 0; i < before_perm_attr.size(); ++i) {
  auto attr = before_perm_attr.at(i);
  before_perm.push_back(attr.dyn_cast<pir::Int32Attribute>().data());
}

const auto after_perm_attr = op.attribute<pir::ArrayAttribute>("perm");
std::vector<int32_t> after_perm;
for (size_t i = 0; i < after_perm_attr.size(); ++i) {
  auto attr = after_perm_attr.at(i);
  after_perm.push_back(attr.dyn_cast<pir::Int32Attribute>().data());
}

if (before_perm[0] == after_perm[0] && before_perm[1] == after_perm[3] &&
    before_perm[2] == after_perm[1] && before_perm[3] == after_perm[2] &&
    before_perm == NCHW2NHWC_) {
  return true;
}

if (before_perm[0] == after_perm[0] && before_perm[1] == after_perm[2] &&
    before_perm[2] == after_perm[3] && before_perm[3] == after_perm[1] &&
    before_perm == NHWC2NCHW_) {
  return true;
}
Can the check logic here be simplified?
Done
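One possible simplification, for reference: two stacked transposes cancel exactly when composing their perms yields the identity, which subsumes both hard-coded branches above (a sketch, not necessarily what the PR landed):

// transpose(transpose(x, first), second) has dims z[i] = x[first[second[i]]],
// so the pair is a no-op iff first[second[i]] == i for every axis.
bool ComposesToIdentity(const std::vector<int32_t>& first,
                        const std::vector<int32_t>& second) {
  if (first.size() != second.size()) return false;
  for (int32_t i = 0; i < static_cast<int32_t>(first.size()); ++i) {
    if (first[second[i]] != i) return false;
  }
  return true;
}

If the pass only ever emits the NCHW2NHWC_/NHWC2NCHW_ perms, the membership check can be kept as an additional guard.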
It is recommended to build one unit test case for the layout adjustment and another for the redundant-transpose simplification, to ensure the functionality always stays correct.
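A gtest-style sketch of such a case (the factory names CreateAutoLayoutPass / CreateAutoLayoutSimplifyPass are assumed from this PR; graph construction is elided to the essentials):

TEST(auto_layout_pass, transpose_pairs_are_eliminated) {
  pir::IrContext* ctx = pir::IrContext::Instance();
  ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
  pir::Program program(ctx);
  pir::Builder builder(ctx, program.block());
  // Build a small NCHW graph here, e.g. full -> conv2d -> relu.

  pir::PassManager pm(ctx);
  pm.AddPass(pir::CreateAutoLayoutPass());          // layout adjustment
  pm.AddPass(pir::CreateAutoLayoutSimplifyPass());  // redundant-transpose cleanup
  EXPECT_TRUE(pm.Run(&program));
  // Finally, walk program.block() and assert that no transpose feeds another
  // transpose whose perms compose to the identity.
}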
paddle/common/flags.cc (Outdated)
/**
 * Performance related FLAG
 * Name: enable_auto_layout_pass
 * Since Version: 2.6.0
Suggested change:
- * Since Version: 2.6.0
+ * Since Version: 3.0.0
To be precise, this should be the next official version to be released.
#include <unordered_set>

#include "paddle/common/enforce.h"
#include "paddle/common/errors.h"
Only enforce.h needs to be included; there is no need to explicitly include errors.h.
class AutoLayoutPass : public pir::Pass {
 public:
  AutoLayoutPass() : pir::Pass("auto_layout_pass", 1) {}
// opt_level=1: constant fold, cse, memory optimize, etc.
// opt_level=2: the fusion logical pass.
// opt_level=3: layout, etc.
// opt_level=4: the radical optimization, maybe affect precision, etc.
uint8_t opt_level;
See opt_level in paddle/pir/include/pass/pass.h. Suggested change:
- AutoLayoutPass() : pir::Pass("auto_layout_pass", 1) {}
+ AutoLayoutPass() : pir::Pass("auto_layout_pass", 3) {}
builder.set_insertion_point(op);

// For conv2d, only transpose the input.
if (op->name() == "pd_op.conv2d") {
In general, we recommend using op->isa() instead.
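That is, prefer the type-safe form (sketch; Conv2dOp is the dialect class behind pd_op.conv2d):

// String match, as in the diff:
if (op->name() == "pd_op.conv2d") { /* transpose only the input */ }
// Preferred:
if (op->isa<paddle::dialect::Conv2dOp>()) { /* transpose only the input */ }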
Sorry to inform you that b871d4a's CIs passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.
PR Category
Performance Optimization
PR Types
Performance
Description
The two passes greedily insert transpose ops into the IR and then remove consecutive redundant transposes, ensuring that each operation runs in its appropriate layout.
Pcard-67164