From c23768098d817f83069836a1581f6d9d0ab4cb92 Mon Sep 17 00:00:00 2001
From: zhyncs
Date: Tue, 27 Aug 2024 14:04:35 +0000
Subject: [PATCH] feat: replace GeluAndMul

---
 python/sglang/srt/layers/activation.py    | 14 ++++++++++----
 python/sglang/srt/models/gemma.py         |  4 ++--
 test/srt/models/test_generation_models.py |  2 +-
 3 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 5df387cb2b9..9047197af2f 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
+from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 from vllm.distributed import (
     divide,
     get_tensor_model_parallel_rank,
@@ -43,18 +43,24 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class GeluAndMul(CustomOp):
-    def __init__(self, **kwargs):
+    def __init__(self, approximate="tanh"):
         super().__init__()
+        self.approximate = approximate
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        gelu_tanh_and_mul(x, out)
+        if self.approximate == "tanh":
+            gelu_tanh_and_mul(x, out)
+        elif self.approximate == "none":
+            gelu_and_mul(x, out)
+        else:
+            raise RuntimeError("GeluAndMul only supports tanh or none")
         return out
 
 
diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py
index 990937f5180..ae3b1b1948c 100644
--- a/python/sglang/srt/models/gemma.py
+++ b/python/sglang/srt/models/gemma.py
@@ -23,7 +23,6 @@
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -34,6 +33,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -60,7 +60,7 @@ def __init__(
             bias=False,
             quant_config=quant_config,
         )
-        self.act_fn = GeluAndMul()
+        self.act_fn = GeluAndMul("none")
 
     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py
index e38584741e0..08288c510c9 100644
--- a/test/srt/models/test_generation_models.py
+++ b/test/srt/models/test_generation_models.py
@@ -96,7 +96,7 @@ def assert_close_prefill_logits_and_output_strs(
            if hf_logprobs.shape[0] <= 100:
                assert torch.all(
                    abs(hf_logprobs - srt_logprobs) < prefill_tolerance
-                ), "prefill logprobs are not all close"
+                ), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"
 
        print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
        print(f"srt_outputs.output_strs={srt_outputs.output_strs}")