-
Notifications
You must be signed in to change notification settings - Fork 298
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix SAC BC breaking and renaming to ac_freq #397
Conversation
as titled, SAC moved to a different public API, move to the new API to avoid CI breaking
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix! Not sure why CI failed on PP related jobs
@@ -45,6 +45,7 @@ | |||
# currently selective per op and per layer checkpointing are supported | |||
def checkpoint_wrapper(module, config): | |||
if config.mode == "selective" and config.selective_ac_option == "op": | |||
from torch.utils.checkpoint import create_selective_checkpoint_contexts |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey, sorry for bc-breaking! I've updated the top of the PR description to have a list of things to restore prior behavior -
Two updates below that are still needed for your case:
- return a CheckpointPolicy Enum instead of bool
- accept a
ctx: SelectiveCheckpoint
instead ofmode: str
def _custom_policy(ctx, func, *args, **kwargs):
mode = "recompute" if ctx.is_recompute else "forward"
mm_count_key = f"{mode}_mm_count"
if func == torch.ops.aten.mm.default:
meta[mm_count_key] += 1
# Saves output of all compute ops, except every second mm
to_save = func in no_recompute_list and not (
func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
)
return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE
@@ -60,7 +61,7 @@ def _custom_policy(mode, func, *args, **kwargs): | |||
|
|||
def selective_checkpointing_context_fn(): | |||
meta = defaultdict(int) | |||
return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as titled, SAC moved to a different public API, move to the new API to avoid CI breaking
as titled, SAC moved to a different public API, move to the new API to avoid CI breaking
as titled, SAC moved to a different public API, move to the new API to avoid CI breaking