From ee8864eacc5812a49b5eadc4af8f935355ad796c Mon Sep 17 00:00:00 2001 From: zhyncs Date: Tue, 27 Aug 2024 13:23:13 +0000 Subject: [PATCH 1/2] feat: update GemmaRMSNorm --- python/sglang/srt/layers/layernorm.py | 48 +++++++++++++++++++++++- python/sglang/srt/models/gemma2.py | 49 +----------------------- python/sglang/test/test_layernorm.py | 54 ++++++++++++++++++++++++++- 3 files changed, 101 insertions(+), 50 deletions(-) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index ac4d368d3f6..4c24f50ffe4 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -19,7 +19,12 @@ import torch import torch.nn as nn -from flashinfer.norm import fused_add_rmsnorm, rmsnorm +from flashinfer.norm import ( + fused_add_rmsnorm, + gemma_fused_add_rmsnorm, + gemma_rmsnorm, + rmsnorm, +) from vllm.model_executor.custom_op import CustomOp @@ -63,3 +68,44 @@ def forward_native( return x else: return x, residual + + +class GemmaRMSNorm(CustomOp): + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + if residual is not None: + x = x + residual + residual = x + + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x * (1.0 + self.weight.float()) + x = x.to(orig_dtype) + return x if residual is None else (x, residual) + + def forward_cuda( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + gemma_fused_add_rmsnorm( + x, residual, self.weight.data, self.variance_epsilon + ) + return x, residual + out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon) + return out diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index c6dbc7e5569..9bdf90505c4 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -25,8 +25,6 @@ # FIXME: temporary solution, remove after next vllm release from vllm.model_executor.custom_op import CustomOp - -# from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -39,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import GeluAndMul +from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -50,52 +49,6 @@ def get_attention_sliding_window_size(config): return config.sliding_window - 1 -class GemmaRMSNorm(CustomOp): - """RMS normalization for Gemma. - - Two differences from the above RMSNorm: - 1. x * (1 + w) instead of x * w. - 2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w. 
- """ - - def __init__( - self, - hidden_size: int, - eps: float = 1e-6, - ) -> None: - super().__init__() - self.weight = nn.Parameter(torch.zeros(hidden_size)) - self.variance_epsilon = eps - - def forward_native( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """PyTorch-native implementation equivalent to forward().""" - orig_dtype = x.dtype - if residual is not None: - x = x + residual - residual = x - - x = x.float() - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) - # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) - # See https://github.com/huggingface/transformers/pull/29402 - x = x * (1.0 + self.weight.float()) - x = x.to(orig_dtype) - return x if residual is None else (x, residual) - - def forward_cuda( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - # from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm. - return self.forward_native(x, residual) - - # FIXME: temporary solution, remove after next vllm release from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding diff --git a/python/sglang/test/test_layernorm.py b/python/sglang/test/test_layernorm.py index ab61aa80405..770e69733db 100644 --- a/python/sglang/test/test_layernorm.py +++ b/python/sglang/test/test_layernorm.py @@ -3,7 +3,7 @@ import torch -from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm class TestRMSNorm(unittest.TestCase): @@ -56,5 +56,57 @@ def test_rms_norm(self): self._run_rms_norm_test(*params) +class TestGemmaRMSNorm(unittest.TestCase): + DTYPES = [torch.half, torch.bfloat16] + NUM_TOKENS = [7, 83, 4096] + HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199] + ADD_RESIDUAL = [False, True] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _run_gemma_rms_norm_test( + self, num_tokens, hidden_size, add_residual, dtype, seed + ): + torch.manual_seed(seed) + + layer = GemmaRMSNorm(hidden_size).to(dtype=dtype) + layer.weight.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale + residual = torch.randn_like(x) * scale if add_residual else None + + with torch.inference_mode(): + ref_out = layer.forward_native(x, residual) + out = layer(x, residual) + + if add_residual: + self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-3, rtol=1e-3)) + self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-3, rtol=1e-3)) + else: + self.assertTrue(torch.allclose(out, ref_out, atol=1e-3, rtol=1e-3)) + + def test_gemma_rms_norm(self): + for params in itertools.product( + self.NUM_TOKENS, + self.HIDDEN_SIZES, + self.ADD_RESIDUAL, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + num_tokens=params[0], + hidden_size=params[1], + add_residual=params[2], + dtype=params[3], + seed=params[4], + ): + self._run_gemma_rms_norm_test(*params) + + if __name__ == "__main__": unittest.main(verbosity=2) From ae6fb2b00d746350f2cbfdec70ba4bef757286d6 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Wed, 28 Aug 2024 19:29:14 +1000 Subject: [PATCH 2/2] fix --- python/sglang/srt/models/gemma2.py | 3 --- 1 file changed, 3 deletions(-) diff --git 
a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 9bdf90505c4..3223424d79c 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -22,9 +22,6 @@ from transformers import PretrainedConfig from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size - -# FIXME: temporary solution, remove after next vllm release -from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear,
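
Note: the snippet below is not part of the patch; it is a minimal standalone sketch of how the new GemmaRMSNorm can be cross-checked against its forward_native reference, mirroring the test added in test_layernorm.py. It assumes a CUDA device, a flashinfer build that exposes gemma_rmsnorm / gemma_fused_add_rmsnorm (as imported in the patch), and that vllm's CustomOp dispatch routes layer(...) to forward_cuda on CUDA.

import torch

from sglang.srt.layers.layernorm import GemmaRMSNorm

torch.manual_seed(0)
torch.set_default_device("cuda")

hidden_size = 4096
layer = GemmaRMSNorm(hidden_size).to(dtype=torch.half)
# Gemma stores the scale as (1 + weight), so a weight near 1.0 gives a ~2x scale.
layer.weight.data.normal_(mean=1.0, std=0.1)

scale = 1 / (2 * hidden_size)
x = torch.randn(16, hidden_size, dtype=torch.half) * scale
residual = torch.randn_like(x) * scale

with torch.inference_mode():
    # Compute the PyTorch-native reference on clones first: the fused
    # gemma_fused_add_rmsnorm kernel updates x and residual in place.
    ref_out, ref_res = layer.forward_native(x.clone(), residual.clone())
    # layer(...) is expected to dispatch to forward_cuda, i.e. the flashinfer path.
    out, res = layer(x, residual)

print(torch.allclose(out, ref_out, atol=1e-3, rtol=1e-3))
print(torch.allclose(res, ref_res, atol=1e-3, rtol=1e-3))

The clones matter because the fused CUDA path mutates its inputs while forward_native does not; the 1e-3 tolerances are the same ones used by the added TestGemmaRMSNorm case.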