Skip to content

Commit

Permalink
Fix SAC BC breaking and renaming to ac_freq (#397)
Browse files Browse the repository at this point in the history
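As titled: selective activation checkpointing (SAC) moved to a different public API in torch.utils.checkpoint
(create_selective_checkpoint_contexts), so this commit switches to the new API to avoid breaking CI.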
  • Loading branch information
wanchaol authored Jun 13, 2024
1 parent a4cd9ba commit 33f301d
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
SequenceParallel,
)

from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint
from torch.utils.checkpoint import checkpoint

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
Expand All @@ -45,6 +45,7 @@
# currently selective per op and per layer checkpointing are supported
def checkpoint_wrapper(module, config):
if config.mode == "selective" and config.selective_ac_option == "op":
from torch.utils.checkpoint import create_selective_checkpoint_contexts

def _get_custom_policy(meta):
def _custom_policy(mode, func, *args, **kwargs):
Expand All @@ -60,7 +61,7 @@ def _custom_policy(mode, func, *args, **kwargs):

def selective_checkpointing_context_fn():
meta = defaultdict(int)
return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))
return create_selective_checkpoint_contexts(_get_custom_policy(meta))

return ptd_checkpoint_wrapper(
module,
Expand All @@ -86,15 +87,15 @@ def selective_checkpointing_context_fn():
1 == checkpointing every one (all).
2 == checkpoint every 2nd one
"""
every_x_layer = int(config.selective_ac_option)
ac_freq = int(config.selective_ac_option)
assert (
every_x_layer >= 0
), f"selective layer AC policy (every_x_layer) expects a positive integer, received {every_x_layer}"
ac_freq >= 0
), f"selective layer AC policy (ac_freq) expects a positive integer, received {ac_freq}"

checkpoint_wrapper.__dict__.setdefault("_count", 0)

checkpoint_wrapper._count += 1
if not every_x_layer or checkpoint_wrapper._count % every_x_layer == 0:
if not ac_freq or checkpoint_wrapper._count % ac_freq == 0:
return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
Expand Down

0 comments on commit 33f301d

Please sign in to comment.