Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] convert pd_op.sum to cinn_op.reduce_sum #58207

Merged
merged 30 commits into from
Oct 24, 2023

Conversation

phlrain
Copy link
Collaborator

@phlrain phlrain commented Oct 18, 2023

PR types

Others

PR changes

Others

Description

由于编译器中间部分算子定义和主框架不一样,主框架的reduce 系列,支持axis为可变attribute,但是编译器中当前不支持axis 为可变attribute,因此做了一个变换映射

对于full int array内部存储的phi::IntArray 改成vector,方便drr的适配,这个算子仅在新IR中使用,不存在兼容性问题

Pcard-67164

@paddle-bot
Copy link

paddle-bot bot commented Oct 18, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot
Copy link

paddle-bot bot commented Oct 18, 2023

❌ The PR is not created using PR's template. You can refer to this Demo.
Please use PR's template, it helps save our maintainers' time so that more developers get helped.

@phlrain phlrain changed the title update [PIR] convert pd_op.sum to cinn_op.reduce_sum Oct 19, 2023
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/pd_op_to_cinn_op_convert_pass.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

文件名是不是可精简为:pd_to_cinn_pass.h,因为目前我们没有Kernel层的转换需求,所以默认是Op层面的。或者pd_to_cinn_op_pass.hpd_op_to_cinn_pass.h

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

namespace dialect {
namespace ir {

class PDSum2CINNReduceSumPattern
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class PDSum2CINNReduceSumPattern
class SumOpPattern

是不是可以约定下这个命名范式,因为源Dialect是pd_op,所以我们只需要以其为命名即可?即使后续出现一对多的场景,这里的class name也会比较简洁。具体的映射到哪个算子,看code也是很清楚的。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pat = ctx->SourcePattern();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pir::drr::SourcePattern pat = ctx->SourcePattern();
pir::drr::SourcePattern patttern = ctx->SourcePattern();

pat 作为 pattern 缩写似乎比较少见?不像idx这些

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

void operator()(pir::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pat = ctx->SourcePattern();
const auto &full_int_array = pat.Op("pd_op.full_int_array",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const auto &full_int_array = pat.Op("pd_op.full_int_array",
const auto &full_int_array = pat.Op(paddle::dialect::FullIntArrayOp::name(),

我们应该尽量避免直接使用 op_name_str

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

{"place", pat.Attr("place_2")}});

const auto &sum = pat.Op(
"pd_op.sum",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}

bool PdOpToCinnOpPass::CanApplyOn(pir::Operation *op) const {
return op->name() == "builtin.module" && op->num_regions() > 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return op->name() == "builtin.module" && op->num_regions() > 0;
return op->isa<ModuleOp>() && op->num_regions() > 0;

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
};

PdOpToCinnOpPass::PdOpToCinnOpPass() : pir::Pass("PdOpToCinnOpPass", 1) {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

关于pass name的命名规范,之前的想法是统一使用小写+下划线的形式,我已提了统一修改之前pass的pr,见#58205 辛苦这里改一下?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Charles-hit
Charles-hit previously approved these changes Oct 20, 2023
Copy link
Contributor

@Charles-hit Charles-hit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for generated_vjp.cc.j2

@@ -41,13 +41,11 @@ void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) {
CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out);
}

void CreateIntArrayInferMeta(const IntArray& data,
void CreateIntArrayInferMeta(const std::vector<int64_t>& shape,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CreateIntArrayInferMeta的函数名是不也需要调整下?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines +70 to +78
const auto &new_perm_attr =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> phi::IntArray {
auto shape =
match_ctx.Attr<std::vector<int64_t>>("expand_shape_value");

return phi::IntArray(shape);
});
const auto &full2 = res.Op("pd_op.full",
{{"shape", pat.Attr("expand_shape_value")},
{{"shape", new_perm_attr},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

full的shape参数还是vector<int64_t>类型的,这里换成int_array是用来测试吗?

Copy link
Collaborator Author

@phlrain phlrain Oct 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个类似我理解和参数表的里面的不是一样的,主要看底层的attribute map里面存的是啥, 这个full 里面底层存的是int array

Copy link
Contributor

@Aurelius84 Aurelius84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@yuanlehome yuanlehome left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@Charles-hit Charles-hit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for generated_vjp.cc.j2

@phlrain phlrain merged commit 3e677c4 into PaddlePaddle:develop Oct 24, 2023
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
* update

* update

* update

* update

* fix bug

* add test flag

* fix bug

* update

* fix cmake bug

* remove cinn_op header

* fix full int array bug

* fix vjp gene bug

* fix bug

* fix bug

* polish code

* update

* update

* update

* fix bug

* fix bug
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants