feat: replace GeluAndMul #1234

Merged 2 commits on Aug 28, 2024
14 changes: 10 additions & 4 deletions python/sglang/srt/layers/activation.py
@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
+from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 from vllm.distributed import (
     divide,
     get_tensor_model_parallel_rank,
@@ -43,18 +43,24 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:


 class GeluAndMul(CustomOp):
-    def __init__(self, **kwargs):
+    def __init__(self, approximate="tanh"):
         super().__init__()
+        self.approximate = approximate

     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        gelu_tanh_and_mul(x, out)
+        if self.approximate == "tanh":
+            gelu_tanh_and_mul(x, out)
+        elif self.approximate == "none":
+            gelu_and_mul(x, out)
+        else:
+            raise RuntimeError("GeluAndMul only supports tanh or none")
         return out
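For context, here is a minimal CPU-only sketch of the gated-GELU contract the updated constructor exposes. It mirrors `forward_native` rather than the flashinfer kernels, and the helper name and example tensors are illustrative only, not part of this PR.

```python
import torch
import torch.nn.functional as F

def gelu_and_mul_reference(x: torch.Tensor, approximate: str = "tanh") -> torch.Tensor:
    # Split the last dimension in half: the first half passes through GELU,
    # the second half is the multiplicative gate (same contract as GeluAndMul).
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]

x = torch.randn(2, 8)
out_tanh = gelu_and_mul_reference(x, approximate="tanh")   # default; maps to gelu_tanh_and_mul on CUDA
out_exact = gelu_and_mul_reference(x, approximate="none")  # maps to gelu_and_mul on CUDA
print(out_tanh.shape, out_exact.shape)  # torch.Size([2, 4]) torch.Size([2, 4])
```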
4 changes: 2 additions & 2 deletions python/sglang/srt/models/gemma.py
@@ -23,7 +23,6 @@
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -34,6 +33,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -60,7 +60,7 @@ def __init__(
             bias=False,
             quant_config=quant_config,
         )
-        self.act_fn = GeluAndMul()
+        self.act_fn = GeluAndMul("none")

     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
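With this change, Gemma imports `GeluAndMul` from sglang's own activation layer and requests the exact (non-tanh) GELU explicitly. As a sanity check outside this PR, the snippet below illustrates that the exact and tanh formulations differ slightly, which is why the explicit "none" argument matters.

```python
import torch
import torch.nn.functional as F

# Exact GELU uses the Gaussian CDF; the tanh variant is an approximation.
# Passing "none" keeps Gemma's MLP on the exact formulation instead of the
# tanh default that GeluAndMul would otherwise use.
x = torch.linspace(-4, 4, steps=9)
exact = F.gelu(x, approximate="none")
tanh = F.gelu(x, approximate="tanh")
print((exact - tanh).abs().max())  # small but nonzero difference
```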
2 changes: 1 addition & 1 deletion test/srt/models/test_generation_models.py
@@ -96,7 +96,7 @@ def assert_close_prefill_logits_and_output_strs(
             if hf_logprobs.shape[0] <= 100:
                 assert torch.all(
                     abs(hf_logprobs - srt_logprobs) < prefill_tolerance
-                ), "prefill logprobs are not all close"
+                ), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"

         print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
         print(f"srt_outputs.output_strs={srt_outputs.output_strs}")