Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[checkpoint v2] Download api #3447

Merged
merged 9 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions composer/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Module for checkpointing API."""

from composer.checkpoint.download import download_monolithic_checkpoint
from composer.checkpoint.state_dict import (
get_metadata_state_dict,
get_model_state_dict,
Expand All @@ -15,4 +16,5 @@
'get_optim_state_dict',
'get_metadata_state_dict',
'get_resumption_state_dict',
'download_monolithic_checkpoint',
]
85 changes: 85 additions & 0 deletions composer/checkpoint/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Useful functions for loading checkpoints from a remote object store or local disk."""

import logging
from typing import Optional

from composer.utils import (
dist,
extract_path_from_symlink,
maybe_create_object_store_from_uri,
parse_uri,
retry,
)

log = logging.getLogger(__name__)


def download_file(
    source_uri: str,
    destination_path: str,
    node_ranks: Optional[list[int]] = None,
    num_attempts: int = 5,
):
    """Downloads a file (object) from the specified URI to the specified local path.

    Only local rank 0 on each participating node performs the download; every other
    rank returns immediately without creating an object store.

    Args:
        source_uri (str): The URI to download the file from, or the URI of a ``.symlink``
            file whose content is the path of the object to download.
        destination_path (str): The local file path to download the object to. It is
            passed as ``filename`` to the object store's ``download_object``.
        node_ranks (list[int], optional): The ranks of the nodes that will download the file.
            If None, all nodes will download the file.
        num_attempts (int): Number of attempts for the object store download. Default to 5.

    Raises:
        ValueError: If no object store could be created from ``source_uri``.
    """
    # Only local rank 0 downloads
    if dist.get_local_rank() != 0:
        return

    # Optionally restrict the download to a subset of nodes.
    if node_ranks is not None and dist.get_node_rank() not in node_ranks:
        return

    object_store = maybe_create_object_store_from_uri(source_uri)
    # Validate explicitly instead of with ``assert``: asserts are stripped under
    # ``python -O``, which would turn this into an AttributeError further down.
    if object_store is None:
        raise ValueError(f'No object store could be created from source_uri={source_uri!r}.')

    _, _, source_path = parse_uri(source_uri)
    if source_uri.endswith('.symlink'):
        # The symlink file holds the path of the real object; resolve it first.
        source_path = extract_path_from_symlink(source_path, object_store)

    @retry(num_attempts=num_attempts)
    def _download():
        object_store.download_object(
            object_name=source_path,
            filename=destination_path,
        )

    log.debug(f'Downloading {source_path} to {destination_path}')
    _download()
    log.debug(f'Finished downloading {source_path} to {destination_path}')


def download_monolithic_checkpoint(
    source_uri: str,
    destination_path: str,
    global_rank_zero_only: bool = True,
) -> Optional[str]:
    """Downloads a monolithic checkpoint from the specified URI to the specified local path.

    Args:
        source_uri (str): The URI to download the checkpoint from or symlink that points to the URI.
        destination_path (str): The local file path to download the checkpoint to.
        global_rank_zero_only (bool): If True, the download is restricted to node 0
            (performed by its local rank 0). If False, local rank 0 of every node
            downloads the checkpoint.

    Returns:
        Optional[str]: ``destination_path`` on global rank 0, and on every rank when
        ``global_rank_zero_only`` is False; ``None`` on all other ranks.
    """
    # ``None`` means "no node restriction" for ``download_file``.
    node_ranks = [0] if global_rank_zero_only else None
    download_file(
        source_uri=source_uri,
        destination_path=destination_path,
        node_ranks=node_ranks,
    )
    # Ranks that were excluded from the download have no local copy to report.
    if global_rank_zero_only and dist.get_global_rank() != 0:
        return None
    return destination_path
2 changes: 2 additions & 0 deletions composer/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
create_symlink_file,
ensure_folder_has_no_conflicting_files,
ensure_folder_is_empty,
extract_path_from_symlink,
format_name_with_dist,
format_name_with_dist_and_time,
get_file,
Expand Down Expand Up @@ -158,6 +159,7 @@
'ParallelismConfig',
'MLFLOW_EXPERIMENT_ID_FORMAT_KEY',
'MLFLOW_RUN_ID_FORMAT_KEY',
'extract_path_from_symlink',
'RemoteUploader',
'validate_credentials',
'build_remote_backend',
Expand Down
11 changes: 11 additions & 0 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
'maybe_create_object_store_from_uri',
'maybe_create_remote_uploader_downloader_from_uri',
'parse_uri',
'extract_path_from_symlink',
'validate_credentials',
]

Expand All @@ -57,6 +58,16 @@ def extract_path_from_symlink(
source_path: str,
object_store: Optional[Union[LoggerDestination, ObjectStore]] = None,
) -> str:
"""Returns the checkpont path from symlink file.

Args:
source_path(str): The remote symlink path.
object_store(LoggerDestination | ObjectStore, optional): The object store
used to download the remote symlink file

Returns:
str: The content of the remote symlink file.
"""
if object_store is not None:
with tempfile.TemporaryDirectory() as tmpdir:
_, _, source_path = parse_uri(source_path)
Expand Down
56 changes: 56 additions & 0 deletions tests/checkpoint/test_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import os
import tempfile
from unittest.mock import patch

import pytest
import torch

from composer.checkpoint import download_monolithic_checkpoint
from composer.utils import dist
from tests.checkpoint.helpers import init_model
from tests.common.markers import world_size
from tests.utils.test_remote_uploader import DummyObjectStore


@world_size(1, 2)
@pytest.mark.gpu
@pytest.mark.parametrize('rank_zero_only', [True, False])
def test_download_monolithic_checkpoint(world_size: int, rank_zero_only: bool):
    """Checks that a saved full state dict can be fetched via ``download_monolithic_checkpoint``.

    ``S3ObjectStore`` is patched with a dummy store rooted at a temporary directory,
    so the ``s3://`` URI resolves to the file written by global rank 0.
    """
    # Write a checkpoint
    tmp_dir = tempfile.TemporaryDirectory()
    use_fsdp = world_size > 1
    model, _ = init_model(use_fsdp=use_fsdp)

    from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
    state = get_model_state_dict(model, options=StateDictOptions(full_state_dict=True))

    checkpoint_filename = 'state_dict'
    save_filename = os.path.join(tmp_dir.name, checkpoint_filename)
    if dist.get_global_rank() == 0:
        torch.save(state, save_filename)

    class DummyS3ObjectStore(DummyObjectStore):

        def get_tmp_dir(self):
            # Serve objects out of the directory the checkpoint was saved to.
            return tmp_dir

    # Download a monolithic checkpoint
    local_file_name = 'state_dict.download'
    with patch('composer.utils.file_helpers.S3ObjectStore', DummyS3ObjectStore):
        ret = download_monolithic_checkpoint(
            source_uri=f's3://bucket_name/{checkpoint_filename}',
            destination_path=local_file_name,
            global_rank_zero_only=rank_zero_only,
        )
    dist.barrier()

    is_global_rank_zero = dist.get_global_rank() == 0
    # The path is returned on rank 0, and on every rank when not restricted to rank 0.
    if is_global_rank_zero or not rank_zero_only:
        assert ret == local_file_name
    else:
        assert ret is None
    if is_global_rank_zero:
        assert os.path.isfile(local_file_name)
2 changes: 2 additions & 0 deletions tests/utils/test_remote_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def download_object(
overwrite: bool = False,
callback: Optional[Callable[[int, int], None]] = None,
):
if overwrite is False and os.path.isfile(filename):
raise FileExistsError(f'The file at {filename} already exists and overwrite is set to False.')
object_path = pathlib.Path(self.root) / object_name
shutil.copy2(object_path, filename)

Expand Down
Loading