[WIP] Add 2D Parallelism (FSDP + Tensor Parallel) LoRA #2204

Status: Open · wants to merge 1 commit into base: main
102 changes: 97 additions & 5 deletions recipes/lora_finetune_distributed.py
@@ -15,7 +15,8 @@
from omegaconf import DictConfig, ListConfig

from torch import nn
-from torch.distributed import destroy_process_group, init_process_group
+from torch.distributed import destroy_process_group, DeviceMesh, init_process_group
+from torch.distributed.device_mesh import init_device_mesh

from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
@@ -116,6 +117,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):

Args:
cfg (DictConfig): OmegaConf object parsed from yaml file
device_mesh (DeviceMesh): DeviceMesh object that contains the device topology

Raises:
ValueError: If ``dtype`` is set to fp16.
@@ -126,9 +128,10 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
"""

-def __init__(self, cfg: DictConfig) -> None:
+def __init__(self, cfg: DictConfig, device_mesh: DeviceMesh) -> None:
    self._device = utils.get_device(device=cfg.device)
    self._dtype = training.get_dtype(cfg.dtype, device=self._device)
    self.device_mesh = device_mesh

if self._dtype == torch.float16:
    raise ValueError(
@@ -268,6 +271,34 @@ def setup(self, cfg: DictConfig) -> None:
checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
self._compile = cfg.get("compile", False)

# Remap the base_model_state_dict so the LoRALinear modules can be sharded with
# Tensor Parallel. Since DTensor currently only shards nn.Linear and nn.Embedding
# (not LoRALinear), the original nn.Linear weights must be moved into the nested
# `linear` submodule of each LoRALinear.
def remap_base_model_state_dict(base_model_state_dict):
    new_state_dict = {}
    for k, v in base_model_state_dict.items():
        if "q_proj.bias" in k or "output_proj.bias" in k or "v_proj.bias" in k:
            new_state_dict[k.replace(".bias", ".linear.bias")] = v
        elif (
            "q_proj.weight" in k
            or "output_proj.weight" in k
            or "v_proj.weight" in k
        ):
            new_state_dict[k.replace(".weight", ".linear.weight")] = v
        elif "w1.bias" in k or "w2.bias" in k or "w3.bias" in k:
            new_state_dict[k.replace(".bias", ".linear.bias")] = v
        elif "w1.weight" in k or "w2.weight" in k or "w3.weight" in k:
            new_state_dict[k.replace(".weight", ".linear.weight")] = v
        else:
            new_state_dict[k] = v
    return new_state_dict
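# For example, with a hypothetical torchtune-style key the remap looks like:
#   "layers.0.attn.q_proj.weight" -> "layers.0.attn.q_proj.linear.weight"
# while adapter keys such as "layers.0.attn.q_proj.lora_a.weight" pass through unchanged.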

# Remap the base model state dict
base_model_state_dict = remap_base_model_state_dict(
    checkpoint_dict[training.MODEL_KEY]
)
checkpoint_dict.update({training.MODEL_KEY: base_model_state_dict})
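# From this point on, the checkpoint keys follow the nested `.linear.*` layout
# of the TP-friendly LoRALinear (see torchtune/modules/peft/lora.py below).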

self._model = self._setup_model(
    cfg_model=cfg.model,
    enable_activation_checkpointing=self._enable_activation_checkpointing,
@@ -444,7 +475,7 @@ def _setup_model(

utils.log_rank_zero(
    log,
-    "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+    "FSDP and TP are enabled. Instantiating model and loading checkpoint on Rank 0 ...",
)
init_start = time.perf_counter()

@@ -470,6 +501,7 @@
]
training.shard_model(
    model=model,
    device_mesh=self.device_mesh,
    shard_conditions=fsdp_shard_conditions,
    cpu_offload=fsdp_cpu_offload,
    reshard_after_forward=reshard_after_forward,
@@ -602,8 +634,17 @@ def _setup_data(
raise RuntimeError("left_pad_sequence collator is only for inference.")
collate_fn = _get_component_from_path(collate_fn)

# Set up the data-parallel degree and rank from the device mesh
dp_mesh = self.device_mesh["dp"]
dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
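# Note: because the sampler below is built from the "dp" sub-mesh, all TP ranks
# within the same DP group share a dp_rank and therefore receive identical
# batches, which is what tensor parallelism requires.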

# Create the DistributedSampler over data-parallel ranks only
sampler = DistributedSampler(
-    ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
+    ds,
+    num_replicas=dp_degree,  # number of dp ranks
+    rank=dp_rank,  # use dp_rank, not the global rank
+    shuffle=shuffle,
+    seed=0,
)

dataloader = DataLoader(
@@ -657,6 +698,43 @@ def save_checkpoint(
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
state_dict = self._model.state_dict()

# Undo the remapping that was applied for Tensor Parallel sharding.
# This restores the key layout the checkpointer expects so the checkpoint
# can be saved in the original format.
def unmap_base_model_state_dict(base_model_state_dict):
    new_state_dict = {}
    for k, v in base_model_state_dict.items():
        if (
            "q_proj.linear.bias" in k
            or "output_proj.linear.bias" in k
            or "v_proj.linear.bias" in k
        ):
            new_state_dict[k.replace(".linear.bias", ".bias")] = v
        elif (
            "q_proj.linear.weight" in k
            or "output_proj.linear.weight" in k
            or "v_proj.linear.weight" in k
        ):
            new_state_dict[k.replace(".linear.weight", ".weight")] = v
        elif (
            "w1.linear.bias" in k
            or "w2.linear.bias" in k
            or "w3.linear.bias" in k
        ):
            new_state_dict[k.replace(".linear.bias", ".bias")] = v
        elif (
            "w1.linear.weight" in k
            or "w2.linear.weight" in k
            or "w3.linear.weight" in k
        ):
            new_state_dict[k.replace(".linear.weight", ".weight")] = v
        else:
            new_state_dict[k] = v
    return new_state_dict

# Unmap the state dict so we can save the checkpoint
state_dict = unmap_base_model_state_dict(state_dict)
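# e.g. "layers.0.attn.q_proj.linear.weight" -> "layers.0.attn.q_proj.weight"
# (hypothetical key): the exact inverse of remap_base_model_state_dict in setup(),
# so saved checkpoints keep the original key format.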

if self._save_adapter_weights_only:
    state_dict = get_adapter_state_dict(state_dict, device=None)

@@ -921,9 +999,23 @@ def recipe_main(cfg: DictConfig) -> None:
# speed up when benchmarking fused AdamW on CPU
training.set_torch_num_threads()

# Get world size and rank to initialize the device mesh
world_size, rank = training.get_world_size_and_rank()
tp_size = 8
if world_size % tp_size != 0:
    raise ValueError(
        f"World size {world_size} must be divisible by tensor parallel size {tp_size}"
    )
dp_size = world_size // tp_size

# Initialize device mesh
device_mesh = init_device_mesh(
    "cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=("dp", "tp")
)
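# Note: tp_size is hardcoded to 8 in this WIP recipe. init_device_mesh lays ranks
# out row-major, so e.g. on 16 GPUs (dp_size=2) ranks 0-7 form the first TP group
# (dp coordinate 0) and ranks 8-15 the second (dp coordinate 1).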

config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg)

-recipe = LoRAFinetuneRecipeDistributed(cfg=cfg)
+recipe = LoRAFinetuneRecipeDistributed(cfg=cfg, device_mesh=device_mesh)
recipe.setup(cfg=cfg)
recipe.train()
recipe.cleanup()
26 changes: 6 additions & 20 deletions torchtune/modules/peft/lora.py
@@ -11,13 +11,12 @@

from torch import nn

-from torchao.dtypes.nf4tensor import linear_nf4, to_nf4
from torchtune.modules.low_precision import _register_nf4_dispatch_ops # noqa: F401
from torchtune.modules.peft import AdapterModule


class LoRALinear(nn.Module, AdapterModule):
"""LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`_.
"""Modified LoRA Linear config to support Tensor Parallel Sharding with DTensors
(which currently only support sharding nn.Linear and nn.Embedding layers)

LoRA perturbs a given layer via a low-rank approximation where only
the rank decomposition matrices are trainable. In a linear layer instead of
@@ -70,23 +69,15 @@ def __init__(
f"``quantize_base`` is False, but received the following quantization arguments: {quantization_kwargs}"
)

-# Setup weight and bias
-linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=self.use_bias)
-weight = (
-    linear.weight
-    if not self._quantize_base
-    else to_nf4(linear.weight, **quantization_kwargs)
-)
-bias = linear.bias if self.use_bias else None
+# Setup weight and bias (these are the original weights that we will be loading from state_dict)
+self.linear = nn.Linear(
+    in_features=in_dim, out_features=out_dim, bias=self.use_bias
+)
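# As the class docstring notes, DTensor-based TP currently only shards nn.Linear
# and nn.Embedding, so the frozen base weight now lives in a plain nn.Linear
# submodule (self.linear) that TP plans (e.g. ColwiseParallel/RowwiseParallel)
# can shard; the NF4 quantize_base path that previously lived here is dropped.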

# 'self.disabled' is a flag showing whether to turn off LoRA adapters,
# this can be used in DPO for treating the lora adapters as the policy model
# and disabling it to treat the base model as the reference model
self.disabled = False
self.register_parameter("weight", nn.Parameter(weight))
self.register_parameter(
"bias", nn.Parameter(bias) if bias is not None else None
)
self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
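# Shapes: lora_a maps (..., in_dim) -> (..., rank); lora_b maps (..., rank) -> (..., out_dim)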
@@ -126,12 +117,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.Tensor: output tensor with shape ``(..., out_dim)``

"""
-if self._quantize_base:
-    out = linear_nf4(input=x, weight=self.weight)
-    if self.use_bias:
-        out = out + self.bias
-else:
-    out = F.linear(x, self.weight, self.bias)
+out = self.linear(x)
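# The frozen base forward is now a plain nn.Linear call on self.linear, so the
# base weight can be a sharded DTensor during TP training.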
if self.disabled:
    return out
lora_out = self.lora_a(self.dropout(x))