fix: custom op fallback forward native when lower sm80 (#1177)
zhyncs authored Aug 21, 2024
1 parent bea2bb9 commit 1fb9459
Showing 2 changed files with 10 additions and 0 deletions.
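Both files apply the same fix: at construction time the op records whether the GPU's CUDA compute capability is below SM80 (major version 8, i.e. pre-Ampere), and forward_cuda falls back to the pure-PyTorch forward_native path on such devices, where the fused kernels these ops otherwise call are not supported. A minimal sketch of the capability check, assuming a CUDA device is present (the standalone helper name is illustrative, not from this commit):

import torch

def is_lower_sm80() -> bool:
    # torch.cuda.get_device_capability() returns (major, minor),
    # e.g. (7, 5) for Turing or (8, 0) for Ampere; SM80 means major >= 8.
    major, _minor = torch.cuda.get_device_capability()
    return major < 8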
python/sglang/srt/layers/activation.py (7 additions & 0 deletions)
@@ -20,11 +20,18 @@
 
 
 class SiluAndMul(CustomOp):
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        if self.is_lower_sm80:
+            return self.forward_native(x)
+
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
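For reference, forward_native splits the last dimension in half and computes silu(gate) elementwise times the second half; a quick standalone check of that math (not a test from the repository):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)                  # last dim is 2 * d, so d = 4
d = x.shape[-1] // 2
out = F.silu(x[..., :d]) * x[..., d:]  # same computation as forward_native
assert out.shape == (2, 4)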
python/sglang/srt/layers/layernorm.py (3 additions & 0 deletions)
@@ -32,12 +32,15 @@ def __init__(
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
 
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if self.is_lower_sm80:
+            return self.forward_native(x, residual)
 
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
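The diff does not show RMSNorm.forward_native; for context, a hedged sketch of an equivalent pure-PyTorch RMS normalization with the optional residual add that fused_add_rmsnorm fuses. This is an illustrative assumption, not the file's actual implementation:

from typing import Optional, Tuple, Union

import torch

def rmsnorm_native(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Hypothetical equivalent of the fallback path: add the residual first
    # (the "add" half of fused_add_rmsnorm), then apply RMS normalization.
    if residual is not None:
        x = x + residual
        residual = x
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps) * weight
    return out if residual is None else (out, residual)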
