SAC API follow ups to restore old behavior
see #397
wanchaol committed Jun 13, 2024
1 parent 8895f41 commit 12efb2a
Showing 1 changed file with 8 additions and 3 deletions.
torchtitan/parallelisms/parallelize_llama.py (+8, -3)
@@ -45,17 +45,22 @@
 # currently selective per op and per layer checkpointing are supported
 def checkpoint_wrapper(module, config):
     if config.mode == "selective" and config.selective_ac_option == "op":
-        from torch.utils.checkpoint import create_selective_checkpoint_contexts
+        from torch.utils.checkpoint import (
+            CheckpointPolicy,
+            create_selective_checkpoint_contexts
+        )

         def _get_custom_policy(meta):
-            def _custom_policy(mode, func, *args, **kwargs):
+            def _custom_policy(ctx, func, *args, **kwargs):
+                mode = "recompute" if ctx.is_recompute else "forward"
                 mm_count_key = f"{mode}_mm_count"
                 if func == torch.ops.aten.mm.default:
                     meta[mm_count_key] += 1
                 # Saves output of all compute ops, except every second mm
-                return func in no_recompute_list and not (
+                to_save = func in no_recompute_list and not (
                     func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                 )
+                return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE

             return _custom_policy

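For reference, below is a minimal, self-contained sketch of how a policy written this way plugs into the updated selective activation checkpointing (SAC) API. It is not code from this commit: the toy no_recompute_list, the context_fn wiring, and the bias-free Linear model are illustrative (torchtitan routes the policy through its own checkpoint wrapper rather than calling torch.utils.checkpoint.checkpoint directly), and it assumes a PyTorch build that ships CheckpointPolicy and create_selective_checkpoint_contexts. The two API changes the diff adapts to are visible here: the policy now receives a context object exposing is_recompute instead of a mode string, and it returns a CheckpointPolicy member instead of a bool.

from collections import defaultdict

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative stand-in for the module-level list in parallelize_llama.py:
# save matmul outputs, subject to the every-second-mm exception below.
no_recompute_list = [torch.ops.aten.mm.default]


def _get_custom_policy(meta):
    def _custom_policy(ctx, func, *args, **kwargs):
        # ctx.is_recompute distinguishes the original forward from the
        # recompute pass, so each pass keeps its own mm counter.
        mode = "recompute" if ctx.is_recompute else "forward"
        mm_count_key = f"{mode}_mm_count"
        if func == torch.ops.aten.mm.default:
            meta[mm_count_key] += 1
        to_save = func in no_recompute_list and not (
            func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
        )
        return (
            CheckpointPolicy.MUST_SAVE
            if to_save
            else CheckpointPolicy.PREFER_RECOMPUTE
        )

    return _custom_policy


def context_fn():
    # Fresh counters per checkpointed region; the returned pair of context
    # managers shares this dict between the forward and recompute passes.
    meta = defaultdict(int)
    return create_selective_checkpoint_contexts(_get_custom_policy(meta))


# bias=False so the linear layers dispatch to aten.mm (with bias they hit
# aten.addmm and the mm counter never fires).
model = torch.nn.Sequential(
    torch.nn.Linear(16, 16, bias=False),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 16, bias=False),
)
x = torch.randn(4, 16, requires_grad=True)
out = checkpoint(model, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()

Creating the meta dict inside context_fn gives each checkpointed region fresh counters while still sharing them between that region's forward and recompute passes, which is what keeps the every-second-mm selection consistent across the two passes.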
