-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PIR] convert pd_op.sum to cinn_op.reduce_sum #58207
[PIR] convert pd_op.sum to cinn_op.reduce_sum #58207
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
❌ The PR is not created using PR's template. You can refer to this Demo. |
… replace_sum_to_cinn_reduce_sum
… replace_sum_to_cinn_reduce_sum
… replace_sum_to_cinn_reduce_sum
… replace_sum_to_cinn_reduce_sum
… replace_sum_to_cinn_reduce_sum
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/cinn/hlir/dialect/operator/transforms/pd_op_to_cinn_op_convert_pass.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件名是不是可精简为:pd_to_cinn_pass.h
,因为目前我们没有Kernel层的转换需求,所以默认是Op层面的。或者pd_to_cinn_op_pass.h
、pd_op_to_cinn_pass.h
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
namespace dialect { | ||
namespace ir { | ||
|
||
class PDSum2CINNReduceSumPattern |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class PDSum2CINNReduceSumPattern | |
class SumOpPattern |
是不是可以约定下这个命名范式,因为源Dialect是pd_op,所以我们只需要以其为命名即可?即使后续出现一对多的场景,这里的class name也会比较简洁。具体的映射到哪个算子,看code也是很清楚的。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
public: | ||
void operator()(pir::drr::DrrPatternContext *ctx) const override { | ||
// Source Pattern | ||
pir::drr::SourcePattern pat = ctx->SourcePattern(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pir::drr::SourcePattern pat = ctx->SourcePattern(); | |
pir::drr::SourcePattern patttern = ctx->SourcePattern(); |
pat 作为 pattern 缩写似乎比较少见?不像idx这些
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
void operator()(pir::drr::DrrPatternContext *ctx) const override { | ||
// Source Pattern | ||
pir::drr::SourcePattern pat = ctx->SourcePattern(); | ||
const auto &full_int_array = pat.Op("pd_op.full_int_array", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const auto &full_int_array = pat.Op("pd_op.full_int_array", | |
const auto &full_int_array = pat.Op(paddle::dialect::FullIntArrayOp::name(), |
我们应该尽量避免直接使用 op_name_str
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
{"place", pat.Attr("place_2")}}); | ||
|
||
const auto &sum = pat.Op( | ||
"pd_op.sum", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
} | ||
|
||
bool PdOpToCinnOpPass::CanApplyOn(pir::Operation *op) const { | ||
return op->name() == "builtin.module" && op->num_regions() > 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return op->name() == "builtin.module" && op->num_regions() > 0; | |
return op->isa<ModuleOp>() && op->num_regions() > 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
… replace_sum_to_cinn_reduce_sum
} | ||
}; | ||
|
||
PdOpToCinnOpPass::PdOpToCinnOpPass() : pir::Pass("PdOpToCinnOpPass", 1) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
关于pass name的命名规范,之前的想法是统一使用小写+下划线的形式,我已提了统一修改之前pass的pr,见#58205 辛苦这里改一下?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for generated_vjp.cc.j2
paddle/phi/infermeta/nullary.cc
Outdated
@@ -41,13 +41,11 @@ void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) { | |||
CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out); | |||
} | |||
|
|||
void CreateIntArrayInferMeta(const IntArray& data, | |||
void CreateIntArrayInferMeta(const std::vector<int64_t>& shape, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CreateIntArrayInferMeta
的函数名是不也需要调整下?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
const auto &new_perm_attr = | ||
res.Attr([](const pir::drr::MatchContext &match_ctx) -> phi::IntArray { | ||
auto shape = | ||
match_ctx.Attr<std::vector<int64_t>>("expand_shape_value"); | ||
|
||
return phi::IntArray(shape); | ||
}); | ||
const auto &full2 = res.Op("pd_op.full", | ||
{{"shape", pat.Attr("expand_shape_value")}, | ||
{{"shape", new_perm_attr}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个类似我理解和参数表的里面的不是一样的,主要看底层的attribute map里面存的是啥, 这个full 里面底层存的是int array
… replace_sum_to_cinn_reduce_sum
… replace_sum_to_cinn_reduce_sum
… replace_sum_to_cinn_reduce_sum
… replace_sum_to_cinn_reduce_sum
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for generated_vjp.cc.j2
* update * update * update * update * fix bug * add test flag * fix bug * update * fix cmake bug * remove cinn_op header * fix full int array bug * fix vjp gene bug * fix bug * fix bug * polish code * update * update * update * fix bug * fix bug
PR types
Others
PR changes
Others
Description
由于编译器中间部分算子定义和主框架不一样,主框架的reduce 系列,支持axis为可变attribute,但是编译器中当前不支持axis 为可变attribute,因此做了一个变换映射
对于full int array内部存储的phi::IntArray 改成vector,方便drr的适配,这个算子仅在新IR中使用,不存在兼容性问题
Pcard-67164