[PRIM][IR] Migrate vjp rules to new ir in non primitive mode #55647

Merged Aug 8, 2023 · 45 commits

Commits
f630e33
[prim][newir] add basic framework for primitive
cxxly Jul 20, 2023
c8bd625
support desctensor in new ir
Charles-hit Jul 24, 2023
5612359
add vjp interface
zhangbo9674 Jul 24, 2023
4d8079f
Merge commit 'refs/pull/55660/head' of https://github.com/PaddlePaddl…
Charles-hit Jul 25, 2023
fe5605b
support vjp in new ir
Charles-hit Jul 25, 2023
f9389ec
support vjp in new ir
Charles-hit Jul 26, 2023
67cf1fc
polish vjp interface
Charles-hit Jul 27, 2023
35f867b
fix stop_gradients set
Charles-hit Jul 27, 2023
5fe88d5
resolve conflict
Charles-hit Jul 27, 2023
703c168
fix vjp dispatch
Charles-hit Jul 27, 2023
0738201
add comment
Charles-hit Jul 27, 2023
d49d38a
add vjp test for new ir
Charles-hit Jul 27, 2023
a9e9d01
add test for tanh vjp
Charles-hit Jul 27, 2023
4df18b5
[prim][newir] add basic framework for primitive
cxxly Jul 20, 2023
5a65b50
support desctensor in new ir
Charles-hit Jul 24, 2023
5a3710a
support vjp in new ir
Charles-hit Jul 25, 2023
c035675
support vjp in new ir
Charles-hit Jul 26, 2023
a9b8240
polish vjp interface
Charles-hit Jul 27, 2023
901352c
fix stop_gradients set
Charles-hit Jul 27, 2023
de4ac55
fix vjp dispatch
Charles-hit Jul 27, 2023
f3da449
add comment
Charles-hit Jul 27, 2023
84b92dd
add vjp test for new ir
Charles-hit Jul 27, 2023
690a0b9
add test for tanh vjp
Charles-hit Jul 27, 2023
4ee2d44
add eager and static backend for warp lower level api
cxxly Jul 28, 2023
866dc2c
support call_vjp pybind
Charles-hit Jul 28, 2023
0d3d7d6
support call_vjp pybind
Charles-hit Jul 28, 2023
dc3e7be
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Charles-hit Aug 1, 2023
b4579f2
polish code and add test for vjp
Charles-hit Aug 2, 2023
be05029
remove useless code
Charles-hit Aug 2, 2023
619bcd0
polish code
Charles-hit Aug 2, 2023
e57d1f0
remove useless code
Charles-hit Aug 2, 2023
ac8b2a6
support mean vjp
Charles-hit Aug 3, 2023
5612b2f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Charles-hit Aug 3, 2023
afcb454
add test for mean vjp and support has_vjp function
Charles-hit Aug 3, 2023
40d7ab0
fix call_vjp
Charles-hit Aug 3, 2023
d9a78f6
polish code
Charles-hit Aug 4, 2023
ed442ff
add primitive ops set for backend
Charles-hit Aug 4, 2023
95efe5e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Charles-hit Aug 4, 2023
f802b36
add vjp test for tanh_
Charles-hit Aug 6, 2023
820b313
fix inference CI
Charles-hit Aug 7, 2023
4f320f0
fix inference ci
Charles-hit Aug 7, 2023
fe1b035
modify fluid cmake
Charles-hit Aug 7, 2023
587bea0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Charles-hit Aug 7, 2023
c155302
remove useless deps
Charles-hit Aug 7, 2023
d4f37b2
add cmake
Charles-hit Aug 7, 2023
1 change: 1 addition & 0 deletions paddle/fluid/CMakeLists.txt
@@ -10,5 +10,6 @@ add_subdirectory(prim)
add_subdirectory(jit)
add_subdirectory(ir)
add_subdirectory(ir_adaptor)
add_subdirectory(primitive)
# NOTE: please add subdirectory inference at last.
add_subdirectory(inference)
3 changes: 3 additions & 0 deletions paddle/fluid/framework/type_info.cc
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/raw_tensor.h"
#include "paddle/fluid/framework/string_array.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/primitive/type/desc_tensor.h"

namespace phi {

@@ -40,6 +41,8 @@ template class TypeInfoTraits<phi::TensorBase, paddle::framework::Strings>;
template class TypeInfoTraits<phi::TensorBase, paddle::framework::FeedList>;
template class TypeInfoTraits<phi::TensorBase, egr::VariableCompatTensor>;
template class TypeInfoTraits<phi::TensorBase, paddle::prim::DescTensor>;
template class TypeInfoTraits<phi::TensorBase,
                              paddle::primitive::experimental::DescTensor>;
template class TypeInfoTraits<phi::TensorBase,
                              paddle::framework::VariableRefArray>;

Reviewer: Will paddle::prim::DescTensor be removed later?

Author: Yes. We are in a transition phase right now; once the migration is complete, the old composite operator system will be removed.

9 changes: 8 additions & 1 deletion paddle/fluid/ir/dialect/CMakeLists.txt
@@ -52,5 +52,12 @@ file(GLOB PD_DIALECT_SRCS "*.cc")
cc_library(
pd_dialect
SRCS ${PD_DIALECT_SRCS} ${op_source_file}
DEPS framework_proto phi phi_utils pd_interface pd_trait ir)
DEPS framework_proto
phi
phi_utils
pd_interface
pd_trait
ir
primitive_vjp_experimental
type_info)
target_include_directories(pd_dialect PRIVATE ${PD_DIALECT_BINARY_DIR})
21 changes: 20 additions & 1 deletion paddle/fluid/ir/dialect/op_generator/op_gen.py
@@ -17,7 +17,11 @@

import yaml
from op_build_gen import gen_build_func_str
from op_interface_gen import gen_exclusive_interface_str, gen_op_infer_meta_str
from op_interface_gen import (
gen_exclusive_interface_str,
gen_op_infer_meta_str,
vjp_interface_gen_op_list,
)
from op_member_func_gen import gen_op_get_inputs_outputs_str
from op_verify_gen import gen_verify_func_str

Expand All @@ -43,6 +47,7 @@
#include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir/interface/infermeta.h"
#include "paddle/fluid/ir/interface/vjp.h"
#include "paddle/fluid/ir/trait/inplace.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
@@ -303,6 +308,9 @@ def __init__(self, op_yaml_item, op_compat_item):
else:
self.infer_meta_func = None

# parse backward name
self.backward_name = self.parse_backward_name()

# parse inplace && view
self.inplace_map = self.parse_op_inplace_info()
self.view_map = self.parse_op_view_info()
@@ -612,6 +620,12 @@ def parse_kernel_map(self):
else:
return None

def parse_backward_name(self):
if 'backward' in self.op_yaml_item:
return self.op_yaml_item['backward']
else:
return None

def get_phi_dtype_name(self, name):
name = name.replace('Scalar', 'phi::Scalar')
name = name.replace('IntArray', 'phi::IntArray')
@@ -720,6 +734,11 @@ def OpGenerator(
if op_info.infer_meta_func:
op_interfaces += ["InferMetaInterface"]

if (
op_info.backward_name
and op_info.op_phi_name[0] in vjp_interface_gen_op_list
):
op_interfaces += ["VjpInterface"]
exclusive_interface_str = gen_exclusive_interface_str(op_info)

# If op has inplace info, we will generate inplace op and non-inplace op.
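(For context: parse_backward_name reads the op's backward field from the ops YAML, e.g. tanh declares backward: tanh_grad there, so an op receives VjpInterface only when it both declares a backward op and appears in the vjp whitelist shown below.)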
3 changes: 3 additions & 0 deletions paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
@@ -13,6 +13,7 @@
# limitations under the License.

# generator interfaces
from vjp_interface_gen_op_list import vjp_interface_gen_op_list

OP_INFER_SHAPE_TEMPLATE = """
void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
@@ -38,4 +39,6 @@ def gen_exclusive_interface_str(op_info):
exclusive_interface_str += (
" static void InferMeta( phi::InferMetaContext *infer_meta );"
)
if op_info.op_phi_name[0] in vjp_interface_gen_op_list:
exclusive_interface_str += "\n static std::vector<std::vector<ir::OpResult>> Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<int>>& stop_gradients);"
return exclusive_interface_str
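For readability, the single-line declaration string appended above corresponds to this generated C++ signature (formatted here for clarity; it is emitted into the generated op header for whitelisted ops such as tanh and mean):

// Declaration generated for each op in vjp_interface_gen_op_list:
static std::vector<std::vector<ir::OpResult>> Vjp(
    ir::Operation* op,
    const std::vector<std::vector<ir::OpResult>>& out_grads,
    const std::vector<std::vector<int>>& stop_gradients);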
24 changes: 24 additions & 0 deletions paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
@@ -0,0 +1,24 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# =====================================
# VjpInterface gen op list
# =====================================
# Code generation of vjp functions is
# not supported yet, so this whitelist
# controls which ops get generated Vjp methods.
# TODO(wanghao107): remove this file once
# Vjp method code generation is supported.
vjp_interface_gen_op_list = ["tanh", "mean"]
18 changes: 18 additions & 0 deletions paddle/fluid/ir/dialect/pd_api.cc
@@ -53,5 +53,23 @@ ir::OpResult full(std::vector<int64_t> shape,
return full_op.out();
}

ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out) {
paddle::dialect::TanhGradOp tanh_grad_op =
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::TanhGradOp>(
out, grad_out);
return tanh_grad_op.result(0);
}

ir::OpResult mean_grad(ir::OpResult x,
ir::OpResult out_grad,
std::vector<int64_t> axis,
Reviewer: Suggested change: const std::vector<int64_t>& axis,

Author: Thanks; this will be fixed uniformly in the next PR.

bool keepdim,
bool reduce_all) {
Reviewer: Is this reduce_all parameter the unused parameter mentioned last time?

Author: Yes. To stay consistent with the dynamic graph, the forward simply passes false through here.

paddle::dialect::MeanGradOp mean_grad_op =
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::MeanGradOp>(
x, out_grad, axis, keepdim, reduce_all);
return mean_grad_op.result(0);
}

} // namespace dialect
} // namespace paddle
7 changes: 7 additions & 0 deletions paddle/fluid/ir/dialect/pd_api.h
@@ -39,5 +39,12 @@ ir::OpResult full(std::vector<int64_t> shape,
phi::DataType dtype = phi::DataType::FLOAT32,
phi::Place place = phi::CPUPlace());

ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out);

ir::OpResult mean_grad(ir::OpResult x,
ir::OpResult out_grad,
std::vector<int64_t> axis = {},
Reviewer: Should this be const&?

Author: Thanks; this will be fixed uniformly in the next PR.

bool keepdim = false,
bool reduce_all = false);
} // namespace dialect
} // namespace paddle
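A minimal usage sketch of the two new C++ APIs above (an editor's illustration, not code from the PR; assumes out, grad_out, x, and out_grad are ir::OpResult values already built in the current program):

// Build a tanh_grad op and a mean_grad op through the dialect API.
ir::OpResult x_grad = paddle::dialect::tanh_grad(out, grad_out);
ir::OpResult mean_x_grad = paddle::dialect::mean_grad(
    x, out_grad, /*axis=*/{}, /*keepdim=*/false, /*reduce_all=*/false);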
3 changes: 3 additions & 0 deletions paddle/fluid/ir/dialect/pd_dialect.h
@@ -91,6 +91,9 @@ class APIBuilder {
ctx_ = ir::IrContext::Instance();
ctx_->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
}

APIBuilder(const APIBuilder&) = delete;
Reviewer: ir/core/macros.h has a DISABLE_COPY_AND_ASSIGN macro that can be used directly (see the sketch after this diff).

Author: Thanks; this will be fixed uniformly in the next PR.


ir::IrContext* ctx_;
std::shared_ptr<ir::Builder> builder_;
};
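As the reviewer notes, a macro can replace the hand-written deleted copy constructor. A typical definition of such a macro is sketched below; the actual definition in paddle/ir/core/macros.h may differ in detail:

// Sketch of a conventional copy-disabling macro:
#define DISABLE_COPY_AND_ASSIGN(classname)      \
  classname(const classname&) = delete;         \
  classname& operator=(const classname&) = delete

// Inside APIBuilder, the deleted copy constructor above would then
// be written as: DISABLE_COPY_AND_ASSIGN(APIBuilder);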
101 changes: 101 additions & 0 deletions paddle/fluid/ir/dialect/pd_op_vjp_manual.cc
@@ -0,0 +1,101 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/primitive/type/desc_tensor.h"
#include "paddle/ir/core/op_base.h"

// TODO(wanghao107)
// this file will be generated in pd_op.cc

namespace paddle {
namespace dialect {
std::vector<std::vector<ir::OpResult>> TanhOp::Vjp(
Reviewer: Use small_vector for the outer container.

Author: This will be changed in the next PR; for now this version is being merged so that everyone's work can proceed.

ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<int>>& stop_gradients) {
Reviewer: Is stop_gradients stored as int? Why not bool?

Author (Charles-hit, Aug 7, 2023): Elements of std::vector<bool> cannot have their address taken and access is slower; but for semantic clarity this can be changed to bool in the next PR. (A small illustration follows.)
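A small standalone illustration of the std::vector<bool> limitation mentioned above (an editor's sketch, not part of the PR):

#include <vector>

int main() {
  std::vector<int> int_flags = {1, 0};
  int* p = &int_flags[0];       // fine: int elements are addressable objects
  std::vector<bool> bool_flags = {true, false};
  // bool* q = &bool_flags[0];  // does not compile: operator[] returns a
  //                            // bit-reference proxy object, not a bool&
  return *p == 1 ? 0 : 1;
}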

TanhOp op_obj = op->dyn_cast<TanhOp>();
Tensor out(
std::make_shared<primitive::experimental::DescTensor>(op_obj.out()));
Tensor grad_out(
std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::tanh_vjp(out, grad_out, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
if (!stop_gradients[0][0]) {
Reviewer: tensor_res is only used inside the if branch, so why not move that line into the if?

Author: If it were moved into the if, the upper-level pruning logic would leak down into this layer and the backward op would not get built. Design-wise, it is better to keep the pruning logic in one place.

res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
tensor_res[0][0].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
return res;
}

std::vector<std::vector<ir::OpResult>> Tanh_Op::Vjp(
Reviewer: Same as above.

Author: This will be changed in the next PR; merging this version for now so work can proceed.

ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<int>>& stop_gradients) {
// TODO(wanghao107): inplace is not supported yet,
// so the non-inplace version is used for now.
// Inplace support will be added in the future.
Tanh_Op op_obj = op->dyn_cast<Tanh_Op>();
Tensor out(
std::make_shared<primitive::experimental::DescTensor>(op_obj.out()));
Tensor grad_out(
std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::tanh_vjp(out, grad_out, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
if (!stop_gradients[0][0]) {
res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
tensor_res[0][0].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
return res;
}

std::vector<std::vector<ir::OpResult>> MeanOp::Vjp(
Reviewer: Same as above.

Author: This will be changed in the next PR; merging this version for now so work can proceed.

ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<int>>& stop_gradients) {
MeanOp op_obj = op->dyn_cast<MeanOp>();
Tensor x(std::make_shared<primitive::experimental::DescTensor>(op_obj.x()));
Tensor out_grad(
std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));

std::vector<int64_t> axis =
op->attribute("axis")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData();
bool keepdim = op->attribute("keepdim").dyn_cast<ir::BoolAttribute>().data();
bool reduce_all = false;
std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::mean_vjp(
x, out_grad, axis, keepdim, reduce_all, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
if (!stop_gradients[0][0]) {
res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
tensor_res[0][0].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
return res;
}
} // namespace dialect
} // namespace paddle
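For reference, the mathematical rules these Vjp methods delegate to are the standard ones (the tanh_vjp and mean_vjp implementations live under paddle/fluid/primitive/rule/vjp and are not shown in this diff):

tanh: out = tanh(x), so x_grad = out_grad * (1 - out * out)
mean: out = mean(x, axis), so each reduced element receives x_grad_i = out_grad / n broadcast back to x's shape, where n is the number of elements reduced.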
24 changes: 14 additions & 10 deletions paddle/fluid/ir/interface/vjp.h
@@ -20,21 +20,24 @@ namespace dialect {
class VjpInterface : public ir::OpInterfaceBase<VjpInterface> {
public:
struct Concept {
explicit Concept(std::vector<std::vector<ir::Value>> (*vjp)(
std::vector<std::vector<ir::Value>> out_grads,
explicit Concept(std::vector<std::vector<ir::OpResult>> (*vjp)(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<int>>& stop_gradients))
: vjp_(vjp) {}
std::vector<std::vector<ir::Value>> (*vjp_)(
std::vector<std::vector<ir::Value>> out_grads,
std::vector<std::vector<ir::OpResult>> (*vjp_)(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<int>>& stop_gradients);
};

template <class ConcreteOp>
struct Model : public Concept {
static std::vector<std::vector<ir::Value>> Vjp(
std::vector<std::vector<ir::Value>> out_grads,
static std::vector<std::vector<ir::OpResult>> Vjp(
ir::Operation* op,
Reviewer: Once control flow is considered, this interface will need to take a block, and the backward ops produced by vjp should be inserted into that passed-in block.

Author: A TODO has been added here; this will be supported later.

const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<int>>& stop_gradients) {
return ConcreteOp::Vjp(out_grads, stop_gradients);
return ConcreteOp::Vjp(op, out_grads, stop_gradients);
}

Model() : Concept(Vjp) {}
Expand All @@ -43,10 +46,11 @@ class VjpInterface : public ir::OpInterfaceBase<VjpInterface> {
VjpInterface(ir::Operation* op, Concept* impl)
: ir::OpInterfaceBase<VjpInterface>(op), impl_(impl) {}

std::vector<std::vector<ir::Value>> Vjp(
std::vector<std::vector<ir::Value>> out_grads,
std::vector<std::vector<ir::OpResult>> Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<int>>& stop_gradients) {
return impl_->vjp_(out_grads, stop_gradients);
return impl_->vjp_(op, out_grads, stop_gradients);
}

private:
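A minimal sketch of how a caller might invoke this interface (an editor's illustration; it mirrors the dyn_cast pattern used for concrete ops in pd_op_vjp_manual.cc, and assumes op, out_grads, and stop_gradients already exist):

paddle::dialect::VjpInterface vjp_interface =
    op->dyn_cast<paddle::dialect::VjpInterface>();
if (vjp_interface) {
  // Dispatches through Concept/Model to the concrete op's static Vjp.
  std::vector<std::vector<ir::OpResult>> in_grads =
      vjp_interface.Vjp(op, out_grads, stop_gradients);
}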
2 changes: 2 additions & 0 deletions paddle/fluid/primitive/CMakeLists.txt
@@ -0,0 +1,2 @@
add_subdirectory(backend)
add_subdirectory(rule)
1 change: 1 addition & 0 deletions paddle/fluid/primitive/README.md
@@ -0,0 +1 @@
# Paddle Primitive Operator System and Combined Strategy Design
10 changes: 10 additions & 0 deletions paddle/fluid/primitive/backend/CMakeLists.txt
@@ -0,0 +1,10 @@
if(NOT (NOT WITH_PYTHON AND ON_INFER))
Reviewer: Suggested change: if(WITH_PYTHON OR NOT ON_INFER). Wouldn't that be more concise?

Author: Thanks; this will be fixed uniformly in the next PR.

cc_library(
primitive_backend_eager_experimental
SRCS eager_backend.cc
DEPS final_dygraph_function eager_utils phi)
endif()
cc_library(
primitive_backend_static_experimental
SRCS static_backend.cc
DEPS pd_dialect)
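Note: the reviewer's suggestion is a pure De Morgan rewrite, since NOT (NOT WITH_PYTHON AND ON_INFER) is logically identical to WITH_PYTHON OR NOT ON_INFER; it changes readability only, not which configurations build the eager backend.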
26 changes: 26 additions & 0 deletions paddle/fluid/primitive/backend/eager_backend.cc
@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/primitive/backend/eager_backend.h"
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/primitive/primitive/primitive.h"

namespace paddle {
namespace primitive {
namespace backend {
namespace experimental {} // namespace experimental
} // namespace backend
} // namespace primitive
} // namespace paddle