From 9e47ae69b181032bc46e6d8589bbbda857f592ce Mon Sep 17 00:00:00 2001 From: will-jl944 Date: Mon, 6 Jan 2025 20:37:24 +0800 Subject: [PATCH] Update npu recompute unittests to cover offloading cases --- .../test_dygraph_recompute_for_eager.py | 2 +- .../npu/tests/unittests/test_sum_op_npu.py | 26 ------------------- 2 files changed, 1 insertion(+), 27 deletions(-) diff --git a/backends/npu/tests/unittests/test_dygraph_recompute_for_eager.py b/backends/npu/tests/unittests/test_dygraph_recompute_for_eager.py index 6a802cebd..961863d72 100644 --- a/backends/npu/tests/unittests/test_dygraph_recompute_for_eager.py +++ b/backends/npu/tests/unittests/test_dygraph_recompute_for_eager.py @@ -297,7 +297,7 @@ def test_recompute_kwargs(self): pos = paddle.randn(shape=[10, 10], dtype="float32") pos.stop_gradient = False - kwargs = {"pos": pos, "use_reentrant": True} + kwargs = {"pos": pos, "use_reentrant": True, "offload_indices": [0]} loss_ref, param_ref, grad_ref = run_model( recompute_block=[2], recompute_kwargs=kwargs ) diff --git a/backends/npu/tests/unittests/test_sum_op_npu.py b/backends/npu/tests/unittests/test_sum_op_npu.py index d18d8af80..e281332d7 100644 --- a/backends/npu/tests/unittests/test_sum_op_npu.py +++ b/backends/npu/tests/unittests/test_sum_op_npu.py @@ -112,31 +112,5 @@ def test_check_output(self): self.check_output_with_place(self.place) -class TestSum4(OpTest): - def setUp(self): - self.set_npu() - self.init_dtype() - self.op_type = "sum" - self.place = paddle.CustomPlace("npu", 0) - - x0 = np.random.random((3, 40)).astype(self.dtype) - x1 = np.random.random((0, 0)).astype(self.dtype) - x2 = np.random.random((3, 40)).astype(self.dtype) - self.inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]} - y = x0 + x2 - self.outputs = {"Out": y} - - self.attrs = {"use_mkldnn": False} - - def init_dtype(self): - self.dtype = np.float32 - - def set_npu(self): - self.__class__.use_custom_device = True - - def test_check_output(self): - self.check_output_with_place(self.place) - - if __name__ == "__main__": unittest.main()