Skip to content

Commit

Permalink
Update npu recompute unittests to cover offloading cases
Browse files Browse the repository at this point in the history
  • Loading branch information
will-jl944 committed Jan 6, 2025
1 parent 43b1e12 commit e9c6439
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def test_recompute_kwargs(self):
pos = paddle.randn(shape=[10, 10], dtype="float32")
pos.stop_gradient = False

kwargs = {"pos": pos, "use_reentrant": True}
kwargs = {"pos": pos, "use_reentrant": True, "offload_indices": [0]}
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[2], recompute_kwargs=kwargs
)
Expand Down

0 comments on commit e9c6439

Please sign in to comment.