Used per-parameter FSDP
awgu committed Mar 27, 2024
1 parent 8dd5798 commit dbb793a
Showing 4 changed files with 102 additions and 195 deletions.
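For context, "per-parameter FSDP" refers to the composable fully_shard API, which shards each parameter individually along dim 0 as a DTensor, instead of flattening a module's parameters into a single FlatParameter the way the wrapper-based FullyShardedDataParallel class does. Below is a minimal sketch of what that looks like after sharding; the toy two-layer model and single data-parallel mesh are assumptions for illustration, not code from this repository.

```python
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh

# Assumes the process group is already initialized (e.g. launched via torchrun).
mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("dp",))

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()
fully_shard(model, mesh=mesh)

# Each parameter remains an individually addressable tensor, now a DTensor
# sharded along dim 0 across the data-parallel mesh.
for name, param in model.named_parameters():
    print(name, isinstance(param, DTensor), param.placements)
```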
48 changes: 0 additions & 48 deletions torchtrain/meta_init.py

This file was deleted.

97 changes: 41 additions & 56 deletions torchtrain/models/llama/model.py
@@ -45,9 +45,7 @@ def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim))

# re-enable if not using meta-init
# self.reset_parameters()
self.reset_parameters()

def _norm(self, x: torch.Tensor):
"""
@@ -207,19 +205,10 @@ def __init__(self, model_args: ModelArgs):
model_args.n_heads * self.head_dim, model_args.dim, bias=False
)

def reset_parameters(self, init_std):
for item in (self.wq, self.wk, self.wv):
nn.init.trunc_normal_(
item.weight,
mean=0.0,
std=0.02,
)

nn.init.trunc_normal_(
self.wo.weight,
mean=0.0,
std=init_std,
)
def init_weights(self, init_std: float):
for linear in (self.wq, self.wk, self.wv):
nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)

def forward(
self,
@@ -309,19 +298,10 @@ def __init__(
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))

def reset_parameters(self, init_std):
nn.init.trunc_normal_(
self.w1.weight,
mean=0.0,
std=0.02,
)

for item in (self.w2, self.w3):
nn.init.trunc_normal_(
item.weight,
mean=0.0,
std=init_std,
)
def init_weights(self, init_std: float):
nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
for linear in (self.w2, self.w3):
nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


class RotaryEmbedding(nn.Module):
@@ -333,13 +313,15 @@ def __init__(self, model_args: ModelArgs):
super().__init__()
self.model_args = model_args
self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
self.register_buffer(
"freqs_cis", self._precompute_freqs_cis(), persistent=False
)

self.freqs_cis = precompute_freqs_cis(
# Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
# of models is 4096.
# Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
# or fine-tuning.
def _precompute_freqs_cis(self):
return precompute_freqs_cis(
self.model_args.dim // self.model_args.n_heads,
# Need to compute until at least the max token limit for generation
# (use 2x max sequence length to be safe)
self.model_args.max_seq_len * 2,
)

@@ -355,10 +337,14 @@ def forward(self, tokens: torch.Tensor):
"""
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[0:seqlen]
return h, freqs_cis

def init_weights(self):
with torch.device(self.freqs_cis.device):
self.freqs_cis = self._precompute_freqs_cis()
nn.init.normal_(self.tok_embeddings.weight)


class TransformerBlock(nn.Module):
"""
@@ -421,13 +407,11 @@ def forward(
out = h + self.feed_forward(self.ffn_norm(h))
return out

def reset_parameters(self):
"""reset params and norms for entire block"""
self.attention_norm.reset_parameters()
self.ffn_norm.reset_parameters()

self.attention.reset_parameters(self.weight_init_std)
self.feed_forward.reset_parameters(self.weight_init_std)
def init_weights(self):
for norm in (self.attention_norm, self.ffn_norm):
norm.reset_parameters()
self.attention.init_weights(self.weight_init_std)
self.feed_forward.init_weights(self.weight_init_std)


class Transformer(nn.Module):
@@ -457,28 +441,29 @@ def __init__(self, model_args: ModelArgs):
self.model_dim = model_args.dim

self.embeddings = RotaryEmbedding(model_args)

self.layers = torch.nn.ModuleList()
for layer_id in range(model_args.n_layers):
self.layers.append(TransformerBlock(layer_id, model_args))

self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
self.init_weights()

# init model weights

# we are doing meta_init, which will call reset_parameters() after
# the model is moved to actual device.
# If you modify and are not using meta_init, you will need to call
# reset_parameters() manually as below:

# self.reset_parameters()

def reset_parameters(
self,
):
def init_weights(self):
"""
[Note: On ``init_weights`` vs. ``reset_parameters``]
Modules may define ``reset_parameters`` to initialize parameter values.
``reset_parameters`` is meant to only initialize directly owned
parameters/buffers, not those of their children, and it can be
used to give the initial values for these tensors.
Separately, users may want custom initialization for their modules,
different from that in ``reset_parameters``. For this, we define
``init_weights``. We only call it in the constructor of this
``Transformer`` root module to avoid reinitializing tensors.
"""
self.embeddings.init_weights()
for layer in self.layers:
layer.reset_parameters()
layer.init_weights()
self.norm.reset_parameters()
final_out_std = self.model_dim**-0.5
cutoff_factor = 3
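To make the init_weights vs. reset_parameters convention from the note above concrete, here is a small illustrative sketch; the ToyBlock and ToyModel names are hypothetical and not part of this file. reset_parameters initializes only the tensors a module directly owns, while init_weights carries the model-specific initialization and is called once from the root so that children are not reinitialized.

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(dim, dim))

    def reset_parameters(self):
        # Initializes only the tensors this module directly owns.
        nn.init.trunc_normal_(self.weight, mean=0.0, std=0.02)


class ToyModel(nn.Module):
    def __init__(self, dim: int, n_layers: int):
        super().__init__()
        self.blocks = nn.ModuleList(ToyBlock(dim) for _ in range(n_layers))
        self.norm = nn.LayerNorm(dim)
        self.init_weights()  # called exactly once, from the root

    def init_weights(self):
        # Model-wide custom initialization: recurses into the children.
        for block in self.blocks:
            block.reset_parameters()
        self.norm.reset_parameters()


model = ToyModel(dim=16, n_layers=2)
```

Keeping per-module reset_parameters separate from the root-level init_weights lets the caller decide when, and on which device, initialization actually runs.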
106 changes: 40 additions & 66 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -8,19 +8,13 @@
from typing import Tuple

import torch
from torch.distributed._tensor import Replicate, Shard

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.fsdp import (
BackwardPrefetch,
FullyShardedDataParallel as FSDP,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -33,7 +27,6 @@

from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import logger
from torchtrain.meta_init import meta_to_real_init_fn


# for selective AC
@@ -75,7 +68,6 @@ def selective_checkpointing_context_fn():
preserve_rng_state=False,
)
elif config.mode == "full":
# full AC
return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
@@ -136,28 +128,23 @@ def get_tp_parallel_strategy(

def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply parallelisms to the model, including PTD parallelisms, and AC.
Apply parallelisms and activation checkpointing to the model.
NOTE: the model passed in preferrablably shoule be a meta device model,
otherwise the model needs to be small enough on GPU or can fit into CPU.
NOTE: The passed-in model preferably should be on meta device. Otherwise,
the model must fit on GPU or CPU memory.
"""
# apply PTD parallelisms
if parallel_dims.pp_enabled:
raise NotImplementedError("PP not implemented yet.")

# First we apply Tensor Parallelism if it's enabled
if parallel_dims.tp_enabled:
tp_mesh = world_mesh["tp"]
tp_degree = job_config.training.tensor_parallel_degree

row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
job_config
)

# First:
# 1. parallelize the first embedding and the last linear proj layer
# 2. parallelize the root norm layer by sequence dim
# 3. shard the first layer of transformer block
# 1. Parallelize the first embedding and the last linear proj layer
# 2. Parallelize the root norm layer over the sequence dim
# 3. Shard the first transformer block's inputs
model = parallelize_module(
model,
tp_mesh,
@@ -167,9 +154,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
),
"output": col_parallel_strategy(
input_layouts=Shard(0),
output_layouts=Shard(-1)
if parallel_dims.loss_parallel_enabled
else Replicate(),
output_layouts=(
Shard(-1)
if parallel_dims.loss_parallel_enabled
else Replicate()
),
use_local_output=not parallel_dims.loss_parallel_enabled,
),
"norm": SequenceParallel(sequence_dim=0),
@@ -181,7 +170,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
},
)

# apply tensor + sequence parallelism to every transformer block
# Apply tensor + sequence parallelism to every transformer block
for layer_id, transformer_block in enumerate(model.layers):
layer_plan = {
"attention": PrepareModuleInput(
@@ -203,62 +192,47 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"ffn_norm": SequenceParallel(sequence_dim=0),
}

# adjust num_heads in attention layer to local heads
# Adjust attention module to use the local number of heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // tp_degree
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_degree
attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_plan,
)

logger.info("Applied Sequence Parallelism to the model")
logger.info("Applied Tensor Parallelism to the model")

if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"]

fsdp_config = {
"mixed_precision": MixedPrecision(
param_dtype=torch.bfloat16,
# TODO: see whether we should expose a option to user
reduce_dtype=torch.float32,
),
"sharding_strategy": ShardingStrategy.FULL_SHARD,
"backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
# When torch.compile is active, it requires us to set use_orig_params=True
"use_orig_params": True,
"device_mesh": dp_mesh,
"param_init_fn": meta_to_real_init_fn,
}

dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
# TODO: Expose `reduce_dtype` as a config option.
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32
)
ac_mode = job_config.activation_checkpoint.mode
with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
for layer_id, transformer_block in enumerate(model.layers):
# apply AC to the transformer block
if ac_mode in ("full", "selective"):
# wrap the transformer block with checkpoint wrapper, using config settings
transformer_block = checkpoint_wrapper(
transformer_block, job_config.activation_checkpoint
)

# Wraps each layer with FSDP
model.layers[layer_id] = wrap(transformer_block)

# wrap the rest layers with FSDP
model = wrap(model)

fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id, transformer_block in enumerate(model.layers):
if job_config.activation_checkpoint.mode in ("full", "selective"):
transformer_block = checkpoint_wrapper(
transformer_block, job_config.activation_checkpoint
)
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = layer_id < len(model.layers) - 1
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.layers[layer_id] = transformer_block
model = fully_shard(model, **fsdp_config)
if ac_mode in ("full", "selective"):
logger.info(f"Applied {ac_mode} activation checkpointing to the model")
logger.info("Applied FSDP to the model")
else:
meta_to_real_init_fn(model)
model.cuda()

# we have now moved from meta to device,
# reset parameters for proper initialization
model.reset_parameters()
logger.info("Model fully initialized via reset_parameters")

return model
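The note about meta-device models above implies a particular calling pattern. The following is a rough, hypothetical driver sketch rather than part of this diff: it assumes torch.distributed is already initialized and reuses the Transformer, model_args, world_mesh, parallel_dims, and job_config names from the surrounding code; none of it is taken verbatim from the repository's training script.

```python
import torch

# Construct the model on the meta device so no real memory is allocated yet.
with torch.device("meta"):
    model = Transformer(model_args)

# Apply TP/FSDP as defined in this file; with per-parameter FSDP the sharded
# parameters are DTensors, so standard nn.Module utilities still apply.
model = parallelize_llama(model, world_mesh, parallel_dims, job_config)

# Materialize real (sharded) storage, then run the model's own initialization.
model.to_empty(device="cuda")
model.init_weights()
```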
