Skip to content

Commit

Permalink
[PIR]Rename flags (PaddlePaddle#57496)
Browse files Browse the repository at this point in the history
* rename flag

* fix py3 bugs

* modify demo code
  • Loading branch information
YuanRisheng authored and iosmers committed Sep 21, 2023
1 parent 1329e29 commit 2a52a8c
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 43 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/feed_fetch_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ limitations under the License. */
#include "glog/logging.h"

PHI_DECLARE_bool(enable_new_ir_in_executor);
PHI_DECLARE_bool(enable_new_ir_api);
PHI_DECLARE_bool(enable_pir_api);

namespace phi {
class DenseTensor;
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/framework/new_executor/standalone_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include "paddle/pir/pass/pass_manager.h"

PHI_DECLARE_bool(enable_new_ir_in_executor);
PHI_DECLARE_bool(enable_new_ir_api);
PHI_DECLARE_bool(enable_pir_api);
PHI_DECLARE_bool(new_ir_apply_inplace_pass);

namespace paddle {
Expand All @@ -55,7 +55,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
const std::string& job_type = job->Type();
std::shared_ptr<ProgramDesc> program = nullptr;
std::shared_ptr<::pir::Program> ir_program = nullptr;
if (FLAGS_enable_new_ir_api) {
if (FLAGS_enable_pir_api) {
ir_program = plan_.IrProgram(job_type);
} else {
program = std::make_shared<ProgramDesc>(*(plan_.Program(job_type)));
Expand All @@ -69,7 +69,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
micro_batch_id,
micro_batch_num));

if (micro_batch_num > 1 && !FLAGS_enable_new_ir_api) {
if (micro_batch_num > 1 && !FLAGS_enable_pir_api) {
SetColAttrForFeedFetchOps(program, micro_batch_num, micro_batch_id);
}

Expand All @@ -80,7 +80,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
// TODO(phlrain) we only support cpu for now
if (FLAGS_enable_new_ir_in_executor) {
std::shared_ptr<::pir::Program> base_program = ir_program;
if (!FLAGS_enable_new_ir_api) {
if (!FLAGS_enable_pir_api) {
VLOG(6) << "begin to translate" << std::endl;
base_program = paddle::TranslateLegacyProgramToProgram(*program);
}
Expand Down
24 changes: 12 additions & 12 deletions paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class StaticTensorOperants : public TensorOperantsBase {
#include "paddle/fluid/primitive/backend/backend.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
PHI_DECLARE_bool(enable_new_ir_api);
PHI_DECLARE_bool(enable_pir_api);
"""

Expand All @@ -227,39 +227,39 @@ class StaticTensorOperants : public TensorOperantsBase {
using LazyTensor = paddle::primitive::LazyTensor;
Tensor StaticTensorOperants::add(const Tensor& x, const Scalar& y) {
if (FLAGS_enable_new_ir_api) {
if (FLAGS_enable_pir_api) {
return paddle::primitive::backend::add<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
} else {
return paddle::prim::add<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}
}
Tensor StaticTensorOperants::subtract(const Tensor& x, const Scalar& y) {
if (FLAGS_enable_new_ir_api) {
if (FLAGS_enable_pir_api) {
return paddle::primitive::backend::subtract<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
} else {
return paddle::prim::subtract<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}
}
Tensor StaticTensorOperants::multiply(const Tensor& x, const Scalar& y) {
if (FLAGS_enable_new_ir_api) {
if (FLAGS_enable_pir_api) {
return paddle::primitive::backend::scale<LazyTensor>(x, y, 0.0f, true);
} else {
return paddle::prim::scale<DescTensor>(x, y, 0.0f, true);
}
}
Tensor StaticTensorOperants::divide(const Tensor& x, const Scalar& y) {
if (FLAGS_enable_new_ir_api) {
if (FLAGS_enable_pir_api) {
return paddle::primitive::backend::divide<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
} else {
return paddle::prim::divide<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}
}
Tensor StaticTensorOperants::add(const Scalar& x, const Tensor& y) {
if (FLAGS_enable_new_ir_api) {
if (FLAGS_enable_pir_api) {
return paddle::primitive::backend::add<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
} else {
return paddle::prim::add<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
Expand All @@ -268,39 +268,39 @@ class StaticTensorOperants : public TensorOperantsBase {
Tensor StaticTensorOperants::subtract(const Scalar& x, const Tensor& y) {
if (FLAGS_enable_new_ir_api) {
if (FLAGS_enable_pir_api) {
return paddle::primitive::backend::subtract<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
} else {
return paddle::prim::subtract<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
}
}
Tensor StaticTensorOperants::multiply(const Scalar& x, const Tensor& y) {
if (FLAGS_enable_new_ir_api) {
if (FLAGS_enable_pir_api) {
return paddle::primitive::backend::scale<LazyTensor>(y, x, 0.0f, true);
} else {
return paddle::prim::scale<DescTensor>(y, x, 0.0f, true);
}
}
Tensor StaticTensorOperants::divide(const Scalar& x, const Tensor& y) {
if (FLAGS_enable_new_ir_api) {
if (FLAGS_enable_pir_api) {
return paddle::primitive::backend::divide<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
} else {
return paddle::prim::divide<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
}
}
Tensor StaticTensorOperants::pow(const Tensor& x, const Tensor& y) {
if (FLAGS_enable_new_ir_api) {
if (FLAGS_enable_pir_api) {
return paddle::primitive::backend::elementwise_pow<LazyTensor>(x, y);
} else {
return paddle::prim::elementwise_pow<DescTensor>(x, y);
}
}
Tensor StaticTensorOperants::pow(const Tensor& x, const Scalar& y) {
if (FLAGS_enable_new_ir_api) {
if (FLAGS_enable_pir_api) {
return paddle::primitive::backend::elementwise_pow<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
} else {
return paddle::prim::elementwise_pow<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
Expand Down Expand Up @@ -393,7 +393,7 @@ def gene_static_tensor_func_call(self):
)
static_func_parameters = self.get_func_args()

static_tensor_func_call = f"""if (FLAGS_enable_new_ir_api) {{
static_tensor_func_call = f"""if (FLAGS_enable_pir_api) {{
return {backend_static_func_name}({static_func_parameters});
}} else {{
return {prim_static_func_name}({static_func_parameters});
Expand Down
6 changes: 2 additions & 4 deletions paddle/phi/core/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1278,15 +1278,13 @@ PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor,

/**
* Using new IR API in Python
* Name: enable_new_ir_api
* Name: enable_pir_api
* Since Version: 2.6.0
* Value Range: bool, default=false
* Example:
 * Note: If True, the new IR API will be used in Python
*/
PHI_DEFINE_EXPORTED_bool(enable_new_ir_api,
false,
"Enable new IR API in Python");
PHI_DEFINE_EXPORTED_bool(enable_pir_api, false, "Enable new IR API in Python");

/**
* Using new IR in executor FLAG
Expand Down
8 changes: 4 additions & 4 deletions python/paddle/base/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def __init__(self):
self._in_to_static_mode_ = False
self._functional_dygraph_context_manager = None
self._dygraph_tracer_ = _dygraph_tracer_
self._use_pir_api_ = get_flags("FLAGS_enable_new_ir_api")[
'FLAGS_enable_new_ir_api'
self._use_pir_api_ = get_flags("FLAGS_enable_pir_api")[
'FLAGS_enable_pir_api'
]

def __str__(self):
Expand Down Expand Up @@ -340,8 +340,8 @@ def in_dynamic_or_pir_mode():
>>> print(paddle.framework.in_dynamic_or_pir_mode())
False
>>> paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
>>> print(paddle.framework.in_dynamic_or_pir_mode())
>>> with paddle.pir_utils.IrGuard():
... print(paddle.framework.in_dynamic_or_pir_mode())
True
"""
Expand Down
26 changes: 13 additions & 13 deletions python/paddle/pir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
class IrGuard:
def __init__(self):
self.in_dygraph_outside = False
old_flag = paddle.base.framework.get_flags("FLAGS_enable_new_ir_api")
paddle.base.framework.set_flags({"FLAGS_enable_new_ir_api": False})
old_flag = paddle.base.framework.get_flags("FLAGS_enable_pir_api")
paddle.base.framework.set_flags({"FLAGS_enable_pir_api": False})
paddle.base.framework.global_var._use_pir_api_ = False
if not paddle.base.framework.get_flags("FLAGS_enable_new_ir_api")[
"FLAGS_enable_new_ir_api"
if not paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
"FLAGS_enable_pir_api"
]:
self.old_Program = paddle.static.Program
self.old_program_guard = paddle.base.program_guard
Expand All @@ -34,31 +34,31 @@ def __init__(self):
else:
raise RuntimeError(
"IrGuard only init when paddle.framework.in_pir_mode(): is false, \
please set FLAGS_enable_new_ir_api = false"
please set FLAGS_enable_pir_api = false"
)
paddle.base.framework.set_flags(old_flag)
paddle.base.framework.global_var._use_pir_api_ = old_flag[
"FLAGS_enable_new_ir_api"
"FLAGS_enable_pir_api"
]

def __enter__(self):
self.in_dygraph_outside = paddle.base.framework.in_dygraph_mode()
if self.in_dygraph_outside:
paddle.enable_static()
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
paddle.framework.set_flags({"FLAGS_enable_pir_api": True})
paddle.base.framework.global_var._use_pir_api_ = True
self._switch_to_pir()

def __exit__(self, exc_type, exc_val, exc_tb):
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
paddle.framework.set_flags({"FLAGS_enable_pir_api": False})
paddle.base.framework.global_var._use_pir_api_ = False
self._switch_to_old_ir()
if self.in_dygraph_outside:
paddle.disable_static()

def _switch_to_pir(self):
if paddle.base.framework.get_flags("FLAGS_enable_new_ir_api")[
"FLAGS_enable_new_ir_api"
if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
"FLAGS_enable_pir_api"
]:
paddle.framework.set_flags(
{"FLAGS_enable_new_ir_in_executor": True}
Expand All @@ -76,8 +76,8 @@ def _switch_to_pir(self):
)

def _switch_to_old_ir(self):
if not paddle.base.framework.get_flags("FLAGS_enable_new_ir_api")[
"FLAGS_enable_new_ir_api"
if not paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
"FLAGS_enable_pir_api"
]:
paddle.framework.set_flags(
{"FLAGS_enable_new_ir_in_executor": False}
Expand All @@ -93,5 +93,5 @@ def _switch_to_old_ir(self):
else:
raise RuntimeError(
"IrGuard._switch_to_old_ir only work when paddle.framework.in_pir_mode() is false, \
please set FLAGS_enable_new_ir_api = false"
please set FLAGS_enable_pir_api = false"
)
2 changes: 1 addition & 1 deletion test/ir/new_ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ foreach(target ${TEST_INTERP_CASES})
endforeach()

foreach(target ${TEST_IR_SYSTEM_CASES})
py_test_modules(${target} MODULES ${target} ENVS FLAGS_enable_new_ir_api=true)
py_test_modules(${target} MODULES ${target} ENVS FLAGS_enable_pir_api=true)
endforeach()

set_tests_properties(test_pd_inplace_pass PROPERTIES TIMEOUT 60)
6 changes: 3 additions & 3 deletions test/ir/new_ir/test_ir_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_ir_program_0():

class TesBackward_1(unittest.TestCase):
def tearDown(self) -> None:
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
paddle.framework.set_flags({"FLAGS_enable_pir_api": False})

def test_grad(self):
newir_program = get_ir_program_0()
Expand Down Expand Up @@ -155,7 +155,7 @@ def get_ir_program_1():

class TesBackward_2(unittest.TestCase):
def tearDown(self) -> None:
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
paddle.framework.set_flags({"FLAGS_enable_pir_api": False})

def test_add_n(self):
newir_program = get_ir_program_1()
Expand Down Expand Up @@ -231,7 +231,7 @@ def get_ir_program_2():

class TestBackward_3(unittest.TestCase):
def tearDown(self) -> None:
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
paddle.framework.set_flags({"FLAGS_enable_pir_api": False})

def test_basic_network(self):
newir_program = get_ir_program_2()
Expand Down
2 changes: 1 addition & 1 deletion test/prim/new_ir_prim/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ set(TEST_PRIM_PURE_NEW_IR_CASES test_prim_program test_prim_simpnet

foreach(target ${TEST_PRIM_PURE_NEW_IR_CASES})
py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1
FLAGS_enable_new_ir_api=true)
FLAGS_enable_pir_api=true)
endforeach()

file(
Expand Down

0 comments on commit 2a52a8c

Please sign in to comment.