diff --git a/composer/checkpoint/__init__.py b/composer/checkpoint/__init__.py
index d4b21c790d..33162fc5e6 100644
--- a/composer/checkpoint/__init__.py
+++ b/composer/checkpoint/__init__.py
@@ -3,6 +3,7 @@
 
 """Module for checkpointing API."""
 
+from composer.checkpoint.download import download_monolithic_checkpoint
 from composer.checkpoint.state_dict import (
     get_metadata_state_dict,
     get_model_state_dict,
@@ -15,4 +16,5 @@
     'get_optim_state_dict',
     'get_metadata_state_dict',
     'get_resumption_state_dict',
+    'download_monolithic_checkpoint',
 ]
diff --git a/composer/checkpoint/download.py b/composer/checkpoint/download.py
new file mode 100644
index 0000000000..01a80beb5f
--- /dev/null
+++ b/composer/checkpoint/download.py
@@ -0,0 +1,95 @@
+# Copyright 2024 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Useful functions for loading checkpoints from remote object store or local disk."""
+
+import logging
+from typing import Optional
+
+from composer.utils import (
+    dist,
+    extract_path_from_symlink,
+    maybe_create_object_store_from_uri,
+    parse_uri,
+    retry,
+)
+
+log = logging.getLogger(__name__)
+
+
+def download_file(
+    source_uri: str,
+    destination_path: str,
+    node_ranks: Optional[list[int]] = None,
+    num_attempts: int = 5,
+) -> None:
+    """Downloads a file (object) from the specified URI to the specified local path.
+
+    Only local rank 0 of each selected node performs the download; every other rank returns immediately.
+
+    Args:
+        source_uri (str): The URI to download the file from or a symlink to the URI.
+        destination_path (str): The local file path to download the file to.
+        node_ranks (list[int], optional): The ranks of the nodes that will download the file.
+            If None, all nodes will download the file.
+        num_attempts (int): Number of attempts for the object store download. Defaults to 5.
+
+    Raises:
+        ValueError: If no object store can be created from ``source_uri``.
+    """
+    # Only local rank 0 downloads
+    if dist.get_local_rank() != 0:
+        return
+
+    if node_ranks is not None and dist.get_node_rank() not in node_ranks:
+        return
+
+    object_store = maybe_create_object_store_from_uri(source_uri)
+    if object_store is None:
+        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
+        raise ValueError(f'Could not create an object store from {source_uri}.')
+
+    _, _, source_path = parse_uri(source_uri)
+    if source_uri.endswith('.symlink'):
+        source_path = extract_path_from_symlink(source_path, object_store)
+
+    @retry(num_attempts=num_attempts)
+    def _download():
+        object_store.download_object(
+            object_name=source_path,
+            filename=destination_path,
+        )
+
+    log.debug(f'Downloading {source_path} to {destination_path}')
+    _download()
+    log.debug(f'Finished downloading {source_path} to {destination_path}')
+
+
+def download_monolithic_checkpoint(
+    source_uri: str,
+    destination_path: str,
+    global_rank_zero_only: bool = True,
+) -> Optional[str]:
+    """Downloads a monolithic checkpoint from the specified URI to the specified local path.
+
+    Args:
+        source_uri (str): The URI to download the checkpoint from or symlink that points to the URI.
+        destination_path (str): The local file path to download the checkpoint to.
+        global_rank_zero_only (bool): If True, only local rank 0 on node 0 downloads the checkpoint.
+            If False, local rank 0 of every node downloads it.
+
+    Returns:
+        Optional[str]: ``destination_path`` on every rank when ``global_rank_zero_only`` is False,
+            otherwise ``destination_path`` on global rank 0 only;
+            None on all other ranks.
+    """
+    # Restrict the download to node 0 when only global rank 0 needs the file.
+    node_ranks = [0] if global_rank_zero_only else None
+    download_file(
+        source_uri=source_uri,
+        destination_path=destination_path,
+        node_ranks=node_ranks,
+    )
+    if not global_rank_zero_only or dist.get_global_rank() == 0:
+        return destination_path
+    return None
diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py
index 0850fd2bdd..20fa44e092 100644
--- a/composer/utils/__init__.py
+++ b/composer/utils/__init__.py
@@ -37,6 +37,7 @@
     create_symlink_file,
     ensure_folder_has_no_conflicting_files,
     ensure_folder_is_empty,
+    extract_path_from_symlink,
     format_name_with_dist,
     format_name_with_dist_and_time,
     get_file,
@@ -158,6 +159,7 @@
     'ParallelismConfig',
     'MLFLOW_EXPERIMENT_ID_FORMAT_KEY',
     'MLFLOW_RUN_ID_FORMAT_KEY',
+    'extract_path_from_symlink',
     'RemoteUploader',
     'validate_credentials',
     'build_remote_backend',
diff --git a/composer/utils/file_helpers.py b/composer/utils/file_helpers.py
index 11d10328ea..4f458b0a8e 100644
--- a/composer/utils/file_helpers.py
+++ b/composer/utils/file_helpers.py
@@ -49,6 +49,7 @@
     'maybe_create_object_store_from_uri',
     'maybe_create_remote_uploader_downloader_from_uri',
     'parse_uri',
+    'extract_path_from_symlink',
     'validate_credentials',
 ]
 
@@ -57,6 +58,16 @@ def extract_path_from_symlink(
     source_path: str,
     object_store: Optional[Union[LoggerDestination, ObjectStore]] = None,
 ) -> str:
+    """Returns the checkpoint path from symlink file.
+
+    Args:
+        source_path (str): The remote symlink path.
+        object_store (LoggerDestination | ObjectStore, optional): The object store
+            used to download the remote symlink file.
+
+    Returns:
+        str: The content of the remote symlink file.
+    """
     if object_store is not None:
         with tempfile.TemporaryDirectory() as tmpdir:
             _, _, source_path = parse_uri(source_path)
diff --git a/tests/checkpoint/test_download.py b/tests/checkpoint/test_download.py
new file mode 100644
index 0000000000..98c937bac4
--- /dev/null
+++ b/tests/checkpoint/test_download.py
@@ -0,0 +1,56 @@
+# Copyright 2024 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import tempfile
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from composer.checkpoint import download_monolithic_checkpoint
+from composer.utils import dist
+from tests.checkpoint.helpers import init_model
+from tests.common.markers import world_size
+from tests.utils.test_remote_uploader import DummyObjectStore
+
+
+@world_size(1, 2)
+@pytest.mark.gpu
+@pytest.mark.parametrize('rank_zero_only', [True, False])
+def test_download_monolithic_checkpoint(world_size: int, rank_zero_only: bool):
+    # Write a checkpoint
+    tmp_dir = tempfile.TemporaryDirectory()
+    use_fsdp = False
+    if world_size > 1:
+        use_fsdp = True
+    fsdp_model, _ = init_model(use_fsdp=use_fsdp)
+
+    from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
+    state = get_model_state_dict(fsdp_model, options=StateDictOptions(full_state_dict=True))
+
+    checkpoint_filename = 'state_dict'
+    save_filename = os.path.join(tmp_dir.name, checkpoint_filename)
+    if dist.get_global_rank() == 0:
+        torch.save(state, save_filename)
+
+    class DummyS3ObjectStore(DummyObjectStore):
+
+        def get_tmp_dir(self):
+            return tmp_dir
+
+    # Download a monolithic checkpoint
+    local_file_name = 'state_dict.download'
+    with patch('composer.utils.file_helpers.S3ObjectStore', DummyS3ObjectStore):
+        ret = download_monolithic_checkpoint(
+            source_uri=f's3://bucket_name/{checkpoint_filename}',
+            destination_path=local_file_name,
+            global_rank_zero_only=rank_zero_only,
+        )
+    dist.barrier()
+
+    if rank_zero_only and dist.get_global_rank() != 0:
+        assert ret is None
+    if dist.get_global_rank() == 0:
+        assert ret == local_file_name
+        assert os.path.isfile(local_file_name)
diff --git a/tests/utils/test_remote_uploader.py b/tests/utils/test_remote_uploader.py
index 2e41e91d18..100e64ecf0 100644
--- a/tests/utils/test_remote_uploader.py
+++ b/tests/utils/test_remote_uploader.py
@@ -57,6 +57,8 @@ def download_object(
         overwrite: bool = False,
         callback: Optional[Callable[[int, int], None]] = None,
     ):
+        if overwrite is False and os.path.isfile(filename):
+            raise FileExistsError(f'The file at {filename} already exists and overwrite is set to False.')
         object_path = pathlib.Path(self.root) / object_name
         shutil.copy2(object_path, filename)
 