add geglu to mlp swap (NVIDIA#8999)
* add geglu to mlp swap

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* match swiglu

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and galv committed Apr 29, 2024
1 parent 0ae8231 commit 66ff8cd
Showing 1 changed file with 11 additions and 3 deletions.
@@ -15,7 +15,9 @@
 import torch
 import torch.nn.functional as F
 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
+from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
 from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
+from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
 from megatron.core.transformer.attention import SelfAttention
@@ -279,10 +281,16 @@ def forward(self, hidden_states):
 
         if self.config.bias_activation_fusion:
             if self.activation_func == F.gelu:
-                assert self.config.add_bias_linear is True
-                intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
+                if self.config.gated_linear_unit:
+                    intermediate_parallel = bias_geglu_impl(intermediate_parallel, bias_parallel)
+                else:
+                    assert self.config.add_bias_linear is True
+                    intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
             elif self.activation_func == F.silu and self.config.gated_linear_unit:
-                intermediate_parallel = bias_swiglu_impl(intermediate_parallel, bias_parallel)
+                intermediate_parallel = bias_swiglu_impl(
+                    intermediate_parallel, bias_parallel, self.config.activation_func_fp8_input_store,
+                )
+
             else:
                 raise ValueError("Only support fusion of gelu and swiglu")
         else:
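For context, the fused kernels dispatched here implement gated linear units: the fc1 projection is produced at double width, split in half, and one half gates the other through the activation. A rough unfused sketch of what the bias_geglu_impl and bias_swiglu_impl paths compute (ignoring the kernel fusion and the optional fp8 input caching) follows; the reference function names and the tanh-approximate GELU are illustrative assumptions, not code from this commit.

import torch
import torch.nn.functional as F

def bias_geglu_reference(y: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Add fc1's bias, split the doubled-width projection along the last
    # dim, and gate the second half with GELU of the first half (GEGLU).
    y = y + bias
    y_1, y_2 = torch.chunk(y, 2, dim=-1)
    return F.gelu(y_1, approximate='tanh') * y_2

def bias_swiglu_reference(y: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Same split, but gated with SiLU, i.e. x * sigmoid(x) (SwiGLU).
    y = y + bias
    y_1, y_2 = torch.chunk(y, 2, dim=-1)
    return F.silu(y_1) * y_2

In both cases the input is the output of fc1 with gated_linear_unit enabled, so each chunk has the MLP's ffn_hidden_size; the fused implementations compute the same gating inside a single kernel.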

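To exercise the new branch, a model needs gelu activation together with gated_linear_unit and bias_activation_fusion; previously that combination fell through to the plain bias-GELU fusion. A hypothetical minimal config is sketched below, with field names as in megatron.core's TransformerConfig and the sizes purely illustrative.

import torch.nn.functional as F
from megatron.core.transformer.transformer_config import TransformerConfig

# Illustrative config: gelu + gated_linear_unit now dispatches the MLP
# forward to the fused bias-GEGLU kernel added by this commit.
config = TransformerConfig(
    num_layers=2,
    hidden_size=64,
    num_attention_heads=4,
    activation_func=F.gelu,       # default activation
    gated_linear_unit=True,       # turns gelu into GEGLU
    bias_activation_fusion=True,  # take the fused path patched above
)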