Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] convert pd_op.sum to cinn_op.reduce_sum #58207

Merged
merged 30 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f1c39dc
update
phlrain Oct 18, 2023
f2aca1b
update
phlrain Oct 18, 2023
31b30c9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Oct 18, 2023
864097c
update
phlrain Oct 18, 2023
a0f63ea
update
phlrain Oct 18, 2023
e474160
fix bug
phlrain Oct 18, 2023
42e02c0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Oct 18, 2023
dfedf62
add test flag
phlrain Oct 18, 2023
a9bf79a
fix bug
phlrain Oct 18, 2023
f5387a7
update
phlrain Oct 19, 2023
9386a5f
fix cmake bug
phlrain Oct 19, 2023
25f3567
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Oct 19, 2023
47e139f
remove cinn_op header
phlrain Oct 19, 2023
71f6215
fix full int array bug
phlrain Oct 19, 2023
109f99a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Oct 19, 2023
25d1a18
fix vjp gene bug
phlrain Oct 19, 2023
0813092
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Oct 19, 2023
8eb71d0
fix bug
phlrain Oct 19, 2023
8785aba
fix bug
phlrain Oct 20, 2023
e1c59ec
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Oct 20, 2023
5edc3fa
polish code
phlrain Oct 23, 2023
6254224
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Oct 23, 2023
433d009
update
phlrain Oct 23, 2023
319a113
update
phlrain Oct 23, 2023
b62c952
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Oct 23, 2023
e0b3fef
update
phlrain Oct 23, 2023
9288d2a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Oct 23, 2023
2bcd6e1
fix bug
phlrain Oct 23, 2023
6e162d7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Oct 23, 2023
e7ee536
fix bug
phlrain Oct 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,13 @@ if(NOT CINN_ONLY)
pd_op_dialect
pir_compiler
cinn_runtime_dialect)

cinn_cc_library(
pd_to_cinn_pass
SRCS
pd_to_cinn_pass.cc
DEPS
drr
cinn_op_dialect
pd_op_dialect)
endif()
113 changes: 113 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h"

#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_manager.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace cinn {
namespace dialect {
namespace ir {

class SumOpPattern : public pir::drr::DrrPatternBase<SumOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern patttern = ctx->SourcePattern();
const auto &full_int_array =
patttern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", patttern.Attr("axis_info")},
{"dtype", patttern.Attr("dtype_2")},
{"place", patttern.Attr("place_2")}});

const auto &sum = patttern.Op(paddle::dialect::SumOp::name(),
{{"dtype", patttern.Attr("dtype")},
{"keepdim", patttern.Attr("keep_dim")}});
patttern.Tensor("ret") = sum(patttern.Tensor("arg0"), full_int_array());

// Result patterns
pir::drr::ResultPattern res = patttern.ResultPattern();
const auto &cinn_reduce_sum =
res.Op(cinn::dialect::ReduceSumOp::name(),
{{"axis", patttern.Attr("axis_info")},
{"keep_dim", patttern.Attr("keep_dim")}});
res.Tensor("ret") = cinn_reduce_sum(res.Tensor("arg0"));
}
};

class MaxOpPattern : public pir::drr::DrrPatternBase<MaxOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern patttern = ctx->SourcePattern();
const auto &full_int_array =
patttern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", patttern.Attr("axis_info")},
{"dtype", patttern.Attr("dtype_2")},
{"place", patttern.Attr("place_2")}});

const auto &pd_max = patttern.Op(paddle::dialect::MaxOp::name(),
{{"keepdim", patttern.Attr("keep_dim")}});
patttern.Tensor("ret") = pd_max(patttern.Tensor("arg0"), full_int_array());

// Result patterns
pir::drr::ResultPattern res = patttern.ResultPattern();
const auto &cinn_reduce_max =
res.Op(cinn::dialect::ReduceMaxOp::name(),
{{"axis", patttern.Attr("axis_info")},
{"keep_dim", patttern.Attr("keep_dim")}});
res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0"));
}
};

PdOpToCinnOpPass::PdOpToCinnOpPass() : pir::Pass("pd_to_cinn_pass", 1) {}

bool PdOpToCinnOpPass::Initialize(pir::IrContext *context) {
pir::RewritePatternSet ps(context);
ps.Add(SumOpPattern().Build(context));
ps.Add(MaxOpPattern().Build(context));

patterns_ = ::pir::FrozenRewritePatternSet(std::move(ps));
return true;
}

void PdOpToCinnOpPass::Run(pir::Operation *op) {
pir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true;
cfg.max_iterations = 10;
pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
}

bool PdOpToCinnOpPass::CanApplyOn(pir::Operation *op) const {
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
}

void PdOp2CinnOpConverter(::pir::Program *program) {
pir::IrContext *ctx = pir::IrContext::Instance();

pir::PassManager pm(ctx);
pm.AddPass(std::make_unique<PdOpToCinnOpPass>());

pm.Run(program);
}
} // namespace ir
} // namespace dialect
} // namespace cinn
43 changes: 43 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/pir/core/program.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"

namespace cinn {
namespace dialect {
namespace ir {

class PdOpToCinnOpPass : public pir::Pass {
public:
PdOpToCinnOpPass();

bool Initialize(pir::IrContext *context) override;

void Run(pir::Operation *op) override;

bool CanApplyOn(pir::Operation *op) const override;

private:
pir::FrozenRewritePatternSet patterns_;
};

void PdOp2CinnOpConverter(::pir::Program *program);

} // namespace ir
} // namespace dialect
} // namespace cinn
8 changes: 2 additions & 6 deletions paddle/cinn/hlir/framework/pir/op_mapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@ void AppendAttrForReduceOp(const ::pir::Operation& op,
auto* source_op =
op.operand_source(/*dim_idx=*/1).dyn_cast<::pir::OpResult>().owner();
CHECK(source_op->isa<paddle::dialect::FullIntArrayOp>());
const std::vector<int64_t>& dim_val =
source_op->attributes()
.at("value")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData();
auto dim_val =
paddle::dialect::GetInt64Vector(source_op->attributes().at("value"));
std::vector<int> dim(dim_val.begin(), dim_val.end());
attrs["dim"] = dim;
}
Expand Down
3 changes: 1 addition & 2 deletions paddle/cinn/hlir/framework/pir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/cinn/backends/compiler.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/utils/type_defs.h"
Expand All @@ -30,7 +29,7 @@ struct CUDAJITInfo {
void* fn_ptr;
std::vector<int> block_dims;
std::vector<int> grid_dims;
backends::Compiler* compiler;
void* compiler;
};

struct CompatibleInfo {
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/pir_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ std::vector<pir::CUDAJITInfo> PIRCompiler::BuildCUDAJITInfo(
for (int idx = 0; idx < groups.size(); ++idx) {
pir::CUDAJITInfo jit_info;
jit_info.fn_ptr = fn_ptrs[idx];
jit_info.compiler = compilter_ptr;
jit_info.compiler = reinterpret_cast<void*>(compilter_ptr);

lowered_funcs[idx][0]->cuda_axis_info.CopyBlockDimsTo(
&(jit_info.block_dims));
Expand Down
16 changes: 6 additions & 10 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,10 @@ def GenBuildOutputs(
"""
CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::IntArray {name};
if ({name}_.dyn_cast<pir::OpResult>().owner()->isa<paddle::dialect::FullIntArrayOp>()) {{
{name} = std::move(phi::IntArray({name}_.dyn_cast<pir::OpResult>().owner()
{name} = std::move(phi::IntArray(paddle::dialect::GetInt64Vector(
{name}_.dyn_cast<pir::OpResult>().owner()
->dyn_cast<paddle::dialect::FullIntArrayOp>()
.attribute("value")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData()));
.attribute("value"))));
}} else if ({name}_.type().isa<pir::VectorType>()) {{
size_t {name}_size = {name}_.type().dyn_cast<pir::VectorType>().size();
{name} = std::move(phi::IntArray(std::vector<int64_t>({name}_size, -1)));
Expand All @@ -378,12 +376,10 @@ def GenBuildOutputs(

CREATE_VECTOR_INT_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ std::vector<int64_t> {name};
if ({name}_.dyn_cast<pir::OpResult>().owner()->isa<paddle::dialect::FullIntArrayOp>()) {{
{name} = {name}_.dyn_cast<pir::OpResult>().owner()
{name} = paddle::dialect::GetInt64Vector(
{name}_.dyn_cast<pir::OpResult>().owner()
->dyn_cast<paddle::dialect::FullIntArrayOp>()
.attribute("value")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData();
.attribute("value"));
}} else if ({name}_.type().isa<pir::VectorType>()) {{
size_t {name}_size = {name}_.type().dyn_cast<pir::VectorType>().size();
{name} = std::vector<int64_t>({name}_size, -1);
Expand Down
21 changes: 17 additions & 4 deletions paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
CPP_FILE_TEMPLATE = """
#include "paddle/fluid/pir/drr/ir_operation_factory.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
{op_header}
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"

namespace pir {{
namespace drr {{

void OperationFactory::RegisterGeneratedOpCreator() {{
void OperationFactory::Register{dialect}GeneratedOpCreator() {{
{body}
}}

Expand All @@ -41,7 +41,7 @@
[](const std::vector<Value>& inputs,
const pir::AttributeMap& attrs,
pir::PatternRewriter& rewriter) {{
return rewriter.Build<paddle::dialect::{op_class_name}>(
return rewriter.Build<{namespace}::{op_class_name}>(
{params_code});
}});
"""
Expand All @@ -63,6 +63,12 @@
}});
"""

Dialect2NameSpaceMap = {"pd_op": "paddle::dialect", "cinn_op": "cinn::dialect"}
Dialect2OpHeaderMap = {
"pd_op": "#include \"paddle/fluid/pir/dialect/operator/ir/pd_op.h\"",
"cinn_op": "#include \"paddle/cinn/hlir/dialect/operator/ir/cinn_op.h\"",
}


class OpCreatorCodeGen:
def __init__(self, op_yaml_files, op_compat_yaml_file, dialect_name):
Expand Down Expand Up @@ -107,6 +113,7 @@ def gen_cpp_file_code(self, cpp_file_path):
if len(op_info_item.mutable_attribute_name_list) == 0:
body_code += NORMAL_FUNCTION_TEMPLATE.format(
op_name=ir_op_name,
namespace=Dialect2NameSpaceMap[self.dialect_name],
op_class_name=(to_pascal_case(phi_op_name) + "Op"),
params_code=", ".join(params_no_mutable_attr),
)
Expand Down Expand Up @@ -139,7 +146,13 @@ def gen_cpp_file_code(self, cpp_file_path):
)

with open(cpp_file_path, 'w') as f:
f.write(CPP_FILE_TEMPLATE.format(body=body_code))
f.write(
CPP_FILE_TEMPLATE.format(
dialect=to_pascal_case(self.dialect_name),
op_header=Dialect2OpHeaderMap[self.dialect_name],
body=body_code,
)
)


def ParseArguments():
Expand Down
19 changes: 19 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,5 +203,24 @@ bool IsEmptyValue(const pir::Value& value) {
return !value.impl() || !value.type();
}

std::vector<int64_t> GetInt64Vector(const pir::Attribute& attr) {
PADDLE_ENFORCE_EQ(attr.isa<pir::ArrayAttribute>(),
true,
phi::errors::PreconditionNotMet(
"attribute MUST be a pir::ArrayAttribute"));
auto attr_vec = attr.dyn_cast<pir::ArrayAttribute>().AsVector();

std::vector<int64_t> vec_int64;
for (auto vec_element : attr_vec) {
PADDLE_ENFORCE_EQ(
vec_element.isa<pir::Int64Attribute>(),
true,
phi::errors::PreconditionNotMet("element MUST be a Int64Attribute"));
vec_int64.push_back(vec_element.dyn_cast<pir::Int64Attribute>().data());
}

return vec_int64;
}

} // namespace dialect
} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,5 +133,7 @@ bool IsLegacyOp(const std::string& name);

bool IsEmptyValue(const pir::Value& value);

std::vector<int64_t> GetInt64Vector(const pir::Attribute& attr);

} // namespace dialect
} // namespace paddle
36 changes: 34 additions & 2 deletions paddle/fluid/pir/drr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ set(fused_op_backward_yaml_file
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_backward.parsed.yaml
)

set(cinn_op_yaml_file
${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/generated/ops.parsed.yaml)

set(cinn_op_yaml_source_file
${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/operator/ir/ops.yaml)

set(parsed_op_dir
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated)

Expand All @@ -39,6 +45,12 @@ set(op_creator_file_tmp ${op_creator_file}.tmp)

set(dialect_name pd_op)

set(cinn_op_creator_file
${PADDLE_BINARY_DIR}/paddle/fluid/pir/drr/cinn_op_factory_generated.cc)
set(cinn_op_creator_file_tmp ${cinn_op_creator_file}.tmp)

set(cinn_dialect_name cinn_op)

add_custom_command(
OUTPUT ${op_creator_file}
COMMAND
Expand All @@ -59,7 +71,27 @@ add_custom_command(
pd_op_dialect_op
VERBATIM)

if(WITH_CINN AND NOT CINN_ONLY)
add_custom_command(
OUTPUT ${cinn_op_creator_file}
COMMAND
${PYTHON_EXECUTABLE} ${op_creator_gen_file} --op_yaml_files
${cinn_op_yaml_file} --op_compat_yaml_file ${op_compat_yaml_file}
--dialect_name ${cinn_dialect_name} --op_creator_file
${cinn_op_creator_file_tmp}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${cinn_op_creator_file_tmp}
${cinn_op_creator_file}
COMMENT "copy_if_different ${cinn_op_creator_file}"
DEPENDS ${op_creator_gen_file} ${op_compat_yaml_file}
${cinn_op_yaml_source_file} pd_op_dialect_op cinn_op_dialect
VERBATIM)
set(CINN_SOURCE_FILE ${cinn_op_creator_file})

set(CINN_DEPS cinn_op_dialect)

endif()

cc_library(
drr
SRCS ${DRR_SRCS} ${op_creator_file}
DEPS pd_op_dialect pir)
SRCS ${DRR_SRCS} ${op_creator_file} ${CINN_SOURCE_FILE}
DEPS pd_op_dialect ${CINN_DEPS} pir)
Loading