Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SAC BC breaking and renaming to ac_freq #397

Merged
merged 3 commits into from
Jun 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
SequenceParallel,
)

from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint
from torch.utils.checkpoint import checkpoint

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
Expand All @@ -45,6 +45,7 @@
# currently selective per op and per layer checkpointing are supported
def checkpoint_wrapper(module, config):
if config.mode == "selective" and config.selective_ac_option == "op":
from torch.utils.checkpoint import create_selective_checkpoint_contexts

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, sorry for bc-breaking! I've updated the top of the PR description to have a list of things to restore prior behavior -

Two updates below that are still needed for your case:

  • return a CheckpointPolicy Enum instead of bool
  • accept a ctx: SelectiveCheckpoint instead of mode: str
def _custom_policy(ctx, func, *args, **kwargs):
      mode = "recompute" if ctx.is_recompute else "forward"
      mm_count_key = f"{mode}_mm_count"
      if func == torch.ops.aten.mm.default:
          meta[mm_count_key] += 1
      # Saves output of all compute ops, except every second mm
      to_save =  func in no_recompute_list and not (
          func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
      )
      return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE


def _get_custom_policy(meta):
def _custom_policy(mode, func, *args, **kwargs):
Expand All @@ -60,7 +61,7 @@ def _custom_policy(mode, func, *args, **kwargs):

def selective_checkpointing_context_fn():
meta = defaultdict(int)
return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @wanchaol, I think this PR (or #397) might have broken SAC + compile. I didn't verify though since _pt2_selective_checkpoint_context_fn_gen isn't in PyTorch trunk anymore. Do you think it's due to this change, or SAC + compile was already broken?

return create_selective_checkpoint_contexts(_get_custom_policy(meta))

return ptd_checkpoint_wrapper(
module,
Expand All @@ -86,15 +87,15 @@ def selective_checkpointing_context_fn():
1 == checkpointing every one (all).
2 == checkpoint every 2nd one
"""
every_x_layer = int(config.selective_ac_option)
ac_freq = int(config.selective_ac_option)
assert (
every_x_layer >= 0
), f"selective layer AC policy (every_x_layer) expects a positive integer, received {every_x_layer}"
ac_freq >= 0
), f"selective layer AC policy (ac_freq) expects a positive integer, received {ac_freq}"

checkpoint_wrapper.__dict__.setdefault("_count", 0)

checkpoint_wrapper._count += 1
if not every_x_layer or checkpoint_wrapper._count % every_x_layer == 0:
if not ac_freq or checkpoint_wrapper._count % ac_freq == 0:
return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
Expand Down
Loading