diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index fe9e900f4ad07..a85c155cc0a85 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -80,7 +80,7 @@ def mcore_register_adapters(self): if ( self.config.sequence_parallel and hasattr(self.linear_qkv, "return_layernorm_output_gathered") - and not self.config.tp_comm_overlap + and not self.linear_qkv.ub_overlap_ag ): # for LoRA SP, TE v1.5 can return layernorm output gathered so there is no need # to perform the redundant gather in the adapter module, unless TP communication @@ -142,11 +142,19 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): if SplitAlongDim is not None: # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] - (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list,) + (query, key, value) = SplitAlongDim( + mixed_qkv, + 3, + split_arg_list, + ) else: # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] - (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3,) + (query, key, value) = torch.split( + mixed_qkv, + split_arg_list, + dim=3, + ) # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) @@ -231,11 +239,21 @@ def forward( if self.checkpoint_core_attention: core_attn_out = self._checkpointed_attention_forward( - query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params, + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + packed_seq_params=packed_seq_params, ) else: core_attn_out = self.core_attention( - query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params, + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + packed_seq_params=packed_seq_params, ) if packed_seq_params is not None: @@ -316,7 +334,9 @@ def forward(self, hidden_states): intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) elif self.activation_func == F.silu and self.config.gated_linear_unit: intermediate_parallel = bias_swiglu_impl( - intermediate_parallel, bias_parallel, self.config.activation_func_fp8_input_store, + intermediate_parallel, + bias_parallel, + self.config.activation_func_fp8_input_store, ) else: