diff --git a/run_train.sh b/run_train.sh
index f8b465a55..64570316b 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -19,7 +19,10 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
+
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 torchtitan/train.py --job.config_file ${CONFIG_FILE} $overrides
diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py
index 0a701d1d5..34cf174e9 100644
--- a/tests/unit_tests/test_checkpoint.py
+++ b/tests/unit_tests/test_checkpoint.py
@@ -61,10 +61,18 @@ class DummyJob:
     dump_folder: str = "dummy_folder"
 
 
+@dataclass
+class DummyExperimental:
+    ft_replica_id = 0
+    ft_group_size = 1
+
+
 @dataclass
 class DummyJobConfig:
     checkpoint: DummyCheckpointConfig = field(default_factory=DummyCheckpointConfig)
     job: DummyJob = field(default_factory=DummyJob)
+    experimental: DummyExperimental = field(default_factory=DummyExperimental)
+    ft_manager = None
 
 
 # Dummy instances to supply as constructor arguments.
diff --git a/tests/unit_tests/test_model_converter.py b/tests/unit_tests/test_model_converter.py
index 8ac278301..a84fbb509 100644
--- a/tests/unit_tests/test_model_converter.py
+++ b/tests/unit_tests/test_model_converter.py
@@ -22,6 +22,7 @@ def build_parallel_dims(job_config, world_size):
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=not job_config.training.disable_loss_parallel,
+        ft_manager=None,
     )
     return parallel_dims
diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index ee821c819..8ce661f98 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -16,7 +16,7 @@
 from dataclasses import dataclass, field
 from io import BytesIO
 from multiprocessing import get_context
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch.distributed as dist
@@ -36,6 +36,9 @@
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.tools.utils import GarbageCollection
 
+if TYPE_CHECKING:
+    import torchft as ft
+
 MODEL = "model"
 OPTIMIZER = "optimizer"
@@ -214,6 +217,19 @@ class CheckpointManager:
     3. LR schedulers also index model states like optimizers. Here we flatten the
        lr_schedulers with the assumption that all lr_schedulers have the same state_dict.
 
+    Note: TorchFT checkpointing flow
+
+    There are two types of checkpoints when TorchFT is enabled: 1) the full persistent
+    checkpoint, 2) the per-replica checkpoint.
+
+    The full persistent checkpoint is saved by the replica with
+    ``ft_manager.participating_rank() == 0``. It contains everything, including the model,
+    optimizer, lr_scheduler, dataloader, and train_state. Right now the full persistent
+    checkpoint is loaded by all replicas. However, we can optimize it to load only when
+    there are no other alive replicas.
+
+    The per-replica checkpoint contains only the dataloader and is saved/loaded by all
+    replicas to/from its own folder. The folder name is prefixed with the ft_replica_id.
+
     Args:
         dataloader (DataLoader): The dataloader used to load the data.
@@ -223,6 +239,7 @@
         states (Dict[str, Any]): The states that need to be saved, other than the
             previous 4 components.
         job_config (JobConfig): The job config used to configure the checkpointing.
+        ft_manager (Optional[ft.Manager]): The TorchFT manager, if TorchFT is enabled.
     """
 
     def __init__(
@@ -233,16 +250,41 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional["ft.Manager"] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.ft_manager = ft_manager
+
+        if self.ft_manager:
+            optimizers.init_cache_state_dict()
+
+            def state_dict():
+                ret = {}
+                for k, v in self.states.items():
+                    if k in {
+                        MODEL,
+                        OPTIMIZER,
+                        LR_SCHEDULER,
+                        TRAIN_STATE,
+                    }:
+                        ret[k] = v.state_dict()
+                return ret
+
+            def load_state_dict(state_dict):
+                assert state_dict is not None
+                for k, v in state_dict.items():
+                    self.states[k].load_state_dict(v)
+
+            ft_manager.set_state_dict_fns(load_state_dict, state_dict)
+            self.ft_replica_id = job_config.experimental.ft_replica_id
 
         async_mode = ckpt_config.async_mode.lower()
         self.enable_staging = (
             self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        )
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
 
         self.states = states
@@ -254,6 +296,13 @@ def __init__(
                 LR_SCHEDULER: lr_schedulers,
             }
         )
+        self.ft_states = {DATALOADER: dataloader}
+
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
 
         self.staging = False
         self.sending_to_checkpoint_mp = False
@@ -264,7 +313,7 @@ def __init__(
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval = ckpt_config.interval
         async_mode = ckpt_config.async_mode.lower()
-        if async_mode == AsyncMode.ASYNC:
+        if async_mode == AsyncMode.ASYNC or self.ft_manager:
             self.pg = dist.new_group(backend="gloo")
 
         self.keep_latest_k = ckpt_config.keep_latest_k
@@ -339,35 +388,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
             None
         """
 
+        if self.ft_manager:
+            self._ft_save(curr_step)
+
         if not self._should_save(curr_step, force):
             return
 
         begin = time.monotonic()
-        logger.info("Saving the checkpoint (or staging if async is enabled).")
-        checkpoint_id = self._create_checkpoint_id(curr_step)
-        self._async_wait()
-        # This GC is called for async checkpoint as it is useless to do
-        # GC right after async_save -- the CPU memory is not able to be
-        # freed until _async_wait()
-        if force:
-            self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self._async_with_pinned_memory(checkpoint_id)
-        elif self.async_mode == AsyncMode.ASYNC:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-        else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
-        self._purge_stale_checkpoints()
+        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+            logger.info("Saving the checkpoint (or staging if async is enabled).")
+            checkpoint_id = self._create_checkpoint_id(curr_step)
+            self._async_wait()
+            # This GC is called for async checkpoint as it is useless to do
+            # GC right after async_save -- the CPU memory is not able to be
+            # freed until _async_wait()
+            if force:
+                self._save_last_step(curr_step)
+            elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self._async_with_pinned_memory(checkpoint_id)
+            elif self.async_mode == AsyncMode.ASYNC:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self.async_future = dcp.async_save(
+                    self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                )
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+            else:
+                save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            self._purge_stale_checkpoints()
 
-        logger.info(
-            "Finished saving the checkpoint (or staging if async is enabled)"
-            f"in {time.monotonic() - begin:.2f} seconds."
-        )
+            logger.info(
+                "Finished saving the checkpoint (or staging if async is enabled) "
+                f"in {time.monotonic() - begin:.2f} seconds."
+            )
+        elif self.ft_manager:
+            logger.info(
+                "Replica %d doesn't save checkpoint.",
+                self.ft_manager.participating_rank(),
+            )
 
     @torch.no_grad()
     def load(self, step: int = -1) -> bool:
@@ -384,6 +442,9 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """
 
+        if self.ft_manager:
+            self._ft_load()
+
         if not self.enable_checkpoint or not os.path.isdir(self.folder):
             return False
 
@@ -467,10 +528,36 @@ def _find_load_step(self, folder: str = "") -> int:
             return -1
         return max(step_counts)
 
+    def _ft_folder(self) -> str:
+        return os.path.join(self.folder, f"ft-replica-{self.ft_replica_id}")
+
     def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
         folder = folder if folder else self.folder
         return os.path.join(folder, f"step-{step}")
 
+    def _ft_save(self, step: int) -> None:
+        begin = time.monotonic()
+        self._async_wait()
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        self.async_future = dcp.async_save(
+            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        )
+        logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
+
+    def _ft_load(self) -> None:
+        step = self._find_load_step(folder=self._ft_folder())
+        if step == -1:
+            return
+
+        begin = time.monotonic()
+        logger.info(f"Loading the FT checkpoint at step {step}.")
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
+        logger.info(
+            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
+        )
+
     def _states_to_load(self, step: int) -> Dict[str, Any]:
         """Determines which states to load for the given step.
 
@@ -491,6 +578,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
         for exclude_key in self.exclude_from_loading:
             if exclude_key not in states:
                 raise ValueError(f"{exclude_key} not found in state_dict.")
+        if self.ft_manager:
+            states_to_load.pop(DATALOADER)
         return states_to_load
 
     def _save_last_step(self, curr_step: int) -> None:
@@ -577,6 +666,7 @@ def _purge_stale_checkpoints(self):
             self.keep_latest_k > 0
             and dist.get_rank() == 0
             and os.path.isdir(self.folder)
+            and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
         ):
             discovered_checkpoints = []
             for filename in os.listdir(self.folder):
diff --git a/torchtitan/components/ft.py b/torchtitan/components/ft.py
new file mode 100644
index 000000000..6bffac26e
--- /dev/null
+++ b/torchtitan/components/ft.py
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import importlib.util
+from typing import Optional
+
+from torchtitan.config_manager import JobConfig
+
+if importlib.util.find_spec("torchft") is not None:
+    import torchft as ft
+
+    has_torchft = True
+else:
+    has_torchft = False
+
+
+def init_ft_manager(job: JobConfig) -> Optional["ft.Manager"]:
+    """Initialize the FT manager if TorchFT is enabled.
+
+    Args:
+        job (JobConfig): The job configuration.
+
+    Returns:
+        Optional[ft.Manager]: The FT manager if TorchFT is enabled, otherwise None.
+    """
+    if not job.experimental.enable_torchft:
+        return None
+
+    if not has_torchft:
+        raise ImportError("torchft is not installed. Please install it.")
+
+    if job.experimental.ft_min_replica_size < 1:
+        raise ValueError("At least one FT replica is required.")
+
+    pg = ft.ProcessGroupBabyNCCL()
+
+    return ft.Manager(
+        pg=pg,
+        min_replica_size=job.experimental.ft_min_replica_size,
+        load_state_dict=None,
+        state_dict=None,
+        use_async_quorum=True,
+        replica_id=f"torchtitan_ft_{job.experimental.ft_replica_id}",
+    )
diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py
index 2a4b52f0e..385b8aca0 100644
--- a/torchtitan/components/optimizer.py
+++ b/torchtitan/components/optimizer.py
@@ -6,7 +6,8 @@
 import copy
 import functools
-from typing import Any, Callable, Dict, Generic, List, TypeVar
+
+from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, TypeVar
 
 import torch
 import torch.nn as nn
@@ -19,6 +20,7 @@
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
 
+from torchtitan.components.ft import has_torchft
 from torchtitan.config_manager import JobConfig
 
@@ -30,6 +32,10 @@
 ]
 
 
+if has_torchft:
+    import torchft as ft
+
+
 T = TypeVar("T", bound=Optimizer)
 
@@ -84,13 +90,13 @@ def __iter__(self) -> Optimizer:
     def __len__(self) -> int:
         return len(self.optimizers)
 
-    def step(self) -> None:
+    def step(self, *args, **kwargs) -> None:
         for optimizer in self.optimizers:
-            optimizer.step()
+            optimizer.step(*args, **kwargs)
 
-    def zero_grad(self) -> None:
+    def zero_grad(self, *args, **kwargs) -> None:
         for optimizer in self.optimizers:
-            optimizer.zero_grad()
+            optimizer.zero_grad(*args, **kwargs)
 
     def state_dict(self) -> Dict[str, Any]:
         func = functools.partial(
@@ -114,7 +120,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
     def _validate_length(self, expected_length: int) -> None:
         assert expected_length == len(
             self.optimizers
-        ), "Must pass one optimizer per model part or per param if using OptimizersInBackwardContainer"
+        ), (
+            "Must pass one optimizer per model part or per param if "
+            "using OptimizersInBackwardContainer."
+        )
 
     def _post_init(
         self, all_params: list[nn.Parameter], optimizer_kwargs: dict[str, Any]
@@ -175,8 +184,72 @@ def zero_grad(self) -> None:
         pass
 
 
+class FTOptimizersContainer(OptimizersContainer):
+    def __init__(
+        self,
+        model_parts: List[nn.Module],
+        optimizer_cls: type[T],
+        optimizer_kwargs: Dict[str, Any],
+        ft_manager: Optional["ft.Manager"],
+    ) -> None:
+        super().__init__(model_parts, optimizer_cls, optimizer_kwargs)
+
+        # Force the optimizer state to be initialized so that `optim.step()`
+        # won't be called by state_dict() and load_state_dict().
+        _ = {
+            k: v
+            for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
+            for k, v in sd.items()
+        }
+        self.cache_state_dict: Dict[str, Any] = {}
+        self._ft_optimizer = ft.Optimizer(ft_manager, self)
+        self._call_from_ft: bool = False
+
+    def init_cache_state_dict(self) -> None:
+        self.cache_state_dict = super().state_dict()
+
+    def state_dict(self) -> Dict[str, Any]:
+        return self.cache_state_dict
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # We have to invalidate the `cache_state_dict` because the optimizer uses
+        # assign instead of copy when doing `load_state_dict()`. Without
+        # invalidating the `cache_state_dict`, there will be memory leakage.
+        self.cache_state_dict = {}
+        super().load_state_dict(state_dict)
+        self.init_cache_state_dict()
+
+    def step(self, *args, **kwargs) -> None:
+        """Call the correct step() depending on the caller.
+
+        TorchFT's OptimizerWrapper.step() is designed to be called only once
+        per train step per ft.Manager, regardless of how many optimizers are used.
+        Hence we need to dispatch the call appropriately.
+        """
+        if self._call_from_ft:
+            super().step(*args, **kwargs)
+        else:
+            self._call_from_ft = True
+            self._ft_optimizer.step(*args, **kwargs)
+            self._call_from_ft = False
+
+    def zero_grad(self, *args, **kwargs) -> None:
+        """Call the correct zero_grad() depending on the caller.
+
+        Check the comment in ``step()``.
+        """
+        if self._call_from_ft:
+            super().zero_grad(*args, **kwargs)
+        else:
+            self._call_from_ft = True
+            self._ft_optimizer.zero_grad(*args, **kwargs)
+            self._call_from_ft = False
+
+
 def build_optimizers(
-    model_parts: List[nn.Module], job_config: JobConfig
+    model_parts: List[nn.Module],
+    job_config: JobConfig,
+    ft_manager: Optional["ft.Manager"] = None,
 ) -> OptimizersContainer:
     """Create a OptimizersContainer for the given model parts and job config.
@@ -216,6 +289,7 @@ def build_optimizers(
         "fused": fused,
         "foreach": foreach,
     }
+
     optimizer_classes = {
         "Adam": torch.optim.Adam,
         "AdamW": torch.optim.AdamW,
     }
     if name not in optimizer_classes:
         raise NotImplementedError(f"Optimizer {name} not added.")
     optimizer_cls = optimizer_classes[name]
-    return (
-        OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
-        if not optim_in_bwd
-        else OptimizersInBackwardContainer(model_parts, optimizer_cls, optimizer_kwargs)
-    )
+
+    if optim_in_bwd and ft_manager:
+        raise ValueError("TorchFT is not supported with optimizers in backward.")
+    elif optim_in_bwd:
+        return OptimizersInBackwardContainer(
+            model_parts, optimizer_cls, optimizer_kwargs
+        )
+    elif ft_manager:
+        return FTOptimizersContainer(
+            model_parts, optimizer_cls, optimizer_kwargs, ft_manager
+        )
+    else:
+        return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
 
 
 class LRSchedulersContainer(Stateful):
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 8fbe65263..d5ff7f37a 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -672,6 +672,37 @@ def __init__(self):
             action="store_true",
         )
 
+        self.parser.add_argument(
+            "--experimental.enable_torchft",
+            action="store_true",
+            help="Enable TorchFT integration.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_replica_id",
+            type=int,
+            default=0,
+            help="The TorchFT replica ID of this run.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_group_size",
+            type=int,
+            default=1,
+            help="""
+                The number of TorchFT replicate groups. This number will be used by the
+                dataloader to split the dataset across the replicate groups and the FSDP
+                dimension.
+            """,
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_min_replica_size",
+            type=int,
+            default=1,
+            help="The minimum number of FT replicas for each step.",
+        )
+
     def to_dict(self):
         return self.args_dict
diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py
index 6f8c32c5c..26d5e2df7 100644
--- a/torchtitan/distributed/parallel_dims.py
+++ b/torchtitan/distributed/parallel_dims.py
@@ -6,6 +6,7 @@
 from dataclasses import dataclass
 from functools import cached_property
+from typing import Optional, TYPE_CHECKING
 
 from torch.distributed.device_mesh import init_device_mesh
 
@@ -15,6 +16,10 @@
 __all__ = ["ParallelDims"]
 
 
+if TYPE_CHECKING:
+    import torchft as ft
+
+
 @dataclass
 class ParallelDims:
     dp_replicate: int
@@ -24,6 +29,7 @@ class ParallelDims:
     pp: int
     world_size: int
     enable_loss_parallel: bool
+    ft_manager: Optional["ft.Manager"] = None
 
     def __post_init__(self):
         self._validate()
@@ -56,13 +62,24 @@ def build_mesh(self, device_type):
             [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
             ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
-            if d > 1:
+            if d > 1 or (name == "dp_replicate" and self.ft_manager is not None):
                 dims.append(d)
                 names.append(name)
 
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         names = tuple(names)
-        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+        if self.ft_manager is None:
+            mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+        else:
+            from torchft.process_group import ft_init_device_mesh
+
+            mesh = ft_init_device_mesh(
+                device_type=device_type,
+                mesh_shape=dims,
+                mesh_dim_names=names,
+                replicate_dim=names.index("dp_replicate"),
+                manager=self.ft_manager,
+            )
 
         # Create all the submesh here to ensure all required process groups are
         # initialized:
@@ -73,7 +90,7 @@ def build_mesh(self, device_type):
         # Mesh for loss all-reduce
         dp_cp_mesh_dim_names = []
 
-        if self.dp_replicate_enabled:
+        if self.dp_replicate_enabled or self.ft_manager is not None:
             dp_mesh_dim_names.append("dp_replicate")
             dp_cp_mesh_dim_names.append("dp_replicate")
         if self.dp_shard_enabled:
@@ -101,7 +118,7 @@ def dp_enabled(self):
 
     @property
     def dp_replicate_enabled(self):
-        return self.dp_replicate > 1
+        return self.dp_replicate > 1 or self.ft_manager is not None
 
     @property
     def dp_shard_enabled(self):
diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 7311eb721..8fcc2baa0 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import contextlib
+import copy
 import math
 import os
 from datetime import timedelta
@@ -17,11 +18,22 @@
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
 
+from torchtitan.components.ft import has_torchft
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type
 
+if has_torchft:
+    import torchft as ft
+
+
 def _dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float:
+    if has_torchft:
+        if isinstance(mesh, ft.process_group._FlattenDeviceMesh):
+            x = funcol.all_reduce(
+                x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg
+            )
+            mesh = mesh.managed_mesh.mesh
+
     if isinstance(x, DTensor):
         # functional collectives do not support DTensor inputs
         x = x.full_tensor()
@@ -286,6 +298,17 @@ def clip_grad_norm_(
     if isinstance(total_norm, DTensor):
         # Will reach here if any non-PP parallelism is used.
         # If only using PP, total_norm will be a local tensor.
+        mesh = total_norm._spec.mesh
+        if has_torchft:
+            if isinstance(mesh, ft.process_group.ManagedDeviceMesh):
+                # The gradients along the replicated dim have already been reduced,
+                # so we don't need another reduction before removing the
+                # replicate dimension.
+                local_tensor = total_norm.to_local()
+                placements = list(copy.copy(total_norm._spec.placements))
+                placements.pop(mesh.replicate_dim)
+                total_norm = DTensor.from_local(local_tensor, mesh.mesh, placements)
+
         total_norm = total_norm.full_tensor()
 
     if pp_mesh is not None:
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 3763175e7..5b404f65c 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -12,8 +12,10 @@
 from torch.distributed.elastic.multiprocessing.errors import record
 
 from torchtitan.components.checkpoint import CheckpointManager, TrainState
+from torchtitan.components.ft import init_ft_manager
 from torchtitan.config_manager import JobConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
+
 from torchtitan.protocols.model_converter import build_model_converters
 from torchtitan.protocols.train_spec import get_train_spec
 
@@ -43,6 +45,11 @@ def main(job_config: JobConfig):
     # take control of garbage collection to avoid stragglers
     gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
 
+    device_module, device_type = utils.device_module, utils.device_type
+    device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
+    device_module.set_device(device)
+    ft_manager = init_ft_manager(job_config)
+
     # init distributed
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
@@ -53,10 +60,8 @@ def main(job_config: JobConfig):
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=not job_config.training.disable_loss_parallel,
+        ft_manager=ft_manager,
     )
-    device_module, device_type = utils.device_module, utils.device_type
-    device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
-    device_module.set_device(device)
     dist_utils.init_distributed(job_config)
     # initialize device memory monitor and get peak flops for MFU calculation
     device_memory_monitor = build_device_memory_monitor()
@@ -82,6 +87,8 @@ def main(job_config: JobConfig):
     # build dataloader
     tokenizer = train_spec.tokenizer_cls(job_config.model.tokenizer_path)
+    dp_rank = dp_degree * job_config.experimental.ft_replica_id + dp_rank
+    dp_degree = dp_degree * job_config.experimental.ft_group_size
     dataloader = train_spec.build_dataloader_fn(
         dp_world_size=dp_degree,
         dp_rank=dp_rank,
@@ -181,7 +188,7 @@ def main(job_config: JobConfig):
     )
 
     # build optimizer after applying parallelisms to the model
-    optimizers = train_spec.build_optimizers_fn(model_parts, job_config)
+    optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
     lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)
     # Post optimizer step model converters hook.
     # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
@@ -200,6 +207,7 @@
         lr_schedulers=lr_schedulers,
         states={"train_state": train_state},
         job_config=job_config,
+        ft_manager=ft_manager,
     )
 
     if job_config.checkpoint.create_seed_checkpoint:
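
Example launch (a minimal sketch): assuming a torchft lighthouse is already reachable at the TORCHFT_LIGHTHOUSE address consumed by run_train.sh, and with the config path below replaced by a real job config, two replica groups could be started with the --experimental.* flags added in config_manager.py:

# replica group 0
TORCHFT_LIGHTHOUSE=http://localhost:29510 NGPU=4 CONFIG_FILE=<path/to/job_config.toml> \
./run_train.sh --experimental.enable_torchft --experimental.ft_replica_id 0 --experimental.ft_group_size 2

# replica group 1 (a second set of GPUs or host, pointing at the same lighthouse)
TORCHFT_LIGHTHOUSE=http://localhost:29510 NGPU=4 CONFIG_FILE=<path/to/job_config.toml> \
./run_train.sh --experimental.enable_torchft --experimental.ft_replica_id 1 --experimental.ft_group_size 2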