Merge pull request PaddlePaddle#38 from zyfncg/drr_pass
Fix bug of setting InsertPoint
yuanlehome authored Oct 10, 2023
2 parents 3e166f2 + d3bc65d commit c985ec7
Showing 2 changed files with 88 additions and 45 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/pir/drr/drr_rewrite_pattern.h
@@ -644,7 +644,7 @@ class DrrRewritePattern : public pir::RewritePattern {
}
}
}
- if (max_input_op_index == -1UL) {
+ if (max_input_op_index == 0UL) {
VLOG(6) << "Not found producer op for (" << op_call.name() << ")";
Operation* source_patter_first_op =
src_match_ctx
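A note on the one-line fix above: from the surrounding lines, `max_input_op_index` appears to start at 0UL and to grow only when a producer op of op_call is found, so the old comparison with -1UL (i.e. SIZE_MAX) could never be true and the branch that falls back to the first source-pattern op for the insert point was effectively dead. The sketch below illustrates that sentinel pattern; everything except the name max_input_op_index is hypothetical and not taken from the Paddle sources.

// Hypothetical sketch of the sentinel logic fixed above (not the actual
// DrrRewritePattern code): the index starts at 0UL, so "no producer found"
// must be detected by comparing against 0UL, not against -1UL.
#include <algorithm>
#include <cstddef>
#include <vector>

size_t MaxProducerIndex(const std::vector<bool> &is_producer_of_op_call) {
  size_t max_input_op_index = 0UL;  // initial value: nothing found yet
  for (size_t i = 0UL; i < is_producer_of_op_call.size(); ++i) {
    if (is_producer_of_op_call[i]) {
      // producers are recorded at 1-based positions, so 0UL stays free as "none"
      max_input_op_index = std::max(max_input_op_index, i + 1);
    }
  }
  if (max_input_op_index == 0UL) {
    // no producer found: fall back to inserting before the first
    // source-pattern op (the old check against -1UL never reached this)
  }
  return max_input_op_index;
}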
131 changes: 87 additions & 44 deletions test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc
@@ -52,45 +52,6 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
}
};

// class FusedLinearGradPattern
// : public pir::drr::DrrPatternBase<FusedLinearGradPattern> {
// public:
// void operator()(pir::drr::DrrPatternContext *ctx) const override {
// pir::drr::SourcePattern pat = ctx->SourcePattern();
// const auto &matmul_grad = pat.Op("pd_op.matmul_grad",
// {{"transpose_x", pat.Attr("trans_x")},
// {"transpose_y", pat.Attr("trans_y")}});
// const auto &add_grad = pat.Op("pd_op.add_grad");

// add_grad({&pat.Tensor("tmp"), &pat.Tensor("bias"),
// &pat.Tensor("out_grad")},
// {&pat.Tensor("tmp_grad"), &pat.Tensor("bias_grad")});
// matmul_grad({&pat.Tensor("x"), &pat.Tensor("w"),
// &pat.Tensor("tmp_grad")},
// {&pat.Tensor("x_grad"), &pat.Tensor("w_grad")});

// // Result patterns: the subgraph to replace with
// pir::drr::ResultPattern res = pat.ResultPattern();
// const auto &act_attr =
// res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
// return "none";
// });

// const auto &fused_gemm_epilogue_grad =
// res.Op("pd_op.fused_gemm_epilogue_grad",
// {{{"trans_x", pat.Attr("trans_x")},
// {"trans_y", pat.Attr("trans_y")},
// {"activation_grad", act_attr}}});
// fused_gemm_epilogue_grad({&res.Tensor("x"),
// &res.Tensor("w"),
// &res.NoneTensor(),
// &res.Tensor("out_grad")},
// {&res.Tensor("x_grad"),
// &res.Tensor("w_grad"),
// &res.Tensor("bias_grad")});
// }
// };

class FusedLinearGradPattern
: public pir::drr::DrrPatternBase<FusedLinearGradPattern> {
public:
@@ -208,6 +169,85 @@ class FusedLinearGeluGradPattern
}
};

class FusedLinearReluGradPattern
: public pir::drr::DrrPatternBase<FusedLinearReluGradPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
pir::drr::SourcePattern pat = ctx->SourcePattern();
const auto &fused_gemm_epilogue =
pat.Op("pd_op.fused_gemm_epilogue",
{{{"trans_x", pat.Attr("trans_x1")},
{"trans_y", pat.Attr("trans_y1")},
{"activation", pat.Attr("act1")}}});
const auto &fused_gemm_epilogue_grad =
pat.Op("pd_op.fused_gemm_epilogue_grad",
{{{"trans_x", pat.Attr("trans_x2")},
{"trans_y", pat.Attr("trans_y2")},
{"activation_grad", pat.Attr("act2")}}});
const auto &fused_gemm_epilogue_grad1 =
pat.Op("pd_op.fused_gemm_epilogue_grad",
{{{"trans_x", pat.Attr("trans_x3")},
{"trans_y", pat.Attr("trans_y3")},
{"activation_grad", pat.Attr("act3")}}});
fused_gemm_epilogue(
{&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("bias")},
{&pat.Tensor("fuse_out"), &pat.Tensor("reserve_space")});
pat.Tensor("out") = pat.Op("pd_op.relu")(pat.Tensor("fuse_out"));

fused_gemm_epilogue_grad1({&pat.Tensor("x1"),
&pat.Tensor("w1"),
&pat.Tensor("reserve_space2"),
&pat.Tensor("out_grad")},
{&pat.Tensor("x1_grad"),
&pat.Tensor("w1_grad"),
&pat.Tensor("bias1_grad")});
pat.Tensor("relu_dx") =
pat.Op("pd_op.relu_grad")(pat.Tensor("x1"), pat.Tensor("x1_grad"));
fused_gemm_epilogue_grad({&pat.Tensor("x"),
&pat.Tensor("w"),
&pat.Tensor("reserve_space1"),
&pat.Tensor("relu_dx")},
{&pat.Tensor("x_grad"),
&pat.Tensor("w_grad"),
&pat.Tensor("bias_grad")});

pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) {
return match_ctx.Attr<std::string>("act1") == "none" &&
match_ctx.Attr<std::string>("act3") == "none";
});

pir::drr::ResultPattern res = pat.ResultPattern();
const auto &act_attr =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
return "relu";
});
const auto &fused_gemm_epilogue_new =
res.Op("pd_op.fused_gemm_epilogue",
{{{"trans_x", pat.Attr("trans_x1")},
{"trans_y", pat.Attr("trans_y1")},
{"activation", act_attr}}});
const auto &act_grad_attr =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
return "relu_grad";
});
const auto &fused_gemm_epilogue_grad1_new =
res.Op("pd_op.fused_gemm_epilogue_grad",
{{{"trans_x", pat.Attr("trans_x2")},
{"trans_y", pat.Attr("trans_y2")},
{"activation_grad", act_grad_attr}}});
fused_gemm_epilogue_new(
{&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")},
{&res.Tensor("out"), &res.Tensor("reserve_space3")});
fused_gemm_epilogue_grad1_new({&res.Tensor("x1"),
&res.Tensor("w1"),
&res.Tensor("reserve_space3"),
&res.Tensor("out_grad")},
{&res.Tensor("relu_dx"),
&res.Tensor("w1_grad"),
&res.Tensor("bias1_grad")});
}
};

class FusedLinearPass : public pir::Pass {
public:
FusedLinearPass() : pir::Pass("FusedLinearPass", 1) {}
@@ -217,6 +257,7 @@ class FusedLinearPass : public pir::Pass {
ps.Add(FusedLinearGradPattern().Build(context));
ps.Add(FusedLinearPattern().Build(context));
ps.Add(FusedLinearGeluGradPattern().Build(context));
ps.Add(FusedLinearReluGradPattern().Build(context));

patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
return true;
@@ -253,9 +294,10 @@ void BuildProgram(pir::Builder &builder) { // NOLINT
matmul_op1.out(), full_bias_op1.out());
// linear 2
paddle::dialect::FullOp full_weight_op2 =
- builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 64}, 1.5);
+ builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 128},
+                                        1.5);
paddle::dialect::FullOp full_bias_op2 =
- builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64}, 1.0);
+ builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{128}, 1.0);
paddle::dialect::MatmulOp matmul_op2 =
builder.Build<paddle::dialect::MatmulOp>(add_op1.out(),
full_weight_op2.out());
@@ -265,7 +307,8 @@ void BuildProgram(pir::Builder &builder) { // NOLINT
builder.Build<paddle::dialect::ReluOp>(add_op2.out());
// linear 3
paddle::dialect::FullOp full_weight_op3 =
- builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 64}, 1.5);
+ builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{128, 64},
+                                        1.5);
paddle::dialect::FullOp full_bias_op3 =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64}, 1.0);
paddle::dialect::MatmulOp matmul_op3 =
@@ -315,7 +358,7 @@ void BuildProgram(pir::Builder &builder) { // NOLINT
relu_op.out(), full_weight_op3.out(), add_op3_grad.x_grad());

paddle::dialect::ReluGradOp relu_op_grad =
- builder.Build<paddle::dialect::ReluGradOp>(add_op2.out(),
+ builder.Build<paddle::dialect::ReluGradOp>(relu_op.out(),
matmul_op3_grad.x_grad());
// backward linear 2
paddle::dialect::AddGradOp add_op2_grad =
@@ -353,5 +396,5 @@ TEST(DrrTest, FusedLinear) {
pm.EnableIRPrinting();

CHECK_EQ(pm.Run(&program), true);
- EXPECT_EQ(program.block()->size(), 24u);
+ EXPECT_EQ(program.block()->size(), 22u);
}
