From 812ea87338b1138c80551484f9e768fc9c2a7dc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Thu, 4 Jul 2024 12:59:03 +0200 Subject: [PATCH 01/10] Allow non-strict load MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mikołaj Błaż --- .../conf/megatron_gpt_config.yaml | 1 + nemo/utils/callbacks/dist_ckpt_io.py | 42 +++++++------------ 2 files changed, 15 insertions(+), 28 deletions(-) diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index 98bf7d448845..794237812fad 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -177,6 +177,7 @@ model: dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format. dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint + dist_ckpt_load_strictness: 'assume_ok_unexpected' # defines checkpoint keys mismatch behavior (only during dist-ckpt load). Choices: assume_ok_unexpected (try loading without any check), log_all (log mismatches), raise_all (raise mismatches) ## Activation Checkpointing # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py index 31ab0c84dd3a..437cdd9f3b18 100644 --- a/nemo/utils/callbacks/dist_ckpt_io.py +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from time import time -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union import pytorch_lightning as pl from lightning_fabric.plugins import CheckpointIO @@ -33,6 +33,7 @@ from megatron.core.dist_checkpointing.dict_utils import extract_matching_values from megatron.core.dist_checkpointing.mapping import ShardedBase from megatron.core.dist_checkpointing.strategies import tensorstore + from megatron.core.dist_checkpointing.validation import StrictHandling from nemo.utils.callbacks.torch_dist_async import AsyncCallsQueue, AsyncRequest, TorchDistAsyncSaveShardedStrategy @@ -183,6 +184,7 @@ def __init__( self, save_ckpt_format: str, load_directly_on_device: bool = True, + load_strictness: StrictHandling = StrictHandling.ASSUME_OK_UNEXPECTED, async_save: bool = False, ): super().__init__() @@ -191,6 +193,7 @@ def __init__( self.save_ckpt_format = save_ckpt_format self.load_directly_on_device = load_directly_on_device + self.load_strictness = load_strictness self.async_save = async_save self.save_sharded_strategy = self._determine_dist_ckpt_save_strategy() @@ -207,6 +210,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False): return cls( save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'), load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True), + load_strictness=model_cfg.get('dist_ckpt_load_strictness', 'assume_ok_unexpected'), async_save=async_save, ) @@ -241,7 +245,7 @@ def load_checkpoint( path: _PATH, map_location: Optional[Any] = None, sharded_state_dict: Dict[str, Any] = None, - strict: Optional[bool] = True, + strict: Union[None, bool, StrictHandling] = None, validate_access_integrity: Optional[bool] = True, ) -> Dict[str, Any]: """Loads a distributed checkpoint. @@ -253,6 +257,9 @@ def load_checkpoint( defines the loading procedure for the distributed checkpoint. Defaults to None to comply with the CheckpointIO interface, but it's a required argument. + strict (bool, StrictHandling, optional): adjust load strictness. bool value + is translated to StrictHandling instance. Defaults to None, in which + case `self.load_strictness` is used as a default. Returns: Dist[str, Any]: loaded checkpoint. @@ -267,40 +274,19 @@ def load_checkpoint( else: sharded_strategy = None - if not strict: - sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict) + if strict is None: + strict = self.load_strictness + elif isinstance(strict, bool): + strict = StrictHandling.RAISE_ALL if strict else StrictHandling.LOG_ALL return dist_checkpointing.load( sharded_state_dict=sharded_state_dict, checkpoint_dir=path, sharded_strategy=sharded_strategy, validate_access_integrity=validate_access_integrity, + strict=strict, ) - def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]): - ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path) - loaded_keys = [] - missing_keys = [] - unexpected_keys = [] - - def should_remove_missing_sharded_base(x: Any): - if isinstance(x, ShardedBase): - if x.key in ckpt_sharded_metadata: - loaded_keys.append(x.key) - return False - else: - unexpected_keys.append(x.key) - return True - return False - - _, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base) - logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}') - - # TODO: compute missing_keys by: - # 1. all_gather_object of loaded_keys - # 2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys - return sharded_state_dict - @_debug_time('DistributedCheckpointIO.remove_checkpoint') def remove_checkpoint(self, path: _PATH) -> None: """Remove a distributed checkpoint. From be985fb9a44bca5bccc466afcbf19364c1916864 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Thu, 4 Jul 2024 13:07:21 +0200 Subject: [PATCH 02/10] Point to non-stric load MCore branch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mikołaj Błaż --- Dockerfile.ci | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.ci b/Dockerfile.ci index b376aacd0bfe..e52389f5b7f6 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -34,7 +34,7 @@ WORKDIR /workspace # Install NeMo requirements ARG TE_TAG=bfe21c3d68b0a9951e5716fb520045db53419c5e ARG MODELOPT_VERSION=0.13.0 -ARG MCORE_TAG=02871b4df8c69fac687ab6676c4246e936ce92d0 +ARG MCORE_TAG=4168a374a3287ebd5b335db7829875a392419055 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c RUN \ --mount=type=bind,source=requirements,target=requirements \ From cc8e4a1885006052ba258d121950352dfde5cdda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Thu, 4 Jul 2024 13:12:13 +0200 Subject: [PATCH 03/10] Avoid module level StrictHandling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mikołaj Błaż --- nemo/utils/callbacks/dist_ckpt_io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py index 437cdd9f3b18..324c171a048d 100644 --- a/nemo/utils/callbacks/dist_ckpt_io.py +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -184,7 +184,7 @@ def __init__( self, save_ckpt_format: str, load_directly_on_device: bool = True, - load_strictness: StrictHandling = StrictHandling.ASSUME_OK_UNEXPECTED, + load_strictness: 'StrictHandling' = 'assume_ok_unexpected', async_save: bool = False, ): super().__init__() @@ -245,7 +245,7 @@ def load_checkpoint( path: _PATH, map_location: Optional[Any] = None, sharded_state_dict: Dict[str, Any] = None, - strict: Union[None, bool, StrictHandling] = None, + strict: Union[None, bool, 'StrictHandling'] = None, validate_access_integrity: Optional[bool] = True, ) -> Dict[str, Any]: """Loads a distributed checkpoint. From bf9504225d7ee6905ba1abcc45b6bba20573c485 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Thu, 4 Jul 2024 13:24:12 +0200 Subject: [PATCH 04/10] Use MCore fork MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mikołaj Błaż --- Dockerfile.ci | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Dockerfile.ci b/Dockerfile.ci index e52389f5b7f6..42e9b3b33009 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -53,7 +53,8 @@ pip install --no-cache-dir --no-build-isolation --extra-index-url https://pypi.n ".[all]" # Megatron Core installation -git clone https://github.com/NVIDIA/Megatron-LM.git && \ +# TODO: revert original repo +git clone https://github.com/mikolajblaz/Megatron-LM.git && \ pushd Megatron-LM && \ git checkout ${MCORE_TAG} && \ pushd megatron/core/datasets && \ From a9a53f6273794d7fb8bddbb90519be1557f0ce6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Thu, 4 Jul 2024 14:20:31 +0200 Subject: [PATCH 05/10] Update to MCore fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mikołaj Błaż --- Dockerfile.ci | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.ci b/Dockerfile.ci index 42e9b3b33009..6d3e61d01f08 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -34,7 +34,7 @@ WORKDIR /workspace # Install NeMo requirements ARG TE_TAG=bfe21c3d68b0a9951e5716fb520045db53419c5e ARG MODELOPT_VERSION=0.13.0 -ARG MCORE_TAG=4168a374a3287ebd5b335db7829875a392419055 +ARG MCORE_TAG=e448c5ae6f2961fc11c31ddbae827d910faa93f0 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c RUN \ --mount=type=bind,source=requirements,target=requirements \ From a08d87366fedf356f18be7ec9a0539bcf1806718 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Thu, 4 Jul 2024 17:31:02 +0200 Subject: [PATCH 06/10] Restore ackward compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mikołaj Błaż --- nemo/utils/callbacks/dist_ckpt_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py index 324c171a048d..1f48562535f1 100644 --- a/nemo/utils/callbacks/dist_ckpt_io.py +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -277,7 +277,7 @@ def load_checkpoint( if strict is None: strict = self.load_strictness elif isinstance(strict, bool): - strict = StrictHandling.RAISE_ALL if strict else StrictHandling.LOG_ALL + strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL return dist_checkpointing.load( sharded_state_dict=sharded_state_dict, From 096b10a5b58fc46aa4adaf33bc7a48067103f2b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Thu, 4 Jul 2024 17:35:48 +0200 Subject: [PATCH 07/10] Update flag defaults MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mikołaj Błaż --- .../conf/megatron_gpt_config.yaml | 2 +- nemo/utils/callbacks/dist_ckpt_io.py | 22 +++++++++++++------ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index 794237812fad..511e1ee7d460 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -177,7 +177,7 @@ model: dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format. dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint - dist_ckpt_load_strictness: 'assume_ok_unexpected' # defines checkpoint keys mismatch behavior (only during dist-ckpt load). Choices: assume_ok_unexpected (try loading without any check), log_all (log mismatches), raise_all (raise mismatches) + dist_ckpt_load_strictness: null # defines checkpoint keys mismatch behavior (only during dist-ckpt load). Choices: assume_ok_unexpected (default - try loading without any check), log_all (log mismatches), raise_all (raise mismatches) ## Activation Checkpointing # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py index 1f48562535f1..ca9f4c79edfb 100644 --- a/nemo/utils/callbacks/dist_ckpt_io.py +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -176,6 +176,9 @@ class DistributedCheckpointIO(AsyncCompatibleCheckpointIO): load_directly_on_device (bool, optional): if True, loads the weights directly on GPU. Has effect only for `zarr` based checkpoints (PyT Distributed always loads on device). Defaults to True. + load_strictness (StrictHandling, optional): defines loading strictness. + If not None, overwrites the `strict` flag passed to `load_checkpoint`. + Defaults to None. async_save (bool): whether to save asynchronously. Should be set to True if this class will be wrapped with AsyncFinalizableCheckpointIO. """ @@ -184,7 +187,7 @@ def __init__( self, save_ckpt_format: str, load_directly_on_device: bool = True, - load_strictness: 'StrictHandling' = 'assume_ok_unexpected', + load_strictness: Optional['StrictHandling'] = None, async_save: bool = False, ): super().__init__() @@ -210,7 +213,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False): return cls( save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'), load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True), - load_strictness=model_cfg.get('dist_ckpt_load_strictness', 'assume_ok_unexpected'), + load_strictness=model_cfg.get('dist_ckpt_load_strictness', None), async_save=async_save, ) @@ -258,8 +261,9 @@ def load_checkpoint( Defaults to None to comply with the CheckpointIO interface, but it's a required argument. strict (bool, StrictHandling, optional): adjust load strictness. bool value - is translated to StrictHandling instance. Defaults to None, in which - case `self.load_strictness` is used as a default. + is translated to StrictHandling instance. Gets overwritten by + `self.load_strictness`. Defaults to None. If `self.load_strictness` + is also None, strict becomes StrictHandling.ASSUME_OK_UNEXPECTED. Returns: Dist[str, Any]: loaded checkpoint. @@ -274,10 +278,14 @@ def load_checkpoint( else: sharded_strategy = None - if strict is None: - strict = self.load_strictness - elif isinstance(strict, bool): + if isinstance(strict, bool): strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL + if self.load_strictness is not None: + # Overwrites function argument + strict = self.load_strictness + if strict is None: + # Default behavior + strict = StrictHandling.ASSUME_OK_UNEXPECTED return dist_checkpointing.load( sharded_state_dict=sharded_state_dict, From e95d5cc86abba587a95c4b060db2335a946d38a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Thu, 4 Jul 2024 17:36:01 +0200 Subject: [PATCH 08/10] Update MCore tag MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mikołaj Błaż --- Dockerfile.ci | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.ci b/Dockerfile.ci index 6d3e61d01f08..6db7f9139639 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -34,7 +34,7 @@ WORKDIR /workspace # Install NeMo requirements ARG TE_TAG=bfe21c3d68b0a9951e5716fb520045db53419c5e ARG MODELOPT_VERSION=0.13.0 -ARG MCORE_TAG=e448c5ae6f2961fc11c31ddbae827d910faa93f0 +ARG MCORE_TAG=0595b32de74d568038f4957a9069e39c7f54b310 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c RUN \ --mount=type=bind,source=requirements,target=requirements \ From ef618047ccc13de494a85c7e6de7a2f5e22860a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Thu, 4 Jul 2024 18:35:59 +0200 Subject: [PATCH 09/10] Update PyT Dist interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mikołaj Błaż --- nemo/utils/callbacks/torch_dist_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/utils/callbacks/torch_dist_async.py b/nemo/utils/callbacks/torch_dist_async.py index 1cd226af9cdb..8faeec25f766 100644 --- a/nemo/utils/callbacks/torch_dist_async.py +++ b/nemo/utils/callbacks/torch_dist_async.py @@ -64,7 +64,7 @@ def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): # Use PyT saving mechanism writer = FileSystemWriterAsync(checkpoint_dir, thread_count=self.thread_count) - save_state_dict_ret = save_state_dict_async_plan( + save_state_dict_ret, _, _, _ = save_state_dict_async_plan( pyt_state_dict, writer, None, From 84064086cd9ba6788b6b96aec535cf7296cb2f05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Wed, 10 Jul 2024 11:11:39 +0200 Subject: [PATCH 10/10] Update to latest core_r0.8.0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mikołaj Błaż --- Dockerfile.ci | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.ci b/Dockerfile.ci index f52120860259..2a7006c057f1 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -34,7 +34,7 @@ WORKDIR /workspace # Install NeMo requirements ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea ARG MODELOPT_VERSION=0.13.0 -ARG MCORE_TAG=0bc3547702464501feefeb5523b7a17e591b21fa +ARG MCORE_TAG=c0164bcfd4f8213a10a6b1e47ef80721a68b4fb6 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c RUN \ --mount=type=bind,source=requirements,target=requirements \