Integrate TorchFT
**Summary**
This is a WIP TorchFT integration PR.

**Current Issues**

This doesn't work at the moment because groups hang when a new group joins.

**Issue 1:**
~Group 0 and group 1 will hang during the first `should_commit` after
group 1 applies the pending state_dict from group 0.~

Fixed with: pytorch/torchft#83

**Issue 2:**
~Group 0 and group 1 will pass `should_commit`, but group 0 incorrectly
requires healing, and the healing process causes another hang.~

Fixed with: pytorch/torchft#83

**Issue 3:**
~A byproduct of issues 1 and 2: group 1 will keep printing~
```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN
socketProgress: Connection closed by remote peer
devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other
fixes.

**Issue 4:**
When there are 3 groups, every group requests the state dict at every step.
***How to reproduce?***
Use the `Reproduce steps` below to run 2 groups, then add a third group by
modifying the command.

Seems to be fixed; needs more testing.

**Issue 5:**
A hang occurs when using functional collectives.
***How to reproduce?***
Pull the latest version of this PR, then comment out line 41 and uncomment
line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Start the TorchFT lighthouse server
3. Execute the following command in one terminal:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1
NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2
--experimental.enable_torchft --experimental.ft_replica_group_id=0
```
4. Wait 10 seconds, then execute the following command in another terminal:
```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3
NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2
--experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 34c0f96844c46399715bff6302741b8d8ee18532
Pull Request resolved: #834
fegin committed Feb 27, 2025
1 parent 71c6876 commit 3514de3
Showing 10 changed files with 354 additions and 48 deletions.
3 changes: 3 additions & 0 deletions run_train.sh
@@ -19,7 +19,10 @@ if [ $# -ne 0 ]; then
overrides="$*"
fi

TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}

PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
torchtitan/train.py --job.config_file ${CONFIG_FILE} $overrides
8 changes: 8 additions & 0 deletions tests/unit_tests/test_checkpoint.py
@@ -61,10 +61,18 @@ class DummyJob:
dump_folder: str = "dummy_folder"


@dataclass
class DummyExperimental:
ft_replica_id = 0
ft_group_size = 1


@dataclass
class DummyJobConfig:
checkpoint: DummyCheckpointConfig = field(default_factory=DummyCheckpointConfig)
job: DummyJob = field(default_factory=DummyJob)
experimental: DummyExperimental = field(default_factory=DummyExperimental)
ft_manager = None


# Dummy instances to supply as constructor arguments.
1 change: 1 addition & 0 deletions tests/unit_tests/test_model_converter.py
@@ -22,6 +22,7 @@ def build_parallel_dims(job_config, world_size):
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=not job_config.training.disable_loss_parallel,
ft_manager=None,
)
return parallel_dims

146 changes: 118 additions & 28 deletions torchtitan/components/checkpoint.py
@@ -16,7 +16,7 @@
from dataclasses import dataclass, field
from io import BytesIO
from multiprocessing import get_context
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

import torch
import torch.distributed as dist
@@ -36,6 +36,9 @@
from torchtitan.tools.logging import init_logger, logger
from torchtitan.tools.utils import GarbageCollection

if TYPE_CHECKING:
import torchft as ft


MODEL = "model"
OPTIMIZER = "optimizer"
@@ -214,6 +217,19 @@ class CheckpointManager:
3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers
with the assumption that all lr_schedulers have the same state_dict.
Note: TorchFT checkpointing flow
There are two types of checkpoints when TorchFT is enabled: 1) the full persistent
checkpoint, 2) the per-replica checkpoint.
The full persistent checkpoint is saved by the replica with
``ft_manager.participating_rank() == 0``. It contains everything including the model,
optimizer, lr_scheduler, dataloader, and train_state. Right now the full persistent
checkpoint is loaded by all replicas. However, we can optimize it to load only if
there are no other alive replicas.
The per-replica checkpoint contains only the dataloader and is saved/loaded by each
replica to/from its own folder. The folder name is prefixed with the ft_replica_id.
Args:
dataloader (DataLoader): The dataloader used to load the data.
@@ -223,6 +239,7 @@ class CheckpointManager:
states (Dict[str, Any]): The states that need to be saved, other than the
previous 4 components.
job_config (JobConfig): The job config used to configure the checkpointing.
ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
"""

def __init__(
@@ -233,16 +250,41 @@ def __init__(
lr_schedulers: LRSchedulersContainer,
states: Dict[str, Any],
job_config: JobConfig,
ft_manager: Optional["ft.Manager"] = None,
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
self.ft_manager = ft_manager

if self.ft_manager:
optimizers.init_cache_state_dict()

def state_dict():
ret = {}
for k, v in self.states.items():
if k in {
MODEL,
OPTIMIZER,
LR_SCHEDULER,
TRAIN_STATE,
}:
ret[k] = v.state_dict()
return ret

def load_state_dict(state_dict):
assert state_dict is not None
for k, v in state_dict.items():
self.states[k].load_state_dict(v)

ft_manager.set_state_dict_fns(load_state_dict, state_dict)
self.ft_replica_id = job_config.experimental.ft_replica_id

async_mode = ckpt_config.async_mode.lower()
self.enable_staging = (
self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
)
) or self.ft_manager

if not self.enable_checkpoint:
if not self.enable_checkpoint and self.ft_manager is None:
return

self.states = states
@@ -254,6 +296,13 @@ def __init__(
LR_SCHEDULER: lr_schedulers,
}
)
self.ft_states = {DATALOADER: dataloader}

self.staging = False
self.sending_to_checkpoint_mp = False
self.staging_id = None
self.cpu_offload_state_dict = None
self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

self.staging = False
self.sending_to_checkpoint_mp = False
@@ -264,7 +313,7 @@ def __init__(
self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval = ckpt_config.interval
async_mode = ckpt_config.async_mode.lower()
if async_mode == AsyncMode.ASYNC:
if async_mode == AsyncMode.ASYNC or self.ft_manager:
self.pg = dist.new_group(backend="gloo")

self.keep_latest_k = ckpt_config.keep_latest_k
@@ -339,35 +388,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
None
"""

if self.ft_manager:
self._ft_save(curr_step)

if not self._should_save(curr_step, force):
return

begin = time.monotonic()
logger.info("Saving the checkpoint (or staging if async is enabled).")
checkpoint_id = self._create_checkpoint_id(curr_step)
self._async_wait()
# This GC is called for async checkpoint as it is useless to do
# GC right after async_save -- the CPU memory is not able to be
# freed until _async_wait()
if force:
self._save_last_step(curr_step)
elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self._async_with_pinned_memory(checkpoint_id)
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.async_future = dcp.async_save(
self.states, checkpoint_id=checkpoint_id, process_group=self.pg
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
save_with_gc(self.states, checkpoint_id=checkpoint_id)
self._purge_stale_checkpoints()
if not self.ft_manager or self.ft_manager.participating_rank() == 0:
logger.info("Saving the checkpoint (or staging if async is enabled).")
checkpoint_id = self._create_checkpoint_id(curr_step)
self._async_wait()
# This GC is called for async checkpoint as it is useless to do
# GC right after async_save -- the CPU memory is not able to be
# freed until _async_wait()
if force:
self._save_last_step(curr_step)
elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self._async_with_pinned_memory(checkpoint_id)
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.async_future = dcp.async_save(
self.states, checkpoint_id=checkpoint_id, process_group=self.pg
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
save_with_gc(self.states, checkpoint_id=checkpoint_id)
self._purge_stale_checkpoints()

logger.info(
"Finished saving the checkpoint (or staging if async is enabled)"
f"in {time.monotonic() - begin:.2f} seconds."
)
logger.info(
"Finished saving the checkpoint (or staging if async is enabled)"
f"in {time.monotonic() - begin:.2f} seconds."
)
elif self.ft_manager:
logger.info(
"Replica %d doesn't save checkpoint.",
self.ft_manager.participating_rank(),
)

@torch.no_grad()
def load(self, step: int = -1) -> bool:
@@ -384,6 +442,9 @@ def load(self, step: int = -1) -> bool:
bool: Whether the checkpoint was loaded successfully.
"""

if self.ft_manager:
self._ft_load()

if not self.enable_checkpoint or not os.path.isdir(self.folder):
return False

@@ -467,10 +528,36 @@ def _find_load_step(self, folder: str = "") -> int:
return -1
return max(step_counts)

def _ft_folder(self) -> str:
return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")

def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
folder = folder if folder else self.folder
return os.path.join(folder, f"step-{step}")

def _ft_save(self, step: int) -> None:
begin = time.monotonic()
self._async_wait()
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
self.async_future = dcp.async_save(
self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
)
logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")

def _ft_load(self) -> None:
step = self._find_load_step(folder=self._ft_folder())
if step == -1:
return

begin = time.monotonic()
logger.info(f"Loading the FT checkpoint at step {step}.")
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
)

def _states_to_load(self, step: int) -> Dict[str, Any]:
"""Determines which states to load for the given step.
@@ -491,6 +578,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
for exclude_key in self.exclude_from_loading:
if exclude_key not in states:
raise ValueError(f"{exclude_key} not found in state_dict.")
if self.ft_manager:
states_to_load.pop(DATALOADER)
return states_to_load

def _save_last_step(self, curr_step: int) -> None:
Expand Down Expand Up @@ -577,6 +666,7 @@ def _purge_stale_checkpoints(self):
self.keep_latest_k > 0
and dist.get_rank() == 0
and os.path.isdir(self.folder)
and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
):
discovered_checkpoints = []
for filename in os.listdir(self.folder):
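To make the checkpointing flow described in the class docstring easier to follow, here is a minimal, self-contained sketch (not part of the diff) of where the two checkpoint types land on disk. The helper name and its arguments are hypothetical; the real logic lives in `save`, `_ft_save`, `_ft_folder`, and `_create_checkpoint_id` above.

```python
import os
from typing import Optional, Tuple


def sketch_checkpoint_paths(
    checkpoint_folder: str, step: int, ft_replica_id: int, participating_rank: int
) -> Tuple[str, Optional[str]]:
    """Hypothetical helper mirroring the paths used in this diff.

    Every replica writes a per-replica checkpoint (dataloader state only) under
    an ft-prefixed subfolder; only the replica whose
    ft_manager.participating_rank() is 0 also writes the full checkpoint.
    """
    # Per-replica checkpoint id, mirroring _ft_folder() + _create_checkpoint_id()
    # (the "ft-replicat-" prefix is the spelling as committed).
    per_replica = os.path.join(
        checkpoint_folder, f"ft-replicat-{ft_replica_id}", f"step-{step}"
    )
    # Full persistent checkpoint (model/optimizer/lr_scheduler/train_state/...),
    # written only by the participating-rank-0 replica.
    full = os.path.join(checkpoint_folder, f"step-{step}") if participating_rank == 0 else None
    return per_replica, full


# Example: replica 1 at step 100 only writes its per-replica checkpoint.
print(sketch_checkpoint_paths("outputs/checkpoint", 100, ft_replica_id=1, participating_rank=1))
```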
47 changes: 47 additions & 0 deletions torchtitan/components/ft.py
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import importlib
from typing import Optional

from torchtitan.config_manager import JobConfig

if importlib.util.find_spec("torchft") is not None:
import torchft as ft

has_torchft = True
else:
has_torchft = False


def init_ft_manager(job: JobConfig) -> Optional["ft.Manager"]:
"""Initialize the FT manager if TorchFT is enabled.
Args:
job (JobConfig): The job configuration.
Returns:
Optional[ft.Manager]: The FT manager if TorchFT is enabled, otherwise None.
"""
if not job.experimental.enable_torchft:
return None

if not has_torchft:
raise ImportError("torchft is not installed. Please install it.")

if job.experimental.ft_min_replica_size < 1:
raise ValueError("At least one FT replica is required.")

pg = ft.ProcessGroupBabyNCCL()

return ft.Manager(
pg=pg,
min_replica_size=job.experimental.ft_min_replica_size,
load_state_dict=None,
state_dict=None,
use_async_quorum=True,
replica_id=f"torchtitan_ft_{job.experimental.ft_replica_id}",
)
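For context, a small usage sketch of the new helper (not part of the diff): with TorchFT disabled in the config, `init_ft_manager` simply returns `None`, and the rest of the trainer passes `ft_manager=None` through unchanged, as the updated `ParallelDims` and `CheckpointManager` signatures elsewhere in this commit expect. The `_Dummy*` classes below are stand-ins for `JobConfig`, not real torchtitan types, and the snippet assumes torchtitan with this PR applied is importable.

```python
from dataclasses import dataclass, field

from torchtitan.components.ft import init_ft_manager


@dataclass
class _DummyExperimental:
    # Mirrors only the fields read by init_ft_manager in this diff.
    enable_torchft: bool = False
    ft_min_replica_size: int = 1
    ft_replica_id: int = 0


@dataclass
class _DummyJobConfig:
    experimental: _DummyExperimental = field(default_factory=_DummyExperimental)


# With enable_torchft=False the helper returns None and the run behaves as before.
ft_manager = init_ft_manager(_DummyJobConfig())
assert ft_manager is None
```

With `enable_torchft=True` it would instead build an `ft.Manager` wrapping a `ProcessGroupBabyNCCL`, which requires torchft to be installed and, per the reproduce steps above, a running lighthouse.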
