Implement async_checkpoint #313

Merged: 1 commit, May 7, 2024
302 changes: 229 additions & 73 deletions torchtitan/checkpoint.py
@@ -8,12 +8,14 @@
import os
import re
import time
from multiprocessing import get_context
from typing import Any, Dict

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
@@ -22,7 +24,7 @@
)
from torch.distributed.checkpoint.stateful import Stateful
from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import logger
from torchtitan.logging_utils import init_logger, logger


DTYPE_MAP = {
@@ -37,6 +39,12 @@ class IntervalType(enum.Enum):
STEPS = enum.auto()


class AsyncMode(str, enum.Enum):
DISABLED = "disabled"
ASYNC = "async"
ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"


class ModelWrapper(Stateful):
def __init__(self, model: nn.Module) -> None:
self.model = model
@@ -60,6 +68,43 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
set_optimizer_state_dict(self.model, self.optim, optim_state_dict=state_dict)


class Terminate:
pass


class SaveDone:
pass


def checkpoint_mp(recv, send):
init_logger()
os.environ["MASTER_PORT"] = str(int(os.environ["MASTER_PORT"]) + 2)
os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False"
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group()
try:
while True:
logger.debug("Checkpoint background process is done.")
send.put(SaveDone())
logger.debug("Wait for the new state_dict.")
obj = recv.get()
logger.debug("Received the new state_dict.")
if isinstance(obj, Terminate):
logger.info("Terminating the checkpoint background process.")
return
assert isinstance(obj, tuple)
begin = time.monotonic()
state, checkpoint_id = obj
dcp.save(state, checkpoint_id=checkpoint_id)
logger.info(
"Finish saving the checkpoint in the background process in "
f"{time.monotonic() - begin:.2f} seconds."
)
finally:
logger.info("Destroying the process group.")
dist.destroy_process_group()


class CheckpointManager:
def __init__(
self,
@@ -72,82 +117,88 @@ def __init__(
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint

if self.enable_checkpoint:
self.states = states
self.states.update(
{
"model": ModelWrapper(model),
"optimizer": OptimizerWrapper(model, optimizer),
"lr_scheduler": lr_scheduler,
}
)
if not self.enable_checkpoint:
return

self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval_type = (
IntervalType.SECONDS
if ckpt_config.interval_type == "seconds"
else IntervalType.STEPS
)
self.interval = ckpt_config.interval
self.model_weights_only = ckpt_config.model_weights_only
self.export_dtype = DTYPE_MAP[ckpt_config.export_dtype]
self.states = states
self.states.update(
{
"model": ModelWrapper(model),
"optimizer": OptimizerWrapper(model, optimizer),
"lr_scheduler": lr_scheduler,
}
)

logger.info(
f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}"
self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval_type = (
IntervalType.SECONDS
if ckpt_config.interval_type == "seconds"
else IntervalType.STEPS
)
self.interval = ckpt_config.interval
self.begin_time = 0
self.time_sync_work = None
self.time_sync_result = None
self.pg = dist.new_group(backend="gloo")

self.model_weights_only = ckpt_config.model_weights_only
self.export_dtype = DTYPE_MAP[ckpt_config.export_dtype]

self.mp = None
async_mode = ckpt_config.async_mode.lower()
if async_mode == AsyncMode.DISABLED:
self.async_mode = AsyncMode.DISABLED
elif async_mode == AsyncMode.ASYNC:
self.async_mode = AsyncMode.ASYNC
self.async_future = None
elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
self.async_mode = AsyncMode.ASYNC_WITH_PINNED_MEM
ctx = get_context("spawn")
self.mp_queue_send = ctx.Queue()
self.mp_queue_recv = ctx.Queue()
self.mp = ctx.Process(
target=checkpoint_mp,
args=(
self.mp_queue_send,
self.mp_queue_recv,
),
daemon=True,
)
self.mp.start()
self.cpu_offload_state_dict = None
self.staging = False
self.staging_state_dict = None
self.staging_id = None
self.staging_stream = torch.cuda.Stream()
else:
raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")

logger.info(
f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}"
)

self.begin = 0
self.work = None
self.pg = dist.new_group(backend="gloo")
self.doit = None
def __del__(self):
if self.enable_checkpoint and self.mp and self.mp.is_alive():
self.mp_queue_send.put(Terminate())
self.mp.join()

def reset(self) -> None:
self.begin = time.monotonic()
self.begin_time = time.monotonic()

def _create_checkpoint_id(self, step: int) -> str:
return os.path.join(self.folder, f"step-{step}")

def save(self, curr_step: int, force: bool = False) -> None:
"""
force = True will force the checkpoint to be saved, even if the interval has not been reached.
This only happens when train_state.step == job_config.training.steps, or for initial seed checkpoint.
"""
if not self.enable_checkpoint:
return

if not force:
if self.interval_type == IntervalType.STEPS and not (
curr_step % self.interval == 0
):
return
if self.interval_type == IntervalType.SECONDS:
doit = (time.monotonic() - self.begin) >= self.interval
self.doit = torch.tensor(int(doit))
if self.work is None:
self.work = dist.all_reduce(self.doit, group=self.pg, async_op=True)
return
elif curr_step % 5 == 4:
self.work.wait()
self.work = None
doit = self.doit.item()
self.doit = None
if doit == 0:
return
else:
return

if self.work:
self.work.wait()
self.work = None
self.doit = None

# We only consider saving weights only at the end of the training.
# So this won't affect preemption and training resume.
# We also only allow dtype conversion when we are checkpoint model weights only
# and the current dtype is not the same as the export dtype at the end of the training.
if force and self.model_weights_only:
def _save_last_step(self, curr_step: int) -> None:
# We only consider saving weights only at the end of the training. So
# this won't affect preemption and training resume. We also only allow
# dtype conversion when we are checkpointing model weights only and the
# current dtype is not the same as the export dtype at the end of the training.
if self.model_weights_only:
# We update self.states to keep the model only.
# After this update, self.states = {'tok_embeddings.weight':...,''layers.0.attention.wq.weight': ...}.
# After this update, self.states = {
# 'tok_embeddings.weight':...,
# 'layers.0.attention.wq.weight': ...
# }.
self.states = self.states["model"].state_dict()

# For now, we will manually pop the freqs_cis buffer, as we made this permanent
@@ -160,19 +211,124 @@ def save(self, curr_step: int, force: bool = False) -> None:
k: v.to(self.export_dtype) for k, v in self.states.items()
}
logger.info(
f"Saving a model weights only checkpoint in {self.export_dtype} at step {curr_step}"
f"Saving a model weights only checkpoint in {self.export_dtype} "
f"at last step, step {curr_step}."
)

else:
logger.info(f"Saving a full checkpoint at step {curr_step}")
logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")

begin = time.monotonic()
dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
self.reset()

def _should_save(self, curr_step: int, force: bool = False) -> bool:
if not self.enable_checkpoint:
return False

if not force:
if self.interval_type == IntervalType.STEPS and not (
curr_step % self.interval == 0
):
return False
if self.interval_type == IntervalType.SECONDS:
time_sync_result = (time.monotonic() - self.begin_time) >= self.interval
self.time_sync_result = torch.tensor(int(time_sync_result))
if self.time_sync_work is None:
self.time_sync_work = dist.all_reduce(
self.time_sync_result, group=self.pg, async_op=True
)
return False
elif curr_step % 5 == 4:
self.time_sync_work.wait()
self.time_sync_work = None
time_sync_result = self.time_sync_result.item()
self.time_sync_result = None
if time_sync_result == 0:
return False
else:
return False

if self.time_sync_work:
self.time_sync_work.wait()
self.time_sync_work = None
self.time_sync_result = None

return True

def _async_wait(self) -> None:
if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
logger.debug(
f"Waiting for the background process to finish, {time.monotonic()=}.:.2f"
)
if not self.mp.is_alive():
raise RuntimeError("The checkpoint background process is dead.")
_ = self.mp_queue_recv.get()
elif self.async_mode == AsyncMode.ASYNC:
if self.async_future is not None:
self.async_future.result()

def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
if self.cpu_offload_state_dict is None:
logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
self.cpu_offload_state_dict = _create_cpu_state_dict(
state_dict, pin_memory=True
)

logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
with torch.cuda.stream(self.staging_stream):
self.cpu_offload_state_dict = _copy_state_dict(
state_dict,
self.cpu_offload_state_dict,
non_blocking=True,
)
self.staging = True
self.staging_state_dict = state_dict
self.staging_id = checkpoint_id

def save(self, curr_step: int, force: bool = False) -> None:
"""
force = True will force the checkpoint to be saved, even if the interval
has not been reached.
This only happens when train_state.step == job_config.training.steps, or
for the initial seed checkpoint.
"""
if not self._should_save(curr_step, force):
return

begin = time.monotonic()
checkpoint_id = self._create_checkpoint_id(curr_step)
self._async_wait()
if force:
self._save_last_step(curr_step)
elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
self._async_with_pinned_memory(checkpoint_id)
elif self.async_mode == AsyncMode.ASYNC:
self.async_future = dcp.async_save(
self.states, checkpoint_id=checkpoint_id, process_group=self.pg
Contributor

@fegin Why did you choose to use the GLOO process group for the async save? Is it expected to make this more efficient?

Neither the DCP docs nor https://discuss.pytorch.org/t/distributed-w-torchtitan-optimizing-checkpointing-efficiency-with-pytorch-dcp/211250 mention or recommend this.

I'm curious to know if this was on purpose and if you have any numbers to show.

Thanks!

Contributor Author

We don't want checkpointing to interfere with training, which involves NCCL, so we chose gloo for all checkpointing communication. Also, the main bottleneck of checkpointing is unlikely to be communication; the storage read/write (or upload/download) is the major overhead.
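
As a minimal illustration of this reply (not part of the PR; the helper names are made up), training collectives can stay on the default NCCL group while a second gloo group is passed to dcp.async_save for checkpoint coordination:

import torch.distributed as dist
import torch.distributed.checkpoint as dcp

def make_checkpoint_pg():
    # Assumes dist.init_process_group("nccl") has already been called for training.
    # A separate gloo-backed group keeps checkpoint coordination off the NCCL path.
    return dist.new_group(backend="gloo")

def save_async(states, checkpoint_id, checkpoint_pg):
    # dcp.async_save returns a Future; call .result() on it before the next save,
    # as CheckpointManager._async_wait does above.
    return dcp.async_save(
        states, checkpoint_id=checkpoint_id, process_group=checkpoint_pg
    )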

)
else:
dcp.save(self.states, checkpoint_id=checkpoint_id)
self.reset()

logger.info(
f"Finished saving the checkpoint in {time.monotonic() - begin:.2f} seconds"
"Finished saving the checkpoint (or staging if async is enabled)"
f"in {time.monotonic() - begin:.2f} seconds."
)

def wait_for_staging(self) -> None:
if (
self.enable_checkpoint
and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
and self.staging
):
logger.debug(f"Waiting for staging, {time.monotonic()=:.2f}.")
self.staging_stream.synchronize()
logger.debug(
f"Sending the state dict to the background process, {time.monotonic()=:.2f}."
)
self.mp_queue_send.put((self.staging_state_dict, self.staging_id))
self.staging = False

def load(self, step: int = -1) -> bool:
if not self.enable_checkpoint:
return False
@@ -193,13 +349,13 @@ def load(self, step: int = -1) -> bool:

# We won't have optimizer states to load, if we are loading a seed checkpoint
states = {"model": self.states["model"]} if step == 0 else self.states
logger.info(f"Loading the checkpoint at step {step}")
logger.info(f"Loading the checkpoint at step {step}.")
begin = time.monotonic()
dcp.load(
states,
checkpoint_id=self._create_checkpoint_id(step),
)
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds"
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
)
return True
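
For context, a hypothetical usage sketch of the CheckpointManager introduced above; the constructor argument order and the names train_one_step and num_steps are assumptions for illustration, not the actual torchtitan training loop:

checkpoint = CheckpointManager(model, optimizer, lr_scheduler, states, job_config)
checkpoint.load()  # resume from the latest step folder if a checkpoint exists

for step in range(1, num_steps + 1):
    train_one_step(model, optimizer, lr_scheduler)  # hypothetical training step
    # Save (or stage, in async_with_pinned_mem mode) according to the configured
    # interval; force=True on the last step so the final checkpoint is always written.
    checkpoint.save(step, force=(step == num_steps))
    # In async_with_pinned_mem mode, hand the staged CPU state_dict to the
    # background process once the staging stream has finished copying.
    checkpoint.wait_for_staging()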