Update fake_initialization.
Victarry committed Nov 27, 2024
1 parent be0bd93 commit c2ed354
Showing 1 changed file with 35 additions and 4 deletions.
39 changes: 35 additions & 4 deletions nemo/lightning/megatron_init.py
@@ -254,6 +254,7 @@ def fake_initialize_model_parallel(
     pipeline_model_parallel_split_rank_=None,
     virtual_pipeline_model_parallel_size_=None,
     expert_model_parallel_size_=1,
+    expert_tensor_parallel_size_=None,
     context_parallel_size_=1,
     encoder_tensor_model_parallel_size_=0,
     encoder_pipeline_model_parallel_size_=0,
@@ -349,23 +350,53 @@ def fake_initialize_model_parallel(
 
     decoder_rank_generator = RankGenerator(
         tp=tensor_model_parallel_size,
-        ep=expert_model_parallel_size_,
+        ep=1,
         dp=data_parallel_size,
         pp=pipeline_model_parallel_size,
         cp=context_parallel_size,
         order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
         rank_offset=encoder_world_size,
     )
+    # Build expert rank generator
+    if expert_tensor_parallel_size_ is None:
+        expert_tensor_parallel_size_ = tensor_model_parallel_size
+    expert_tensor_model_pipeline_parallel_size = (
+        expert_tensor_parallel_size_ * expert_model_parallel_size_ * pipeline_model_parallel_size
+    )
+    expert_data_parallel_size = decoder_world_size // expert_tensor_model_pipeline_parallel_size
+    if decoder_world_size % expert_tensor_model_pipeline_parallel_size != 0:
+        raise RuntimeError(
+            f"decoder world_size ({decoder_world_size}) is not divisible by expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})"
+        )
+
+    expert_decoder_rank_generator = RankGenerator(
+        tp=expert_tensor_parallel_size_,
+        ep=expert_model_parallel_size_,
+        dp=expert_data_parallel_size,
+        pp=pipeline_model_parallel_size,
+        cp=1,
+        order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
+        rank_offset=encoder_world_size,
+    )
 
-    def generator_wrapper(group_type, **kwargs):
+    assert decoder_rank_generator.get_ranks("pp") == expert_decoder_rank_generator.get_ranks(
+        "pp"
+    ), f"Pipeline parallel groups are expected to be the same for Non-Expert and Expert part, \
+    but got {decoder_rank_generator.get_ranks('pp')} and {expert_decoder_rank_generator.get_ranks('pp')}"
+
+
+    def generator_wrapper(group_type, is_expert=False, **kwargs):
         from itertools import cycle
 
         """The `RankGenerator` class produces a hyper-rectangle for a given set of
         tensor, pipeline, data, expert, and context parallelism. If we have an encoder,
         in addition to the default decoder, we essentially instantiate two `RankGenerator`
         classes to construct the parallelism for each module separately, and we then have
         to stitch them together for the right groups. For now, this means pp and tp-pp."""
-        d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
+        if is_expert:
+            d_ranks = expert_decoder_rank_generator.get_ranks(group_type, **kwargs)
+        else:
+            d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
         if encoder_rank_generator is None:
             for x in d_ranks:
                 yield x
@@ -446,7 +477,7 @@ def generator_wrapper(group_type, **kwargs):
     # EP rank
     expert_model_parallel_rank = 0
     if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1:
-        for ranks in generator_wrapper('ep', independent_ep=True):
+        for ranks in generator_wrapper('ep', is_expert=True):
             if rank in ranks:
                 expert_model_parallel_rank = list(ranks).index(rank)
 
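Note on the sizing arithmetic introduced in the second hunk: the expert data-parallel size is the decoder world size divided by the combined expert-TP x EP x PP size, with the dense TP size reused as the expert-TP size when expert_tensor_parallel_size_ is left unset. A minimal standalone sketch, using hypothetical sizes (16 decoder ranks, TP=2, PP=2, CP=1, EP=4) that are not taken from the commit:

# Hypothetical sizes for illustration only (not taken from the commit).
decoder_world_size = 16
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 2
context_parallel_size = 1
expert_model_parallel_size_ = 4
expert_tensor_parallel_size_ = None  # unset, so it falls back to the TP size

# Dense (non-expert) data-parallel size: 16 // (2 * 2 * 1) = 4.
data_parallel_size = decoder_world_size // (
    tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
)

# Expert TP defaults to the dense TP size, as in the diff.
if expert_tensor_parallel_size_ is None:
    expert_tensor_parallel_size_ = tensor_model_parallel_size

# Combined expert TP x EP x PP size: 2 * 4 * 2 = 16, which must divide the decoder world size.
expert_tensor_model_pipeline_parallel_size = (
    expert_tensor_parallel_size_ * expert_model_parallel_size_ * pipeline_model_parallel_size
)
if decoder_world_size % expert_tensor_model_pipeline_parallel_size != 0:
    raise RuntimeError("decoder world size must be divisible by the expert TP x EP x PP size")

# Expert data-parallel size: 16 // 16 = 1.
expert_data_parallel_size = decoder_world_size // expert_tensor_model_pipeline_parallel_size
print(data_parallel_size, expert_data_parallel_size)  # 4 1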

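Similarly, a toy sketch of the EP-rank lookup pattern at the end of the diff, with hand-built groups standing in for generator_wrapper('ep', is_expert=True); the group layout assumes the default 'tp-cp-ep-dp-pp' order and the hypothetical sizes above, so EP peers within a pipeline stage are two ranks apart:

# Toy stand-in for generator_wrapper('ep', is_expert=True): yields the EP groups
# for the hypothetical sizes above (ETP=2, EP=4, expert-DP=1, PP=2) under the
# assumed 'tp-cp-ep-dp-pp' order, where TP is the innermost dimension.
def toy_ep_groups():
    for pp_stage in range(2):        # PP = 2, each stage spans 8 ranks
        for tp_rank in range(2):     # ETP = 2
            base = pp_stage * 8 + tp_rank
            yield [base + 2 * ep for ep in range(4)]  # EP = 4, peers 2 ranks apart

# Same lookup pattern as the last hunk: find the EP index of this rank.
rank = 5
expert_model_parallel_rank = 0
for ranks in toy_ep_groups():
    if rank in ranks:
        expert_model_parallel_rank = list(ranks).index(rank)
print(expert_model_parallel_rank)  # rank 5 sits in group [1, 3, 5, 7] -> EP rank 2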