Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Refine IfOp translate #58088

Merged
merged 2 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 86 additions & 37 deletions paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,36 @@ const std::unordered_set<std::string> ProgramTranslator::unsupported_ops = {
static std::vector<uint64_t> GetCondOpIds(const BlockDesc& src_block,
uint64_t first_id) {
std::vector<uint64_t> op_list = {first_id};
if (src_block.Op(static_cast<int>(first_id + 1))->Type() == "logical_not") {
if (((first_id + 1) < src_block.OpSize()) &&
(src_block.Op(static_cast<int>(first_id + 1))->Type() == "logical_not")) {
op_list.emplace_back(first_id + 1);
}
if (src_block.Op(static_cast<int>(first_id + 2))->Type() ==
"conditional_block") {
if (((first_id + 2) < src_block.OpSize()) &&
(src_block.Op(static_cast<int>(first_id + 2))->Type() ==
"conditional_block")) {
op_list.emplace_back(first_id + 2);
}
if (src_block.Op(static_cast<int>(first_id + 3))->Type() == "cast") {
if (((first_id + 3) < src_block.OpSize()) &&
(src_block.Op(static_cast<int>(first_id + 3))->Type() == "cast")) {
op_list.emplace_back(first_id + 3);
}
size_t output_size =
src_block.Op(static_cast<int>(first_id))->Output("Out").size();
// Note(zhangbo): Some output variables are input, without select_input op.
std::vector<std::string> output_names =
src_block.Op(static_cast<int>(first_id))->Output("Out");
std::vector<std::string> input_names =
src_block.Op(static_cast<int>(first_id))->Input("Input");
std::vector<std::string> diffs(output_names.size());
auto iter = std::set_difference(output_names.begin(),
output_names.end(),
input_names.begin(),
input_names.end(),
diffs.begin());
diffs.resize(iter - diffs.begin());
size_t output_size = diffs.size();
for (size_t i = 0; i < output_size; i++) {
if (src_block.Op(static_cast<int>(first_id + 4 + i))->Type() ==
"select_input") {
if (((first_id + 4 + i) < src_block.OpSize()) &&
(src_block.Op(static_cast<int>(first_id + 4 + i))->Type() ==
"select_input")) {
op_list.emplace_back(first_id + 4 + i);
}
}
Expand All @@ -97,7 +112,16 @@ const std::string& ConditionBlockCombination::CondVarName() const {
}

size_t ConditionBlockCombination::OutputSize() const {
return op_list_[0]->Output("Out").size();
std::vector<std::string> output_names = op_list_[0]->Output("Out");
std::vector<std::string> input_names = op_list_[0]->Input("Input");
std::vector<std::string> diffs(output_names.size());
auto iter = std::set_difference(output_names.begin(),
output_names.end(),
input_names.begin(),
input_names.end(),
diffs.begin());
diffs.resize(iter - diffs.begin());
return diffs.size();
}

std::vector<::paddle::framework::VarDesc*>
Expand All @@ -112,23 +136,41 @@ ConditionBlockCombination::OutputVars() const {
return outputs;
}

const std::vector<std::string>&
ConditionBlockCombination::TrueBlockOutputVarNames() const {
return op_list_[0]->Output("Out");
}

int ConditionBlockCombination::TrueBlockId() const {
return op_list_[0]->GetBlockAttrId("sub_block");
std::vector<std::string> ConditionBlockCombination::TrueBlockOutputVarNames()
const {
std::vector<std::string> output_names = op_list_[0]->Output("Out");
std::vector<std::string> input_names = op_list_[0]->Input("Input");
std::vector<std::string> diffs(output_names.size());
auto iter = std::set_difference(output_names.begin(),
output_names.end(),
input_names.begin(),
input_names.end(),
diffs.begin());
diffs.resize(iter - diffs.begin());
return diffs;
}

std::vector<std::string> ConditionBlockCombination::FalseBlockOutputVarNames()
const {
if (op_list_.size() > 1) {
return op_list_[2]->Output("Out");
std::vector<std::string> output_names = op_list_[2]->Output("Out");
std::vector<std::string> input_names = op_list_[2]->Input("Input");
std::vector<std::string> diffs(output_names.size());
auto iter = std::set_difference(output_names.begin(),
output_names.end(),
input_names.begin(),
input_names.end(),
diffs.begin());
diffs.resize(iter - diffs.begin());
return diffs;
}
return {""};
}

int ConditionBlockCombination::TrueBlockId() const {
return op_list_[0]->GetBlockAttrId("sub_block");
}

int ConditionBlockCombination::FalseBlockId() const {
if (op_list_.size() > 1) {
return op_list_[2]->GetBlockAttrId("sub_block");
Expand All @@ -143,9 +185,6 @@ bool ConditionBlockCombination::Verify(
if (op_list[id]->Type() != "conditional_block") {
return false;
}
if (op_list.size() == 1 && op_list[id]->Output("Out").size() != 0) {
return false;
}
} else if (id == 1) {
if (op_list[id]->Type() != "logical_not") {
return false;
Expand Down Expand Up @@ -207,11 +246,13 @@ void ProgramTranslator::Translate() {
}
}

void ProgramTranslator::TranslateBlock(const BlockDesc& src_block,
uint64_t start_id,
uint64_t end_id,
pir::Block* dest_block,
bool for_cond_block) {
void ProgramTranslator::TranslateBlock(
const BlockDesc& src_block,
uint64_t start_id,
uint64_t end_id,
pir::Block* dest_block,
bool for_cond_block,
std::vector<std::string> skip_cond_assign) {
VLOG(8) << "=============>start to translate a block";
PADDLE_ENFORCE(
(src_block.OpSize() >= end_id) && (start_id <= end_id),
Expand All @@ -223,10 +264,12 @@ void ProgramTranslator::TranslateBlock(const BlockDesc& src_block,
src_block.OpSize()));

std::unordered_map<uint64_t, bool> translate_completed;
std::vector<std::string> assign_inputs;
for (uint64_t op_id = start_id; op_id < end_id; op_id++) {
if (translate_completed.count(op_id) && translate_completed.at(op_id)) {
continue;
}

auto op = src_block.Op(static_cast<int>(op_id));
VLOG(8) << "=============>start to translate a op: " << op->Type();

Expand All @@ -246,20 +289,24 @@ void ProgramTranslator::TranslateBlock(const BlockDesc& src_block,
}
VLOG(10) << "[op translated][conditional_block]" << if_op;
} else {
TranslateGeneralOperation(op, dest_block);
translate_completed[op_id] = true;
if (for_cond_block && op->Type() == "assign" &&
std::count(skip_cond_assign.begin(),
skip_cond_assign.end(),
op->Output("Out")[0])) {
assign_inputs.push_back(op->Input("X")[0]);
translate_completed[op_id] = true;
} else {
TranslateGeneralOperation(op, dest_block);
translate_completed[op_id] = true;
}
}
}
// NOTE(zhangbo): If conditional_block operator has output, the cf.yeild
// operator needs to be inserted
if (for_cond_block) {
std::vector<pir::Value> yeild_inputs;
for (size_t id = end_id; id < src_block.OpSize(); id++) {
PADDLE_ENFORCE(
src_block.Op(id)->Type() == "assign",
"The operator at the end of the sub block needs to be assign");
yeild_inputs.emplace_back(
param_map_[src_block.Op(static_cast<int>(id))->Input("X")[0]].value);
for (size_t id = 0; id < assign_inputs.size(); id++) {
yeild_inputs.emplace_back(param_map_[assign_inputs[id]].value);
}
pir::AttributeMap attribute_map;
auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name());
Expand Down Expand Up @@ -308,9 +355,10 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
if (true_region.empty()) true_region.emplace_back();
TranslateBlock(true_sub_block,
0,
true_sub_block.OpSize() - cond_ops.OutputSize(),
true_sub_block.OpSize(),
true_region.front(),
true);
true,
cond_ops.TrueBlockOutputVarNames());
}
VLOG(4) << "[general op][conditional_block] IfOp true block translate end.";

Expand All @@ -321,9 +369,10 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
if (false_region.empty()) false_region.emplace_back();
TranslateBlock(false_sub_block,
0,
false_sub_block.OpSize() - cond_ops.OutputSize(),
false_sub_block.OpSize(),
false_region.front(),
true);
true,
cond_ops.FalseBlockOutputVarNames());
}
VLOG(4) << "[general op][conditional_block] IfOp false block translate end.";
VLOG(4) << "[general op][conditional_block] IfOp translate end.";
Expand Down
9 changes: 5 additions & 4 deletions paddle/fluid/ir_adaptor/translator/program_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ class ConditionBlockCombination {
ConditionBlockCombination(const ::paddle::framework::BlockDesc& src_block,
const std::vector<uint64_t>& op_ids);
const std::string& CondVarName() const;
int TrueBlockId() const;
int FalseBlockId() const;
size_t OutputSize() const;
std::vector<::paddle::framework::VarDesc*> OutputVars() const;
const std::vector<std::string>& TrueBlockOutputVarNames() const;
int TrueBlockId() const;
std::vector<std::string> TrueBlockOutputVarNames() const;
std::vector<std::string> FalseBlockOutputVarNames() const;
int FalseBlockId() const;

private:
bool Verify(const std::vector<::paddle::framework::OpDesc*>& op_list);
Expand Down Expand Up @@ -101,7 +101,8 @@ class ProgramTranslator {
uint64_t start_id,
uint64_t end_id,
pir::Block* dest_block,
bool for_cond_block = false);
bool for_cond_block = false,
std::vector<std::string> skip_cond_assign = {});
void TranslateGeneralOperation(const OpDesc* src_op, pir::Block* dest_block);
void GetParameterForSingleBlock(const BlockDesc& block);
void InsertOperationToSingleBlock(const BlockDesc& block);
Expand Down