Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MoE parameter passing #8255

Merged
merged 7 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pipeline {
steps {
sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \
cd Megatron-LM && \
git checkout 240a8ef7a21df201e47b5b2ae33cc5f4c5486849 && \
git checkout 98da3792f53c80ac9e865eab49a6fa5ccc293d22 && \
pip install .'
}
}
Expand Down
21 changes: 18 additions & 3 deletions examples/nlp/language_modeling/megatron_gpt_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ def main(cfg) -> None:

assert (
cfg.trainer.devices * cfg.trainer.num_nodes
== cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
== cfg.tensor_model_parallel_size
* cfg.pipeline_model_parallel_size
* max(1, cfg.get('expert_model_parallel_size', 1))
), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size"

if cfg.gpt_model_file:
Expand All @@ -224,6 +226,8 @@ def main(cfg) -> None:
# with dist checkpointing we can use the model parallel config specified by the user
pretrained_cfg.tensor_model_parallel_size = cfg.tensor_model_parallel_size
pretrained_cfg.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size
pretrained_cfg.expert_model_parallel_size = cfg.get('expert_model_parallel_size', 1)
pretrained_cfg.micro_batch_size = 1
if trainer.precision == "16":
pretrained_cfg.megatron_amp_O2 = False
elif trainer.precision in ['bf16', 'bf16-mixed'] and cfg.get('megatron_amp_O2', False):
Expand All @@ -237,13 +241,23 @@ def main(cfg) -> None:
)
elif cfg.checkpoint_dir:
app_state = AppState()
if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1:
app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
if (
cfg.tensor_model_parallel_size > 1
or cfg.pipeline_model_parallel_size > 1
or cfg.get('expert_model_parallel_size', 1) > 1
):
app_state.model_parallel_size = (
cfg.tensor_model_parallel_size
* cfg.pipeline_model_parallel_size
* cfg.get('expert_model_parallel_size', 1)
)
app_state.tensor_model_parallel_size = cfg.tensor_model_parallel_size
app_state.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size
app_state.expert_model_parallel_size = cfg.get('expert_model_parallel_size', 1)
(
app_state.tensor_model_parallel_rank,
app_state.pipeline_model_parallel_rank,
app_state.expert_model_parallel_rank,
app_state.model_parallel_size,
app_state.data_parallel_size,
app_state.pipeline_model_parallel_split_rank,
Expand All @@ -254,6 +268,7 @@ def main(cfg) -> None:
tensor_model_parallel_size_=cfg.tensor_model_parallel_size,
pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size,
pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank,
expert_model_parallel_size_=cfg.get('expert_model_parallel_size', 1),
)
checkpoint_path = os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name)
# checkpoint_path is a dir in case of distributed checkpointing
Expand Down
1 change: 1 addition & 0 deletions examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
gpt_cfg.ffn_dropout = cfg.model.ffn_dropout
gpt_cfg.use_flash_attention = cfg.model.get('use_flash_attention', False)
gpt_cfg.tensor_model_parallel_size = cfg.model.get('tensor_model_parallel_size', 1)
gpt_cfg.expert_model_parallel_size = cfg.model.get('expert_model_parallel_size', 1)
gpt_cfg.pipeline_model_parallel_size = cfg.model.get('pipeline_model_parallel_size', 1)
gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get('pipeline_model_parallel_split_rank', 0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
# Overrides used when converting checkpoints
if os.environ.get(NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, "false").lower() == "true":
app_state = AppState()
init_world_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size
init_world_size = (
app_state.tensor_model_parallel_size
* app_state.pipeline_model_parallel_size
* (app_state.expert_model_parallel_size or 1)
)
init_global_rank = app_state.global_rank
init_local_rank = app_state.local_rank
else:
Expand All @@ -185,6 +189,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
global_rank=init_global_rank,
local_rank=init_local_rank,
tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1),
expert_model_parallel_size=cfg.get('expert_model_parallel_size', 1),
pipeline_model_parallel_size=cfg.get('pipeline_model_parallel_size', 1),
virtual_pipeline_model_parallel_size=vp_size,
pipeline_model_parallel_split_rank=cfg.get('pipeline_model_parallel_split_rank', 0),
Expand Down
26 changes: 26 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/megatron_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from megatron.core import tensor_parallel
from megatron.core.parallel_state import (
get_pipeline_model_parallel_rank,
set_expert_model_parallel_rank,
set_expert_model_parallel_world_size,
set_pipeline_model_parallel_rank,
set_pipeline_model_parallel_split_rank,
set_pipeline_model_parallel_world_size,
Expand Down Expand Up @@ -60,6 +62,7 @@ def initialize_model_parallel_for_nemo(
global_rank,
local_rank,
tensor_model_parallel_size=1,
expert_model_parallel_size=1,
pipeline_model_parallel_size=1,
virtual_pipeline_model_parallel_size=None,
pipeline_model_parallel_split_rank=None,
Expand All @@ -81,6 +84,7 @@ def initialize_model_parallel_for_nemo(
app_state.global_rank = global_rank
app_state.world_size = world_size
app_state.local_rank = local_rank
app_state.expert_model_parallel_size = expert_model_parallel_size
app_state.tensor_model_parallel_size = tensor_model_parallel_size
app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
Expand All @@ -90,6 +94,7 @@ def initialize_model_parallel_for_nemo(
(
app_state.tensor_model_parallel_rank,
app_state.pipeline_model_parallel_rank,
app_state.expert_model_parallel_rank,
app_state.model_parallel_size,
app_state.data_parallel_size,
app_state.pipeline_model_parallel_split_rank,
Expand All @@ -102,12 +107,16 @@ def initialize_model_parallel_for_nemo(
virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
context_parallel_size_=context_parallel_size,
expert_model_parallel_size_=expert_model_parallel_size,
)

# update apex.transformer globals
set_tensor_model_parallel_world_size(app_state.tensor_model_parallel_size)
set_tensor_model_parallel_rank(app_state.tensor_model_parallel_rank)

set_expert_model_parallel_world_size(app_state.expert_model_parallel_size)
set_expert_model_parallel_rank(app_state.expert_model_parallel_rank)

set_pipeline_model_parallel_rank(app_state.pipeline_model_parallel_rank)
if HAVE_INTERLEAVED:
set_virtual_pipeline_model_parallel_world_size(app_state.virtual_pipeline_model_parallel_size)
Expand Down Expand Up @@ -179,6 +188,7 @@ def fake_initialize_model_parallel(
pipeline_model_parallel_size_,
pipeline_model_parallel_split_rank_=None,
virtual_pipeline_model_parallel_size_=None,
expert_model_parallel_size_=1,
context_parallel_size_=1,
):
"""
Expand Down Expand Up @@ -302,6 +312,21 @@ def fake_initialize_model_parallel(
logging.info(f'All tensor model parallel group ranks: {all_tensor_model_parallel_group_ranks}')
logging.info(f'Rank {rank} has tensor model parallel rank: {tensor_model_parallel_rank}')

# EP rank
expert_model_parallel_rank = 0
if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1:
tensor_and_data_group_size: int = tensor_model_parallel_size * data_parallel_size
num_tensor_and_data_groups: int = world_size // tensor_and_data_group_size
tensor_and_expert_group_size: int = tensor_model_parallel_size * expert_model_parallel_size_
num_expert_groups: int = data_parallel_size // expert_model_parallel_size_
for i in range(num_tensor_and_data_groups):
for j in range(num_expert_groups):
start_rank = i * tensor_and_data_group_size + j * tensor_and_expert_group_size
end_rank = i * tensor_and_data_group_size + (j + 1) * tensor_and_expert_group_size
ranks = range(start_rank, end_rank)
if rank in ranks:
expert_model_parallel_rank = list(ranks).index(rank)

# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
all_pipeline_model_parallel_group_ranks = []
Expand Down Expand Up @@ -340,6 +365,7 @@ def fake_initialize_model_parallel(
return (
tensor_model_parallel_rank,
pipeline_model_parallel_rank,
expert_model_parallel_rank,
model_parallel_size,
data_parallel_size,
pipeline_model_parallel_split_rank_,
Expand Down
1 change: 1 addition & 0 deletions nemo/collections/nlp/parts/nlp_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def init_model_parallel(sharp: bool, nccl_communicator_config_path: str = None)
context_parallel_size=app_state.context_parallel_size,
nccl_communicator_config_path=nccl_communicator_config_path,
use_sharp=sharp,
expert_model_parallel_size=app_state.expert_model_parallel_size,
)

# assert that fake tp and pp rank match after model parallel init
Expand Down
34 changes: 34 additions & 0 deletions nemo/utils/app_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ def __init__(self):
self._local_rank = None
self._global_rank = None
self._tensor_model_parallel_rank = None
self._expert_model_parallel_rank = None
self._pipeline_model_parallel_rank = None
self._data_parallel_rank = None

self._world_size = None
self._model_parallel_size = None
self._tensor_model_parallel_size = None
self._tensor_model_parallel_group = None
self._expert_model_parallel_size = None
self._pipeline_model_parallel_size = None
self._virtual_pipeline_model_parallel_size = None
self._pipeline_model_parallel_group = None
Expand Down Expand Up @@ -141,6 +143,38 @@ def tensor_model_parallel_size(self, size):
"""
self._tensor_model_parallel_size = size

@property
def expert_model_parallel_rank(self):
""" Property returns the expert model parallel rank.
Returns:
Tensor model parallel rank.
"""
return self._expert_model_parallel_rank

@expert_model_parallel_rank.setter
def expert_model_parallel_rank(self, rank):
""" Property sets the expert model parallel rank.
Args:
rank (int): Tensor model parallel rank.
"""
self._expert_model_parallel_rank = rank

@property
def expert_model_parallel_size(self):
""" Property returns the number of GPUs in each expert parallel group.
Returns:
Number of GPUs in each expert parallel group.
"""
return self._expert_model_parallel_size

@expert_model_parallel_size.setter
def expert_model_parallel_size(self, size):
""" Property sets the number of GPUs in each expert parallel group.
Args:
size (int): Number of GPUs in each expert parallel group.
"""
self._expert_model_parallel_size = size

@property
def pipeline_model_parallel_size(self):
""" Property returns the number of GPUs in each model parallel group.
Expand Down
Loading