Neva updates to latest mcore and some fixes (NVIDIA#11565)
* api updates and fixes

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* fix

Signed-off-by: yaoyu-33 <[email protected]>

* fix arg

Signed-off-by: yaoyu-33 <[email protected]>

---------

Signed-off-by: yaoyu-33 <[email protected]>
Signed-off-by: yaoyu-33 <[email protected]>
Co-authored-by: yaoyu-33 <[email protected]>
yaoyu-33 and yaoyu-33 authored Dec 16, 2024
1 parent 6139a10 commit 42ad3c0
Showing 4 changed files with 99 additions and 68 deletions.
4 changes: 3 additions & 1 deletion nemo/collections/llm/gpt/model/base.py
@@ -171,7 +171,8 @@ class GPTConfig(TransformerConfig, io.IOMixin):
masked_softmax_fusion: bool = True
cross_entropy_loss_fusion: bool = True
gradient_accumulation_fusion: bool = _grad_accum_fusion_available
deallocate_pipeline_outputs = True
deallocate_pipeline_outputs: bool = True
scatter_embedding_sequence_parallel: bool = True

use_transformer_engine_full_layer_spec: bool = False
transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = default_layer_spec
@@ -216,6 +217,7 @@ def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MC
seq_len_interpolation_factor=self.seq_len_interpolation_factor,
pre_process=pre_process or parallel_state.is_pipeline_first_stage(),
post_process=post_process or parallel_state.is_pipeline_last_stage(),
scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel,
)

# If using full TE layer, need to set TP, CP group since the module call
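For context, `scatter_embedding_sequence_parallel` is an ordinary config field, so it can be overridden like any other GPTConfig knob and is now forwarded to the MCore GPT model in `configure_model`. A minimal sketch (the model dimensions below are placeholders, not values from this commit):

from nemo.collections import llm

cfg = llm.GPTConfig(
    num_layers=2,
    hidden_size=256,
    num_attention_heads=4,
    seq_length=512,
    scatter_embedding_sequence_parallel=True,  # new field; passed through to MCoreGPTModel
)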
148 changes: 84 additions & 64 deletions nemo/collections/vlm/neva/model/base.py
@@ -21,15 +21,15 @@
import torch
import torch.distributed
import torch.nn.functional as F
from megatron.core import dist_checkpointing
from megatron.core import InferenceParams, dist_checkpointing
from megatron.core import parallel_state as ps
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.inference_params import InferenceParams
from megatron.core.models.multimodal.llava_model import LLaVAModel as MCoreLLaVAModel
from megatron.core.models.vision.clip_vit_model import CLIPViTModel as MCoreCLIPViTModel
from megatron.core.models.vision.multimodal_projector import MultimodalProjector as MCoreMultimodalProjector
from megatron.core.optimizer import OptimizerConfig
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer.custom_layers.transformer_engine import (
TEColumnParallelLinear,
@@ -133,18 +133,17 @@ def neva_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:

def neva_forward_step(model, batch) -> torch.Tensor:
forward_args = {
"media": batch["media"],
"images": batch["media"],
"input_ids": batch["tokens"],
"position_ids": batch["position_ids"],
"attention_mask": batch.get("attention_mask", None),
"loss_mask": batch.get("loss_mask", None),
"labels": batch.get("labels", None),
"num_media_tiles": batch.get("num_media_tiles", None),
"num_image_tiles": batch.get("num_media_tiles", None),
"image_token_mask": batch.get("image_token_mask", None),
"packed_seq_params": batch.get("packed_seq_params", None),
}

if 'cu_seqlens' in batch:
forward_args['packed_seq_params'] = get_packed_seq_params(batch)

return model(**forward_args)
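The batch produced by `neva_data_step` keeps its NeMo-side keys (`media`, `num_media_tiles`); only the keyword names passed to the model change to the new MCore names, and `packed_seq_params` is now forwarded unconditionally. An illustrative call (shapes and values are made up):

import torch

batch = {
    "media": torch.zeros(2, 3, 336, 336),            # (num_images_in_mbs, c, h, w)
    "tokens": torch.zeros(1, 512, dtype=torch.long),
    "position_ids": torch.arange(512).unsqueeze(0),
    "loss_mask": torch.ones(1, 512),
    "labels": torch.zeros(1, 512, dtype=torch.long),
    "num_media_tiles": [1, 1],                        # forwarded as num_image_tiles
    "packed_seq_params": None,                        # forwarded as-is
}
output = neva_forward_step(model, batch)              # model: a NevaModel instance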


@@ -219,10 +218,22 @@ class HFCLIPVisionConfig(CLIPVisionConfig, io.IOMixin):
"""

hidden_size: int = 1024
num_image_embeddings_per_tile: Optional[int] = None
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None

def __post_init__(self, *args, **kwargs) -> None:
CLIPVisionConfig.__init__(self, *args, **kwargs, hidden_size=self.hidden_size)
if self.pretrained_model_name_or_path is not None:
config = CLIPVisionConfig.from_pretrained(self.pretrained_model_name_or_path)
for key, value in config.to_dict().items():
setattr(self, key, value)
self.num_image_embeddings_per_tile = get_image_sequence_length(
img_h=self.image_size,
img_w=self.image_size,
patch_dim=self.patch_size,
add_class_token=False,
class_token_len=1,
)

def configure_model(self) -> "CLIPVisionModel":
# Monkey patch the method to the vision encoder
@@ -232,9 +243,6 @@ def configure_model(self) -> "CLIPVisionModel":
model = CLIPVisionModel(self)
else:
model = CLIPVisionModel.from_pretrained(self.pretrained_model_name_or_path)
# Extend all model.config fields to self
for key, value in model.config.to_dict().items():
setattr(self, key, value)
return model
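Because the HF config fields are now copied in `__post_init__`, `num_image_embeddings_per_tile` is available as soon as the config object is built, before `configure_model` is called. A sketch (the checkpoint name is only an example of a 336x336, patch-14 CLIP model and must be downloadable):

vision_cfg = HFCLIPVisionConfig(
    pretrained_model_name_or_path="openai/clip-vit-large-patch14-336",
)
# get_image_sequence_length(336, 336, patch_dim=14, add_class_token=False, class_token_len=1)
# -> (336 // 14) ** 2 = 576 embeddings per tile
print(vision_cfg.num_image_embeddings_per_tile)  # 576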


@@ -248,6 +256,7 @@ class CLIPViTConfig(TransformerConfig, io.IOMixin):
img_h: int = 336
img_w: int = 336
vision_model_type: str = "clip" # ["clip", "siglip"]
num_image_embeddings_per_tile: Optional[int] = None
transformer_layer_spec: ModuleSpec = transformer_engine_layer_spec

num_layers: int = 1 # Placeholder, NOT used!
@@ -257,6 +266,13 @@ def __post_init__(self):
if self.vision_model_type == "siglip":
self.add_class_token = False
self.class_token_len = 0
self.num_image_embeddings_per_tile = get_image_sequence_length(
img_h=self.img_h,
img_w=self.img_w,
patch_dim=self.patch_dim,
add_class_token=self.add_class_token,
class_token_len=self.class_token_len,
)

def configure_model(self) -> "CLIPViTModel":
transformer_layer_spec = self.transformer_layer_spec
@@ -311,8 +327,6 @@ def __post_init__(self):
setattr(self, attr, getattr(self.language_transformer_config, attr))

def configure_model(self, tokenizer) -> "MCoreNevaModel":
from megatron.core import parallel_state as ps

self.language_transformer_config.tensor_model_parallel_size = self.tensor_model_parallel_size
self.language_transformer_config.sequence_parallel = self.sequence_parallel
self.vision_transformer_config.tensor_model_parallel_size = self.tensor_model_parallel_size
@@ -394,8 +408,11 @@ def __init__(
self.context_parallel_lm = language_transformer_config.context_parallel_size
self.tensor_model_parallel_size_lm = language_transformer_config.tensor_model_parallel_size

# This attribute is needed to check if an all-reduce is required
# on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`.
self.share_embeddings_and_output_weights = False
if self.add_decoder:
language_transformer_config.scatter_embedding_sequence_parallel = False
self.language_model = language_transformer_config.configure_model(
tokenizer=tokenizer, pre_process=pre_process, post_process=post_process
)
@@ -436,36 +453,22 @@ def __init__(
# on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`.

self.vision_model_from_hf = hasattr(vision_transformer_config, "image_size")
if self.vision_model_from_hf:
# img_h, img_w, patch_dim, add_class_token, class_token_len
self._img_seq_len = get_image_sequence_length(
img_h=vision_transformer_config.image_size,
img_w=vision_transformer_config.image_size,
patch_dim=vision_transformer_config.patch_size,
add_class_token=not drop_vision_class_token,
class_token_len=0 if "siglip" in vision_transformer_config.model_type else 1,
)
else:
self._img_seq_len = get_image_sequence_length(
img_h=vision_transformer_config.img_h,
img_w=vision_transformer_config.img_w,
patch_dim=vision_transformer_config.patch_dim,
add_class_token=not drop_vision_class_token,
class_token_len=vision_transformer_config.class_token_len,
)
self._img_seq_len = vision_transformer_config.num_image_embeddings_per_tile

def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
loss_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
media: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
inference_params: Optional[InferenceParams] = None,
num_media_tiles: Optional[List[int]] = None,
media_token_index: Optional[int] = IMAGE_TOKEN_INDEX,
num_image_tiles: Optional[List[int]] = None,
image_token_index: Optional[int] = IMAGE_TOKEN_INDEX,
runtime_gather_output: Optional[bool] = None,
image_token_mask: Optional[torch.Tensor] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
) -> torch.Tensor:
"""Forward function of the LLaVA model.
@@ -477,61 +480,69 @@ def forward(
labels (torch.Tensor): Optional target text labels [batch, combined_seq_len].
loss_mask (torch.Tensor): Text loss mask [batch, text_seq_len].
inference_params (InferenceParams): Inference-time parameters including KV cache.
num_media_tiles (list of int): Number of tiles per image. Default None assumes 1 tile per image.
image_token_index (int): ID for input images.
num_image_tiles (list of int): Number of tiles per image. Default 1 tile per image.
image_token_index (int): ID for input images. Default None means `image_token_index`
arg in the constructor will be used.
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
image_token_mask (torch.Tensor): Tensor indicating the location of
image token index in input_ids.
packed_seq_params (PackedSeqParams): Dict with padded token information.
Required for using SP/CP with padding mask type.
Returns:
output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size].
output (torch.Tensor): Loss of shape [b, s] if labels are provided,
otherwise logits of shape [b, s, vocab_size].
loss_mask (torch.Tensor): Loss mask expanded to combined sequence length. Shape [b, s].
"""

use_inference_kv_cache = (
inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict
)
has_images = media is not None and media.shape[0] > 0
has_images = images is not None and images.shape[0] > 0

# If running inference, we can skip media token computation if they were computed already earlier for this sample.
# If running inference, we can skip images token computation if they were computed already earlier for this sample.
if use_inference_kv_cache:
media_embeddings = None
image_embeddings = None
elif self.add_encoder and not has_images:
vision_param = next(self.vision_model.parameters())
# If no images provided, use an empty image embeddings tensor.
media_embeddings = torch.tensor([], dtype=vision_param.dtype, device=vision_param.device).reshape(0, 0, 0)
image_embeddings = torch.tensor([], dtype=vision_param.dtype, device=vision_param.device).reshape(0, 0, 0)
elif self.add_encoder and has_images:
# media is in shape of (num_images_in_mbs, c, h, w)
# images is in shape of (num_images_in_mbs, c, h, w)
# note num_images_in_mbs is not mbs but total images in this mbs.
media = media.to(next(self.vision_model.parameters()).dtype)
images = images.to(next(self.vision_model.parameters()).dtype)
if self.vision_model_from_hf:
self.vision_model = self.vision_model.eval()
media_embeddings = self.vision_model(media, output_hidden_states=True)
media_embeddings = media_embeddings[-1][
image_embeddings = self.vision_model(images, output_hidden_states=True)
image_embeddings = image_embeddings[-1][
self.config.vision_feature_layer
] # [num_images, img_seq_len, h_vision]
else:
# TODO(yuya): MCore Clip path not yet support taking a specific layer hidden states
media_embeddings = self.vision_model(media, num_unused_layers=-self.config.vision_feature_layer - 1)
image_embeddings = self.vision_model(images, num_unused_layers=-self.config.vision_feature_layer - 1)
if self._drop_vision_class_token:
class_token_len = getattr(self.vision_model, "class_token_len", 1)
media_embeddings = media_embeddings[:, class_token_len:, :]
image_embeddings = image_embeddings[:, class_token_len:, :]

# contiguous() required as `permute` can sparsify the tensor and this breaks pipelining
media_embeddings = media_embeddings.permute(1, 0, 2).contiguous() # [img_seq_len, num_tiles, h_vision]
image_embeddings = image_embeddings.permute(1, 0, 2).contiguous() # [img_seq_len, num_tiles, h_vision]

# map vision model output size to language model input size.
media_embeddings = self.vision_projection(media_embeddings) # [img_seq_len, num_tiles, h_language]
image_embeddings = self.vision_projection(image_embeddings) # [img_seq_len, num_tiles, h_language]

# TODO: Support batched inference.
# In inference, the language model KV cache will be updated for image token positions.
# Store the image tokens sequence length to be used as an offset to the KV cache later.
if inference_params is not None:
inference_params.key_value_memory_dict["media_tokens_count"] = (
media_embeddings.shape[0] * media_embeddings.shape[1]
inference_params.key_value_memory_dict["image_tokens_count"] = (
image_embeddings.shape[0] * image_embeddings.shape[1]
)
else:
media_embeddings = self.encoder_hidden_state
image_embeddings = self.encoder_hidden_state

if not self.add_decoder:
return media_embeddings
return image_embeddings

language_embeddings = None
if self.pre_process:
@@ -569,32 +580,33 @@ def forward(
language_embeddings = language_embeddings.transpose(1, 0).contiguous() # [b, text_seq_len, h_language]

# Assume 1 tile per image if the number of tiles is not provided.
if num_media_tiles is None:
num_media_tiles = torch.ones(media.shape[0], dtype=torch.int, device=input_ids.device)
elif isinstance(num_media_tiles, list):
num_media_tiles = torch.tensor(num_media_tiles, dtype=torch.int, device=input_ids.device)
if num_image_tiles is None:
num_image_tiles = torch.ones(images.shape[0], dtype=torch.int, device=input_ids.device)
elif isinstance(num_image_tiles, list):
num_image_tiles = torch.tensor(num_image_tiles, dtype=torch.int, device=input_ids.device)

# Preprocess input, labels and loss mask.
combined_embeddings, final_labels, final_loss_mask, final_attention_mask = self._preprocess_data(
media_embeddings,
image_embeddings,
language_embeddings,
input_ids,
loss_mask,
labels,
use_inference_kv_cache,
media_token_index,
num_media_tiles,
image_token_index,
num_image_tiles,
attention_mask,
) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len]

output = self.language_model(
input_ids=None,
position_ids=None,
attention_mask=attention_mask,
attention_mask=final_attention_mask,
decoder_input=combined_embeddings,
labels=final_labels,
inference_params=inference_params,
runtime_gather_output=runtime_gather_output,
packed_seq_params=packed_seq_params,
)

if labels is None or loss_mask is None:
@@ -878,20 +890,28 @@ def forward(
position_ids: torch.Tensor,
loss_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
media: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
inference_params: InferenceParams = None,
num_media_tiles: Optional[List[int]] = None,
inference_params: Optional[InferenceParams] = None,
num_image_tiles: Optional[List[int]] = None,
image_token_index: Optional[int] = IMAGE_TOKEN_INDEX,
runtime_gather_output: Optional[bool] = None,
image_token_mask: Optional[torch.Tensor] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
) -> torch.Tensor:
output_tensor = self.module(
media=media,
images=images,
input_ids=input_ids,
position_ids=position_ids,
loss_mask=loss_mask,
attention_mask=attention_mask,
labels=labels,
inference_params=inference_params,
num_media_tiles=num_media_tiles,
num_image_tiles=num_image_tiles,
image_token_index=image_token_index,
runtime_gather_output=runtime_gather_output,
image_token_mask=image_token_mask,
packed_seq_params=packed_seq_params,
)

return output_tensor
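An illustrative call into the updated wrapper with the renamed arguments (all tensors and the `model` object are assumed to exist; values are placeholders):

output = model(
    input_ids=tokens,
    position_ids=position_ids,
    attention_mask=None,
    images=images,                              # was `media`
    labels=labels,
    loss_mask=loss_mask,
    num_image_tiles=[1] * images.shape[0],      # was `num_media_tiles`
    image_token_index=IMAGE_TOKEN_INDEX,
    packed_seq_params=None,
)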
4 changes: 4 additions & 0 deletions nemo/lightning/megatron_parallel.py
@@ -1724,6 +1724,10 @@ def masked_token_loss_context_parallel(tensor: Tensor, mask: Tensor, num_valid_t

losses = tensor.float()
loss_mask = mask.view(-1).float()
if num_valid_tokens_in_ub is None:
num_valid_tokens_in_ub = loss_mask.sum()
if num_valid_tokens_in_ub < 0.5: # no valid tokens
num_valid_tokens_in_ub += 1.0
loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub # sequence level nll
torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
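The new guard falls back to the mask sum when `num_valid_tokens_in_ub` is not provided and bumps an all-zero count to 1.0, so a micro-batch with no valid tokens yields a zero loss instead of a NaN. A standalone sketch of just that arithmetic (no context-parallel group involved):

import torch

losses = torch.tensor([[0.7, 0.3]])
loss_mask = torch.tensor([[0.0, 0.0]]).view(-1)   # no valid tokens in this micro-batch
num_valid_tokens_in_ub = loss_mask.sum()          # 0.0
if num_valid_tokens_in_ub < 0.5:
    num_valid_tokens_in_ub += 1.0                 # avoid 0/0
loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub  # tensor(0.), not NaN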

11 changes: 8 additions & 3 deletions scripts/vlm/neva_finetune.py
@@ -27,6 +27,7 @@
from nemo import lightning as nl
from nemo.collections import llm, vlm
from nemo.collections.vlm import ImageDataConfig
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.utils.exp_manager import TimingCallback
@@ -111,7 +112,7 @@ def main(args):
ddp=DistributedDataParallelConfig(
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
overlap_grad_reduce=False,
overlap_grad_reduce=True,
overlap_param_gather=True,
average_in_collective=True,
),
@@ -134,7 +135,11 @@ def main(args):
accelerator="gpu",
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
callbacks=[checkpoint_callback, TimingCallback()],
callbacks=[
checkpoint_callback,
TimingCallback(),
MegatronCommOverlapCallback(tp_comm_overlap=True),
],
val_check_interval=500,
limit_val_batches=gbs,
log_every_n_steps=1,
@@ -223,7 +228,7 @@ def main(args):
parser.add_argument("--name", type=str, required=False, default="neva_pretrain")
parser.add_argument("--peft", type=str, default='none', help="none | lora")
parser.add_argument("--wandb_project", type=str, required=False, default=None)
parser.add_argument("--gbs", type=int, required=False, default=64, help="Global batch size")
parser.add_argument("--gbs", type=int, required=False, default=128, help="Global batch size")
parser.add_argument("--mbs", type=int, required=False, default=2, help="Micro batch size")
parser.add_argument("--lr", type=float, required=False, default=2.0e-06, help="Learning rate")

