From e85a4546ddccaf49ea081d79ecdf6aed402fd9cf Mon Sep 17 00:00:00 2001
From: yaoyu-33
Date: Thu, 14 Nov 2024 11:45:58 -0800
Subject: [PATCH 1/7] fix api

Signed-off-by: yaoyu-33
---
 .../nlp/models/language_modeling/megatron/bert/bert_model.py  | 1 +
 .../language_modeling/megatron/falcon/falcon_decoder_layer.py | 1 +
 .../megatron/gpt_full_te_layer_autocast_spec.py               | 1 +
 .../nlp/modules/common/megatron/adapters/mcore_mixins.py      | 3 +++
 nemo/collections/vlm/mllama/model/vision.py                   | 1 +
 5 files changed, 7 insertions(+)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron/bert/bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron/bert/bert_model.py
index 0d75ab7cc706..c629db5af3c3 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron/bert/bert_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron/bert/bert_model.py
@@ -208,6 +208,7 @@ def forward(
         rotary_pos_emb=None,
         rotary_pos_cos=None,
         rotary_pos_sin=None,
+        attention_bias=None,
         inference_params=None,
         packed_seq_params=None,
     ):
diff --git a/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py b/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py
index 131f154d6709..7c3f3c194f14 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py
@@ -108,6 +108,7 @@ def forward(
         rotary_pos_emb=None,
         rotary_pos_cos=None,
         rotary_pos_sin=None,
+        attention_bias=None,
         inference_params=None,
         packed_seq_params=None,
     ):
diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
index d1945139dee9..1def214113ee 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
@@ -252,6 +252,7 @@ def forward(
         rotary_pos_emb=None,
         rotary_pos_cos=None,
         rotary_pos_sin=None,
+        attention_bias=None,
         inference_params=None,
         packed_seq_params=None,  # TODO: handle this
     ):
diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
index da9c98fd94ea..85d5235fccfb 100644
--- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
+++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
@@ -82,6 +82,7 @@ def forward(
         rotary_pos_emb: Tensor = None,
         rotary_pos_cos: Tensor = None,
         rotary_pos_sin: Tensor = None,
+        attention_bias: Tensor = None,
         inference_params: InferenceParams = None,
         packed_seq_params: PackedSeqParams = None,
     ):
@@ -93,6 +94,7 @@ def forward(
             rotary_pos_emb,
             rotary_pos_cos,
             rotary_pos_sin,
+            attention_bias,
             inference_params,
             packed_seq_params,
         )
@@ -232,6 +234,7 @@ def forward(
         packed_seq_params=None,
         rotary_pos_cos=None,
         rotary_pos_sin=None,
+        attention_bias=None,
     ):
         # hidden_states: [sq, b, h]

diff --git a/nemo/collections/vlm/mllama/model/vision.py b/nemo/collections/vlm/mllama/model/vision.py
index f662546d21ae..aa38e8f16ea6 100644
--- a/nemo/collections/vlm/mllama/model/vision.py
+++ b/nemo/collections/vlm/mllama/model/vision.py
@@ -423,6 +423,7 @@ def forward(
         rotary_pos_emb=None,
         rotary_pos_cos=None,
         rotary_pos_sin=None,
+        attention_bias=None,
         inference_params=None,
         packed_seq_params=None,
     ):
From 0de74894feb766d6dcb82ecfb2f41b168a776342 Mon Sep 17 00:00:00 2001
From: yaoyu-33
Date: Fri, 15 Nov 2024 08:59:34 -0800
Subject: [PATCH 2/7] fix ci

Signed-off-by: yaoyu-33
---
 .../common/megatron/adapters/mcore_mixins.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
index 85d5235fccfb..e306a0a9b6b7 100644
--- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
+++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
@@ -87,16 +87,15 @@ def forward(
         packed_seq_params: PackedSeqParams = None,
     ):
         hidden_states = super().forward(
-            hidden_states,
-            attention_mask,
-            context,
-            context_mask,
-            rotary_pos_emb,
-            rotary_pos_cos,
-            rotary_pos_sin,
-            attention_bias,
-            inference_params,
-            packed_seq_params,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            context=context,
+            context_mask=context_mask,
+            rotary_pos_emb=rotary_pos_emb,
+            rotary_pos_cos=rotary_pos_cos,
+            rotary_pos_sin=rotary_pos_sin,
+            inference_params=inference_params,
+            packed_seq_params=packed_seq_params,
         )

         mlp_head_adapter = self.get_adapter_module(AdapterName.MLP_HEAD_ADAPTER)

From b14c0329dff93b0b8ec57edbea70121c487db504 Mon Sep 17 00:00:00 2001
From: yaoyu-33
Date: Fri, 15 Nov 2024 11:12:27 -0800
Subject: [PATCH 3/7] add docstring

Signed-off-by: yaoyu-33
---
 nemo/collections/vlm/mllama/model/vision.py | 95 +++++++++++++++++----
 1 file changed, 80 insertions(+), 15 deletions(-)

diff --git a/nemo/collections/vlm/mllama/model/vision.py b/nemo/collections/vlm/mllama/model/vision.py
index aa38e8f16ea6..c60f7740924b 100644
--- a/nemo/collections/vlm/mllama/model/vision.py
+++ b/nemo/collections/vlm/mllama/model/vision.py
@@ -59,6 +59,9 @@


 def to_2tuple(x):
+    """
+    Convert an input to a 2-tuple.
+    """
     if isinstance(x, collections.abc.Iterable):
         return x
     return (x, x)
@@ -71,9 +74,16 @@ def _stack_images(
     max_num_images: int,
 ) -> Tuple[torch.Tensor, List[int]]:
     """
-    Takes a list of list of images and stacks them into a tensor.
-    This function is needed since images can be of completely
-    different resolutions and aspect ratios.
+    Stack a list of image lists into a tensor while accounting for varying resolutions and aspect ratios.
+
+    Args:
+        images (List[List[PIL_Image.Image]]): List of image lists for stacking.
+        max_num_chunks (int): Maximum number of chunks per image.
+        image_res (int): Target resolution for each image.
+        max_num_images (int): Maximum number of images to stack.
+
+    Returns:
+        Tuple[torch.Tensor, List[int]]: Tensor of stacked images and a list of chunk counts for each image.
     """
     out_images, out_num_chunks = [], []
     for imgs_sample in images:
@@ -97,7 +107,17 @@ def build_encoder_attention_mask(
     x: torch.Tensor, ar_ids: torch.Tensor, ntok: int, num_chunks: int, supported_aspect_ratios: List[List[int]]
 ):
     """
-    Build vision encoder attention mask that omits padding tiles and tokens.
+    Build attention masks for a vision encoder to handle padding and token alignment.
+
+    Args:
+        x (torch.Tensor): Input tensor of shape (batch_size, sequence_length).
+        ar_ids (torch.Tensor): Aspect ratio IDs for masking.
+        ntok (int): Number of tokens.
+        num_chunks (int): Number of chunks in the data.
+        supported_aspect_ratios (List[List[int]]): List of supported aspect ratios.
+
+    Returns:
+        torch.Tensor: Tensor containing the attention mask.
""" masks = [] for ar_id in ar_ids: @@ -278,16 +298,22 @@ def forward_with_return_intermediate( class ColumnParallelConv2dPatch(MegatronModule): - """Conv2D Patching layer with model parallelism. - Column parallel over unfolded input. - Arguments: - in_channels: Input channels. - out_channels: Output channels. - kernel_size: Size of convolution kernel. - stride (default 1): Stride for convolution. - bias (default False): Use bias in Conv2d. - Input: (bsz, in_channels, width, height) - Output: (bsz, num_tokens, out_channels) + """ + Conv2D Patching layer with model parallelism. Applies convolution in a column-parallel fashion. + + Args: + config (TransformerConfig): Configuration object for the layer. + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (Union[int, Tuple[int, int]]): Size of the convolution kernel. + stride (Union[int, Tuple[int, int]]): Stride of the convolution. + bias (Optional[bool], default=False): Whether to include a bias term. + + Input: + torch.Tensor: Input tensor of shape (batch_size, in_channels, width, height). + + Output: + torch.Tensor: Output tensor of shape (batch_size, num_tokens, out_channels). """ def __init__( @@ -324,6 +350,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class PrecomputedTilePositionEmbedding(torch.nn.Module): + """ + Module to compute positional embeddings for tiles with optional gating. + + Args: + config (TransformerConfig): Configuration object. + gated (bool, default=False): Whether to apply gating to the embeddings. + + Methods: + forward(hidden_states, aspect_ratio_ids): Applies positional embeddings to the input states. + """ def __init__( self, config: TransformerConfig, @@ -351,7 +387,15 @@ def forward(self, hidden_states: torch.Tensor, aspect_ratio_ids: torch.Tensor) - class SelfAttentionNoBias(SelfAttention): - """Self-attention layer class without bias""" + """ + Self-attention layer implementation without bias. + + Args: + config (TransformerConfig): Configuration for the transformer. + submodules (SelfAttentionSubmodules): Submodules required for self-attention. + layer_number (int): The layer number in the transformer stack. + attn_mask_type (AttnMaskType): Type of attention mask to apply. + """ def __init__( self, @@ -396,6 +440,15 @@ def __init__( class ImageTransformerLayer(TransformerLayer): + """ + Transformer layer adapted for processing image data with optional gating. + + Args: + config (TransformerConfig): Transformer configuration object. + submodules (TransformerLayerSubmodules): Submodules to use in the layer. + layer_number (int, default=1): Layer number in the transformer. + hidden_dropout (float, optional): Dropout rate for hidden layers. + """ def __init__( self, config: TransformerConfig, @@ -486,6 +539,18 @@ def forward( class VisionEncoder(MegatronModule): + """ + Vision encoder module for processing image inputs with patch-based embeddings. + + Args: + config ('CrossAttentionVisionConfig'): Configuration object for the encoder. + image_size (int, default=560): Input image size. + patch_size (int, default=14): Size of patches extracted from the image. + in_channels (int, default=3): Number of input channels. + pre_process (bool, default=True): Whether to preprocess input. + post_process (bool, default=True): Whether to postprocess output. + return_intermediate (Optional[bool]): Whether to return intermediate layers. 
+ """ def __init__( self, config: 'CrossAttentionVisionConfig', From d33c5161df4c3148723263224c2ef816126e4733 Mon Sep 17 00:00:00 2001 From: yaoyu-33 Date: Fri, 15 Nov 2024 19:13:22 +0000 Subject: [PATCH 4/7] Apply isort and black reformatting Signed-off-by: yaoyu-33 --- nemo/collections/vlm/mllama/model/vision.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nemo/collections/vlm/mllama/model/vision.py b/nemo/collections/vlm/mllama/model/vision.py index c60f7740924b..d3591082d679 100644 --- a/nemo/collections/vlm/mllama/model/vision.py +++ b/nemo/collections/vlm/mllama/model/vision.py @@ -360,6 +360,7 @@ class PrecomputedTilePositionEmbedding(torch.nn.Module): Methods: forward(hidden_states, aspect_ratio_ids): Applies positional embeddings to the input states. """ + def __init__( self, config: TransformerConfig, @@ -449,6 +450,7 @@ class ImageTransformerLayer(TransformerLayer): layer_number (int, default=1): Layer number in the transformer. hidden_dropout (float, optional): Dropout rate for hidden layers. """ + def __init__( self, config: TransformerConfig, @@ -551,6 +553,7 @@ class VisionEncoder(MegatronModule): post_process (bool, default=True): Whether to postprocess output. return_intermediate (Optional[bool]): Whether to return intermediate layers. """ + def __init__( self, config: 'CrossAttentionVisionConfig', From be0a75fd8a35f83d6c85765e4a639d4386d61275 Mon Sep 17 00:00:00 2001 From: yaoyu-33 Date: Fri, 15 Nov 2024 11:29:21 -0800 Subject: [PATCH 5/7] fix docstring2 Signed-off-by: yaoyu-33 --- .../models/multimodal_llm/neva/neva_model.py | 6 ++++-- nemo/collections/vlm/mllama/model/vision.py | 20 +++++++++++++++---- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py index 5291497f92c3..0a695df68b2a 100644 --- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py +++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py @@ -618,7 +618,8 @@ def pad_word_embeddings(self, state_dict): - state_dict['model.language_model.embedding.word_embeddings.weight'].shape[0] ) state_dict['model.language_model.embedding.word_embeddings.weight'] = F.pad( - state_dict['model.language_model.embedding.word_embeddings.weight'], (0, 0, 0, pad_length) + state_dict['model.language_model.embedding.word_embeddings.weight'], + (0, 0, 0, pad_length) ) if 'model.language_model.output_layer.weight' in state_dict: @@ -627,7 +628,8 @@ def pad_word_embeddings(self, state_dict): == state_dict['model.language_model.output_layer.weight'].shape ) state_dict['model.language_model.output_layer.weight'] = F.pad( - state_dict['model.language_model.output_layer.weight'], (0, 0, 0, pad_length) + state_dict['model.language_model.output_layer.weight'], + (0, 0, 0, pad_length) ) return state_dict diff --git a/nemo/collections/vlm/mllama/model/vision.py b/nemo/collections/vlm/mllama/model/vision.py index d3591082d679..7650f99eb051 100644 --- a/nemo/collections/vlm/mllama/model/vision.py +++ b/nemo/collections/vlm/mllama/model/vision.py @@ -133,6 +133,9 @@ def build_encoder_attention_mask( def apply_scaling(freqs: torch.Tensor): + """ + Scale frequency values based on predefined thresholds and a smoothing factor. 
+ """ # Values obtained from grid search scale_factor = 8 low_freq_factor = 1 @@ -157,6 +160,9 @@ def apply_scaling(freqs: torch.Tensor): # Use this spec for an implementation using modules in TE def get_image_transformer_layer_spec() -> ModuleSpec: + """ + Create a specification for an image transformer layer. + """ image_transformer_submodules = TransformerLayerSubmodules( input_layernorm=TENorm, self_attention=ModuleSpec( @@ -195,6 +201,10 @@ def forward_with_return_intermediate( packed_seq_params: PackedSeqParams = None, return_intermediate: List[int] = None, ): + """ + Perform a forward pass through the transformer layers with optional intermediate outputs. + Override regular MCore transformer layer forward pass. + """ # hidden_states (float): [s, b, h] # attention_mask (bool): [1, 1, s, s] @@ -342,6 +352,7 @@ def __init__( ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward.""" x = self._unfold(x) x = x.permute(0, 2, 1) x = F.linear(x, self._linear.weight) @@ -356,9 +367,6 @@ class PrecomputedTilePositionEmbedding(torch.nn.Module): Args: config (TransformerConfig): Configuration object. gated (bool, default=False): Whether to apply gating to the embeddings. - - Methods: - forward(hidden_states, aspect_ratio_ids): Applies positional embeddings to the input states. """ def __init__( @@ -377,6 +385,7 @@ def __init__( self.gate = nn.Parameter(torch.zeros(1)) def forward(self, hidden_states: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + """Forward.""" embeddings = self.embedding(aspect_ratio_ids) embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) @@ -482,6 +491,7 @@ def forward( inference_params=None, packed_seq_params=None, ): + """Forward.""" # hidden_states: [s, b, h] # Residual connection. 
@@ -625,7 +635,7 @@ def __init__(
         self.gated_positional_embedding_gate = nn.Parameter(torch.zeros(1))

     def apply_positional_embedding(self, x, aspect_ratio_ids):
-        # apply regular position embedding
+        """Apply regular position embedding and tile positional embedding."""
         bsz, num_chunks, num_tokens, dim = x.shape
         x = x.view(bsz * num_chunks, num_tokens, dim)
         x = x + self.positional_embedding * (1 - self.gated_positional_embedding_gate.tanh())
@@ -636,6 +646,7 @@ def apply_positional_embedding(self, x, aspect_ratio_ids):
         return x

     def apply_class_embedding(self, x):
+        """Concat class embedding tokens."""
         x = torch.cat(
             [
                 self.class_embedding.to(x.dtype)
@@ -647,6 +658,7 @@ def apply_class_embedding(self, x):
         return x

     def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor:
+        """Forward."""
         if images.ndim == 5:
             num_concurrent_media = 1
             bsz, num_chunks, nch, w, h = images.shape

From 9891dcd3d2c12ad19cd4d2f0e345d67877bddf02 Mon Sep 17 00:00:00 2001
From: yaoyu-33
Date: Fri, 15 Nov 2024 19:30:37 +0000
Subject: [PATCH 6/7] Apply isort and black reformatting

Signed-off-by: yaoyu-33
---
 .../multimodal/models/multimodal_llm/neva/neva_model.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
index 0a695df68b2a..5291497f92c3 100644
--- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
+++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
@@ -618,8 +618,7 @@ def pad_word_embeddings(self, state_dict):
             - state_dict['model.language_model.embedding.word_embeddings.weight'].shape[0]
         )
         state_dict['model.language_model.embedding.word_embeddings.weight'] = F.pad(
-            state_dict['model.language_model.embedding.word_embeddings.weight'],
-            (0, 0, 0, pad_length)
+            state_dict['model.language_model.embedding.word_embeddings.weight'], (0, 0, 0, pad_length)
         )

         if 'model.language_model.output_layer.weight' in state_dict:
@@ -628,8 +627,7 @@ def pad_word_embeddings(self, state_dict):
                 == state_dict['model.language_model.output_layer.weight'].shape
             )
             state_dict['model.language_model.output_layer.weight'] = F.pad(
-                state_dict['model.language_model.output_layer.weight'],
-                (0, 0, 0, pad_length)
+                state_dict['model.language_model.output_layer.weight'], (0, 0, 0, pad_length)
             )

         return state_dict

From 1eadf6e134cd5ed701f7e360c16274709a198eba Mon Sep 17 00:00:00 2001
From: yaoyu-33
Date: Fri, 15 Nov 2024 11:49:35 -0800
Subject: [PATCH 7/7] fix line too long

Signed-off-by: yaoyu-33
---
 nemo/collections/vlm/mllama/model/vision.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/vlm/mllama/model/vision.py b/nemo/collections/vlm/mllama/model/vision.py
index 7650f99eb051..f023cc7bf943 100644
--- a/nemo/collections/vlm/mllama/model/vision.py
+++ b/nemo/collections/vlm/mllama/model/vision.py
@@ -698,7 +698,8 @@ def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor:
             return_intermediate=self.return_intermediate,
         )

-        # [ntok * num_concurrent_media * num_chunks, bsz, hidden_size] -> [bsz, ntok * num_concurrent_media * num_chunks, hidden_size]
+        # [ntok * num_concurrent_media * num_chunks, bsz, hidden_size]
+        # -> [bsz, ntok * num_concurrent_media * num_chunks, hidden_size]
         x, int_x = x.transpose(0, 1).contiguous(), int_x.transpose(0, 1).contiguous()
         x = self.ln_post(x)
         x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim)
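
Taken together, the series is an API-alignment change: patch 1/7 adds an `attention_bias` argument to each overridden `forward` so the NeMo subclasses keep matching the upstream Megatron-core `TransformerLayer` signature, and patch 2/7 switches the mixin's `super().forward(...)` call to keyword arguments while deliberately omitting `attention_bias`, so the extra slot cannot shift the remaining positional arguments. The sketch below illustrates that pattern in isolation; `Base` and `CustomLayer` are hypothetical stand-ins, not the actual Megatron-core or NeMo classes.

```python
# Minimal sketch of the signature-compatibility pattern in this series.
# Base and CustomLayer are hypothetical stand-ins, not the real
# Megatron-core or NeMo APIs.


class Base:
    def forward(self, hidden_states, attention_mask=None, attention_bias=None,
                inference_params=None, packed_seq_params=None):
        # Stand-in parent implementation: just report what it received.
        return {
            "hidden_states": hidden_states,
            "attention_mask": attention_mask,
            "attention_bias": attention_bias,
            "inference_params": inference_params,
            "packed_seq_params": packed_seq_params,
        }


class CustomLayer(Base):
    def forward(self, hidden_states, attention_mask=None, attention_bias=None,
                inference_params=None, packed_seq_params=None):
        # Accept the new argument so positional callers cannot shift the other
        # parameters into the wrong slots, then forward everything by keyword.
        # Like patch 2/7, attention_bias is accepted but not passed through,
        # which is the conservative choice when the parent may not support it.
        return super().forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            inference_params=inference_params,
            packed_seq_params=packed_seq_params,
        )


if __name__ == "__main__":
    out = CustomLayer().forward("h", None, "bias")  # positional call using the new slot
    print(out["attention_bias"])  # None: accepted by the subclass, not forwarded
```

If the parent class does accept the argument, forwarding `attention_bias=attention_bias` explicitly is the natural extension of the same pattern.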