[PIR] Support while grad exe #59496

Merged Dec 1, 2023 (40 commits). The diff below shows changes from 34 commits.

Commits:
b5c82b4  support lower to kernel for if_grad op (chen2016013, Nov 15, 2023)
8279d42  add PD_DECLARE_KERNEL (chen2016013, Nov 16, 2023)
a395203  fix (chen2016013, Nov 16, 2023)
2c9ad7c  fix (chen2016013, Nov 16, 2023)
9c0c2ae  fix (chen2016013, Nov 16, 2023)
62633df  merge (chen2016013, Nov 16, 2023)
34ce4b1  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (chen2016013, Nov 16, 2023)
44885af  resolve conflict (chen2016013, Nov 21, 2023)
796c10f  resolve conflict (chen2016013, Nov 21, 2023)
cc99d1f  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (chen2016013, Nov 21, 2023)
d410461  update (chen2016013, Nov 21, 2023)
e0e6bc7  update (chen2016013, Nov 21, 2023)
f10592d  update (chen2016013, Nov 21, 2023)
eb2db16  update (chen2016013, Nov 21, 2023)
ca073ca  update (chen2016013, Nov 23, 2023)
1955bd9  update (chen2016013, Nov 23, 2023)
155ece7  fix (zhangbo9674, Nov 24, 2023)
428fcf3  update (chen2016013, Nov 27, 2023)
c4909ea  update (chen2016013, Nov 27, 2023)
121077c  solve conflicts (chen2016013, Nov 27, 2023)
3f48d26  update (chen2016013, Nov 28, 2023)
98ca4ac  resolve conflicts (chen2016013, Nov 28, 2023)
3783545  update (chen2016013, Nov 28, 2023)
811ebd2  update (chen2016013, Nov 28, 2023)
edf124d  update (chen2016013, Nov 28, 2023)
f06de5c  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (chen2016013, Nov 28, 2023)
7e19d9e  update (chen2016013, Nov 28, 2023)
fa619bc  update (chen2016013, Nov 28, 2023)
6d6647a  fix bugs and warnings (chen2016013, Nov 29, 2023)
c264973  Merge commit 'refs/pull/59200/head' of https://github.com/PaddlePaddl… (zhangbo9674, Nov 29, 2023)
2687fbb  fix (zhangbo9674, Nov 29, 2023)
47c4d5e  fix (zhangbo9674, Nov 29, 2023)
d9cbb35  fix (zhangbo9674, Nov 29, 2023)
296ef3f  fix (zhangbo9674, Nov 29, 2023)
ddc2516  fix (zhangbo9674, Nov 29, 2023)
234482f  fix (zhangbo9674, Nov 30, 2023)
645443a  fix conflict (zhangbo9674, Nov 30, 2023)
066727e  fix (zhangbo9674, Nov 30, 2023)
48be391  fix (zhangbo9674, Nov 30, 2023)
7340d73  fix (zhangbo9674, Nov 30, 2023)
paddle/fluid/framework/new_executor/instruction/CMakeLists.txt (10 changes: 8 additions, 2 deletions)
@@ -1,7 +1,13 @@
cc_library(
  instruction_base
-  SRCS instruction_base.cc phi_kernel_instruction.cc
-       legacy_kernel_instruction.cc cond_instruction.cc while_instruction.cc
+  SRCS instruction_base.cc
+       phi_kernel_instruction.cc
+       legacy_kernel_instruction.cc
+       cond_instruction.cc
+       while_instruction.cc
+       has_elements_instruction.cc
+       tuple_push_instruction.cc
+       tuple_pop_instruction.cc
+       instruction_util.cc
  DEPS pir_adaptor phi framework_proto)

paddle/fluid/framework/new_executor/instruction/cond_instruction.cc
@@ -54,8 +54,7 @@ CondInstruction::CondInstruction(size_t id,
  VLOG(6) << "finish process analyse kernel type";

  auto cond_value = if_op.operand_source(0);
-  cond_var_ = value_exec_info->GetScope()->FindVar(
-      value_exec_info->GetValue2VarName().at(cond_value));
+  cond_var_ = value_exec_info->GetVarByValue(cond_value);
  for (size_t i = 0; i < if_op.num_results(); ++i) {
    output_vars_.push_back(value_exec_info->GetScope()->GetVar(
        value_exec_info->GetValue2VarName().at(if_op.result(i))));
@@ -70,9 +69,9 @@ CondInstruction::CondInstruction(size_t id,
  std::unordered_map<pir::Value, std::vector<int>> inputs;
  GetInputIds(op, *value_exec_info, &inputs);
  auto true_outside_inputs =
-      GetOutsideOpInputs(true_branch_block, *value_exec_info, &inputs);
+      GetExternalInputs(true_branch_block, *value_exec_info, &inputs);
  auto false_outside_inputs =
-      GetOutsideOpInputs(false_branch_block, *value_exec_info, &inputs);
+      GetExternalInputs(false_branch_block, *value_exec_info, &inputs);
  SetInputs(inputs);

  std::unordered_map<pir::Value, std::vector<int>> outputs;
@@ -89,6 +88,8 @@ CondInstruction::CondInstruction(size_t id,
      outputs.emplace(value, GetValueIds(value, *value_exec_info));
    }
  }
+  InsertTuplePushContinerToOuts(true_branch_block, *value_exec_info, &outputs);
+  InsertTuplePushContinerToOuts(false_branch_block, *value_exec_info, &outputs);
  SetOutputs(outputs);
  VLOG(6) << "finish process inputs outputs index";

paddle/fluid/framework/new_executor/instruction/has_elements_instruction.cc (new file)
@@ -0,0 +1,63 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/new_executor/instruction/has_elements_instruction.h"
#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
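// What follows (constructor + Run): the op's single output is wired to a
// one-element CPU bool tensor, and a pointer to the stack container (a
// VariableRefArray populated by tuple_push) is cached; Run() then reports
// whether that container still holds elements, which is presumably how the
// while-grad executor decides whether to keep popping.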

namespace paddle {
namespace framework {
HasElementsInstruction::HasElementsInstruction(
    size_t id,
    const platform::Place& place,
    ::pir::Operation* op,
    ValueExecutionInfo* value_exe_info)
    : InstructionBase(id, place), op_(op), value_exe_info_(value_exe_info) {
  auto has_elements_op = op->dyn_cast<paddle::dialect::HasElementsOp>();
  VLOG(6) << "construct has_elements instruction for: "
          << has_elements_op->name();

  std::unordered_map<pir::Value, std::vector<int>> outputs;
  outputs.emplace(has_elements_op.out(),
                  GetValueIds(has_elements_op.out(), *value_exe_info_));
  SetOutputs(outputs);

  std::unordered_map<pir::Value, std::vector<int>> inputs;
  std::vector<int> inputs_id = {
      value_exe_info_->GetVarId(has_elements_op.input())};
  inputs.emplace(has_elements_op.input(), inputs_id);
  SetInputs(inputs);

  type_ = OpFuncType::kCpuSync;

  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* bool_tensor = value_exe_info_->GetVarByValue(op_->result(0))
                          ->GetMutable<phi::DenseTensor>();
  bool_tensor->Resize(phi::make_ddim({1}));
  has_elements_ = pool.Get(platform::CPUPlace())->Alloc<bool>(bool_tensor);

  auto stack_value =
      op_->dyn_cast<paddle::dialect::HasElementsOp>().operand_source(0);
  auto var_array = value_exe_info_->GetVarByValue(stack_value);
  stack_element_var_array_ = var_array->GetMutable<VariableRefArray>();
}

void HasElementsInstruction::Run() {
  VLOG(6) << "run has_elements instruction";
  bool is_empty = stack_element_var_array_->size();
Review comment (Contributor): Doesn't stack_element_var_array_ have an empty() interface? Calling size() and treating the result as a bool feels a bit odd.

Reply (Contributor Author): Good suggestion. There is no such interface at the moment; I will add one later.

  *has_elements_ = is_empty;
}
} // namespace framework
} // namespace paddle
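Following up on the thread above, a minimal sketch of the suggested cleanup, assuming VariableRefArray later gains a hypothetical empty() method (it does not exist at the time of this PR):

void HasElementsInstruction::Run() {
  VLOG(6) << "run has_elements instruction";
  // Hypothetical empty(): equivalent to the current size()-to-bool conversion,
  // but states the intent directly and avoids the misleading is_empty name.
  *has_elements_ = !stack_element_var_array_->empty();
}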
paddle/fluid/framework/new_executor/instruction/has_elements_instruction.h (new file)
@@ -0,0 +1,57 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/tensor_ref_array.h"
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"

namespace paddle {
namespace framework {
class ValueExecutionInfo;

class HasElementsInstruction : public InstructionBase {
 public:
  HasElementsInstruction(size_t id,
                         const platform::Place& place,
                         ::pir::Operation* op,
                         ValueExecutionInfo* value_exe_info);

  void Run() override;

  const std::string& Name() const override { return name_; }

  ::pir::Operation* Operation() const override { return op_; }

 private:
  ::pir::Operation* op_;

  OpFuncType type_;

  std::string name_{"has_elelments_instruction"};

  // const platform::DeviceContext* dev_ctx_;  // not owned
Review comment (Contributor): This can be deleted later.

Reply (Contributor Author): Deleted.


  ValueExecutionInfo* value_exe_info_;  // not owned

  VariableRefArray* stack_element_var_array_;  // not owned

  bool* has_elements_;  // not owned
};

} // namespace framework
} // namespace paddle
paddle/fluid/framework/new_executor/instruction/instruction_base.cc
@@ -193,6 +193,21 @@ const std::vector<size_t>& InstructionBase::GCCheckVars() const {
  return gc_check_vars_;
}

+void InstructionBase::AddEagerGCVar(Variable* var) {
+  eager_gc_vars_.push_back(var);
+}
+
+const std::vector<Variable*>& InstructionBase::EagerGCVars() const {
+  // NOTE(chenxi67): eager_gc_vars_ holds the vars that need to be gc-ed. Some
+  // vars in an Instruction Node are created temporarily and are not the input
+  // or output of an op (e.g. the copy_var created by TuplePushOp), so whether
+  // they need gc cannot be determined by analyzing the op (via GCCheckVars()).
+  // Such vars are added to eager_gc_vars_ and gc-ed directly.
+  return eager_gc_vars_;
+}
+
+void InstructionBase::ClearEagerGCVars() { eager_gc_vars_.clear(); }

const std::vector<std::pair<Variable*, Variable*>>&
InstructionBase::InplaceInfo() const {
  return vec_inplace_in_to_out_;
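To make the NOTE above concrete, a minimal hedged sketch; copy_var, local_scope, and the enclosing instruction are illustrative, not taken from this diff:

// Illustrative sketch only: a temporary variable created by an instruction
// (e.g. the copy_var that TuplePushOp makes) is not an operand or result of
// the op, so GCCheckVars() never sees it; registering it for eager GC lets
// the executor free it right after the instruction runs.
Variable* copy_var = local_scope->Var("copy_var_0");  // hypothetical name/scope
AddEagerGCVar(copy_var);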
paddle/fluid/framework/new_executor/instruction/instruction_base.h
@@ -100,6 +100,9 @@ class InstructionBase {

  const std::vector<size_t>& GCCheckVars() const;
  void AddGCCheckVar(size_t id);
+  const std::vector<Variable*>& EagerGCVars() const;
+  void AddEagerGCVar(Variable* var);
+  void ClearEagerGCVars();

  const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const;
  void AddInplace(Variable* in, Variable* out);
@@ -173,6 +176,8 @@ class InstructionBase {

  std::vector<size_t> gc_check_vars_;

+  std::vector<Variable*> eager_gc_vars_;

  std::vector<std::pair<Variable*, Variable*>>
      vec_inplace_in_to_out_;  // If not use share data, need this ?

paddle/fluid/framework/new_executor/instruction/instruction_util.cc (70 changes: 55 additions, 15 deletions)
@@ -27,13 +27,13 @@
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/operation.h"
#include "paddle/pir/core/value.h"
-#include "paddle/pir/dialect/control_flow/ir/cf_op.h"

#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h"
#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/pir/core/block_argument.h"
+#include "paddle/pir/dialect/control_flow/ir/cf_op.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
@@ -223,18 +223,18 @@ void GetInputIds(pir::Operation* op,
}
}

-std::unordered_set<pir::Value> GetBlockInnerOutputs(pir::Block* block) {
+std::unordered_set<pir::Value> GetInternalOutputs(pir::Block* block) {
  std::unordered_set<pir::Value> inner_outputs;
  for (size_t arg_id = 0; arg_id < block->args_size(); ++arg_id) {
    inner_outputs.insert(block->arg(arg_id));
  }
  for (auto& op : (*block)) {
-    VLOG(8) << "GetBlockInnerOutputs of " << op.name();
+    VLOG(8) << "GetInternalOutputs of " << op.name();
    if (op.num_regions()) {
      for (size_t i = 0; i < op.num_regions(); ++i) {
        for (auto& sub_block : op.region(i)) {
          std::unordered_set<pir::Value> sub_set =
-              GetBlockInnerOutputs(&sub_block);
+              GetInternalOutputs(&sub_block);
          inner_outputs.insert(sub_set.begin(), sub_set.end());
        }
      }
    }
@@ -246,51 +246,91 @@ std::unordered_set<pir::Value> GetBlockInnerOutputs(pir::Block* block) {
  return inner_outputs;
}

-std::unordered_set<pir::Value> GetBlockInnerInputs(pir::Block* block) {
+std::unordered_set<pir::Value> GetInternalInputs(pir::Block* block) {
  std::unordered_set<pir::Value> inner_inputs;
  for (auto& op : (*block)) {
-    VLOG(8) << "GetBlockInnerInputs of " << op.name();
+    VLOG(8) << "GetInternalInputs of " << op.name();
    if (op.num_regions()) {
      for (size_t i = 0; i < op.num_regions(); ++i) {
        for (auto& sub_block : op.region(i)) {
          std::unordered_set<pir::Value> sub_set =
-              GetBlockInnerInputs(&sub_block);
+              GetInternalInputs(&sub_block);
          inner_inputs.insert(sub_set.begin(), sub_set.end());
        }
      }
    }
+    if (op.isa<pir::TuplePopOp>()) {
+      auto tuple_pop_op = op.dyn_cast<pir::TuplePopOp>();
+      inner_inputs.insert(tuple_pop_op.container());
+    }
    for (size_t i = 0; i < op.num_operands(); ++i) {
      inner_inputs.insert(op.operand_source(i));
+      VLOG(10) << op.name()
+               << "'s inner_input: " << op.operand_source(i).impl();
    }
  }
  return inner_inputs;
}
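// GetExternalInputs (renamed from GetOutsideOpInputs in this PR) collects
// every value consumed inside `block`, recursing into nested regions, and
// keeps those not produced inside it: the inputs the enclosing scope must
// supply when the block runs.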

-std::vector<pir::Value> GetOutsideOpInputs(
+std::vector<pir::Value> GetExternalInputs(
    pir::Block* block,
    const ValueExecutionInfo& value_exec_info,
    std::unordered_map<pir::Value, std::vector<int>>* input_ids) {
  std::unordered_set<pir::Value> inner_outputs;
-  inner_outputs = GetBlockInnerOutputs(block);
+  inner_outputs = GetInternalOutputs(block);

  std::unordered_set<pir::Value> inner_inputs;
-  inner_inputs = GetBlockInnerInputs(block);
+  inner_inputs = GetInternalInputs(block);

  std::vector<pir::Value> outside_op_inputs;
  for (pir::Value value : inner_inputs) {
    if (value && (!inner_outputs.count(value))) {
-      PADDLE_ENFORCE_EQ(
-          value_exec_info.HasValue(value),
-          true,
-          phi::errors::PreconditionNotMet("input should be in name map"));
+      PADDLE_ENFORCE_EQ(value_exec_info.HasValue(value),
+                        true,
+                        phi::errors::PreconditionNotMet(
+                            "input %s should be in name map", value.impl()));
      input_ids->emplace(value, GetValueIds(value, value_exec_info));
      outside_op_inputs.push_back(value);
-      VLOG(6) << "GetOutsideOpInputs of " << value.impl();
+      VLOG(6) << "GetExternalInputs of " << value.impl();
    }
  }
  return outside_op_inputs;
}

+std::unordered_set<pir::Value> GetTuplePushContainer(pir::Block* block) {
+  std::unordered_set<pir::Value> inner_outputs;
+  for (auto& op : (*block)) {
+    VLOG(8) << "GetTuplePushContainer of " << op.name();
+    if (op.num_regions()) {
+      for (size_t i = 0; i < op.num_regions(); ++i) {
+        for (auto& sub_block : op.region(i)) {
+          std::unordered_set<pir::Value> sub_set =
+              GetTuplePushContainer(&sub_block);
+          inner_outputs.insert(sub_set.begin(), sub_set.end());
+        }
+      }
+    }
+    if (op.isa<pir::TuplePushOp>()) {
+      auto tuple_push_op = op.dyn_cast<pir::TuplePushOp>();
+      inner_outputs.insert(tuple_push_op.container());
+    }
+  }
+  return inner_outputs;
+}
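// InsertTuplePushContinerToOuts surfaces each stack container written by a
// tuple_push inside `block` as an output of the enclosing instruction, so the
// executor treats the container like one of the op's regular outputs (e.g.
// for dependency tracking).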

+void InsertTuplePushContinerToOuts(
+    pir::Block* block,
+    const ValueExecutionInfo& value_exec_info,
+    std::unordered_map<pir::Value, std::vector<int>>* outputs) {
+  std::unordered_set<pir::Value> inner_stack_outputs;
+  inner_stack_outputs = GetTuplePushContainer(block);
+
+  for (pir::Value value : inner_stack_outputs) {
+    outputs->emplace(value, GetValueIds(value, value_exec_info));
+    VLOG(6) << "InsertTuplePushContinerToOuts of " << value.impl();
+  }
+}

bool GetCondData(const phi::DenseTensor& cond) {
  if (paddle::platform::is_cpu_place(cond.place())) {
    return cond.data<bool>()[0];
paddle/fluid/framework/new_executor/instruction/instruction_util.h
@@ -49,11 +49,16 @@ void GetInputIds(pir::Operation* op,
                 const ValueExecutionInfo& value_exec_info,
                 std::unordered_map<pir::Value, std::vector<int>>* input_ids);

-std::vector<pir::Value> GetOutsideOpInputs(
+std::vector<pir::Value> GetExternalInputs(
    pir::Block* block,
    const ValueExecutionInfo& value_exec_info,
    std::unordered_map<pir::Value, std::vector<int>>* input_ids);

+void InsertTuplePushContinerToOuts(
+    pir::Block* block,
+    const ValueExecutionInfo& value_exec_info,
+    std::unordered_map<pir::Value, std::vector<int>>* outputs);

bool GetCondData(const phi::DenseTensor& cond);
} // namespace framework
} // namespace paddle