Fix SAC BC breaking and renaming to ac_freq #397

Merged
merged 3 commits into from
Jun 13, 2024

Conversation

wanchaol
Collaborator

As titled: SAC (selective activation checkpointing) moved to a different public API, so this PR switches to the new API to avoid breaking CI.

facebook-github-bot added the CLA Signed label Jun 13, 2024
Contributor

tianyu-l left a comment

Thanks for the fix! Not sure why CI failed on PP related jobs

wanchaol merged commit 33f301d into main Jun 13, 2024
5 checks passed
@@ -45,6 +45,7 @@
# currently selective per op and per layer checkpointing are supported
def checkpoint_wrapper(module, config):
    if config.mode == "selective" and config.selective_ac_option == "op":
        from torch.utils.checkpoint import create_selective_checkpoint_contexts

Hey, sorry for the BC break! I've updated the top of the PR description with a list of changes needed to restore the prior behavior.

Two updates are still needed for your case:

  • return a CheckpointPolicy enum instead of a bool
  • accept a ctx: SelectiveCheckpointContext instead of mode: str
def _custom_policy(ctx, func, *args, **kwargs):
    mode = "recompute" if ctx.is_recompute else "forward"
    mm_count_key = f"{mode}_mm_count"
    if func == torch.ops.aten.mm.default:
        meta[mm_count_key] += 1
    # Saves output of all compute ops, except every second mm
    to_save = func in no_recompute_list and not (
        func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
    )
    return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE
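
To make the counting behavior of the policy above easier to check, here is a minimal sketch that runs without PyTorch. The `CheckpointPolicy` enum and `FakeContext` class are stand-ins written for this example (they only mimic the names and the `is_recompute` attribute used in the snippet, and the enum lists just the two members used here), and the op names are plain strings rather than real `torch.ops.aten` overloads. It is meant as an illustration of the logic, not as the code that was merged.

```python
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum


# Stand-in (assumption): mimics the two enum members used in the snippet above
# so the sketch runs without PyTorch installed.
class CheckpointPolicy(Enum):
    MUST_SAVE = 0
    PREFER_RECOMPUTE = 1


# Stand-in (assumption): only models the is_recompute flag read by the policy.
@dataclass
class FakeContext:
    is_recompute: bool


# Hypothetical op identifiers; the real policy compares against
# torch.ops.aten.* overloads instead of strings.
MM = "aten.mm.default"
SDPA = "aten._scaled_dot_product_flash_attention.default"
RELU = "aten.relu.default"
no_recompute_list = {MM, SDPA}


def get_custom_policy(meta):
    def custom_policy(ctx, func, *args, **kwargs):
        # Forward and recompute passes each get their own mm counter.
        mode = "recompute" if ctx.is_recompute else "forward"
        mm_count_key = f"{mode}_mm_count"
        if func == MM:
            meta[mm_count_key] += 1
        # Save outputs of listed ops, except every second mm.
        to_save = func in no_recompute_list and not (
            func == MM and meta[mm_count_key] % 2 == 0
        )
        return (
            CheckpointPolicy.MUST_SAVE
            if to_save
            else CheckpointPolicy.PREFER_RECOMPUTE
        )

    return custom_policy


meta = defaultdict(int)
policy = get_custom_policy(meta)
forward = FakeContext(is_recompute=False)
decisions = [policy(forward, op).name for op in (MM, MM, SDPA, RELU, MM)]
print(decisions)
# ['MUST_SAVE', 'PREFER_RECOMPUTE', 'MUST_SAVE', 'PREFER_RECOMPUTE', 'MUST_SAVE']
```

The first and third mm outputs are saved and the second is left to recompute, while an op outside the list (here the hypothetical `RELU`) is always left to recompute. Keeping separate `forward_mm_count` and `recompute_mm_count` keys means the recompute pass starts counting from zero and makes the same decisions as the forward pass.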

@@ -60,7 +61,7 @@ def _custom_policy(mode, func, *args, **kwargs):

        def selective_checkpointing_context_fn():
            meta = defaultdict(int)
            return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))
Contributor

Hey @wanchaol, I think this PR (or #397) might have broken SAC + compile. I didn't verify though since _pt2_selective_checkpoint_context_fn_gen isn't in PyTorch trunk anymore. Do you think it's due to this change, or SAC + compile was already broken?

tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
tianyu-l pushed a commit that referenced this pull request Aug 16, 2024
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024