Integrate TorchFT
**Summary**
This is a WIP TorchFT integration PR.

**Current Issues**

This does not work at the moment: existing groups hang when a new
group joins.

**Issue 1:**
~Group 0 and group 1 will hang during the first `should_commit` after
group 1 applies the pending state_dict from group 0.~

Fixed with: pytorch/torchft#83

**Issue 2:**
~Group 0 and group 1 pass `should_commit`, but group 0 is incorrectly
marked as needing healing, and the healing process causes another hang.~

Fixed with: pytorch/torchft#83

**Issue 3:**
~A byproduct of issues 1 and 2: group 1 keeps printing~
```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN
socketProgress: Connection closed by remote peer
devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other
fixes.

**Issue 4:**
With 3 groups, every group requests the state dict on every step.
***How to reproduce?***
Follow the `Reproduce steps` below to run 2 groups, then add a third
group by modifying the command (an example is given after the reproduce steps).

This seems to be fixed but needs more testing.

**Issue 5:**
A hang occurs when using functional collectives.
***How to reproduce?***
Pull the latest version of this PR, then comment out line 41 and uncomment
line 42 in `torchtitan/utils.py`.
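
For context, here is a minimal illustrative sketch of the two collective flavors such a toggle presumably switches between. It is not the code at those lines; the helper name and signature are made up, and it assumes an initialized process group.

```
# Illustrative sketch only (not the torchtitan/utils.py code referenced above).
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def dist_max(x: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Eager c10d collective: reduces in place and returns the same tensor.
    dist.all_reduce(x, op=dist.ReduceOp.MAX, group=group)
    return x
    # Functional collective: returns a new tensor instead of mutating `x`.
    # This is the flavor reported to hang in Issue 5.
    # return funcol.all_reduce(x, reduceOp=dist.ReduceOp.MAX.name, group=group)
```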

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Start the TorchFT lighthouse server.
3. Run the following command in one terminal:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 \
./run_llama_train.sh --training.data_parallel_shard_degree=2 \
--experimental.enable_torchft --experimental.ft_replica_group_id=0
```
4. Wait 10 seconds, then run the following command in another terminal:
```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 \
./run_llama_train.sh --training.data_parallel_shard_degree=2 \
--experimental.enable_torchft --experimental.ft_replica_group_id=1
```
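
To add the third group mentioned in Issue 4, launch the same command with a different manager port, replica group ID, and set of GPUs. The specific port and GPU indices below are illustrative choices, not values from this PR:
```
TORCHFT_MANAGER_PORT=29524 REPLICA_GROUP_ID=2 CUDA_VISIBLE_DEVICES=4,5 NGPU=2 \
./run_llama_train.sh --training.data_parallel_shard_degree=2 \
--experimental.enable_torchft --experimental.ft_replica_group_id=2
```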

ghstack-source-id: 440da0f8d30d8466c22e1d8e1d738366b2d58bea
Pull Request resolved: #834
fegin committed Feb 25, 2025
1 parent 0c86fdd commit 28f8c50
Showing 8 changed files with 277 additions and 42 deletions.
4 changes: 4 additions & 0 deletions run_train.sh
@@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then
overrides="$*"
fi

TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"}

PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
TORCHFT_LIGHTHOUSE=http://localhost:29510 \
TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
torchtitan/train.py --job.config_file ${CONFIG_FILE} $overrides
128 changes: 101 additions & 27 deletions torchtitan/components/checkpoint.py
@@ -219,6 +219,7 @@ class CheckpointManager:
states (Dict[str, Any]): The states that need to be saved, other than the
previous 4 components.
job_config (JobConfig): The job config used to configure the checkpointing.
ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
"""

def __init__(
@@ -229,16 +230,41 @@ def __init__(
lr_schedulers: LRSchedulersContainer,
states: Dict[str, Any],
job_config: JobConfig,
ft_manager: Optional["ft.Manager"] = None,
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
self.ft_manager = ft_manager

if self.ft_manager:
optimizers.init_cache_state_dict()

def state_dict():
ret = {}
for k, v in self.states.items():
if k in {
MODEL,
OPTIMIZER,
LR_SCHEDULER,
TRAIN_STATE,
}:
ret[k] = v.state_dict()
return ret

def load_state_dict(state_dict):
assert state_dict is not None
for k, v in state_dict.items():
self.states[k].load_state_dict(v)

ft_manager.set_state_dict_fns(load_state_dict, state_dict)
self.ft_replica_id = job_config.experimental.ft_replica_id

async_mode = ckpt_config.async_mode.lower()
self.enable_staging = (
self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
)
) or self.ft_manager

if not self.enable_checkpoint:
if not self.enable_checkpoint and self.ft_manager is None:
return

self.states = states
@@ -250,6 +276,13 @@ def __init__(
LR_SCHEDULER: lr_schedulers,
}
)
self.ft_states = {DATALOADER: dataloader}

self.staging = False
self.sending_to_checkpoint_mp = False
self.staging_id = None
self.cpu_offload_state_dict = None
self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

self.staging = False
self.sending_to_checkpoint_mp = False
@@ -260,7 +293,7 @@ def __init__(
self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval = ckpt_config.interval
async_mode = ckpt_config.async_mode.lower()
if async_mode == AsyncMode.ASYNC:
if async_mode == AsyncMode.ASYNC or self.ft_manager:
self.pg = dist.new_group(backend="gloo")

self.keep_latest_k = ckpt_config.keep_latest_k
@@ -334,35 +367,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
None
"""

if self.ft_manager:
self._ft_save(curr_step)

if not self._should_save(curr_step, force):
return

begin = time.monotonic()
logger.info("Saving the checkpoint (or staging if async is enabled).")
checkpoint_id = self._create_checkpoint_id(curr_step)
self._async_wait()
# This GC is called for async checkpoint as it is useless to do
# GC right after async_save -- the CPU memory is not able to be
# freed until _async_wait()
if force:
self._save_last_step(curr_step)
elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self._async_with_pinned_memory(checkpoint_id)
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.async_future = dcp.async_save(
self.states, checkpoint_id=checkpoint_id, process_group=self.pg
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
save_with_gc(self.states, checkpoint_id=checkpoint_id)
self._purge_stale_checkpoints()
if not self.ft_manager or self.ft_manager.participating_rank() == 0:
logger.info("Saving the checkpoint (or staging if async is enabled).")
checkpoint_id = self._create_checkpoint_id(curr_step)
self._async_wait()
# This GC is called for async checkpoint as it is useless to do
# GC right after async_save -- the CPU memory is not able to be
# freed until _async_wait()
if force:
self._save_last_step(curr_step)
elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self._async_with_pinned_memory(checkpoint_id)
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.async_future = dcp.async_save(
self.states, checkpoint_id=checkpoint_id, process_group=self.pg
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
save_with_gc(self.states, checkpoint_id=checkpoint_id)
self._purge_stale_checkpoints()

logger.info(
"Finished saving the checkpoint (or staging if async is enabled)"
f"in {time.monotonic() - begin:.2f} seconds."
)
logger.info(
"Finished saving the checkpoint (or staging if async is enabled)"
f"in {time.monotonic() - begin:.2f} seconds."
)
elif self.ft_manager:
logger.info(
"Replica %d doesn't save checkpoint.",
self.ft_manager.participating_rank(),
)

@torch.no_grad()
def load(self, step: int = -1) -> bool:
@@ -379,6 +421,9 @@ def load(self, step: int = -1) -> bool:
bool: Whether the checkpoint was loaded successfully.
"""

if self.ft_manager:
self._ft_load()

if not self.enable_checkpoint or not os.path.isdir(self.folder):
return False

@@ -462,10 +507,36 @@ def _find_load_step(self, folder: str = "") -> int:
return -1
return max(step_counts)

def _ft_folder(self) -> str:
return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")

def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
folder = folder if folder else self.folder
return os.path.join(folder, f"step-{step}")

def _ft_save(self, step: int) -> None:
begin = time.monotonic()
self._async_wait()
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
self.async_future = dcp.async_save(
self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
)
logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")

def _ft_load(self) -> None:
step = self._find_load_step(folder=self._ft_folder())
if step == -1:
return

begin = time.monotonic()
logger.info(f"Loading the FT checkpoint at step {step}.")
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
)

def _states_to_load(self, step: int) -> Dict[str, Any]:
"""Determines which states to load for the given step.
@@ -486,6 +557,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
for exclude_key in self.exclude_from_loading:
if exclude_key not in states:
raise ValueError(f"{exclude_key} not found in state_dict.")
if self.ft_manager:
states_to_load.pop(DATALOADER)
return states_to_load

def _save_last_step(self, curr_step: int) -> None:
@@ -570,6 +643,7 @@ def _cpu_staging(self, checkpoint_id: Optional[str]) -> None:
def _purge_stale_checkpoints(self):
if (
self.keep_latest_k > 0
and self.ft_manager.participating_rank() == 0
and dist.get_rank() == 0
and os.path.isdir(self.folder)
):
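
For orientation, a small self-contained sketch of the on-disk layout implied by `_ft_folder()` and `_create_checkpoint_id()` above. The `outputs/checkpoint` base folder and step numbers are assumed examples, and `ft-replicat-` follows the spelling used in the diff:
```
# Sketch of the checkpoint paths produced by the code above; values are illustrative.
import os


def ft_checkpoint_path(folder: str, replica_id: int, step: int) -> str:
    # Per-replica-group FT checkpoint, holding only the DATALOADER state.
    return os.path.join(folder, f"ft-replicat-{replica_id}", f"step-{step}")


def global_checkpoint_path(folder: str, step: int) -> str:
    # Regular checkpoint (model/optimizer/lr_scheduler/train_state), written
    # only by the replica group whose participating_rank() is 0.
    return os.path.join(folder, f"step-{step}")


print(ft_checkpoint_path("outputs/checkpoint", 0, 10))   # outputs/checkpoint/ft-replicat-0/step-10
print(global_checkpoint_path("outputs/checkpoint", 10))  # outputs/checkpoint/step-10
```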
58 changes: 51 additions & 7 deletions torchtitan/components/optimizer.py
@@ -6,7 +6,7 @@

import copy
import functools
from typing import Any, Callable, Dict, Iterable, List
from typing import Any, Callable, Dict, Iterable, List, Optional

import torch
import torch.nn as nn
@@ -177,8 +177,49 @@ def zero_grad(self) -> None:
pass


class FTOptimizersContainer(OptimizersContainer):
def __init__(
self,
model_parts: List[nn.Module],
optimizer_kwargs: Dict[str, Any],
name: str,
ft_manager: Optional["ft.Manager"],
) -> None:
import torchft as ft

super().__init__(model_parts, optimizer_kwargs, name)

# Force to initialize the optimizer state so that `optim.step()`
# won't be called by state_dict() and load_state_dict().
_ = {
k: v
for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
for k, v in sd.items()
}
self.optimizers = [
ft.Optimizer(ft_manager, optim) for optim in self.optimizers
]
self.cache_state_dict: Dict[str, Any] = {}

def init_cache_state_dict(self) -> None:
self.cache_state_dict = super().state_dict()

def state_dict(self) -> Dict[str, Any]:
return self.cache_state_dict

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# We have to invalidate the `cache_state_dict` because optimizer uses
# assign instead of copy when doing `load_state_dict()`. Without
# invalidating the `cache_state_dict`, there will be memory leakage.
self.cache_state_dict = {}
super().load_state_dict(state_dict)
self.init_cache_state_dict()


def build_optimizers(
model_parts: List[nn.Module], job_config: JobConfig
model_parts: List[nn.Module],
job_config: JobConfig,
ft_manager: Optional["ft.Manager"] = None,
) -> OptimizersContainer:
"""Create a OptimizersContainer for the given model parts and job config.
@@ -219,11 +260,14 @@ def build_optimizers(
"foreach": foreach,
}

return (
OptimizersContainer(model_parts, optimizer_kwargs, name)
if not optim_in_bwd
else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
)
if optim_in_bwd and ft_manager:
raise ValueError("TorchFT is not supported with optimizers in backward.")
elif optim_in_bwd:
return OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
elif ft_manager:
return FTOptimizersContainer(model_parts, optimizer_kwargs, name, ft_manager)
else:
return OptimizersContainer(model_parts, optimizer_kwargs, name)


class LRSchedulersContainer(Stateful):
24 changes: 24 additions & 0 deletions torchtitan/config_manager.py
@@ -661,6 +661,30 @@ def __init__(self):
action="store_true",
)

self.parser.add_argument(
"--experimental.enable_torchft",
action="store_true",
help="Enable TorchFT integration.",
)

self.parser.add_argument(
"--experimental.ft_replica_id",
type=int,
default=0,
help="The TorchFT replica ID of this run.",
)

self.parser.add_argument(
"--experimental.ft_group_size",
type=int,
default=1,
help="""
The number of TorchFT replicate groups. This number will be used for
dataloader to split the dataset across the replicate groups and FSDP
dimension.
""",
)

def to_dict(self):
return self.args_dict

21 changes: 17 additions & 4 deletions torchtitan/distributed/parallel_dims.py
@@ -6,6 +6,7 @@

from dataclasses import dataclass
from functools import cached_property
from typing import Any, Optional

from torch.distributed.device_mesh import init_device_mesh

@@ -24,6 +25,7 @@ class ParallelDims:
pp: int
world_size: int
enable_loss_parallel: bool
ft_manager: Optional["ft.Manager"]

def __post_init__(self):
self._validate()
@@ -56,13 +58,24 @@ def build_mesh(self, device_type):
[self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
["pp", "dp_replicate", "dp_shard", "cp", "tp"],
):
if d > 1:
if d > 1 or (name == "dp_replicate" and self.ft_manager is not None):
dims.append(d)
names.append(name)

logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
names = tuple(names)
mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
if self.ft_manager is None:
mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
else:
from torchft.process_group import ft_init_device_mesh

mesh = ft_init_device_mesh(
device_type=device_type,
mesh_shape=dims,
mesh_dim_names=names,
replicate_dim=names.index("dp_replicate"),
manager=self.ft_manager,
)

# Create all the submesh here to ensure all required process groups are
# initialized:
@@ -73,7 +86,7 @@
# Mesh for loss all-reduce
dp_cp_mesh_dim_names = []

if self.dp_replicate_enabled:
if self.dp_replicate_enabled or self.ft_manager is not None:
dp_mesh_dim_names.append("dp_replicate")
dp_cp_mesh_dim_names.append("dp_replicate")
if self.dp_shard_enabled:
@@ -101,7 +114,7 @@ def dp_enabled(self):

@property
def dp_replicate_enabled(self):
return self.dp_replicate > 1
return self.dp_replicate > 1 or self.ft_manager is not None

@property
def dp_shard_enabled(self):
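
A hedged sketch of how this interacts with the mesh: with an FT manager present, the `dp_replicate` dimension is kept in the mesh even at size 1 so TorchFT can manage it. The constructor arguments below mirror the dataclass fields; the `ft_manager` value is assumed to come from a configured `torchft.Manager`, and running the sketch requires an initialized distributed environment:
```
# Sketch only: field names follow the ParallelDims dataclass above; `ft_manager`
# is assumed to be a torchft.Manager created elsewhere (e.g. in the trainer).
parallel_dims = ParallelDims(
    dp_shard=2,
    dp_replicate=1,
    cp=1,
    tp=1,
    pp=1,
    world_size=2,
    enable_loss_parallel=False,
    ft_manager=ft_manager,
)
mesh = parallel_dims.build_mesh(device_type="cuda")
# Without ft_manager, size-1 dims are dropped and the mesh is 1-D ("dp_shard",);
# with ft_manager, it becomes 2-D ("dp_replicate", "dp_shard") via
# torchft.process_group.ft_init_device_mesh, and dp_replicate_enabled is True.
```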
