Add attention_bias argument in transformer block and transformer layer modules, addressing change in MCore #11289

Merged · 7 commits · Nov 18, 2024
@@ -208,6 +208,7 @@ def forward(
         rotary_pos_emb=None,
         rotary_pos_cos=None,
         rotary_pos_sin=None,
+        attention_bias=None,
         inference_params=None,
         packed_seq_params=None,
     ):
@@ -108,6 +108,7 @@ def forward(
         rotary_pos_emb=None,
         rotary_pos_cos=None,
         rotary_pos_sin=None,
+        attention_bias=None,
         inference_params=None,
         packed_seq_params=None,
     ):
@@ -252,6 +252,7 @@ def forward(
         rotary_pos_emb=None,
         rotary_pos_cos=None,
         rotary_pos_sin=None,
+        attention_bias=None,
         inference_params=None,
         packed_seq_params=None,  # TODO: handle this
     ):
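A note on this pattern: because `attention_bias` is inserted ahead of `inference_params`, call sites that passed the trailing arguments positionally against the old signature would silently bind the wrong parameter. A minimal illustration (hypothetical call, argument list abbreviated):

```python
# Old signature: forward(..., rotary_pos_sin=None, inference_params=None, ...)
# New signature: forward(..., rotary_pos_sin=None, attention_bias=None, inference_params=None, ...)

layer(h, mask, rope, cos, sin, params)   # `params` now binds to attention_bias!

# Safer: pass trailing arguments by keyword, as the next hunk does.
layer(h, mask, rotary_pos_emb=rope, rotary_pos_cos=cos,
      rotary_pos_sin=sin, inference_params=params)
```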
@@ -82,19 +82,20 @@ def forward(
         rotary_pos_emb: Tensor = None,
         rotary_pos_cos: Tensor = None,
         rotary_pos_sin: Tensor = None,
+        attention_bias: Tensor = None,
         inference_params: InferenceParams = None,
         packed_seq_params: PackedSeqParams = None,
     ):
         hidden_states = super().forward(
-            hidden_states,
-            attention_mask,
-            context,
-            context_mask,
-            rotary_pos_emb,
-            rotary_pos_cos,
-            rotary_pos_sin,
-            inference_params,
-            packed_seq_params,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            context=context,
+            context_mask=context_mask,
+            rotary_pos_emb=rotary_pos_emb,
+            rotary_pos_cos=rotary_pos_cos,
+            rotary_pos_sin=rotary_pos_sin,
+            inference_params=inference_params,
+            packed_seq_params=packed_seq_params,
Review thread on this call:

> **Collaborator:** Does `attention_bias` need to be passed to `super().forward()`?
>
> **Author:** The MCore version hasn't been bumped in CI, so it can't be added yet.

         )

         mlp_head_adapter = self.get_adapter_module(AdapterName.MLP_HEAD_ADAPTER)
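Switching this call from positional to keyword arguments is what keeps it robust to upstream signature changes like the one above. Once the pinned MCore version supports it, `attention_bias` could be forwarded the same way; a minimal, version-tolerant sketch (hypothetical, not part of this PR):

```python
import inspect

from megatron.core.transformer.transformer_layer import TransformerLayer


class AdapterTransformerLayer(TransformerLayer):  # hypothetical subclass
    def forward(self, hidden_states, attention_mask=None, attention_bias=None, **kwargs):
        # Forward attention_bias only if the installed Megatron-Core version
        # accepts it, so an older pin (e.g. in CI) keeps working.
        if "attention_bias" in inspect.signature(TransformerLayer.forward).parameters:
            kwargs["attention_bias"] = attention_bias
        return super().forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )
```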
@@ -232,6 +233,7 @@ def forward(
         packed_seq_params=None,
         rotary_pos_cos=None,
         rotary_pos_sin=None,
+        attention_bias=None,
     ):
         # hidden_states: [sq, b, h]

nemo/collections/vlm/mllama/model/vision.py (116 changes: 99 additions, 17 deletions)
@@ -59,6 +59,9 @@


 def to_2tuple(x):
+    """
+    Convert an input to a 2-tuple.
+    """
     if isinstance(x, collections.abc.Iterable):
         return x
     return (x, x)
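For example:

```python
to_2tuple(14)        # -> (14, 14)
to_2tuple((14, 16))  # -> (14, 16); iterables pass through unchanged
```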
@@ -71,9 +74,16 @@ def _stack_images(
     max_num_images: int,
 ) -> Tuple[torch.Tensor, List[int]]:
     """
-    Takes a list of list of images and stacks them into a tensor.
-    This function is needed since images can be of completely
-    different resolutions and aspect ratios.
+    Stack a list of image lists into a tensor while accounting for varying resolutions and aspect ratios.
+
+    Args:
+        images (List[List[PIL_Image.Image]]): List of image lists for stacking.
+        max_num_chunks (int): Maximum number of chunks per image.
+        image_res (int): Target resolution for each image.
+        max_num_images (int): Maximum number of images to stack.
+
+    Returns:
+        Tuple[torch.Tensor, List[int]]: Tensor of stacked images and a list of chunk counts for each image.
     """
     out_images, out_num_chunks = [], []
     for imgs_sample in images:
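A hypothetical call matching the documented signature (all values assumed for illustration):

```python
from PIL import Image as PIL_Image

img_a = PIL_Image.new("RGB", (800, 600))
img_b = PIL_Image.new("RGB", (448, 448))
img_c = PIL_Image.new("RGB", (1024, 768))

batch = [[img_a, img_b], [img_c]]  # two samples with different image counts
stacked, num_chunks = _stack_images(
    batch, max_num_chunks=4, image_res=560, max_num_images=2
)
# stacked: padded tensor of image chunks; num_chunks: chunk counts, e.g. [2, 1]
```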
@@ -97,7 +107,17 @@ def build_encoder_attention_mask(
     x: torch.Tensor, ar_ids: torch.Tensor, ntok: int, num_chunks: int, supported_aspect_ratios: List[List[int]]
 ):
     """
-    Build vision encoder attention mask that omits padding tiles and tokens.
+    Build attention masks for a vision encoder to handle padding and token alignment.
+
+    Args:
+        x (torch.Tensor): Input tensor of shape (batch_size, sequence_length).
+        ar_ids (torch.Tensor): Aspect ratio IDs for masking.
+        ntok (int): Number of tokens.
+        num_chunks (int): Number of chunks in the data.
+        supported_aspect_ratios (List[List[int]]): List of supported aspect ratios.
+
+    Returns:
+        torch.Tensor: Tensor containing the attention mask.
     """
     masks = []
     for ar_id in ar_ids:


 def apply_scaling(freqs: torch.Tensor):
+    """
+    Scale frequency values based on predefined thresholds and a smoothing factor.
+    """
     # Values obtained from grid search
     scale_factor = 8
     low_freq_factor = 1
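The rest of the function is collapsed in this view. For context, a sketch of the Llama-3-style RoPE frequency scaling it appears to implement (a reconstruction under assumed constants, not the PR's verbatim code):

```python
import math

import torch


def apply_scaling_sketch(freqs: torch.Tensor) -> torch.Tensor:
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4       # assumed
    old_context_len = 8192     # assumed original context length
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)                 # high-frequency band: keep
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scale_factor)  # low-frequency band: scale down
        else:
            # Smoothly interpolate between the two bands.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
```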

 # Use this spec for an implementation using modules in TE
 def get_image_transformer_layer_spec() -> ModuleSpec:
+    """
+    Create a specification for an image transformer layer.
+    """
     image_transformer_submodules = TransformerLayerSubmodules(
         input_layernorm=TENorm,
         self_attention=ModuleSpec(
@@ -175,6 +201,10 @@ def forward_with_return_intermediate(
     packed_seq_params: PackedSeqParams = None,
     return_intermediate: List[int] = None,
 ):
+    """
+    Perform a forward pass through the transformer layers with optional intermediate outputs.
+    Overrides the regular MCore transformer layer forward pass.
+    """
     # hidden_states (float): [s, b, h]
     # attention_mask (bool): [1, 1, s, s]

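The body is collapsed here, but the `return_intermediate` mechanism amounts to recording hidden states at the requested layer indices while running the stack normally. A simplified sketch (layer API and stacking dimension assumed):

```python
from typing import List, Optional

import torch


def run_with_intermediates(
    layers: torch.nn.ModuleList,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    return_intermediate: Optional[List[int]] = None,
):
    collected = []
    for idx, layer in enumerate(layers):
        if return_intermediate is not None and idx in return_intermediate:
            collected.append(hidden_states)  # record the input to layer `idx`
        hidden_states = layer(hidden_states, attention_mask)
    if return_intermediate is not None:
        # Stack the recorded states on a trailing dimension.
        return hidden_states, torch.stack(collected, dim=-1)
    return hidden_states
```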
@@ -278,16 +308,22 @@


 class ColumnParallelConv2dPatch(MegatronModule):
-    """Conv2D Patching layer with model parallelism.
-    Column parallel over unfolded input.
-    Arguments:
-        in_channels: Input channels.
-        out_channels: Output channels.
-        kernel_size: Size of convolution kernel.
-        stride (default 1): Stride for convolution.
-        bias (default False): Use bias in Conv2d.
-    Input: (bsz, in_channels, width, height)
-    Output: (bsz, num_tokens, out_channels)
+    """
+    Conv2D Patching layer with model parallelism. Applies convolution in a column-parallel fashion.
+
+    Args:
+        config (TransformerConfig): Configuration object for the layer.
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        kernel_size (Union[int, Tuple[int, int]]): Size of the convolution kernel.
+        stride (Union[int, Tuple[int, int]]): Stride of the convolution.
+        bias (Optional[bool], default=False): Whether to include a bias term.
+
+    Input:
+        torch.Tensor: Input tensor of shape (batch_size, in_channels, width, height).
+
+    Output:
+        torch.Tensor: Output tensor of shape (batch_size, num_tokens, out_channels).
     """

     def __init__(
@@ -316,6 +352,7 @@ def __init__(
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward."""
         x = self._unfold(x)
         x = x.permute(0, 2, 1)
         x = F.linear(x, self._linear.weight)
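The unfold-then-linear trick above is what makes column parallelism straightforward: a Conv2d whose stride equals its kernel size is just a matmul over flattened patches, and a matmul's output dimension can be column-sharded. A quick single-process equivalence check (sizes assumed):

```python
import torch
import torch.nn.functional as F

bsz, in_ch, out_ch, k = 2, 3, 8, 14
x = torch.randn(bsz, in_ch, 56, 56)

conv = torch.nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=k, bias=False)
ref = conv(x).flatten(2).transpose(1, 2)         # (bsz, num_tokens, out_ch)

patches = F.unfold(x, kernel_size=k, stride=k)   # (bsz, in_ch * k * k, num_tokens)
out = F.linear(patches.permute(0, 2, 1), conv.weight.view(out_ch, -1))

assert torch.allclose(ref, out, atol=1e-4)
```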


 class PrecomputedTilePositionEmbedding(torch.nn.Module):
+    """
+    Module to compute positional embeddings for tiles with optional gating.
+
+    Args:
+        config (TransformerConfig): Configuration object.
+        gated (bool, default=False): Whether to apply gating to the embeddings.
+    """

     def __init__(
         self,
         config: TransformerConfig,
         self.gate = nn.Parameter(torch.zeros(1))

     def forward(self, hidden_states: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
+        """Forward."""
         embeddings = self.embedding(aspect_ratio_ids)
         embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)

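The tail of this method is collapsed; with `gated=True`, the usual pattern (assumed here) is to scale the embedding by a zero-initialized tanh gate so the module starts as a no-op:

```python
# Sketch of the collapsed tail of forward() (assumed):
if self.gated:
    embeddings = embeddings * self.gate.tanh()  # gate starts at 0 -> no-op at init
hidden_states = hidden_states + embeddings
return hidden_states
```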


 class SelfAttentionNoBias(SelfAttention):
-    """Self-attention layer class without bias"""
+    """
+    Self-attention layer implementation without bias.
+
+    Args:
+        config (TransformerConfig): Configuration for the transformer.
+        submodules (SelfAttentionSubmodules): Submodules required for self-attention.
+        layer_number (int): The layer number in the transformer stack.
+        attn_mask_type (AttnMaskType): Type of attention mask to apply.
+    """

     def __init__(
         self,
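The constructor body is collapsed; a plausible mechanism (an assumption, not necessarily this PR's code) is to flip the bias flags on a copy of the config before delegating to `SelfAttention`:

```python
from dataclasses import replace

from megatron.core.transformer.attention import SelfAttention
from megatron.core.transformer.enums import AttnMaskType


class SelfAttentionNoBiasSketch(SelfAttention):  # hypothetical name
    def __init__(self, config, submodules, layer_number=1,
                 attn_mask_type=AttnMaskType.padding):
        # TransformerConfig is a dataclass, so a shallow copy with the bias
        # flags disabled (flag names assumed) removes bias from the
        # QKV and output projections.
        config = replace(config, add_bias_linear=False, add_qkv_bias=False)
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type,
        )
```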
@@ -396,6 +450,16 @@ def __init__(


 class ImageTransformerLayer(TransformerLayer):
+    """
+    Transformer layer adapted for processing image data with optional gating.
+
+    Args:
+        config (TransformerConfig): Transformer configuration object.
+        submodules (TransformerLayerSubmodules): Submodules to use in the layer.
+        layer_number (int, default=1): Layer number in the transformer.
+        hidden_dropout (float, optional): Dropout rate for hidden layers.
+    """
+
     def __init__(
         self,
         config: TransformerConfig,
@@ -423,9 +487,11 @@ def forward(
         rotary_pos_emb=None,
         rotary_pos_cos=None,
         rotary_pos_sin=None,
+        attention_bias=None,
         inference_params=None,
         packed_seq_params=None,
     ):
+        """Forward."""
         # hidden_states: [s, b, h]

         # Residual connection.
@@ -485,6 +551,19 @@


 class VisionEncoder(MegatronModule):
+    """
+    Vision encoder module for processing image inputs with patch-based embeddings.
+
+    Args:
+        config ('CrossAttentionVisionConfig'): Configuration object for the encoder.
+        image_size (int, default=560): Input image size.
+        patch_size (int, default=14): Size of patches extracted from the image.
+        in_channels (int, default=3): Number of input channels.
+        pre_process (bool, default=True): Whether to preprocess input.
+        post_process (bool, default=True): Whether to postprocess output.
+        return_intermediate (Optional[bool]): Whether to return intermediate layers.
+    """
+
     def __init__(
         self,
         config: 'CrossAttentionVisionConfig',
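A hypothetical end-to-end call, with shapes following the documented defaults (`config` and the aspect-ratio id shape are assumptions):

```python
encoder = VisionEncoder(config, image_size=560, patch_size=14)
images = torch.randn(2, 4, 3, 560, 560)  # (bsz, num_chunks, channels, H, W)
ar_ids = torch.randint(0, 8, (2, 1))     # aspect-ratio ids per image
features = encoder(images, ar_ids)
```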
@@ -556,7 +635,7 @@ def __init__(
         self.gated_positional_embedding_gate = nn.Parameter(torch.zeros(1))

     def apply_positional_embedding(self, x, aspect_ratio_ids):
-        # apply regular position embedding
+        """Apply regular position embedding and tile positional embedding."""
         bsz, num_chunks, num_tokens, dim = x.shape
         x = x.view(bsz * num_chunks, num_tokens, dim)
         x = x + self.positional_embedding * (1 - self.gated_positional_embedding_gate.tanh())
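The collapsed lines then add the complementary, gated half of the blend; a sketch (assumed, mirroring the `(1 - gate.tanh())` term above; `gated_positional_embedding` is a hypothetical helper):

```python
# Sketch of the collapsed continuation (assumed):
x = x.view(bsz, num_chunks, num_tokens, dim)
tile_pe = self.gated_positional_embedding(aspect_ratio_ids)
x = x + tile_pe * self.gated_positional_embedding_gate.tanh()
```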
         return x

     def apply_class_embedding(self, x):
+        """Concat class embedding tokens."""
         x = torch.cat(
             [
                 self.class_embedding.to(x.dtype)
         return x

     def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor:
+        """Forward."""
         if images.ndim == 5:
             num_concurrent_media = 1
             bsz, num_chunks, nch, w, h = images.shape
@@ -617,7 +698,8 @@ def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor:
             return_intermediate=self.return_intermediate,
         )

-        # [ntok * num_concurrent_media * num_chunks, bsz, hidden_size] -> [bsz, ntok * num_concurrent_media * num_chunks, hidden_size]
+        # [ntok * num_concurrent_media * num_chunks, bsz, hidden_size]
+        # -> [bsz, ntok * num_concurrent_media * num_chunks, hidden_size]
         x, int_x = x.transpose(0, 1).contiguous(), int_x.transpose(0, 1).contiguous()
         x = self.ln_post(x)
         x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim)