diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 74516e7fa2..0ce3162d42 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -15,7 +15,8 @@ from omegaconf import DictConfig, ListConfig from torch import nn -from torch.distributed import destroy_process_group, init_process_group +from torch.distributed import destroy_process_group, DeviceMesh, init_process_group +from torch.distributed.device_mesh import init_device_mesh from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler @@ -116,6 +117,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): Args: cfg (DictConfig): OmegaConf object parsed from yaml file + device_mesh (DeviceMesh): DeviceMesh object that contains the device topology Raises: ValueError: If ``dtype`` is set to fp16. @@ -126,9 +128,10 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. """ - def __init__(self, cfg: DictConfig) -> None: + def __init__(self, cfg: DictConfig, device_mesh: DeviceMesh) -> None: self._device = utils.get_device(device=cfg.device) self._dtype = training.get_dtype(cfg.dtype, device=self._device) + self.device_mesh = device_mesh if self._dtype == torch.float16: raise ValueError( @@ -268,6 +271,34 @@ def setup(self, cfg: DictConfig) -> None: checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) self._compile = cfg.get("compile", False) + # Function to remap the base_model_state_dict so we can shard the LoRALinear modules with Tensor Parallel + # Since DTensor currently only shards nn.Linear and nn.Embedding (not LoRALinear) + # we need to remap the original nn.Linear weights in the LoRALinear modules + def remap_base_model_state_dict(base_model_state_dict): + new_state_dict = {} + for k, v in base_model_state_dict.items(): + if "q_proj.bias" in k or "output_proj.bias" in k or "v_proj.bias" in k: + new_state_dict[k.replace(".bias", ".linear.bias")] = v + elif ( + "q_proj.weight" in k + or "output_proj.weight" in k + or "v_proj.weight" in k + ): + new_state_dict[k.replace(".weight", ".linear.weight")] = v + elif "w1.bias" in k or "w2.bias" in k or "w3.bias" in k: + new_state_dict[k.replace(".bias", ".linear.bias")] = v + elif "w1.weight" in k or "w2.weight" in k or "w3.weight" in k: + new_state_dict[k.replace(".weight", ".linear.weight")] = v + else: + new_state_dict[k] = v + return new_state_dict + + # Remap the base model state dict + base_model_state_dict = remap_base_model_state_dict( + checkpoint_dict[training.MODEL_KEY] + ) + checkpoint_dict.update({training.MODEL_KEY: base_model_state_dict}) + self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=self._enable_activation_checkpointing, @@ -444,7 +475,7 @@ def _setup_model( utils.log_rank_zero( log, - "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...", + "FSDP and TP is enabled. 
Instantiating model and loading checkpoint on Rank 0 ...", ) init_start = time.perf_counter() @@ -470,6 +501,7 @@ def _setup_model( ] training.shard_model( model=model, + device_mesh=self.device_mesh, shard_conditions=fsdp_shard_conditions, cpu_offload=fsdp_cpu_offload, reshard_after_forward=reshard_after_forward, @@ -602,8 +634,17 @@ def _setup_data( raise RuntimeError("left_pad_sequence collator is only for inference.") collate_fn = _get_component_from_path(collate_fn) + # Setup dp_rank and dp_size + dp_mesh = self.device_mesh["dp"] + dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() + + # Create DistributedSampler with appropriate settings sampler = DistributedSampler( - ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 + ds, + num_replicas=dp_degree, # number of dp ranks + rank=dp_rank, # use dp_rank, not the global rank + shuffle=shuffle, + seed=0, ) dataloader = DataLoader( @@ -657,6 +698,43 @@ def save_checkpoint( # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 state_dict = self._model.state_dict() + + # Function to unmap the state_dict that we remapped for TensorParallel Sharding + # This aligns with what the checkpointer expects so we can save the checkpoint + def unmap_base_model_state_dict(base_model_state_dict): + new_state_dict = {} + for k, v in base_model_state_dict.items(): + if ( + "q_proj.linear.bias" in k + or "output_proj.linear.bias" in k + or "v_proj.linear.bias" in k + ): + new_state_dict[k.replace(".linear.bias", ".bias")] = v + elif ( + "q_proj.linear.weight" in k + or "output_proj.linear.weight" in k + or "v_proj.linear.weight" in k + ): + new_state_dict[k.replace(".linear.weight", ".weight")] = v + elif ( + "w1.linear.bias" in k + or "w2.linear.bias" in k + or "w3.linear.bias" in k + ): + new_state_dict[k.replace(".linear.bias", ".bias")] = v + elif ( + "w1.linear.weight" in k + or "w2.linear.weight" in k + or "w3.linear.weight" in k + ): + new_state_dict[k.replace(".linear.weight", ".weight")] = v + else: + new_state_dict[k] = v + return new_state_dict + + # Unmap the state dict so we can save the checkpoint + state_dict = unmap_base_model_state_dict(state_dict) + if self._save_adapter_weights_only: state_dict = get_adapter_state_dict(state_dict, device=None) @@ -921,9 +999,23 @@ def recipe_main(cfg: DictConfig) -> None: # speed up when benchmarking fused AdamW on CPU training.set_torch_num_threads() + # Get world size and rank to initialize the device mesh + world_size, rank = training.get_world_size_and_rank() + tp_size = 8 + if world_size % tp_size != 0: + raise ValueError( + f"World size {world_size} must be divisible by tensor parallel size {tp_size}" + ) + dp_size = world_size // tp_size + + # Initialize device mesh + device_mesh = init_device_mesh( + "cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=("dp", "tp") + ) + config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg) - recipe = LoRAFinetuneRecipeDistributed(cfg=cfg) + recipe = LoRAFinetuneRecipeDistributed(cfg=cfg, device_mesh=device_mesh) recipe.setup(cfg=cfg) recipe.train() recipe.cleanup() diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py index 4c30c7503a..7782152e52 100644 --- a/torchtune/modules/peft/lora.py +++ b/torchtune/modules/peft/lora.py @@ -11,13 +11,12 @@ from torch import nn -from torchao.dtypes.nf4tensor import linear_nf4, to_nf4 -from torchtune.modules.low_precision import _register_nf4_dispatch_ops # noqa: F401 from 
torchtune.modules.peft import AdapterModule class LoRALinear(nn.Module, AdapterModule): - """LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models `_. + """Modified LoRA Linear config to support Tensor Parallel Sharding with DTensors + (which currently only support sharding nn.Linear and nn.Embedding layers) LoRA perturbs a given layer via a low-rank approximation where only the rank decomposition matrices are trainable. In a linear layer instead of @@ -70,23 +69,15 @@ def __init__( f"``quantize_base`` is False, but received the following quantization arguments: {quantization_kwargs}" ) - # Setup weight and bias - linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=self.use_bias) - weight = ( - linear.weight - if not self._quantize_base - else to_nf4(linear.weight, **quantization_kwargs) + # Setup weight and bias (these are the original weights that we will be loading from state_dict) + self.linear = nn.Linear( + in_features=in_dim, out_features=out_dim, bias=self.use_bias ) - bias = linear.bias if self.use_bias else None # 'self.disabled' is a flag showing whether to turn off LoRA adapters, # this can be used in DPO for treating the lora adapters as the policy model # and disabling it to treat the base model as the reference model self.disabled = False - self.register_parameter("weight", nn.Parameter(weight)) - self.register_parameter( - "bias", nn.Parameter(bias) if bias is not None else None - ) self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) @@ -126,12 +117,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch.Tensor: output tensor with shape ``(..., out_dim)`` """ - if self._quantize_base: - out = linear_nf4(input=x, weight=self.weight) - if self.use_bias: - out = out + self.bias - else: - out = F.linear(x, self.weight, self.bias) + out = self.linear(x) if self.disabled: return out lora_out = self.lora_a(self.dropout(x)) diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index e46b1ceecf..7374682e92 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -13,12 +13,20 @@ import torch import torch.distributed as dist from torch import nn +from torch.distributed import DeviceMesh from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard -from torch.distributed._tensor import distribute_tensor, DTensor +from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta from torch.distributed.checkpoint.state_dict import _init_optim_state from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) from torch.optim import Optimizer from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4 from torchtune.modules import TransformerDecoder @@ -434,6 +442,7 @@ def get_shard_conditions( def shard_model( model: TransformerDecoder, + device_mesh: DeviceMesh, shard_conditions: List[Callable[[str, nn.Module], bool]], *, cpu_offload: bool, @@ -447,6 +456,7 @@ def shard_model( Args: model (TransformerDecoder): Model to shard with FSDP. + device_mesh (DeviceMesh): Device mesh to shard the model with. For now just DP and/or TP mesh. 
shard_conditions (List[Callable[[str, nn.Module], bool]]): A list of functions to determine which modules to shard with FSDP. Each function should take module name (relative to root) and the module itself, returning True if FSDP should shard the module and False otherwise. @@ -460,6 +470,133 @@ def shard_model( Raises: ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered. """ + + def apply_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8: bool, + enable_async_tp: bool, + ): + """Apply tensor parallelism.""" + # Slightly modified from torchtitan to work with LoRA + # See: https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + seq_parallel_input_output_layer_plan = { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "norm": SequenceParallel(), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + } + + parallelize_module(model, tp_mesh, seq_parallel_input_output_layer_plan) + + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears + if enable_float8: + # TODO(vkuzo): once float8 configuration supports delayed scaling, + # add a check here to enforce supported float8 all-gather configurations + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + Float8RowwiseParallel, + Float8ColwiseParallel, + PrepareFloat8ModuleInput, + ) + else: + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block + # NOTE: At the cost of model code change, we can accelerate Sequence Parallel + # by folding (and unfolding) the batch dimension and the sequence dimension. 
+ # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + for i, transformer_block in enumerate(model.layers): + + seq_parallel_layer_plan = { + "sa_norm": SequenceParallel(), + "attn": prepare_module_input( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ), + "attn.k_proj": colwise_parallel(), + "attn.q_proj.linear": colwise_parallel(), + "attn.v_proj.linear": colwise_parallel(), + "attn.q_proj.lora_a": colwise_parallel(), + "attn.v_proj.lora_a": colwise_parallel(), + # Shard the q_proj.lora_b and v_proj.lora_b outputs colwise + # since they are added to outputs of q_proj and v_proj (which are colwise sharded) + "attn.q_proj.lora_b": rowwise_parallel(output_layouts=Shard(-1)), + "attn.v_proj.lora_b": rowwise_parallel(output_layouts=Shard(-1)), + "attn.output_proj.linear": rowwise_parallel(output_layouts=Shard(1)), + "attn.output_proj.lora_a": rowwise_parallel(), + "attn.output_proj.lora_b": colwise_parallel(output_layouts=Shard(1)), + "mlp_norm": SequenceParallel(), + "mlp": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "mlp.w1.linear": colwise_parallel(), + "mlp.w1.lora_a": colwise_parallel(), + "mlp.w3.linear": colwise_parallel(), + "mlp.w3.lora_a": colwise_parallel(), + # Shard the .w1.lora_b and .w3.lora_b outputs colwise + # since they are added to outputs of w1 and w3 (which are colwise sharded) + "mlp.w1.lora_b": rowwise_parallel(output_layouts=Shard(-1)), + "mlp.w3.lora_b": rowwise_parallel(output_layouts=Shard(-1)), + "mlp.w2.linear": rowwise_parallel(output_layouts=Shard(1)), + "mlp.w2.lora_a": rowwise_parallel(), + "mlp.w2.lora_b": colwise_parallel(output_layouts=Shard(1)), + } + + # Adjust attention module to use the local number of heads + attn_layer = transformer_block.attn + attn_layer.num_heads = attn_layer.num_heads // tp_mesh.size() + attn_layer.num_kv_heads = attn_layer.num_kv_heads // tp_mesh.size() + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=seq_parallel_layer_plan, + ) + + if enable_async_tp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + torch._inductor.config._micro_pipeline_tp = True + enable_symm_mem_for_group(tp_mesh.get_group().group_name) + + _log.info( + f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" + ) + + # Apply Tensor Parallelism to the model + apply_tp( + model=model, + tp_mesh=device_mesh["tp"], + loss_parallel=False, # Don't support loss parallelism for now + enable_float8=False, # Don't support float8 for now + enable_async_tp=False, # Don't support async TP for now + ) + fsdp_kwargs = {"reshard_after_forward": reshard_after_forward} if cpu_offload: fsdp_kwargs["offload_policy"] = CPUOffloadPolicy() @@ -469,7 +606,9 @@ def shard_model( num_layers_sharded = 0 for n, m in reversed(list(model.named_modules())): if any([shard_condition(n, m) for shard_condition in shard_conditions]): - fully_shard(m, **fsdp_kwargs) + fully_shard( + m, mesh=device_mesh["dp"], **fsdp_kwargs + ) # Use the DP mesh here num_layers_sharded += 1 if num_layers_sharded == 0: @@ -478,4 +617,4 @@ def shard_model( ) # Finally shard the entire model to account for any stragglers - fully_shard(model, **fsdp_kwargs) + fully_shard(model, mesh=device_mesh["dp"], **fsdp_kwargs)
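
Note (not part of the patch): the state-dict key remapping that the recipe performs before TP sharding, and undoes before saving a checkpoint, is easy to sanity-check in isolation. The sketch below mirrors the inline remap/unmap helpers from the diff as standalone functions and round-trips a few illustrative keys. The function names, the _TP_REMAPPED_MODULES tuple, and the sample keys are placeholders of mine; the module set assumes LoRA is applied to the q/v/output projections and the MLP layers (so k_proj keeps a plain nn.Linear), matching this diff.

# Standalone sketch: round-trip the LoRALinear key remapping used for TP sharding.
# Helper names and sample keys are illustrative, not taken from a real checkpoint.
from typing import Dict

import torch

# Modules whose base nn.Linear is wrapped as LoRALinear.linear in this patch.
_TP_REMAPPED_MODULES = ("q_proj", "v_proj", "output_proj", "w1", "w2", "w3")


def remap_for_tp(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Rename base-model params (e.g. "...q_proj.weight") so they load into the
    # nn.Linear now nested inside LoRALinear ("...q_proj.linear.weight").
    out = {}
    for key, value in state_dict.items():
        parts = key.rsplit(".", 2)
        if (
            len(parts) == 3
            and parts[1] in _TP_REMAPPED_MODULES
            and parts[2] in ("weight", "bias")
        ):
            out[f"{parts[0]}.{parts[1]}.linear.{parts[2]}"] = value
        else:
            out[key] = value
    return out


def unmap_for_checkpoint(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Inverse of remap_for_tp: restore the key layout the checkpointer expects.
    out = {}
    for key, value in state_dict.items():
        parts = key.rsplit(".", 3)
        if (
            len(parts) == 4
            and parts[1] in _TP_REMAPPED_MODULES
            and parts[2] == "linear"
            and parts[3] in ("weight", "bias")
        ):
            out[f"{parts[0]}.{parts[1]}.{parts[3]}"] = value
        else:
            out[key] = value
    return out


if __name__ == "__main__":
    sample = {
        "layers.0.attn.q_proj.weight": torch.zeros(1),
        "layers.0.attn.k_proj.weight": torch.zeros(1),  # k_proj keeps a plain nn.Linear
        "layers.0.attn.q_proj.lora_a.weight": torch.zeros(1),  # adapter keys untouched
        "layers.0.mlp.w1.weight": torch.zeros(1),
        "tok_embeddings.weight": torch.zeros(1),
    }
    remapped = remap_for_tp(sample)
    assert "layers.0.attn.q_proj.linear.weight" in remapped
    assert "layers.0.attn.k_proj.weight" in remapped
    assert unmap_for_checkpoint(remapped).keys() == sample.keys()
    print("remap/unmap round-trip OK")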
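
Similarly, the 2D ("dp", "tp") device mesh built in recipe_main can be exercised on its own under torchrun. This is a minimal sketch under stated assumptions, not the recipe's code path: it assumes a torchrun launch with WORLD_SIZE and LOCAL_RANK set, and build_mesh plus the TP_SIZE environment variable are hypothetical stand-ins for the hard-coded tp_size = 8 in the patch.

# Standalone sketch: build the same dp x tp mesh the recipe expects and report
# each rank's coordinates. Run with e.g. `torchrun --nproc_per_node=8 <this_file>.py`.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def build_mesh(tp_size: int):
    world_size = int(os.environ["WORLD_SIZE"])
    if world_size % tp_size != 0:
        raise ValueError(
            f"World size {world_size} must be divisible by tensor parallel size {tp_size}"
        )
    dp_size = world_size // tp_size
    # 2D mesh: outer dim is data parallel, inner dim is tensor parallel,
    # matching the ("dp", "tp") names used by the recipe and shard_model.
    return init_device_mesh(
        "cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=("dp", "tp")
    )


if __name__ == "__main__":
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group("nccl")
    mesh = build_mesh(tp_size=int(os.environ.get("TP_SIZE", "8")))
    dp_mesh, tp_mesh = mesh["dp"], mesh["tp"]
    print(
        f"rank {dist.get_rank()}: dp_rank={dp_mesh.get_local_rank()} "
        f"tp_rank={tp_mesh.get_local_rank()}"
    )
    dist.destroy_process_group()

Reading the TP degree from the config or environment rather than hard-coding 8, as done here for illustration, would make the same recipe easier to reuse across world sizes, including tp_size = 1 where the mesh degenerates to plain FSDP.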