Allows non-strict load with distributed checkpoints #9613

Merged: 16 commits, Jul 12, 2024
Changes from 8 commits
5 changes: 3 additions & 2 deletions Dockerfile.ci
@@ -34,7 +34,7 @@ WORKDIR /workspace
# Install NeMo requirements
ARG TE_TAG=bfe21c3d68b0a9951e5716fb520045db53419c5e
ARG MODELOPT_VERSION=0.13.0
ARG MCORE_TAG=02871b4df8c69fac687ab6676c4246e936ce92d0
ARG MCORE_TAG=0595b32de74d568038f4957a9069e39c7f54b310
ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
RUN \
--mount=type=bind,source=requirements,target=requirements \
@@ -53,7 +53,8 @@ pip install --no-cache-dir --no-build-isolation --extra-index-url https://pypi.n
".[all]"

# Megatron Core installation
git clone https://github.com/NVIDIA/Megatron-LM.git && \
# TODO: revert original repo
git clone https://github.com/mikolajblaz/Megatron-LM.git && \
pushd Megatron-LM && \
git checkout ${MCORE_TAG} && \
pushd megatron/core/datasets && \
@@ -177,6 +177,7 @@ model:
  dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
  dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
  dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint
  dist_ckpt_load_strictness: null # defines checkpoint keys mismatch behavior (only during dist-ckpt load). Choices: assume_ok_unexpected (default - try loading without any check), log_all (log mismatches), raise_all (raise mismatches)

  ## Activation Checkpointing
  # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
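The new `dist_ckpt_load_strictness` option accepts the three string choices listed above, and the plugin code below passes the configured value through to the Megatron Core loader. The following sketch is illustrative only: it shows how those strings correspond to `StrictHandling` members, assuming the enum exposes names matching the config strings (`ASSUME_OK_UNEXPECTED` and `LOG_ALL` appear in this PR's diff; `RAISE_ALL` is assumed by analogy).

# Illustrative sketch only: correspondence between the YAML choices and
# Megatron Core's StrictHandling enum (RAISE_ALL is an assumed member name).
from typing import Optional

from megatron.core.dist_checkpointing.validation import StrictHandling

STRICTNESS_CHOICES = {
    'assume_ok_unexpected': StrictHandling.ASSUME_OK_UNEXPECTED,  # default: load without key checks
    'log_all': StrictHandling.LOG_ALL,      # log missing/unexpected keys, keep loading
    'raise_all': StrictHandling.RAISE_ALL,  # raise on any key mismatch
}


def resolve_strictness(cfg_value: Optional[str]) -> StrictHandling:
    """Translate the YAML value (a string or null) into a StrictHandling member."""
    if cfg_value is None:
        return StrictHandling.ASSUME_OK_UNEXPECTED
    return STRICTNESS_CHOICES[cfg_value]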
50 changes: 22 additions & 28 deletions nemo/utils/callbacks/dist_ckpt_io.py
@@ -17,7 +17,7 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from time import time
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import pytorch_lightning as pl
from lightning_fabric.plugins import CheckpointIO
@@ -33,6 +33,7 @@
from megatron.core.dist_checkpointing.dict_utils import extract_matching_values
from megatron.core.dist_checkpointing.mapping import ShardedBase
from megatron.core.dist_checkpointing.strategies import tensorstore
from megatron.core.dist_checkpointing.validation import StrictHandling

from nemo.utils.callbacks.torch_dist_async import AsyncCallsQueue, AsyncRequest, TorchDistAsyncSaveShardedStrategy

@@ -175,6 +176,9 @@ class DistributedCheckpointIO(AsyncCompatibleCheckpointIO):
        load_directly_on_device (bool, optional): if True, loads the weights directly
            on GPU. Has effect only for `zarr` based checkpoints (PyT Distributed
            always loads on device). Defaults to True.
        load_strictness (StrictHandling, optional): defines loading strictness.
            If not None, overwrites the `strict` flag passed to `load_checkpoint`.
            Defaults to None.
        async_save (bool): whether to save asynchronously. Should be set to True if
            this class will be wrapped with AsyncFinalizableCheckpointIO.
    """
@@ -183,6 +187,7 @@ def __init__(
        self,
        save_ckpt_format: str,
        load_directly_on_device: bool = True,
        load_strictness: Optional['StrictHandling'] = None,
        async_save: bool = False,
    ):
        super().__init__()
@@ -191,6 +196,7 @@ def __init__(

        self.save_ckpt_format = save_ckpt_format
        self.load_directly_on_device = load_directly_on_device
        self.load_strictness = load_strictness
        self.async_save = async_save
        self.save_sharded_strategy = self._determine_dist_ckpt_save_strategy()

@@ -207,6 +213,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False):
        return cls(
            save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'),
            load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True),
            load_strictness=model_cfg.get('dist_ckpt_load_strictness', None),
            async_save=async_save,
        )
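As a hypothetical usage sketch (the dict literal and its values are illustrative, mirroring the YAML defaults shown earlier), the new key flows from the model config into the plugin via `from_config` and is stored on `load_strictness`; it assumes NeMo and Megatron Core are installed.

# Hypothetical usage sketch: the dict literal stands in for the model config.
from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO

model_cfg = {
    'dist_ckpt_format': 'torch_dist',        # or 'zarr' (the YAML default)
    'dist_ckpt_load_on_device': True,
    'dist_ckpt_load_strictness': 'log_all',  # log key mismatches instead of assuming they are fine
}

checkpoint_io = DistributedCheckpointIO.from_config(model_cfg, async_save=False)
# The configured value is kept on the instance and applied during load_checkpoint.
print(checkpoint_io.load_strictness)  # -> 'log_all'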

@@ -241,7 +248,7 @@ def load_checkpoint(
        path: _PATH,
        map_location: Optional[Any] = None,
        sharded_state_dict: Dict[str, Any] = None,
        strict: Optional[bool] = True,
        strict: Union[None, bool, 'StrictHandling'] = None,
        validate_access_integrity: Optional[bool] = True,
    ) -> Dict[str, Any]:
        """Loads a distributed checkpoint.
@@ -253,6 +260,10 @@
                defines the loading procedure for the distributed checkpoint.
                Defaults to None to comply with the CheckpointIO interface,
                but it's a required argument.
            strict (bool, StrictHandling, optional): adjust load strictness. bool value
                is translated to StrictHandling instance. Gets overwritten by
                `self.load_strictness`. Defaults to None. If `self.load_strictness`
                is also None, strict becomes StrictHandling.ASSUME_OK_UNEXPECTED.

        Returns:
            Dict[str, Any]: loaded checkpoint.
@@ -267,40 +278,23 @@
        else:
            sharded_strategy = None

        if not strict:
            sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)
        if isinstance(strict, bool):
            strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL
        if self.load_strictness is not None:
            # Overwrites function argument
            strict = self.load_strictness
        if strict is None:
            # Default behavior
            strict = StrictHandling.ASSUME_OK_UNEXPECTED

        return dist_checkpointing.load(
            sharded_state_dict=sharded_state_dict,
            checkpoint_dir=path,
            sharded_strategy=sharded_strategy,
            validate_access_integrity=validate_access_integrity,
            strict=strict,
        )

    def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]):
        ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path)
        loaded_keys = []
        missing_keys = []
        unexpected_keys = []

        def should_remove_missing_sharded_base(x: Any):
            if isinstance(x, ShardedBase):
                if x.key in ckpt_sharded_metadata:
                    loaded_keys.append(x.key)
                    return False
                else:
                    unexpected_keys.append(x.key)
                    return True
            return False

        _, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base)
        logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}')

        # TODO: compute missing_keys by:
        # 1. all_gather_object of loaded_keys
        # 2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys
        return sharded_state_dict

    @_debug_time('DistributedCheckpointIO.remove_checkpoint')
    def remove_checkpoint(self, path: _PATH) -> None:
        """Remove a distributed checkpoint.
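To summarize the precedence implemented in `load_checkpoint` above, the following standalone sketch restates the resolution logic outside the plugin (the helper name `resolve_strict` is illustrative): a bool argument is translated first, a configured `load_strictness` then overrides the argument, and `None` falls back to the permissive default.

# Standalone restatement of the strictness precedence from load_checkpoint above.
from typing import Optional, Union

from megatron.core.dist_checkpointing.validation import StrictHandling


def resolve_strict(
    strict: Union[None, bool, StrictHandling],
    load_strictness: Optional[StrictHandling],
) -> StrictHandling:
    if isinstance(strict, bool):
        # True keeps the permissive default; False asks for mismatch logging.
        strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL
    if load_strictness is not None:
        # The plugin-level setting (dist_ckpt_load_strictness) wins over the argument.
        strict = load_strictness
    if strict is None:
        # Default behavior when neither is given.
        strict = StrictHandling.ASSUME_OK_UNEXPECTED
    return strict


# Examples of the resulting behavior:
assert resolve_strict(False, None) is StrictHandling.LOG_ALL
assert resolve_strict(None, None) is StrictHandling.ASSUME_OK_UNEXPECTED
assert resolve_strict(False, StrictHandling.ASSUME_OK_UNEXPECTED) is StrictHandling.ASSUME_OK_UNEXPECTED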