Skip to content

Commit

Permalink
Allows non-strict load with distributed checkpoints (#9613) (#9715)
Browse files Browse the repository at this point in the history
* Allow non-strict load



* Point to non-stric load MCore branch



* Avoid module level StrictHandling



* Use MCore fork



* Update to MCore fix



* Restore ackward compatibility



* Update flag defaults



* Update MCore tag



* Update PyT Dist interface



* Update to latest core_r0.8.0



---------

Signed-off-by: Mikołaj Błaż <[email protected]>
Co-authored-by: mikolajblaz <[email protected]>
Signed-off-by: Tugrul Konuk <[email protected]>
  • Loading branch information
2 people authored and ertkonuk committed Jul 19, 2024
1 parent c7b2ead commit 3662b61
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 31 deletions.
6 changes: 3 additions & 3 deletions Dockerfile.ci
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ WORKDIR /workspace
# Install NeMo requirements
ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea
ARG MODELOPT_VERSION=0.13.0
ARG MCORE_TAG=de1b7c223303f6ba21e0540f27361334116efcbc
ARG MCORE_TAG=c0164bcfd4f8213a10a6b1e47ef80721a68b4fb6
ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
RUN \
--mount=type=bind,source=requirements,target=requirements \
Expand Down Expand Up @@ -69,14 +69,14 @@ git clone https://github.com/state-spaces/mamba.git && \
git checkout v2.0.3 && \
python setup.py install && \
cd .. && \
rm -rf mamba
rm -rf mamba

git clone https://github.com/Dao-AILab/causal-conv1d && \
cd causal-conv1d && \
git checkout v1.2.2.post1 && \
python setup.py install && \
cd .. && \
rm -rf causal-conv1d
rm -rf causal-conv1d

EOF

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ model:
dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format
dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves.
dist_ckpt_parallel_dist_opt: True # parallel save/load of a DistributedOptimizer. 'True' allows performant save and reshardable checkpoints. Set to 'False' only in order to minimize the number of checkpoint files.
dist_ckpt_load_strictness: null # defines checkpoint keys mismatch behavior (only during dist-ckpt load). Choices: assume_ok_unexpected (default - try loading without any check), log_all (log mismatches), raise_all (raise mismatches)

## Activation Checkpointing
# NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
Expand Down
50 changes: 22 additions & 28 deletions nemo/utils/callbacks/dist_ckpt_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from time import time
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import pytorch_lightning as pl
from lightning_fabric.plugins import CheckpointIO
Expand All @@ -44,6 +44,7 @@
FullyParallelSaveStrategyWrapper,
)
from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.parallel_state import get_data_parallel_group

HAVE_MEGATRON_CORE = True
Expand Down Expand Up @@ -188,6 +189,9 @@ class DistributedCheckpointIO(AsyncCompatibleCheckpointIO):
load_directly_on_device (bool, optional): if True, loads the weights directly
on GPU. Has effect only for `zarr` based checkpoints (PyT Distributed
always loads on device). Defaults to True.
load_strictness (StrictHandling, optional): defines loading strictness.
If not None, overwrites the `strict` flag passed to `load_checkpoint`.
Defaults to None.
async_save (bool): whether to save asynchronously. Should be set to True if
this class will be wrapped with AsyncFinalizableCheckpointIO.
torch_dist_multiproc (int, optional): number of extra processes per rank
Expand All @@ -202,6 +206,7 @@ def __init__(
self,
save_ckpt_format: str,
load_directly_on_device: bool = True,
load_strictness: Optional['StrictHandling'] = None,
async_save: bool = False,
torch_dist_multiproc: Optional[int] = None,
assume_constant_structure: bool = False,
Expand All @@ -215,6 +220,7 @@ def __init__(

self.save_ckpt_format = save_ckpt_format
self.load_directly_on_device = load_directly_on_device
self.load_strictness = load_strictness
self.async_save = async_save
self.torch_dist_multiproc = torch_dist_multiproc
self.assume_constant_structure = assume_constant_structure
Expand All @@ -238,6 +244,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False):
return cls(
save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'),
load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True),
load_strictness=model_cfg.get('dist_ckpt_load_strictness', None),
async_save=async_save,
torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None),
parallel_save=model_cfg.get('dist_ckpt_parallel_save', False),
Expand Down Expand Up @@ -275,7 +282,7 @@ def load_checkpoint(
path: _PATH,
map_location: Optional[Any] = None,
sharded_state_dict: Dict[str, Any] = None,
strict: Optional[bool] = True,
strict: Union[None, bool, 'StrictHandling'] = None,
validate_access_integrity: Optional[bool] = True,
) -> Dict[str, Any]:
"""Loads a distributed checkpoint.
Expand All @@ -287,6 +294,10 @@ def load_checkpoint(
defines the loading procedure for the distributed checkpoint.
Defaults to None to comply with the CheckpointIO interface,
but it's a required argument.
strict (bool, StrictHandling, optional): adjust load strictness. bool value
is translated to StrictHandling instance. Gets overwritten by
`self.load_strictness`. Defaults to None. If `self.load_strictness`
is also None, strict becomes StrictHandling.ASSUME_OK_UNEXPECTED.
Returns:
Dist[str, Any]: loaded checkpoint.
Expand All @@ -311,40 +322,23 @@ def load_checkpoint(
if sharded_strategy is not None:
logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.')

if not strict:
sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)
if isinstance(strict, bool):
strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL
if self.load_strictness is not None:
# Overwrites function argument
strict = self.load_strictness
if strict is None:
# Default behavior
strict = StrictHandling.ASSUME_OK_UNEXPECTED

return dist_checkpointing.load(
sharded_state_dict=sharded_state_dict,
checkpoint_dir=path,
sharded_strategy=sharded_strategy,
validate_access_integrity=validate_access_integrity,
strict=strict,
)

def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]):
ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path)
loaded_keys = []
missing_keys = []
unexpected_keys = []

def should_remove_missing_sharded_base(x: Any):
if isinstance(x, ShardedBase):
if x.key in ckpt_sharded_metadata:
loaded_keys.append(x.key)
return False
else:
unexpected_keys.append(x.key)
return True
return False

_, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base)
logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}')

# TODO: compute missing_keys by:
# 1. all_gather_object of loaded_keys
# 2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys
return sharded_state_dict

@_debug_time('DistributedCheckpointIO.remove_checkpoint')
def remove_checkpoint(self, path: _PATH) -> None:
"""Remove a distributed checkpoint.
Expand Down

0 comments on commit 3662b61

Please sign in to comment.