[checkpoint] Log timings for checkpoint IO save and load (#11972)
* [checkpoint] Log timings for checkpoint IO save and load

Signed-off-by: Ananth Subramaniam <[email protected]>

* add rank to logline for megatron strategy

Signed-off-by: Ananth Subramaniam <[email protected]>

---------

Signed-off-by: Ananth Subramaniam <[email protected]>
ananthsub authored Feb 27, 2025
1 parent f0ee21d commit b94bff7
Showing 3 changed files with 138 additions and 20 deletions.
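
Both files follow the same pattern: capture a timestamp before and after the checkpoint IO call, then emit a single rank-tagged log line with the elapsed time. A minimal sketch of that pattern, assuming a generic checkpoint_io object and a caller-supplied rank rather than NeMo's actual strategy classes:

import logging
import time

def timed_checkpoint_load(checkpoint_io, checkpoint_path, rank):
    # Record wall-clock time around the (potentially slow) checkpoint read.
    start_time = time.monotonic()
    checkpoint = checkpoint_io.load_checkpoint(checkpoint_path)
    duration = time.monotonic() - start_time
    # One log line per rank, loosely mirroring the format used in the diffs below.
    logging.info(
        "Global Checkpoint Load : Rank : %s : Time spent in load_checkpoint: %.3fs",
        rank,
        duration,
    )
    return checkpoint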
36 changes: 30 additions & 6 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -17,6 +17,7 @@
import os
import shutil
import tempfile
import time
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
from pathlib import Path
@@ -256,6 +257,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
assert self.model is not None

def setup_distributed(self, global_rank: int = None, world_size: int = None) -> None:
"""Set up distributed environment."""
# call PTL init ddp
super().setup_distributed()

@@ -391,13 +393,13 @@ def get_safe(param_id):
def save_checkpoint(
self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None
) -> None:
app_state = AppState()
""" PTL method which we override to accomodate distributed checkpoints and
the legacy model parallel checkpoints.
"""PTL method which we override to accomodate distributed checkpoints and
the legacy model parallel checkpoints.
When using megatron core, the distributed checkpointing library expects save functions to be
called on every rank and internally does the rank checking.
When using megatron core, the distributed checkpointing library expects save functions to be
called on every rank and internally does the rank checking.
"""
app_state = AppState()
# check if using distributed checkpointing
if self.use_distributed_checkpointing:
# Check whether to save optim states
Expand Down Expand Up @@ -434,6 +436,7 @@ def save_checkpoint(
# PTL 2.2 supports non strict loading of the ckpt with the strict arg
# (https://github.com/Lightning-AI/pytorch-lightning/pull/19404)
def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
"""Load lightning module state dict."""
# if using distributed checkpointing, the state dict logic is at the model level
if self.use_distributed_checkpointing:
return
@@ -552,6 +555,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path], load_optimizer_stat
"""

fs = get_filesystem(checkpoint_path)
app_state = AppState()

# Check if using distributed checkpointing
if self.use_distributed_checkpointing:
@@ -593,7 +597,17 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path], load_optimizer_stat
if not fs.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found. Aborting training.")
torch.cuda.empty_cache()
return self.checkpoint_io.load_checkpoint(checkpoint_path)
start_time = time.monotonic()
checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path)
end_time = time.monotonic()
duration = end_time - start_time
logging.info(
f"Global Checkpoint Load : "
f"Rank : {app_state.global_rank} : "
f"Start time : {start_time:.3f}s : "
f"Time spent in load_checkpoint: {duration:.3f}s"
)
return checkpoint

def _integrate_original_checkpoint_data(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
"""
@@ -616,6 +630,7 @@ def _integrate_original_checkpoint_data(self, checkpoint: Dict[str, Any]) -> Dic
return checkpoint

def remove_checkpoint(self, filepath: Union[str, Path]) -> None:
"""Delete checkpoint saved at filepath."""
# check if filepath is a distributed checkpoint
if self.use_distributed_checkpointing:
if self.is_global_zero:
@@ -632,6 +647,7 @@ def remove_checkpoint(self, filepath: Union[str, Path]) -> None:

@property
def use_distributed_checkpointing(self):
"""Whether to use distributed checkpointing from megatron core."""
has_dist_ckpt_io = HAVE_MEGATRON_CORE and isinstance(self.unwrapped_checkpoint_io, DistributedCheckpointIO)
has_sharded_state_dict = (
hasattr(self.lightning_module, 'sharded_state_dict')
@@ -651,6 +667,7 @@ def use_distributed_checkpointing(self):

@property
def distributed_sampler_kwargs(self):
"""Provide distributed sampler kwargs."""
app_state = AppState()
if app_state.model_parallel_size is not None:
# When using model parallel, data parallel groups are non-trivial and they
@@ -894,6 +911,7 @@ def optimizer_state(self, optimizer: torch.optim.Optimizer) -> Dict[str, torch.T
return optim_state_dict

def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict=None) -> None:
"""Load lightning module states from checkpoint."""
# Release strict state dict matching when using Megatron AMP-O2 to skip matching
# half-precision module wrapper module.
# TODO: Refactor this to be more generic.
@@ -1013,6 +1031,8 @@ def restore_checkpoint_after_setup(self) -> bool:


class NLPSaveRestoreConnector(SaveRestoreConnector):
"""Custom connector to support saving and restoring states."""

def __init__(self) -> None:
if not HAVE_APEX:
logging.warning(
@@ -1032,6 +1052,7 @@ def __init__(self) -> None:
super().__init__()

def save_to(self, model, save_path: str):
"""Save model to save path."""
app_state = AppState()

# Check if using distributed checkpointing
@@ -1159,6 +1180,7 @@ def dummy():
return super().save_to(model, save_path)

def modify_state_dict(self, conf, state_dict):
"""Remap keys in state dict."""
if conf.get('megatron_legacy', False):
new_state_dict = {}
for key in state_dict.keys():
@@ -1712,6 +1734,7 @@ def optimizer_step(
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
"""Run optimizer step and scale gradients, if necessary."""
assert isinstance(
optimizer, MainParamsOptimizerWrapper
), "MegatronHalfPrecisionPlugin supports only the optimizer with master parameters"
@@ -1794,6 +1817,7 @@ def init_train_tqdm(self):
return self.bar

def on_train_epoch_start(self, trainer, *_):
"""Override parent class on_train_epoch_start to initialize the progress bar state."""
# Use trainer.max_steps as the num_training_batches since len(dataloader) aka num_training_batches
# is returned as the total num of micro batches instead of total num of global batches with this PR:
# https://github.com/NVIDIA/NeMo/pull/8426
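With this change, NLPDDPStrategy.load_checkpoint logs one line per rank using the f-string above; with illustrative placeholder values it renders roughly as:

Global Checkpoint Load : Rank : 0 : Start time : 12345.678s : Time spent in load_checkpoint: 87.654s

(The numbers are hypothetical. Because the strategy uses time.monotonic(), the start time is measured from an arbitrary reference point and is only meaningful relative to other monotonic readings from the same process.)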
52 changes: 45 additions & 7 deletions nemo/lightning/io/pl.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Generic, Optional, TypeVar, Union
@@ -39,16 +39,14 @@
from nemo.lightning.ckpt_utils import WEIGHTS_PATH, ckpt_to_dir
from nemo.lightning.io.capture import IOProtocol
from nemo.lightning.io.mixin import IOMixin
from nemo.utils import logging

try:
from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO
except ImportError:
AsyncCompatibleCheckpointIO = CheckpointIO


log = logging.getLogger(__name__)


LightningModuleT = TypeVar("LightningModuleT", bound=pl.LightningModule)
ModuleT = TypeVar("ModuleT", bound=nn.Module)

@@ -197,13 +195,35 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure)
self.validated_consistency = True

return dist_checkpointing.save(
rank = torch.distributed.get_rank()
iteration = _get_iteration_from_checkpoint(checkpoint)
start_time = time.time()
async_save_request = dist_checkpointing.save(
sharded_state_dict=checkpoint,
checkpoint_dir=checkpoint_dir,
sharded_strategy=self.save_sharded_strategy,
validate_access_integrity=validate_sharding_integrity,
async_sharded_save=self.async_save,
)
end_time = time.time()
log_parts = (
"Global Checkpoint Save",
f"Rank: {rank}",
f"Iteration: {iteration}" if iteration is not None else None,
f"Start time: {start_time:.3f}s",
f"Save duration: {end_time - start_time:.3f}s",
)
log_message = " : ".join(part for part in log_parts if part is not None)
logging.info(log_message)

def iter_finalize_fn():
logging.info(f'Successfully saved checkpoint from iteration {int(iteration):7d} to {path}')

if self.async_save:
assert async_save_request is not None
async_save_request.add_finalize_fn(iter_finalize_fn)

return async_save_request

@override
def load_checkpoint(
@@ -272,14 +292,22 @@ def load_checkpoint(
# Default behavior
strict = StrictHandling.ASSUME_OK_UNEXPECTED

start_time = time.time()
checkpoint = dist_checkpointing.load(
sharded_state_dict=sharded_state_dict,
checkpoint_dir=str(path),
sharded_strategy=sharded_strategy,
strict=strict,
)
checkpoint = _fix_tensors_device(checkpoint)

end_time = time.time()
duration = end_time - start_time
logging.info(
"Global Checkpoint Load : "
f"Rank : {torch.distributed.get_rank()} : "
f"Start time : {start_time:.3f}s : "
f"Time spent in load_checkpoint: {duration:.3f}s"
)
return checkpoint

@override
@@ -293,7 +321,7 @@ def remove_checkpoint(self, path: _PATH) -> None:
fs = get_filesystem(path)
if fs.exists(path):
fs.rm(path, recursive=True)
log.debug(f"Removed checkpoint: {path}")
logging.debug(f"Removed checkpoint: {path}")

def _determine_dist_ckpt_save_strategy(self):
"""Determine the saving strategy based on constructor args.
@@ -433,3 +461,13 @@ def is_distributed_ckpt(path) -> bool:
checkpoint_dir = ckpt_to_dir(path)
fs = get_filesystem(checkpoint_dir)
return fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir)


def _get_iteration_from_checkpoint(checkpoint: Dict[str, Any]) -> Optional[int]:
return (
checkpoint.get("loops", {})
.get("fit_loop", {})
.get("epoch_loop.batch_progress", {})
.get("total", {})
.get("completed", None)
)
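
The helper above digs the completed-step counter out of Lightning's serialized fit-loop state, and the save path joins only the non-None parts into one log line. A standalone sketch of both pieces, using a toy checkpoint dict and placeholder rank/timestamps instead of real NeMo or torch.distributed state:

import time
from typing import Any, Dict, Optional


def get_iteration_from_checkpoint(checkpoint: Dict[str, Any]) -> Optional[int]:
    # Same lookup as _get_iteration_from_checkpoint above: each .get() falls back to an
    # empty dict, so a checkpoint without loop state yields None instead of raising.
    return (
        checkpoint.get("loops", {})
        .get("fit_loop", {})
        .get("epoch_loop.batch_progress", {})
        .get("total", {})
        .get("completed", None)
    )


# Toy checkpoint shaped like a Lightning trainer checkpoint (values are illustrative).
checkpoint = {"loops": {"fit_loop": {"epoch_loop.batch_progress": {"total": {"completed": 1500}}}}}

rank = 0  # placeholder; the real code calls torch.distributed.get_rank()
iteration = get_iteration_from_checkpoint(checkpoint)

start_time = time.time()
# ... dist_checkpointing.save(...) would run here ...
end_time = time.time()

log_parts = (
    "Global Checkpoint Save",
    f"Rank: {rank}",
    f"Iteration: {iteration}" if iteration is not None else None,
    f"Start time: {start_time:.3f}s",
    f"Save duration: {end_time - start_time:.3f}s",
)
# Drop the optional Iteration field when the checkpoint carries no loop progress.
print(" : ".join(part for part in log_parts if part is not None))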