Integrate TorchFT
**Summary**
This is a WIP TorchFT integration PR.

**Current Issues**

This doesn't work at the moment: groups hang when a new group joins.

**Issue 1:**
~Group 0 and group 1 will hang during the first `should_commit` after
group 1 applies the pending state_dict from group 0.~

Fixed with: pytorch/torchft#83

**Issue 2:**
~Group 0 and group 1 will pass `should_commit`, but group 0 incorrectly requires
healing, and the healing process causes another hang.~

Fixed with: pytorch/torchft#83

**Issue 3:**
~A byproduct of issues 1 and 2: group 1 will continue to print out:~
```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN
socketProgress: Connection closed by remote peer
devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other
fixes.

**Issue 4:**
When there are 3 groups, every group requests the state dict on every step.
***How to reproduce?***
Use the `Reproduce steps` below to run 2 groups, then add a third group by
modifying the command (see the example command after the reproduce steps).

This seems to be fixed, but it needs more testing.

**Issue 5:**
A hang occurs when using functional collectives.
***How to reproduce?***
Pull the latest version of this PR, comment out line 41, and uncomment
line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Start the TorchFT lighthouse (see the example command after these steps)
3. Execute the following command in one terminal:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 \
  ./run_llama_train.sh --training.data_parallel_shard_degree=2 \
  --experimental.enable_torchft --experimental.ft_replica_group_id=0
```
4. Wait 10 seconds, then execute the following command in another terminal:
```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 \
  ./run_llama_train.sh --training.data_parallel_shard_degree=2 \
  --experimental.enable_torchft --experimental.ft_replica_group_id=1
```
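
For step 2, a minimal lighthouse invocation would look like the following; this command is an assumption based on the torchft README, not part of this PR:
```
RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
```

To reproduce Issue 4, a third group can be added with a command along these lines; the manager port, GPU indices, and replica group id here are assumptions extrapolated from the two commands above:
```
TORCHFT_MANAGER_PORT=29524 REPLICA_GROUP_ID=2 CUDA_VISIBLE_DEVICES=4,5 NGPU=2 \
  ./run_llama_train.sh --training.data_parallel_shard_degree=2 \
  --experimental.enable_torchft --experimental.ft_replica_group_id=2
```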

ghstack-source-id: 3e708066a025792cb8ac08bd3e223e66203dbaa7
Pull Request resolved: #834
fegin committed Feb 28, 2025
1 parent ec82573 commit 9cf0e84
Showing 10 changed files with 474 additions and 57 deletions.
3 changes: 3 additions & 0 deletions run_train.sh
@@ -19,7 +19,10 @@ if [ $# -ne 0 ]; then
overrides="$*"
fi

TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}

PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
torchtitan/train.py --job.config_file ${CONFIG_FILE} $overrides
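
This change defaults `TORCHFT_LIGHTHOUSE` to `http://localhost:29510` and forwards it to the training process, so a remote lighthouse can be used by overriding the variable. A hypothetical invocation (the hostname is illustrative, not from this PR):
```
TORCHFT_LIGHTHOUSE=http://lighthouse-host:29510 NGPU=2 ./run_train.sh
```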
5 changes: 3 additions & 2 deletions scripts/estimate/estimation.py
@@ -15,6 +15,7 @@
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.testing._internal.distributed.fake_pg import FakeStore

from torchtitan.components.ft import init_ft_manager
from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims, utils as dist_utils
@@ -102,7 +103,6 @@ def estimate_memory(job_config: JobConfig):
if not job_config.memory_estimation.disable_fake_mode
else contextlib.nullcontext()
):

logger.info(
f"Building {train_spec.name} {job_config.model.flavor} with {model_config}"
)
@@ -122,7 +122,8 @@ def estimate_memory(job_config: JobConfig):
model.train()

# build optimizer after applying parallelisms to the model
optimizers = build_optimizers([model], job_config)
ft_manager = init_ft_manager(job_config)
optimizers = build_optimizers([model], job_config, ft_manager)
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
# Post optimizer step model converters hook.
# e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
23 changes: 23 additions & 0 deletions tests/unit_tests/test_checkpoint.py
@@ -61,10 +61,18 @@ class DummyJob:
dump_folder: str = "dummy_folder"


@dataclass
class DummyFaultTolerance:
replica_id = 0
group_size = 1


@dataclass
class DummyJobConfig:
checkpoint: DummyCheckpointConfig = field(default_factory=DummyCheckpointConfig)
job: DummyJob = field(default_factory=DummyJob)
fault_tolerance: DummyFaultTolerance = field(default_factory=DummyFaultTolerance)
ft_manager = None


# Dummy instances to supply as constructor arguments.
@@ -103,13 +111,16 @@ def tearDown(self):
def test_save(self, *_):
"""Test that calling save() writes a checkpoint file to disk."""
job_config = DummyJobConfig(job=self.dummy_job)
ft_manager = mock.Mock()
ft_manager.enabled = False
manager = CheckpointManager(
dummy_dataloader,
dummy_model_parts,
dummy_optimizers,
dummy_lr_schedulers,
{"trainer": self.trainer_state},
job_config,
ft_manager,
)
step = 20
manager.save(curr_step=step, force=True)
@@ -133,13 +144,16 @@ def test_save(self, *_):
def test_load(self, *_):
"""Test that load() properly reads the checkpoint file from disk and restores state."""
job_config = DummyJobConfig(job=self.dummy_job)
ft_manager = mock.Mock()
ft_manager.enabled = False
manager = CheckpointManager(
dummy_dataloader,
dummy_model_parts,
dummy_optimizers,
dummy_lr_schedulers,
{"trainer": self.trainer_state},
job_config,
ft_manager,
)
step = 30
manager.save(curr_step=step, force=True)
@@ -171,13 +185,16 @@ def test_purge_stale_checkpoints_rank_zero(self, *_):
"""
job_config = DummyJobConfig(job=self.dummy_job)
job_config.checkpoint.keep_latest_k = 3
ft_manager = mock.Mock()
ft_manager.enabled = False
manager = CheckpointManager(
dummy_dataloader,
dummy_model_parts,
dummy_optimizers,
dummy_lr_schedulers,
{"trainer": self.trainer_state},
job_config,
ft_manager,
)
steps = [10, 20, 30, 40, 50]
for s in steps:
@@ -215,13 +232,16 @@ def test_purge_stale_checkpoints_rank_nonzero(self, *_):
"""
job_config = DummyJobConfig(job=self.dummy_job)
job_config.checkpoint.keep_latest_k = 3
ft_manager = mock.Mock()
ft_manager.enabled = False
manager = CheckpointManager(
dummy_dataloader,
dummy_model_parts,
dummy_optimizers,
dummy_lr_schedulers,
{"trainer": self.trainer_state},
job_config,
ft_manager,
)
steps = [10, 20, 30, 40, 50]
for s in steps:
@@ -252,13 +272,16 @@ def test_async_save_calls_async_wait(self, *_):
# Set async_mode to "async" in the job configuration.
job_config = DummyJobConfig(job=self.dummy_job)
job_config.checkpoint.async_mode = "async"
ft_manager = mock.Mock()
ft_manager.enabled = False
manager = CheckpointManager(
dummy_dataloader,
dummy_model_parts,
dummy_optimizers,
dummy_lr_schedulers,
{"trainer": self.trainer_state},
job_config,
ft_manager,
)
# First save: should schedule an async save.
manager.save(curr_step=10, force=False)
142 changes: 115 additions & 27 deletions torchtitan/components/checkpoint.py
@@ -31,6 +31,7 @@
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader

from torchtitan.components.ft import FTManager
from torchtitan.components.optimizer import LRSchedulersContainer, OptimizersContainer
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.tools.logging import init_logger, logger
@@ -214,6 +215,19 @@ class CheckpointManager:
3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers
with the assumption that all lr_schedulers have the same state_dict.
Note: TorchFT checkpointing flow
There are two types of checkpoints when TorchFT is enabled: 1) the full persistent
checkpoint, 2) the per-replica checkpoint.
The full persistent checkpoint is saved by the replica with
``ft_manager.participating_rank() == 0``. It contains everything including the model,
optimizer, lr_scheduler, dataloader, and train_state. Right now the full persistent
checkpoint is loaded by all replicas. However, we can optimize it to only load if
there are no other alive replicas.
The per-replica checkpoint contains only the dataloader and is saved/loaded by each
replica to/from its own folder. The folder name is prefixed with the ft_replica_id.
Args:
dataloader (DataLoader): The dataloader used to load the data.
@@ -223,6 +237,7 @@ class CheckpointManager:
states (Dict[str, Any]): The states that need to be saved, other than the
previous 4 components.
job_config (JobConfig): The job config used to configure the checkpointing.
ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
"""

def __init__(
@@ -233,16 +248,41 @@ def __init__(
lr_schedulers: LRSchedulersContainer,
states: Dict[str, Any],
job_config: JobConfig,
ft_manager: FTManager,
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
self.ft_manager = ft_manager.manager if ft_manager.enabled else None

if self.ft_manager:
optimizers.init_cache_state_dict()

def state_dict():
ret = {}
for k, v in self.states.items():
if k in {
MODEL,
OPTIMIZER,
LR_SCHEDULER,
TRAIN_STATE,
}:
ret[k] = v.state_dict()
return ret

def load_state_dict(state_dict):
assert state_dict is not None
for k, v in state_dict.items():
self.states[k].load_state_dict(v)

self.ft_manager.set_state_dict_fns(load_state_dict, state_dict)
self.ft_replica_id = job_config.fault_tolerance.replica_id

async_mode = ckpt_config.async_mode.lower()
self.enable_staging = (
self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
)
) or self.ft_manager

if not self.enable_checkpoint:
if not self.enable_checkpoint and self.ft_manager is None:
return

self.states = states
@@ -254,6 +294,13 @@ def __init__(
LR_SCHEDULER: lr_schedulers,
}
)
self.ft_states = {DATALOADER: dataloader}

self.staging = False
self.sending_to_checkpoint_mp = False
self.staging_id = None
self.cpu_offload_state_dict = None
self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

self.staging = False
self.sending_to_checkpoint_mp = False
@@ -264,7 +311,7 @@ def __init__(
self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval = ckpt_config.interval
async_mode = ckpt_config.async_mode.lower()
if async_mode == AsyncMode.ASYNC:
if async_mode == AsyncMode.ASYNC or self.ft_manager:
self.pg = dist.new_group(backend="gloo")

self.keep_latest_k = ckpt_config.keep_latest_k
@@ -339,35 +386,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
None
"""

if self.ft_manager:
self._ft_save(curr_step)

if not self._should_save(curr_step, force):
return

begin = time.monotonic()
logger.info("Saving the checkpoint (or staging if async is enabled).")
checkpoint_id = self._create_checkpoint_id(curr_step)
self._async_wait()
# This GC is called for async checkpoint as it is useless to do
# GC right after async_save -- the CPU memory is not able to be
# freed until _async_wait()
if force:
self._save_last_step(curr_step)
elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self._async_with_pinned_memory(checkpoint_id)
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.async_future = dcp.async_save(
self.states, checkpoint_id=checkpoint_id, process_group=self.pg
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
save_with_gc(self.states, checkpoint_id=checkpoint_id)
self._purge_stale_checkpoints()
if not self.ft_manager or self.ft_manager.participating_rank() == 0:
logger.info("Saving the checkpoint (or staging if async is enabled).")
checkpoint_id = self._create_checkpoint_id(curr_step)
self._async_wait()
# This GC is called for async checkpoint as it is useless to do
# GC right after async_save -- the CPU memory is not able to be
# freed until _async_wait()
if force:
self._save_last_step(curr_step)
elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self._async_with_pinned_memory(checkpoint_id)
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.async_future = dcp.async_save(
self.states, checkpoint_id=checkpoint_id, process_group=self.pg
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
save_with_gc(self.states, checkpoint_id=checkpoint_id)
self._purge_stale_checkpoints()

logger.info(
"Finished saving the checkpoint (or staging if async is enabled)"
f"in {time.monotonic() - begin:.2f} seconds."
)
logger.info(
"Finished saving the checkpoint (or staging if async is enabled)"
f"in {time.monotonic() - begin:.2f} seconds."
)
elif self.ft_manager:
logger.info(
"Replica %d doesn't save checkpoint.",
self.ft_manager.participating_rank(),
)

@torch.no_grad()
def load(self, step: int = -1) -> bool:
@@ -384,6 +440,9 @@ def load(self, step: int = -1) -> bool:
bool: Whether the checkpoint was loaded successfully.
"""

if self.ft_manager:
self._ft_load()

if not self.enable_checkpoint or not os.path.isdir(self.folder):
return False

@@ -467,10 +526,36 @@ def _find_load_step(self, folder: str = "") -> int:
return -1
return max(step_counts)

def _ft_folder(self) -> str:
return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")

def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
folder = folder if folder else self.folder
return os.path.join(folder, f"step-{step}")

def _ft_save(self, step: int) -> None:
begin = time.monotonic()
self._async_wait()
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
self.async_future = dcp.async_save(
self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
)
logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")

def _ft_load(self) -> None:
step = self._find_load_step(folder=self._ft_folder())
if step == -1:
return

begin = time.monotonic()
logger.info(f"Loading the FT checkpoint at step {step}.")
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
)

def _states_to_load(self, step: int) -> Dict[str, Any]:
"""Determines which states to load for the given step.
@@ -491,6 +576,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
for exclude_key in self.exclude_from_loading:
if exclude_key not in states:
raise ValueError(f"{exclude_key} not found in state_dict.")
if self.ft_manager:
states_to_load.pop(DATALOADER)
return states_to_load

def _save_last_step(self, curr_step: int) -> None:
@@ -577,6 +664,7 @@ def _purge_stale_checkpoints(self):
self.keep_latest_k > 0
and dist.get_rank() == 0
and os.path.isdir(self.folder)
and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
):
discovered_checkpoints = []
for filename in os.listdir(self.folder):
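
The `save()` changes above follow the split described in the `CheckpointManager` docstring: every replica writes a per-replica (dataloader-only) checkpoint, while only the replica with `participating_rank() == 0` writes the full persistent checkpoint. A minimal sketch of that decision, using hypothetical helper callables rather than the PR's actual API:
```
from typing import Callable, Optional


class FTManagerStub:
    """Illustrative stand-in for the TorchFT manager; only the rank query is modeled."""

    def __init__(self, rank: int) -> None:
        self._rank = rank

    def participating_rank(self) -> int:
        return self._rank


def save_checkpoints(
    step: int,
    ft_manager: Optional[FTManagerStub],
    save_full: Callable[[int], None],         # writes model/optimizer/lr_scheduler/dataloader/train_state
    save_per_replica: Callable[[int], None],  # writes only this replica's dataloader state
) -> None:
    if ft_manager is not None:
        # Per-replica checkpoint: always saved, into a folder keyed by the replica id.
        save_per_replica(step)

    if ft_manager is None or ft_manager.participating_rank() == 0:
        # Full persistent checkpoint: written by a single participating replica.
        save_full(step)


if __name__ == "__main__":
    # With FT enabled on a non-zero replica, only the per-replica checkpoint is written.
    save_checkpoints(
        step=10,
        ft_manager=FTManagerStub(rank=1),
        save_full=lambda s: print(f"full checkpoint at step {s}"),
        save_per_replica=lambda s: print(f"per-replica checkpoint at step {s}"),
    )
```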
