diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 8ae93ea0cb0d51..bdf258f4ee3f1b 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -254,16 +254,16 @@ void WhileOp::Print(pir::IrPrinter &printer) { auto &os = printer.os; auto op = operation(); printer.PrintOpResult(op); - os << " = \"" << name() << "\"("; + os << " = \"" << name() << "\"(cond="; printer.PrintValue(cond()); - os << ") ["; + os << ", inputs="; auto operands = (*this)->operands_source(); pir::PrintInterleave( operands.begin() + 1, operands.end(), [&](pir::Value v) { printer.PrintValue(v); }, [&]() { os << ", "; }); - os << "] { \n ^"; + os << ") { \n ^"; pir::PrintInterleave( body().args_begin(), body().args_end(), diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index c86ca29d85a50c..9de9e4f074a208 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -526,7 +526,10 @@ def _add_pir_fetch_ops(program, fetch_list, fetch_var_name): assert isinstance( fetch_input, (OpResult, Value) ), f"Wrong type for fetch_list[{i}]: {type(fetch_input)}" - paddle._pir_ops.fetch(fetch_input, fetch_var_name + str(i), i) + out = paddle._pir_ops.fetch( + fetch_input, fetch_var_name + str(i), i + ) + out.persistable = True def _merge_tensors(tensor, micro_batch_num): diff --git a/test/cpp/pir/control_flow_dialect/while_op_test.cc b/test/cpp/pir/control_flow_dialect/while_op_test.cc index dbcdc6d361ec20..55291c2ccbc1b4 100644 --- a/test/cpp/pir/control_flow_dialect/while_op_test.cc +++ b/test/cpp/pir/control_flow_dialect/while_op_test.cc @@ -185,6 +185,9 @@ TEST(while_op_test, network_with_backward) { LOG(INFO) << program; auto place = paddle::platform::CPUPlace(); +#ifdef PADDLE_WITH_CUDA + place = paddle::platform::CUDAPlace(0); +#endif auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program, place); paddle::framework::Scope scope; paddle::framework::InterpreterCore test_core(