[MCoreDistOptim] Add assertions for McoreDistOptim and fix fp8 arg specs (NVIDIA#10748)

* add spec fp8 arg

Signed-off-by: Gao Deng <[email protected]>

* add the arg assertions

Signed-off-by: Gao Deng <[email protected]>

* Apply isort and black reformatting

Signed-off-by: gdengk <[email protected]>

---------

Signed-off-by: Gao Deng <[email protected]>
Signed-off-by: gdengk <[email protected]>
Co-authored-by: gdengk <[email protected]>
2 people authored and XuesongYang committed Jan 18, 2025
1 parent 3ecc0fd commit 199953d
Showing 1 changed file with 15 additions and 2 deletions.
@@ -150,7 +150,7 @@ def mcore_supports_moe() -> bool:
 
 
 ## TODO: This function will not work if TE is not installed
-def get_specs(spec_name, transformer_config=None, use_te=True, hyena_cfg: Dict = None):
+def get_specs(spec_name, transformer_config=None, use_te=True, hyena_cfg: Dict = None, fp8=False):
     from nemo.collections.nlp.models.language_modeling.megatron.gemma2.gemma2_spec import get_gemma2_layer_spec
 
     # else cases for backwards compatibility with neva
@@ -164,7 +164,7 @@ def get_specs(spec_name, transformer_config=None, use_te=True, hyena_cfg: Dict =
         spec_name = 'te_gpt'
     name_spec_dict = {
         "": get_gpt_layer_local_spec(num_experts, moe_grouped_gemm),
-        "te_gpt": get_gpt_layer_with_transformer_engine_spec(num_experts, moe_grouped_gemm),
+        "te_gpt": get_gpt_layer_with_transformer_engine_spec(num_experts, moe_grouped_gemm, fp8=fp8),
         "megatron_falcon_gpt": get_falcon_layer_spec(),
         "megatron_gemma2": get_gemma2_layer_spec(),
         "megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(transformer_config),
@@ -367,6 +367,18 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
                 'Expert parallelism is currently not supporting Apex distributed optimizer, use Mcore distributed optimizer instead'
             )
 
+        if self.cfg.optim.get('overlap_param_gather_with_optimizer_step', False):
+            assert self.cfg.optim.get(
+                'overlap_param_sync', False
+            ), "must use overlap_param_gather_with_optimizer_step with overlap_param_sync"
+            assert (
+                self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None
+                and self.cfg.get('virtual_pipeline_model_parallel_size', None) > 1
+            ), "must use overlap_param_gather_with_optimizer_step with interleaved pipeline parallelism"
+
+        if self.cfg.optim.get('overlap_param_sync', False) and not self.cfg.optim.get('overlap_grad_sync', False):
+            raise ValueError('Must use overlap_param_sync together with overlap_grad_sync')
+
         self.transformer_engine = cfg.get('transformer_engine', False)
         if self.megatron_amp_O2 and not self.transformer_engine:
             logging.warning('megatron_amp_O2 is enabled but transformer-engine is not.')
@@ -471,6 +483,7 @@ def model_provider_func(self, pre_process, post_process):
                     self.transformer_config,
                     self.transformer_engine,
                     self.cfg.get('hyena', None),
+                    self.cfg.get('fp8', False),
                 ),
                 vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size),
                 max_sequence_length=self.cfg.get('encoder_seq_length', 512),
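
For reference, a minimal sketch (not part of the commit) of a model config that would satisfy the new assertions and exercise the fp8 path. The key names mirror the cfg lookups in the diff above; the OmegaConf construction and the concrete values are assumptions for illustration only.

# Hypothetical config sketch, assuming OmegaConf-style access like the cfg object in the diff.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        # Interleaved pipeline parallelism (> 1) is required when
        # overlap_param_gather_with_optimizer_step is enabled.
        "virtual_pipeline_model_parallel_size": 2,
        # Read by model_provider_func and forwarded to get_specs(..., fp8=...).
        "fp8": True,
        "optim": {
            "overlap_grad_sync": True,   # required whenever overlap_param_sync is on
            "overlap_param_sync": True,  # required by overlap_param_gather_with_optimizer_step
            "overlap_param_gather_with_optimizer_step": True,
        },
    }
)

# The checks added in __init__ reduce to these conditions:
if cfg.optim.get('overlap_param_gather_with_optimizer_step', False):
    assert cfg.optim.get('overlap_param_sync', False)
    vp_size = cfg.get('virtual_pipeline_model_parallel_size', None)
    assert vp_size is not None and vp_size > 1
if cfg.optim.get('overlap_param_sync', False) and not cfg.optim.get('overlap_grad_sync', False):
    raise ValueError('Must use overlap_param_sync together with overlap_grad_sync')

With fp8: True in the model config, model_provider_func now forwards the flag to get_specs, which in turn passes it to get_gpt_layer_with_transformer_engine_spec for the 'te_gpt' spec.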
