Skip to content

Commit

Permalink
Update npu recompute unittests to cover offloading cases
Browse files Browse the repository at this point in the history
  • Loading branch information
will-jl944 committed Jan 7, 2025
1 parent 43b1e12 commit 9e47ae6
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def test_recompute_kwargs(self):
pos = paddle.randn(shape=[10, 10], dtype="float32")
pos.stop_gradient = False

kwargs = {"pos": pos, "use_reentrant": True}
kwargs = {"pos": pos, "use_reentrant": True, "offload_indices": [0]}
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[2], recompute_kwargs=kwargs
)
Expand Down
26 changes: 0 additions & 26 deletions backends/npu/tests/unittests/test_sum_op_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,31 +112,5 @@ def test_check_output(self):
self.check_output_with_place(self.place)


class TestSum4(OpTest):
def setUp(self):
self.set_npu()
self.init_dtype()
self.op_type = "sum"
self.place = paddle.CustomPlace("npu", 0)

x0 = np.random.random((3, 40)).astype(self.dtype)
x1 = np.random.random((0, 0)).astype(self.dtype)
x2 = np.random.random((3, 40)).astype(self.dtype)
self.inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]}
y = x0 + x2
self.outputs = {"Out": y}

self.attrs = {"use_mkldnn": False}

def init_dtype(self):
self.dtype = np.float32

def set_npu(self):
self.__class__.use_custom_device = True

def test_check_output(self):
self.check_output_with_place(self.place)


if __name__ == "__main__":
unittest.main()

0 comments on commit 9e47ae6

Please sign in to comment.