Skip to content

Commit

Permalink
refine code (PaddlePaddle#58312)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 authored and wentaoyu committed Oct 24, 2023
1 parent 3dd85f7 commit efea14c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
28 changes: 14 additions & 14 deletions paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,10 @@ void ProgramTranslator::TranslateBlock(
uint64_t start_id,
uint64_t end_id,
TranslationContext* translation_ctx,
pir::Block* dest_block,
pir::Block* dst_block,
bool for_cond_block,
std::vector<std::string> cond_sub_block_outputs,
std::vector<::paddle::framework::OpDesc*> cond_init_ops) {
const std::vector<std::string>& cond_sub_block_outputs,
const std::vector<::paddle::framework::OpDesc*>& cond_init_ops) {
VLOG(8) << "=============>start to translate a block";
PADDLE_ENFORCE(
(src_block.OpSize() >= end_id) && (start_id <= end_id),
Expand Down Expand Up @@ -372,13 +372,13 @@ void ProgramTranslator::TranslateBlock(
std::vector<uint64_t> cond_op_ids = GetCondOpIds(src_block, op_id);
ConditionBlockCombination cond_op_combination(src_block, cond_op_ids);
pir::Operation* if_op = TranslateCondIfOperation(
cond_op_combination, translation_ctx, dest_block);
cond_op_combination, translation_ctx, dst_block);
for (auto cond_id : cond_op_ids) {
translate_completed[cond_id] = true;
}
VLOG(10) << "[op translated][conditional_block]" << if_op;
} else if (op->Type() == "while") {
TranslateWhileOperation(op, translation_ctx, dest_block);
TranslateWhileOperation(op, translation_ctx, dst_block);
} else {
if (for_cond_block && op->Type() == "assign" &&
std::count(cond_sub_block_outputs.begin(),
Expand All @@ -387,7 +387,7 @@ void ProgramTranslator::TranslateBlock(
assign_output_2_input[op->Output("Out")[0]] = op->Input("X")[0];
translate_completed[op_id] = true;
} else {
TranslateGeneralOperation(op, translation_ctx, dest_block);
TranslateGeneralOperation(op, translation_ctx, dst_block);
translate_completed[op_id] = true;
}
}
Expand All @@ -398,7 +398,7 @@ void ProgramTranslator::TranslateBlock(
if (for_cond_block) {
// insert init ops
for (::paddle::framework::OpDesc* init_op : cond_init_ops) {
TranslateGeneralOperation(init_op, translation_ctx, dest_block);
TranslateGeneralOperation(init_op, translation_ctx, dst_block);
}
// insert yeild op
std::vector<pir::Value> yeild_inputs;
Expand All @@ -414,14 +414,14 @@ void ProgramTranslator::TranslateBlock(
auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name());
pir::Operation* yeild_op =
pir::Operation::Create(yeild_inputs, attribute_map, {}, yeild_info);
dest_block->push_back(yeild_op);
dst_block->push_back(yeild_op);
}
}

pir::Operation* ProgramTranslator::TranslateCondIfOperation(
const ConditionBlockCombination& cond_ops,
TranslationContext* translation_ctx,
pir::Block* dest_block) {
pir::Block* dst_block) {
auto& type_translator = TypeTranslator::instance();
auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::IfOp::name());
std::vector<pir::Value> op_inputs = {
Expand Down Expand Up @@ -449,7 +449,7 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
VariableDefiningInfo(operation->result(i)));
}

dest_block->push_back(operation);
dst_block->push_back(operation);
VLOG(4) << "[general op][conditional_block] IfOp creation end.";

if (cond_ops.TrueBlockId() != -1) {
Expand Down Expand Up @@ -496,7 +496,7 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
void ProgramTranslator::TranslateWhileOperation(
const OpDesc* op,
TranslationContext* translation_ctx,
pir::Block* dest_block) {
pir::Block* dst_block) {
VLOG(8) << "=============>Start to translate while op:" << op;
auto& sub_block = legacy_program_->Block(op->GetBlockAttrId("sub_block"));
int index = static_cast<int>(sub_block.OpSize()) - 1;
Expand Down Expand Up @@ -536,7 +536,7 @@ void ProgramTranslator::TranslateWhileOperation(
}
pir::Operation* while_op =
pir::Operation::Create(op_inputs, {}, op_outputs_type, op_info, 1);
dest_block->push_back(while_op);
dst_block->push_back(while_op);
while_op->region(0).push_back(body_block);
TranslateBlock(sub_block, 0, index + 1, translation_ctx, body_block);

Expand Down Expand Up @@ -566,15 +566,15 @@ void ProgramTranslator::TranslateWhileOperation(
void ProgramTranslator::TranslateGeneralOperation(
const OpDesc* src_op,
TranslationContext* translation_ctx,
pir::Block* dest_block) {
pir::Block* dst_block) {
auto& op_translator = OpTranslator::instance();
OpTranslateFn& fn = op_translator[src_op->Type()];
if (src_op->Type() == "shadow_output") {
if (!translation_ctx->count(src_op->Input("x")[0])) {
return;
}
}
pir::Operation* operation = fn(ctx_, translation_ctx, *src_op, dest_block);
pir::Operation* operation = fn(ctx_, translation_ctx, *src_op, dst_block);
VLOG(10) << "[op translated][general]" << operation << "end";
}

Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/ir_adaptor/translator/program_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,13 @@ class ProgramTranslator {
uint64_t start_id,
uint64_t end_id,
TranslationContext* translation_ctx,
pir::Block* dest_block,
pir::Block* dst_block,
bool for_cond_block = false,
std::vector<std::string> cond_sub_block_outputs = {},
std::vector<::paddle::framework::OpDesc*> cond_init_ops = {});
const std::vector<std::string>& cond_sub_block_outputs = {},
const std::vector<::paddle::framework::OpDesc*>& cond_init_ops = {});
void TranslateGeneralOperation(const OpDesc* src_op,
TranslationContext* translation_ctx,
pir::Block* dest_block);
pir::Block* dst_block);
void GetParameterForSingleBlock(const BlockDesc& block);
void SetParameterFromSingleBlock(const BlockDesc& block);
void SetStopGradientAttributeForAllValue(const BlockDesc& block);
Expand All @@ -159,10 +159,10 @@ class ProgramTranslator {
pir::Operation* TranslateCondIfOperation(
const ConditionBlockCombination& cond_ops,
TranslationContext* translation_ctx,
pir::Block* dest_block);
pir::Block* dst_block);
void TranslateWhileOperation(const OpDesc* op,
TranslationContext* translation_ctx,
pir::Block* dest_block);
pir::Block* dst_block);
};

} // namespace translator
Expand Down

0 comments on commit efea14c

Please sign in to comment.