Commit

Merge pull request #65 from foundation-model-stack/meta_init
switch to new meta device init method
lchu6 authored Mar 28, 2024
2 parents e53bac1 + 7b862bd commit a2d51ac
Showing 4 changed files with 36 additions and 19 deletions.
1 change: 1 addition & 0 deletions fms_fsdp/policies/__init__.py
@@ -1,3 +1,4 @@
 from .ac_handler import apply_fsdp_checkpointing
 from .mixed_precision import *
+from .param_init import param_init_function
 from .wrapping import get_llama_wrapper
18 changes: 18 additions & 0 deletions fms_fsdp/policies/param_init.py
@@ -0,0 +1,18 @@
+import torch
+from fms.modules.attention import MultiHeadAttention
+from fms.modules.embedding import WordEmbedding
+from fms.modules.feedforward import GatedLinearUnit
+from fms.modules.layernorm import LayerNormParameterized
+
+
+# for details, read https://github.com/foundation-model-stack/fms-fsdp/issues/64
+def param_init_function(module):
+    if (
+        isinstance(module, MultiHeadAttention)
+        or isinstance(module, WordEmbedding)
+        or isinstance(module, GatedLinearUnit)
+        or isinstance(module, LayerNormParameterized)
+    ):
+        module.to_empty(device=torch.cuda.current_device())
+        with torch.no_grad():
+            module.reset_parameters()
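
The helper above follows PyTorch's meta-device materialization pattern: a module is first constructed without real storage, then to_empty allocates uninitialized memory on the target device and the module's own reset_parameters fills it in. Below is a minimal, self-contained sketch of that pattern, using a hypothetical ToyBlock module and CPU storage for illustration rather than the FMS modules imported here.

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for an FMS module that knows how to re-init itself."""

    def __init__(self, dim: int = 1024):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def reset_parameters(self):
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)


# 1. Construct on the meta device: shapes and dtypes exist, but no memory is allocated.
with torch.device("meta"):
    block = ToyBlock()

# 2. Allocate uninitialized storage. The commit targets torch.cuda.current_device();
#    CPU is used here only so the sketch runs without a GPU.
block.to_empty(device="cpu")

# 3. Fill the fresh storage with real initial values.
with torch.no_grad():
    block.reset_parameters()

print(block.proj.weight.device)  # cpu

In the main_training.py diff below, FSDP receives param_init_function as its param_init_fn, which is why the helper dispatches on concrete FMS module types: the model can then be materialized and re-initialized module by module rather than requiring any rank to hold a full copy of the weights.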
13 changes: 9 additions & 4 deletions fms_fsdp/utils/train_utils.py
@@ -180,17 +180,16 @@ def setup_environ_flags():


 def get_policies(cfg, rank):
-    """Get the policies for mixed precision and fsdp wrapping and sharding strategy"""
+    """Get policies for mixed precision, FSDP wrapping, sharding strategy and param init function."""

-    # mixed precision
     verify_bfloat_support = (
         torch.version.cuda
         and torch.cuda.is_bf16_supported()
         and packaging.version.parse(torch.version.cuda).release >= (11, 0)
         and dist.is_nccl_available()
         and nccl.version() >= (2, 10)
     )
-
+    # mixed precision
     if cfg.mixed_precision:
         bf16_ready = verify_bfloat_support
         if bf16_ready:
@@ -219,7 +218,13 @@ def get_policies(cfg, rank):
     if rank == 0:
         print(f"Sharding strategy = {cfg.sharding_strategy}")

-    return mixed_precision_policy, wrapping_policy, sharding_strategy
+    # param init function
+    if cfg.low_cpu_fsdp:
+        param_init_fn = param_init_function
+    else:
+        param_init_fn = None
+
+    return mixed_precision_policy, wrapping_policy, sharding_strategy, param_init_fn


 def get_profiler(cfg, rank):
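
A note on the first hunk above: verify_bfloat_support only establishes whether a bfloat16 policy is safe to enable; the policy object itself comes from the repo's mixed_precision policies module (imported via from .mixed_precision import * in __init__.py). For readers unfamiliar with FSDP mixed precision, such a bf16 policy typically looks roughly like the following sketch (illustrative dtypes, not copied from this repository).

import torch
from torch.distributed.fsdp import MixedPrecision

# Illustrative bf16 policy of the kind gated by verify_bfloat_support.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # parameters are gathered and compute runs in bf16
    reduce_dtype=torch.bfloat16,  # gradient reduction happens in bf16
    buffer_dtype=torch.bfloat16,  # module buffers are cast to bf16
)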
23 changes: 8 additions & 15 deletions main_training.py
@@ -46,20 +46,18 @@ def main(**kwargs):
     setup_environ_flags()

     # get policy
-    mixed_precision_policy, wrapping_policy, sharding_strategy_policy = get_policies(
-        cfg, rank
-    )
+    (
+        mixed_precision_policy,
+        wrapping_policy,
+        sharding_strategy_policy,
+        param_init_fn,
+    ) = get_policies(cfg, rank)

     # get fms model
     llama_config = get_model_config(cfg.model_variant)

     if cfg.low_cpu_fsdp:
-        if rank == 0:
+        with torch.device("meta"):
             model = LLaMA(llama_config)
-            model.reset_parameters()
-        else:
-            with torch.device("meta"):
-                model = LLaMA(llama_config)
     else:
         model = LLaMA(llama_config)
         model.reset_parameters()
@@ -87,12 +85,7 @@ def main(**kwargs):
         use_orig_params=cfg.use_torch_compile,
         device_id=torch.cuda.current_device(),
         limit_all_gathers=True,
-        sync_module_states=cfg.low_cpu_fsdp,
-        param_init_fn=lambda module: (
-            module.to_empty(device=torch.device("cuda"), recurse=False)
-            if cfg.low_cpu_fsdp
-            else None
-        ),
+        param_init_fn=param_init_fn,
     )
     # we need this post-fsdp call to avoid graph break with torch.compile, until we figure out a better solution.
     model.rot_emb.compute_freqs_cis(
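
Taken together, the change replaces the old low-CPU path, where rank 0 materialized the full model and sync_module_states broadcast it to the other ranks, with a pure meta-device flow: every rank builds the model without storage and FSDP materializes each module through the supplied param_init_fn. A condensed, hypothetical sketch of that flow with a toy model in place of the FMS LLaMA stack (it assumes torch.distributed is already initialized, e.g. via torchrun, and that a CUDA device is available):

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class ToyModel(nn.Module):
    """Hypothetical stand-in for LLaMA(llama_config)."""

    def __init__(self, dim: int = 1024, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def toy_param_init_fn(module: nn.Module):
    # Mirrors the shape of param_init_function: materialize a recognized module's
    # meta parameters on the local GPU, then re-run its own initialization.
    if isinstance(module, nn.Linear):
        module.to_empty(device=torch.cuda.current_device())
        with torch.no_grad():
            module.reset_parameters()


# Every rank builds the model on the meta device, so no weights live in host memory.
with torch.device("meta"):
    model = ToyModel()

# FSDP materializes the meta parameters via param_init_fn; a rank-0 broadcast
# through sync_module_states is no longer needed.
model = FSDP(
    model,
    param_init_fn=toy_param_init_fn,
    device_id=torch.cuda.current_device(),
    limit_all_gathers=True,
)

The practical benefit is that no rank ever has to hold a fully initialized copy of the model in host memory, which the removed sync_module_states-based path required of rank 0.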
