Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 authored and SigureMo committed Dec 5, 2023
1 parent 0a423ac commit c47aec4
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,16 +289,16 @@ void WhileOp::Print(pir::IrPrinter &printer) {
auto &os = printer.os;
auto op = operation();
printer.PrintOpResult(op);
os << " = \"" << name() << "\"(";
os << " = \"" << name() << "\"(cond=";
printer.PrintValue(cond());
os << ") [";
os << ", inputs=";
auto operands = (*this)->operands_source();
pir::PrintInterleave(
operands.begin() + 1,
operands.end(),
[&](pir::Value v) { printer.PrintValue(v); },
[&]() { os << ", "; });
os << "] { \n ^";
os << ") { \n ^";
pir::PrintInterleave(
body().args_begin(),
body().args_end(),
Expand Down
5 changes: 4 additions & 1 deletion python/paddle/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,10 @@ def _add_pir_fetch_ops(program, fetch_list, fetch_var_name):
assert isinstance(
fetch_input, (OpResult, Value)
), f"Wrong type for fetch_list[{i}]: {type(fetch_input)}"
paddle._pir_ops.fetch(fetch_input, fetch_var_name + str(i), i)
out = paddle._pir_ops.fetch(
fetch_input, fetch_var_name + str(i), i
)
out.persistable = True


def _merge_tensors(tensor, micro_batch_num):
Expand Down
3 changes: 3 additions & 0 deletions test/cpp/pir/control_flow_dialect/while_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ TEST(while_op_test, network_with_backward) {
LOG(INFO) << program;

auto place = paddle::platform::CPUPlace();
#ifdef PADDLE_WITH_CUDA
place = paddle::platform::CUDAPlace(0);
#endif
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program, place);
paddle::framework::Scope scope;
paddle::framework::InterpreterCore test_core(
Expand Down

0 comments on commit c47aec4

Please sign in to comment.