
Commit

Merge branch 'main' into expert_parallelism/akoumparouli2
akoumpa committed Jan 27, 2024
2 parents 19879d9 + 13c1db4 commit 0aa813e
Showing 4 changed files with 645 additions and 86 deletions.
2 changes: 1 addition & 1 deletion Jenkinsfile
@@ -35,7 +35,7 @@ pipeline {

     stage('Install test requirements') {
       steps {
-        sh 'apt-get update && apt-get install -y bc && pip install -r requirements/requirements_test.txt'
+        sh 'apt-get update && apt-get install -y bc && pip install -r requirements/requirements_test.txt && pip install -r requirements/requirements_lightning.txt'
       }
     }
141 changes: 56 additions & 85 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -18,7 +18,7 @@
 import warnings
 from contextlib import nullcontext
 from dataclasses import fields
-from functools import partial
+from functools import cache, partial
 from importlib.metadata import version
 from typing import Any, Dict, Iterator, List, Optional, Union

@@ -113,11 +113,30 @@
 HAVE_TE = False


-def get_specs(spec_name):
-    name_spec_dict = {"": get_gpt_layer_with_transformer_engine_spec(), "megatron_falcon_gpt": get_falcon_layer_spec()}
-    if spec_name not in name_spec_dict:
+@cache
+def mcore_supports_moe() -> bool:
+    global HAVE_MEGATRON_CORE
+    if not HAVE_MEGATRON_CORE:
+        return False
+    try:
+        from megatron.core.transformer.moe.router import TopKRouter
+
+        return True
+    except ImportError:
+        return False
+
+
+def get_specs(spec_name, num_experts=None):
+    if spec_name == '':
+        if num_experts is not None:
+            assert mcore_supports_moe(), "Megatron-core >= v0.5.0 is required for MoE"
+            return get_gpt_layer_with_transformer_engine_spec(num_experts)
+        else:
+            return get_gpt_layer_with_transformer_engine_spec()
+    elif spec_name == 'megatron_falcon_gpt':
+        return get_falcon_layer_spec()
+    else:
         raise ValueError(f"Spec name '{spec_name}' is not recognized.")
-    return name_spec_dict[spec_name]


 class MegatronGPTExportableModel(torch.nn.Module, Exportable):
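
For orientation, a minimal sketch of how the reworked helper is exercised (call shapes mirror the hunks above and below; the num_experts value is illustrative and assumes a Megatron-core build new enough that mcore_supports_moe() returns True):

# Illustrative usage only -- mirrors the call made in model_provider_func below.
dense_spec = get_specs('')                        # TE-based GPT layer spec, no MoE
moe_spec = get_specs('', num_experts=8)           # asserts mcore_supports_moe(), then builds the MoE spec
falcon_spec = get_specs('megatron_falcon_gpt')    # Falcon layer spec; num_experts is not used here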
@@ -328,7 +347,7 @@ def model_provider_func(self, pre_process, post_process):
         if self.mcore_gpt:
             model = MCoreGPTModel(
                 config=self.transformer_config,
-                transformer_layer_spec=get_specs(self.spec_name),
+                transformer_layer_spec=get_specs(self.spec_name, self.transformer_config.num_moe_experts),
                 vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size),
                 max_sequence_length=self.cfg.get('encoder_seq_length', 512),
                 pre_process=pre_process,
@@ -909,96 +928,25 @@ def __next__(self):
             # TODO @tmoon: Use once available in Megatron-LM
             # return DataIteratorList(iters)

-    def get_ltor_masks_and_position_ids(
-        self, data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss
-    ):
-        """Build masks and position id for left to right model."""
-
-        # Extract batch size and sequence length.
-        micro_batch_size, seq_length = data.size()
-
-        # Attention mask (lower triangular).
-        if reset_attention_mask:
-            att_mask_batch = micro_batch_size
-        else:
-            att_mask_batch = 1
-        attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(
-            att_mask_batch, 1, seq_length, seq_length
-        )
-
-        # Loss mask.
-        loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
-        if eod_mask_loss:
-            loss_mask[data == eod_token] = 0.0
-
-        # Position ids.
-        position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
-        position_ids = position_ids.unsqueeze(0).expand_as(data)
-        # We need to clone as the ids will be modifed based on batch index.
-        if reset_position_ids:
-            position_ids = position_ids.clone()
-
-        if reset_position_ids or reset_attention_mask:
-            # Loop through the batches:
-            for b in range(micro_batch_size):
-
-                # Find indecies where EOD token is.
-                eod_index = position_ids[b, data[b] == eod_token]
-                # Detach indecies from positions if going to modify positions.
-                if reset_position_ids:
-                    eod_index = eod_index.clone()
-
-                # Loop through EOD indecies:
-                prev_index = 0
-                for j in range(eod_index.size()[0]):
-                    i = eod_index[j]
-                    # Mask attention loss.
-                    if reset_attention_mask:
-                        attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
-                    # Reset positions.
-                    if reset_position_ids:
-                        position_ids[b, (i + 1) :] -= i + 1 - prev_index
-                        prev_index = i + 1
-
-        # Convert attention mask to binary:
-        attention_mask = attention_mask < 0.5
-
-        return attention_mask, loss_mask, position_ids
-
     def get_batch(self, data_iterator, tuning):
         """Generate a batch."""

-        # return batch for GPT SFT
-        if tuning:
-            return next(data_iterator)
-
-        # Items and their type.
-        keys = ['text']
-        datatype = torch.int64
-
         # Broadcast data.
         if data_iterator is not None:
             data = next(data_iterator)
         else:
             data = None
-        data_b = tensor_parallel.broadcast_data(keys, data, datatype)

-        # Unpack.
-        tokens_ = data_b['text'].long()
-        labels = tokens_[:, 1:].contiguous()
-        tokens = tokens_[:, :-1].contiguous()
-
-        # Get the masks and postition ids.
-        attention_mask, loss_mask, position_ids = self.get_ltor_masks_and_position_ids(
-            tokens, self.tokenizer.eos_id, self.reset_position_ids, self.reset_attention_mask, self.eod_mask_loss
-        )
+        # return batch for GPT SFT
+        if tuning:
+            return data

         batch = {
-            'tokens': tokens,
-            'labels': labels,
-            'loss_mask': loss_mask,
-            'attention_mask': attention_mask,
-            'position_ids': position_ids,
+            'tokens': data["tokens"],
+            'labels': data["labels"],
+            'loss_mask': data["loss_mask"],
+            'attention_mask': data["attention_mask"],
+            'position_ids': data["position_ids"],
         }

         return batch
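
The net effect of this hunk is that mask and position-id construction moves out of the model and into the dataset, so each item yielded by the data iterator must already carry the five fields that the new get_batch simply repacks. A hypothetical example of such an item (key names come from the diff; dtypes, shapes, and values are illustrative, assuming micro_batch_size=2 and seq_length=4):

import torch

# Hypothetical micro-batch in the layout the new get_batch expects.
data = {
    "tokens": torch.randint(0, 32000, (2, 4)),                    # input token ids
    "labels": torch.randint(0, 32000, (2, 4)),                    # next-token targets
    "loss_mask": torch.ones(2, 4),                                 # 0.0 where loss is masked (e.g. at EOD)
    "attention_mask": torch.zeros(2, 1, 4, 4, dtype=torch.bool),   # boolean mask, now produced by the dataset
    "position_ids": torch.arange(4).unsqueeze(0).expand(2, 4),
}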
@@ -1301,6 +1249,10 @@ def build_train_valid_test_datasets(self):
             "blend": self.cfg.data.data_prefix,
             "split": self.cfg.data.splits_string,
             "path_to_cache": self.cfg.data.index_mapping_dir,
+            "reset_position_ids": self.reset_position_ids,
+            "reset_attention_mask": self.reset_attention_mask,
+            "eod_mask_loss": self.eod_mask_loss,
+            "eod_id": self.tokenizer.eos_id,
         }

         if self.cfg.data.get('add_fim', False):
@@ -1794,7 +1746,26 @@ def build_transformer_config(self) -> TransformerConfig:
             'num_moe_experts': self.cfg.get('num_moe_experts', None),
             'moe_router_type': self.cfg.get('moe_router_type', None),
             'tp_comm_overlap': ub_tp_comm_overlap,
+            # MoE related
+            'num_experts': self.cfg.get('num_experts', None),
+            'moe_router_load_balancing_type': self.cfg.get('moe_router_load_balancing_type', 'aux_loss'),
+            'moe_router_topk': self.cfg.get('moe_router_topk', 2),
+            'moe_grouped_gemm': self.cfg.get('moe_grouped_gemm', False),
+            'moe_aux_loss_coeff': self.cfg.get(
+                'moe_aux_loss_coeff', 0
+            ),  # 1e-2 would be a good start value for load balance loss.
+            'moe_z_loss_coeff': self.cfg.get('moe_z_loss_coeff', None),  # 1e-3 would be a good start value for z-loss
+            'moe_input_jitter_eps': self.cfg.get('moe_input_jitter_eps', None),
+            'moe_token_dropping': self.cfg.get('moe_token_dropping', False),  # TODO: Support token dropping.
         }
+        if model_specific_configs['num_experts'] is not None:
+            assert mcore_supports_moe(), 'Megatron-core >= v0.5.0 is required for MoE'
+        elif not mcore_supports_moe():
+            if 'num_experts' in model_specific_configs:
+                del model_specific_configs['num_experts']
+            moe_keys = list(filter(lambda x: x.startswith('moe_'), model_specific_configs.keys()))
+            for k in moe_keys:
+                del model_specific_configs[k]

         transformer_config = super().build_transformer_config()
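
All of the new keys are read defensively through self.cfg.get, so enabling MoE is purely a matter of adding them to the model config; a hypothetical fragment is sketched below (key names and defaults come from the hunk above, the concrete values are illustrative):

# Hypothetical model-config overrides; key names from the diff, values illustrative.
moe_overrides = {
    'num_experts': 8,                                # non-None triggers the mcore_supports_moe() assert
    'moe_router_load_balancing_type': 'aux_loss',    # default used in the diff
    'moe_router_topk': 2,
    'moe_grouped_gemm': False,
    'moe_aux_loss_coeff': 1e-2,                      # the in-code comment suggests 1e-2 as a starting point
    'moe_z_loss_coeff': 1e-3,                        # and 1e-3 for the z-loss
    'moe_input_jitter_eps': None,
    'moe_token_dropping': False,                     # token dropping is still a TODO in the diff
}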

