-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PRIM][IR] Migrate vjp rules to new ir in non primitive mode #55647
Changes from all commits
f630e33
c8bd625
5612359
4d8079f
fe5605b
f9389ec
67cf1fc
35f867b
5fe88d5
703c168
0738201
d49d38a
a9e9d01
4df18b5
5a65b50
5a3710a
c035675
a9b8240
901352c
de4ac55
f3da449
84b92dd
690a0b9
4ee2d44
866dc2c
0d3d7d6
dc3e7be
b4579f2
be05029
619bcd0
e57d1f0
ac8b2a6
5612b2f
afcb454
40d7ab0
d9a78f6
ed442ff
95efe5e
f802b36
820b313
4f320f0
fe1b035
587bea0
c155302
d4f37b2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# ===================================== | ||
# VjpInterface gen op list | ||
# ===================================== | ||
# Code generation for Vjp functions is not | ||
# supported yet, so this whitelist controls | ||
# which ops get generated Vjp methods. | ||
# TODO(wanghao107) | ||
# Remove this file once Vjp method code | ||
# generation is supported. | ||
vjp_interface_gen_op_list = ["tanh", "mean"] |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -53,5 +53,23 @@ ir::OpResult full(std::vector<int64_t> shape, | |||||
return full_op.out(); | ||||||
} | ||||||
|
||||||
// Builds a paddle::dialect::TanhGradOp from `out` and `grad_out` using the | ||||||
// builder returned by APIBuilder::Instance().GetBuilder(), and returns the | ||||||
// result at index 0 of the newly built op. | ||||||
ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out) { | ||||||
paddle::dialect::TanhGradOp tanh_grad_op = | ||||||
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::TanhGradOp>( | ||||||
out, grad_out); | ||||||
return tanh_grad_op.result(0); | ||||||
} | ||||||
|
||||||
// Builds a paddle::dialect::MeanGradOp from `x`, `out_grad`, `axis`, | ||||||
// `keepdim` and `reduce_all` using the builder returned by | ||||||
// APIBuilder::Instance().GetBuilder(), and returns the result at index 0 | ||||||
// of the newly built op. | ||||||
// NOTE(review): `axis` is taken by value; the review thread below defers a | ||||||
// signature change to a follow-up PR. | ||||||
// NOTE(review): per the review thread below, `reduce_all` is kept to match | ||||||
// the dynamic-graph API and the forward side passes false — confirm. | ||||||
ir::OpResult mean_grad(ir::OpResult x, | ||||||
ir::OpResult out_grad, | ||||||
std::vector<int64_t> axis, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 感谢,这儿会在下个PR进行统一修改 |
||||||
bool keepdim, | ||||||
bool reduce_all) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的reduce_all 参数是上次提到的无用参数么? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是的,这儿跟动态图保持一致,直接在前向传一个false过来了。 |
||||||
paddle::dialect::MeanGradOp mean_grad_op = | ||||||
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::MeanGradOp>( | ||||||
x, out_grad, axis, keepdim, reduce_all); | ||||||
return mean_grad_op.result(0); | ||||||
} | ||||||
|
||||||
} // namespace dialect | ||||||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,5 +39,12 @@ ir::OpResult full(std::vector<int64_t> shape, | |
phi::DataType dtype = phi::DataType::FLOAT32, | ||
phi::Place place = phi::CPUPlace()); | ||
|
||
// Declaration of tanh_grad: takes `out` and `grad_out` and returns an | ||
// ir::OpResult. | ||
ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out); | ||
|
||
// Declaration of mean_grad: takes `x` and `out_grad` plus optional | ||
// `axis` (default {}), `keepdim` (default false) and `reduce_all` | ||
// (default false), and returns an ir::OpResult. | ||
// NOTE(review): `axis` is passed by value; the review thread below asks | ||
// whether it should be const& and defers the change to a follow-up PR. | ||
ir::OpResult mean_grad(ir::OpResult x, | ||
ir::OpResult out_grad, | ||
std::vector<int64_t> axis = {}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里应该是 const &? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 感谢,这儿会在下个PR进行统一修改 |
||
bool keepdim = false, | ||
bool reduce_all = false); | ||
} // namespace dialect | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -91,6 +91,9 @@ class APIBuilder { | |
ctx_ = ir::IrContext::Instance(); | ||
ctx_->GetOrRegisterDialect<paddle::dialect::PaddleDialect>(); | ||
} | ||
|
||
APIBuilder(const APIBuilder&) = delete; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ir/core/macros.h 里有 DISABLE_COPY_AND_ASSIGN 宏可以直接用 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 感谢,这儿会在下个PR进行统一修改 |
||
|
||
ir::IrContext* ctx_; | ||
std::shared_ptr<ir::Builder> builder_; | ||
}; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/ir/dialect/pd_attribute.h" | ||
#include "paddle/fluid/ir/dialect/pd_op.h" | ||
#include "paddle/fluid/primitive/rule/vjp/vjp.h" | ||
#include "paddle/fluid/primitive/type/desc_tensor.h" | ||
#include "paddle/ir/core/op_base.h" | ||
|
||
// TODO(wanghao107) | ||
// this file will be generated in pd_op.cc | ||
|
||
namespace paddle { | ||
namespace dialect { | ||
// Vjp rule for TanhOp. Casts `op` to TanhOp, wraps its out() value and | ||
// out_grads[0][0] in DescTensor-backed Tensors, calls | ||
// primitive::experimental::tanh_vjp, and returns a 1x1 nested vector whose | ||
// only entry is filled from tensor_res[0][0] when stop_gradients[0][0] is 0; | ||
// otherwise that entry stays a default-constructed ir::OpResult. | ||
// NOTE(review): out_grads[0][0] and stop_gradients[0][0] are indexed without | ||
// bounds checks and the dyn_cast result is not checked — confirm callers | ||
// guarantee both. | ||
std::vector<std::vector<ir::OpResult>> TanhOp::Vjp( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 外层用small_vector There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这儿会在下个PR进行修改,目前先合入一版让大家工作可以正常推进 |
||
ir::Operation* op, | ||
const std::vector<std::vector<ir::OpResult>>& out_grads, | ||
const std::vector<std::vector<int>>& stop_gradients) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. stop_gradients 信息是用 int 来存储表示的么?为什么不是bool呢? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. vector< bool > 不可以取地址并且访问速度较慢,但是这儿为了语意清晰可以下一个PR改成bool |
||
TanhOp op_obj = op->dyn_cast<TanhOp>(); | ||
Tensor out( | ||
std::make_shared<primitive::experimental::DescTensor>(op_obj.out())); | ||
Tensor grad_out( | ||
std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0])); | ||
// tanh_vjp is called unconditionally. Per the review thread below, moving | ||
// it under the stop_gradients check would pull the upper-level pruning | ||
// logic down into this layer and the backward op would not be built. | ||
std::vector<std::vector<Tensor>> tensor_res = | ||
primitive::experimental::tanh_vjp(out, grad_out, stop_gradients); | ||
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1)); | ||
if (!stop_gradients[0][0]) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 只有在 if 分支里才会用到 tensor_res ? 那为什么不把tensor_res 这一行放到 if 里面呢? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这儿如果放到if里面会把上层剪枝逻辑带到下层来,不会去构建反向op。 |
||
// Unwrap the DescTensor behind the returned Tensor back into an OpResult. | ||
res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>( | ||
tensor_res[0][0].impl()) | ||
->getValue() | ||
.dyn_cast<ir::OpResult>(); | ||
} | ||
return res; | ||
} | ||
|
||
// Vjp rule for Tanh_Op. Casts `op` to Tanh_Op, wraps its out() value and | ||
// out_grads[0][0] in DescTensor-backed Tensors, calls | ||
// primitive::experimental::tanh_vjp, and returns a 1x1 nested vector whose | ||
// only entry is filled from tensor_res[0][0] when stop_gradients[0][0] is 0. | ||
// NOTE(review): out_grads[0][0] and stop_gradients[0][0] are indexed without | ||
// bounds checks and the dyn_cast result is not checked — confirm callers | ||
// guarantee both. | ||
std::vector<std::vector<ir::OpResult>> Tanh_Op::Vjp( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这儿会在下个PR进行修改,目前先合入一版让大家工作可以正常推进 |
||
ir::Operation* op, | ||
const std::vector<std::vector<ir::OpResult>>& out_grads, | ||
const std::vector<std::vector<int>>& stop_gradients) { | ||
// TODO(wanghao107) | ||
// Inplace is not supported yet, so the | ||
// non-inplace version is used here for now. | ||
// Support inplace in the future. | ||
Tanh_Op op_obj = op->dyn_cast<Tanh_Op>(); | ||
Tensor out( | ||
std::make_shared<primitive::experimental::DescTensor>(op_obj.out())); | ||
Tensor grad_out( | ||
std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0])); | ||
// tanh_vjp is called regardless of stop_gradients; only the copy into | ||
// `res` below is gated on it. | ||
std::vector<std::vector<Tensor>> tensor_res = | ||
primitive::experimental::tanh_vjp(out, grad_out, stop_gradients); | ||
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1)); | ||
if (!stop_gradients[0][0]) { | ||
// Unwrap the DescTensor behind the returned Tensor back into an OpResult. | ||
res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>( | ||
tensor_res[0][0].impl()) | ||
->getValue() | ||
.dyn_cast<ir::OpResult>(); | ||
} | ||
return res; | ||
} | ||
|
||
// Vjp rule for MeanOp. Casts `op` to MeanOp, wraps its x() value and | ||
// out_grads[0][0] in DescTensor-backed Tensors, reads the "axis" and | ||
// "keepdim" attributes from `op`, calls primitive::experimental::mean_vjp, | ||
// and returns a 1x1 nested vector whose only entry is filled from | ||
// tensor_res[0][0] when stop_gradients[0][0] is 0. | ||
// NOTE(review): out_grads[0][0] and stop_gradients[0][0] are indexed without | ||
// bounds checks and the dyn_cast results are not checked — confirm callers | ||
// guarantee both. | ||
std::vector<std::vector<ir::OpResult>> MeanOp::Vjp( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这儿会在下个PR进行修改,目前先合入一版让大家工作可以正常推进 |
||
ir::Operation* op, | ||
const std::vector<std::vector<ir::OpResult>>& out_grads, | ||
const std::vector<std::vector<int>>& stop_gradients) { | ||
MeanOp op_obj = op->dyn_cast<MeanOp>(); | ||
Tensor x(std::make_shared<primitive::experimental::DescTensor>(op_obj.x())); | ||
Tensor out_grad( | ||
std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0])); | ||
|
||
// "axis" is read as an IntArrayAttribute and copied into a vector. | ||
std::vector<int64_t> axis = | ||
op->attribute("axis") | ||
.dyn_cast<paddle::dialect::IntArrayAttribute>() | ||
.data() | ||
.GetData(); | ||
bool keepdim = op->attribute("keepdim").dyn_cast<ir::BoolAttribute>().data(); | ||
// reduce_all is hard-coded to false here rather than read from the op. | ||
bool reduce_all = false; | ||
std::vector<std::vector<Tensor>> tensor_res = | ||
primitive::experimental::mean_vjp( | ||
x, out_grad, axis, keepdim, reduce_all, stop_gradients); | ||
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1)); | ||
if (!stop_gradients[0][0]) { | ||
// Unwrap the DescTensor behind the returned Tensor back into an OpResult. | ||
res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>( | ||
tensor_res[0][0].impl()) | ||
->getValue() | ||
.dyn_cast<ir::OpResult>(); | ||
} | ||
return res; | ||
} | ||
} // namespace dialect | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,21 +20,24 @@ namespace dialect { | |
class VjpInterface : public ir::OpInterfaceBase<VjpInterface> { | ||
public: | ||
struct Concept { | ||
explicit Concept(std::vector<std::vector<ir::Value>> (*vjp)( | ||
std::vector<std::vector<ir::Value>> out_grads, | ||
explicit Concept(std::vector<std::vector<ir::OpResult>> (*vjp)( | ||
ir::Operation* op, | ||
const std::vector<std::vector<ir::OpResult>>& out_grads, | ||
const std::vector<std::vector<int>>& stop_gradients)) | ||
: vjp_(vjp) {} | ||
std::vector<std::vector<ir::Value>> (*vjp_)( | ||
std::vector<std::vector<ir::Value>> out_grads, | ||
std::vector<std::vector<ir::OpResult>> (*vjp_)( | ||
ir::Operation* op, | ||
const std::vector<std::vector<ir::OpResult>>& out_grads, | ||
const std::vector<std::vector<int>>& stop_gradients); | ||
}; | ||
|
||
template <class ConcreteOp> | ||
struct Model : public Concept { | ||
static std::vector<std::vector<ir::Value>> Vjp( | ||
std::vector<std::vector<ir::Value>> out_grads, | ||
static std::vector<std::vector<ir::OpResult>> Vjp( | ||
ir::Operation* op, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果考虑控制流,此接口需要传block, vjp 产生的反向op添加到传入的block接口中 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这儿加了TODO,会在后续进行支持 |
||
const std::vector<std::vector<ir::OpResult>>& out_grads, | ||
const std::vector<std::vector<int>>& stop_gradients) { | ||
return ConcreteOp::Vjp(out_grads, stop_gradients); | ||
return ConcreteOp::Vjp(op, out_grads, stop_gradients); | ||
} | ||
|
||
Model() : Concept(Vjp) {} | ||
|
@@ -43,10 +46,11 @@ class VjpInterface : public ir::OpInterfaceBase<VjpInterface> { | |
VjpInterface(ir::Operation* op, Concept* impl) | ||
: ir::OpInterfaceBase<VjpInterface>(op), impl_(impl) {} | ||
|
||
std::vector<std::vector<ir::Value>> Vjp( | ||
std::vector<std::vector<ir::Value>> out_grads, | ||
std::vector<std::vector<ir::OpResult>> Vjp( | ||
ir::Operation* op, | ||
const std::vector<std::vector<ir::OpResult>>& out_grads, | ||
const std::vector<std::vector<int>>& stop_gradients) { | ||
return impl_->vjp_(out_grads, stop_gradients); | ||
return impl_->vjp_(op, out_grads, stop_gradients); | ||
} | ||
|
||
private: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
add_subdirectory(backend) | ||
add_subdirectory(rule) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Paddle Primitive Operator System and Combined Strategy Design |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,10 @@ | ||||||
# Define the eager backend library unless WITH_PYTHON is off and ON_INFER | ||||||
# is on at the same time. The library is built from eager_backend.cc. | ||||||
if(NOT (NOT WITH_PYTHON AND ON_INFER)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
这样更简洁些? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 感谢 会在下个pr统一修改 |
||||||
cc_library( | ||||||
primitive_backend_eager_experimental | ||||||
SRCS eager_backend.cc | ||||||
DEPS final_dygraph_function eager_utils phi) | ||||||
endif() | ||||||
# Static backend library: always defined, built from static_backend.cc, | ||||||
# and depends on pd_dialect. | ||||||
cc_library( | ||||||
primitive_backend_static_experimental | ||||||
SRCS static_backend.cc | ||||||
DEPS pd_dialect) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/primitive/backend/eager_backend.h" | ||
#include "paddle/fluid/eager/api/all.h" | ||
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" | ||
#include "paddle/fluid/primitive/primitive/primitive.h" | ||
|
||
// The paddle::primitive::backend::experimental namespace is opened here | ||
// but contains no definitions in this file yet. | ||
namespace paddle { | ||
namespace primitive { | ||
namespace backend { | ||
namespace experimental {} // namespace experimental | ||
} // namespace backend | ||
} // namespace primitive | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle.prim.DescTensor 之后是会删除吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的 目前属于过渡阶段,后续过渡完成后会移除旧的组合算子体系