From aaec79106600d99c7a89886e6dcc64cf5eb5f3d2 Mon Sep 17 00:00:00 2001 From: Zhenghai Zhang <65210872+ccsuzzh@users.noreply.github.com> Date: Fri, 10 Nov 2023 17:07:05 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PIR=20API=20adaptor=20No.133-135?= =?UTF-8?q?=E3=80=91Migrate=20paddle.logical=5Fnot/logical=5For/logical=5F?= =?UTF-8?q?xor=20into=20pir=20(#58781)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/tensor/logic.py | 6 +++--- test/legacy_test/test_logical_op.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 7e29f89a9de173..ec9f0fa67644aa 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -201,7 +201,7 @@ def logical_or(x, y, out=None, name=None): [[True , True ], [True , False]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.logical_or(x, y) return _logical_op( op_name="logical_or", x=x, y=y, name=name, out=out, binary_op=True @@ -262,7 +262,7 @@ def logical_xor(x, y, out=None, name=None): [[False, True ], [True , False]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.logical_xor(x, y) return _logical_op( @@ -322,7 +322,7 @@ def logical_not(x, out=None, name=None): Tensor(shape=[4], dtype=bool, place=Place(cpu), stop_gradient=True, [False, True , False, True ]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.logical_not(x) return _logical_op( op_name="logical_not", x=x, y=None, name=name, out=out, binary_op=False diff --git a/test/legacy_test/test_logical_op.py b/test/legacy_test/test_logical_op.py index 81dec36e2f698e..6c4dda65984eb4 100755 --- a/test/legacy_test/test_logical_op.py +++ b/test/legacy_test/test_logical_op.py @@ -19,7 +19,7 @@ import paddle from paddle.framework import in_dynamic_mode -from paddle.static import Executor, Program, program_guard +from paddle.pir_utils import test_with_pir_api SUPPORTED_DTYPES = [ bool, @@ -67,16 +67,15 @@ } -# @test_with_pir_api def run_static(x_np, y_np, op_str, use_gpu=False, binary_op=True): paddle.enable_static() - startup_program = Program() - main_program = Program() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() place = paddle.CPUPlace() if use_gpu and paddle.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) - exe = Executor(place) - with program_guard(main_program, startup_program): + exe = paddle.static.Executor(place) + with paddle.static.program_guard(main_program, startup_program): x = paddle.static.data(name='x', shape=x_np.shape, dtype=x_np.dtype) op = getattr(paddle, op_str) feed_list = {'x': x_np} @@ -135,6 +134,7 @@ def np_data_generator(np_shape, dtype, *args, **kwargs): return np.random.normal(0, 1, np_shape).astype(dtype) +@test_with_pir_api def test(unit_test, use_gpu=False, test_error=False): for op_data in TEST_META_OP_DATA: meta_data = dict(op_data)