feat: use FlashInfer rmsnorm and silu
zhyncs committed Aug 11, 2024
1 parent 43fbb6d commit 0d1ee26
Showing 6 changed files with 156 additions and 10 deletions.
29 changes: 29 additions & 0 deletions python/sglang/srt/layers/activation.py
@@ -0,0 +1,29 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from flashinfer.activation import silu_and_mul


class SiluAndMul(nn.Module):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference path: SiLU on the first half, gated by the second half.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # FlashInfer fused kernel: computes silu(x[..., :d]) * x[..., d:] in one pass.
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        silu_and_mul(x, out)
        return out
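
For context, the new module keeps the same contract as the vllm SiluAndMul it replaces: the last dimension holds the gate and up projections side by side, so an input of shape [..., 2 * d] yields an output of shape [..., d]. A minimal shape check, assuming a CUDA device and FlashInfer installed (the sizes here are arbitrary, not from this commit):

import torch

from sglang.srt.layers.activation import SiluAndMul

act = SiluAndMul()
# gate and up projections concatenated along the last dimension
gate_up = torch.randn(4, 2 * 11008, dtype=torch.float16, device="cuda")
out = act(gate_up)  # FlashInfer kernel: silu(gate) * up
print(out.shape)    # torch.Size([4, 11008])
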
62 changes: 62 additions & 0 deletions python/sglang/srt/layers/layernorm.py
@@ -0,0 +1,62 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from flashinfer.norm import fused_add_rmsnorm, rmsnorm


class RMSNorm(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            # FlashInfer kernel works in place: residual <- x + residual,
            # x <- rmsnorm(x + residual, weight, eps).
            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
            return x, residual
        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
        return out

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Pure-PyTorch reference path, computed in float32 for numerical stability.
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        else:
            return x, residual
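
Note the calling convention this implies for the decoder layers: on the fused path, fused_add_rmsnorm updates both tensors in place, so the returned tensors alias the inputs. A small sketch, assuming a CUDA device and FlashInfer installed (shapes are illustrative only):

import torch

from sglang.srt.layers.layernorm import RMSNorm

norm = RMSNorm(4096).to(device="cuda", dtype=torch.float16)
x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)

out = norm(x.clone())               # no residual: freshly allocated output
x_out, res_out = norm(x, residual)  # fused path: x and residual updated in place
assert x_out.data_ptr() == x.data_ptr()
assert res_out.data_ptr() == residual.data_ptr()
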
9 changes: 2 additions & 7 deletions python/sglang/srt/models/internlm2.py
@@ -23,8 +23,6 @@
from transformers import PretrainedConfig
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
@@ -38,13 +36,14 @@
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class InternLM2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
@@ -74,7 +73,6 @@ def forward(self, x):


class InternLM2Attention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
@@ -150,7 +148,6 @@ def forward(


class InternLMDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
@@ -207,7 +204,6 @@ def forward(


class InternLM2Model(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
@@ -254,7 +250,6 @@ def forward(


class InternLM2ForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
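The model files only swap imports; the call sites are untouched because the sglang layers mirror the vllm interface. As a stripped-down illustration of the pattern those call sites follow (a hypothetical GatedMLP with plain nn.Linear standing in for the tensor-parallel linear layers; the fused activation still needs a CUDA device):

import torch
import torch.nn as nn

from sglang.srt.layers.activation import SiluAndMul


class GatedMLP(nn.Module):
    # Toy stand-in for InternLM2MLP: gate and up projections are merged into
    # one matmul, then SiluAndMul splits the halves and combines them.
    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
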
4 changes: 2 additions & 2 deletions python/sglang/srt/models/llama2.py
@@ -24,8 +24,6 @@
from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
@@ -39,6 +37,8 @@
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
2 changes: 1 addition & 1 deletion python/sglang/srt/server.py
@@ -384,7 +384,7 @@ def _set_envs_and_config(server_args: ServerArgs):
    if not server_args.disable_flashinfer:
        assert_pkg_version(
            "flashinfer",
            "0.1.3",
            "0.1.4",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
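The fused kernels above need FlashInfer 0.1.4, hence the bumped version floor at startup. Roughly what that assertion enforces, as a standalone sketch (not the sglang helper itself; check_flashinfer is a made-up name and packaging is assumed to be installed):

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def check_flashinfer(min_version: str = "0.1.4") -> None:
    # Fail fast if flashinfer is missing or older than the required version.
    try:
        installed = version("flashinfer")
    except PackageNotFoundError as exc:
        raise RuntimeError("flashinfer is not installed") from exc
    if Version(installed) < Version(min_version):
        raise RuntimeError(
            f"flashinfer {installed} is too old; >= {min_version} is required. "
            "See https://docs.flashinfer.ai/installation.html."
        )
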
60 changes: 60 additions & 0 deletions python/sglang/test/test_layernorm.py
@@ -0,0 +1,60 @@
import itertools
import unittest

import torch

from sglang.srt.layers.layernorm import RMSNorm


class TestRMSNorm(unittest.TestCase):
    DTYPES = [torch.half, torch.bfloat16]
    NUM_TOKENS = [7, 83, 4096]
    HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
    ADD_RESIDUAL = [False, True]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _run_rms_norm_test(self, num_tokens, hidden_size, add_residual, dtype, seed):
        torch.manual_seed(seed)

        layer = RMSNorm(hidden_size).to(dtype=dtype)
        layer.weight.data.normal_(mean=1.0, std=0.1)
        scale = 1 / (2 * hidden_size)
        x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
        residual = torch.randn_like(x) * scale if add_residual else None

        with torch.inference_mode():
            ref_out = layer.forward_native(x, residual)
            out = layer(x, residual)

        if add_residual:
            self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2))
            self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2))
        else:
            self.assertTrue(torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2))

    def test_rms_norm(self):
        for params in itertools.product(
            self.NUM_TOKENS,
            self.HIDDEN_SIZES,
            self.ADD_RESIDUAL,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                num_tokens=params[0],
                hidden_size=params[1],
                add_residual=params[2],
                dtype=params[3],
                seed=params[4],
            ):
                self._run_rms_norm_test(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)
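
The commit adds a test only for RMSNorm; a companion test for the new activation path, written in the same style (hypothetical, not part of this diff; tolerances are illustrative), might look like:

import itertools
import unittest

import torch

from sglang.srt.layers.activation import SiluAndMul


class TestSiluAndMul(unittest.TestCase):
    DTYPES = [torch.half, torch.bfloat16]
    NUM_TOKENS = [7, 83, 4096]
    DIMS = [512, 4096, 5120, 13824]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _run_silu_and_mul_test(self, num_tokens, d, dtype, seed):
        torch.manual_seed(seed)
        layer = SiluAndMul()
        x = torch.randn(num_tokens, 2 * d, dtype=dtype)

        with torch.inference_mode():
            ref_out = layer.forward_native(x)
            out = layer(x)

        self.assertTrue(torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2))

    def test_silu_and_mul(self):
        for params in itertools.product(
            self.NUM_TOKENS, self.DIMS, self.DTYPES, self.SEEDS
        ):
            with self.subTest(
                num_tokens=params[0], d=params[1], dtype=params[2], seed=params[3]
            ):
                self._run_silu_and_mul_test(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)
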
