From 1fb94599087e4881c8b31dc4de46b1685fcaa124 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 22 Aug 2024 07:26:35 +1000
Subject: [PATCH] fix: custom op fallback forward native when lower sm80
 (#1177)

---
 python/sglang/srt/layers/activation.py | 7 +++++++
 python/sglang/srt/layers/layernorm.py  | 3 +++
 2 files changed, 10 insertions(+)

diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 7cd8abb6f96..a6f05610bd4 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -20,11 +20,18 @@
 
 
 class SiluAndMul(CustomOp):
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        if self.is_lower_sm80:
+            return self.forward_native(x)
+
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index ac4d368d3f6..6cea85404a0 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -32,12 +32,15 @@ def __init__(
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
 
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if self.is_lower_sm80:
+            return self.forward_native(x, residual)
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
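
Note for reviewers (not part of the patch above): the snippet below is a minimal,
self-contained sketch of the fallback pattern this change introduces, written
against a plain torch.nn.Module rather than sglang's CustomOp base class. The
class name SiluAndMulSketch and the inlined CUDA-path math are illustrative
stand-ins, and the premise (taken from the patch title) is that the fused
kernel behind forward_cuda is only usable on SM80/Ampere and newer, so older
GPUs should take the forward_native path.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class SiluAndMulSketch(nn.Module):
        """Illustrative stand-in for sglang's SiluAndMul custom op."""

        def __init__(self) -> None:
            super().__init__()
            # Guard against CPU-only builds before querying the device;
            # a compute-capability major version < 8 means pre-Ampere (< SM80).
            self.is_lower_sm80 = (
                torch.cuda.is_available()
                and torch.cuda.get_device_capability()[0] < 8
            )

        def forward_native(self, x: torch.Tensor) -> torch.Tensor:
            # Pure-PyTorch path: split the last dim and compute SiLU(gate) * up.
            d = x.shape[-1] // 2
            return F.silu(x[..., :d]) * x[..., d:]

        def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
            if self.is_lower_sm80:
                # Assumed: the fused kernel is unavailable below SM80, so fall back.
                return self.forward_native(x)
            # The real op would launch a fused silu_and_mul CUDA kernel here;
            # the sketch reuses the native math so it stays self-contained.
            d = x.shape[-1] // 2
            out = torch.empty(x.shape[:-1] + (d,), dtype=x.dtype, device=x.device)
            out.copy_(F.silu(x[..., :d]) * x[..., d:])
            return out

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.forward_cuda(x) if x.is_cuda else self.forward_native(x)


    # Usage sketch: the last dimension is halved by the gate/up split.
    x = torch.randn(4, 8)
    print(SiluAndMulSketch()(x).shape)  # torch.Size([4, 4])

The RMSNorm change in layernorm.py follows the same shape: record is_lower_sm80
once in __init__, then have forward_cuda delegate to forward_native before
touching the fused kernels.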