feat: update GemmaRMSNorm #1232

Merged
merged 3 commits on Aug 28, 2024
48 changes: 47 additions & 1 deletion python/sglang/srt/layers/layernorm.py
@@ -19,7 +19,12 @@

import torch
import torch.nn as nn
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from flashinfer.norm import (
fused_add_rmsnorm,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
rmsnorm,
)
from vllm.model_executor.custom_op import CustomOp


@@ -63,3 +68,44 @@ def forward_native(
return x
else:
return x, residual


class GemmaRMSNorm(CustomOp):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps

def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
orig_dtype = x.dtype
if residual is not None:
x = x + residual
residual = x

x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x * (1.0 + self.weight.float())
x = x.to(orig_dtype)
return x if residual is None else (x, residual)

def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
gemma_fused_add_rmsnorm(
x, residual, self.weight.data, self.variance_epsilon
)
return x, residual
out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
return out
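
For reference, a minimal usage sketch of the new layer, not part of the diff. It assumes the sglang package containing this change (and its vllm dependency) is importable; forward_native needs only PyTorch, while the dispatched forward path on CUDA uses the flashinfer gemma_rmsnorm / gemma_fused_add_rmsnorm kernels imported above.

import torch

from sglang.srt.layers.layernorm import GemmaRMSNorm

hidden_size = 256
layer = GemmaRMSNorm(hidden_size, eps=1e-6)

x = torch.randn(4, hidden_size, dtype=torch.bfloat16)

# Without a residual: returns just the normalized tensor.
out = layer.forward_native(x)

# With a residual: returns (normalized(x + residual), x + residual).
residual = torch.randn_like(x)
out, new_residual = layer.forward_native(x, residual)
print(out.shape, new_residual.shape)
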
52 changes: 1 addition & 51 deletions python/sglang/srt/models/gemma2.py
@@ -22,11 +22,6 @@
from transformers import PretrainedConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size

# FIXME: temporary solution, remove after next vllm release
from vllm.model_executor.custom_op import CustomOp

# from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -39,6 +34,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import GemmaRMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -50,52 +46,6 @@ def get_attention_sliding_window_size(config):
return config.sliding_window - 1


class GemmaRMSNorm(CustomOp):
"""RMS normalization for Gemma.

Two differences from the above RMSNorm:
1. x * (1 + w) instead of x * w.
2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
"""

def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps

def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
if residual is not None:
x = x + residual
residual = x

x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
x = x * (1.0 + self.weight.float())
x = x.to(orig_dtype)
return x if residual is None else (x, residual)

def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
return self.forward_native(x, residual)


# FIXME: temporary solution, remove after next vllm release
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
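
The removed in-model class documented the two ways GemmaRMSNorm differs from a Llama-style RMSNorm: the weight is applied as (1 + w), and the cast back to the original dtype happens after the weight multiply rather than before it (see the linked transformers PR in the comments above). A small standalone sketch, not part of the diff, illustrating both points:

import torch

torch.manual_seed(0)
hidden_size = 8
eps = 1e-6
x = torch.randn(2, hidden_size, dtype=torch.float16)
w = torch.randn(hidden_size, dtype=torch.float16) * 0.1  # Gemma stores weights near zero

def rms(x_f32):
    variance = x_f32.pow(2).mean(dim=-1, keepdim=True)
    return x_f32 * torch.rsqrt(variance + eps)

# Llama-style: cast back to fp16 first, then multiply by w.
llama_out = rms(x.float()).to(torch.float16) * w

# Gemma-style: multiply by (1 + w) in fp32, then cast back to fp16.
gemma_out = (rms(x.float()) * (1.0 + w.float())).to(torch.float16)

print((llama_out - gemma_out).abs().max())  # differs due to (1 + w) and the cast order
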

54 changes: 53 additions & 1 deletion python/sglang/test/test_layernorm.py
@@ -3,7 +3,7 @@

import torch

from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm


class TestRMSNorm(unittest.TestCase):
@@ -56,5 +56,57 @@ def test_rms_norm(self):
self._run_rms_norm_test(*params)


class TestGemmaRMSNorm(unittest.TestCase):
DTYPES = [torch.half, torch.bfloat16]
NUM_TOKENS = [7, 83, 4096]
HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
ADD_RESIDUAL = [False, True]
SEEDS = [0]

@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")

def _run_gemma_rms_norm_test(
self, num_tokens, hidden_size, add_residual, dtype, seed
):
torch.manual_seed(seed)

layer = GemmaRMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
residual = torch.randn_like(x) * scale if add_residual else None

with torch.inference_mode():
ref_out = layer.forward_native(x, residual)
out = layer(x, residual)

if add_residual:
self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-3, rtol=1e-3))
else:
self.assertTrue(torch.allclose(out, ref_out, atol=1e-3, rtol=1e-3))

def test_gemma_rms_norm(self):
for params in itertools.product(
self.NUM_TOKENS,
self.HIDDEN_SIZES,
self.ADD_RESIDUAL,
self.DTYPES,
self.SEEDS,
):
with self.subTest(
num_tokens=params[0],
hidden_size=params[1],
add_residual=params[2],
dtype=params[3],
seed=params[4],
):
self._run_gemma_rms_norm_test(*params)


if __name__ == "__main__":
unittest.main(verbosity=2)
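
The new tests compare the flashinfer-backed forward against forward_native and skip when CUDA is unavailable. A quick way to run only the added class, assuming the sglang package from this branch is importable:

import unittest

from sglang.test.test_layernorm import TestGemmaRMSNorm

# Load and run just the new Gemma test case; requires a CUDA device,
# otherwise setUpClass raises unittest.SkipTest and the cases are skipped.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestGemmaRMSNorm)
unittest.TextTestRunner(verbosity=2).run(suite)
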