
[kernel] Use sgl_kernel rope #3169

Merged (10 commits) on Jan 28, 2025
Changes from all commits
40 changes: 28 additions & 12 deletions python/sglang/srt/layers/rotary_embedding.py
@@ -6,9 +6,15 @@

import torch
import torch.nn as nn
from vllm import _custom_ops as ops
from vllm.model_executor.custom_op import CustomOp

from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.utils import is_cuda_available

_is_cuda_available = is_cuda_available()
if _is_cuda_available:
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -75,7 +81,9 @@ def __init__(
self.dtype = dtype

cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
# NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
if not _is_cuda_available:
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)

@@ -141,17 +149,25 @@ def forward_cuda(
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops

self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
if _is_cuda_available:
apply_rope_with_cos_sin_cache_inplace(
Member:

Everything is going well except for an accuracy issue with test_session_control.

Collaborator Author:

I suspect the test is flaky by nature. I switched from the 1B to the 8B model on main, and test_session_control failed with a different output: #3184.

Here are my findings so far:

| Kernel     | 1B Model | 8B Model |
| ---------- | -------- | -------- |
| vLLM       | Pass     | Fail     |
| flashinfer | Fail     | Fail     |

Given that all the other accuracy tests pass, I think the correctness looks OK.

positions=positions,
query=query,
key=key,
head_size=self.head_size,
cos_sin_cache=self.cos_sin_cache,
is_neox=self.is_neox_style,
)
else:
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key

def forward_xpu(
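For readers who want to try the new kernel entry point in isolation, here is a minimal sketch of the call that forward_cuda now makes on the CUDA path. The keyword arguments mirror the diff above; the tensor shapes, dtypes, and the cos/sin cache construction are illustrative assumptions and are not taken from this PR.

```python
import torch
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace

# Illustrative sizes (assumed, not from the PR).
num_tokens, num_heads, num_kv_heads, head_size, max_pos = 4, 8, 2, 128, 4096

# Assumed cache layout: cos and sin halves concatenated on the last dim,
# kept in FP32 (the PR skips the .to(dtype) cast on the sgl_kernel path).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
t = torch.arange(max_pos, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1).cuda()

positions = torch.arange(num_tokens, device="cuda")  # shape [num_tokens]
query = torch.randn(num_tokens, num_heads * head_size, dtype=torch.float16, device="cuda")
key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=torch.float16, device="cuda")

# In-place rotary embedding, using the same keyword arguments as forward_cuda.
apply_rope_with_cos_sin_cache_inplace(
    positions=positions,
    query=query,
    key=key,
    head_size=head_size,
    cos_sin_cache=cos_sin_cache,
    is_neox=True,
)
```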
21 changes: 17 additions & 4 deletions test/srt/test_session_control.py
@@ -54,6 +54,7 @@ def test_session_control(self, gen_len=12):
chunks_ids[i] = chunks_ids[i][1:]

# 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
@@ -215,7 +216,9 @@ def test_session_control(self, gen_len=12):
print(outputs_from_session)
print("outputs from normal queries:")
print(outputs_normal)
assert outputs_from_session == outputs_normal
assert (
outputs_from_session == outputs_normal
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"

async def async_generate(self, payload):
url = self.base_url + "/generate"
@@ -250,6 +253,7 @@ async def run_session_control_backtrack_with_abort(self, replace):
chunks_ids[i] = chunks_ids[i][1:]

# 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
@@ -320,6 +324,7 @@ async def run_session_control_backtrack_with_abort(self, replace):
assert response["meta_info"]["finish_reason"]["type"] == "abort"
else:
# 2. not using session control
requests.post(self.base_url + "/flush_cache")
output_ids = tokenizer.encode(gen_so_far)
if output_ids[0] == tokenizer.bos_token_id:
output_ids = output_ids[1:]
@@ -342,7 +347,9 @@ async def run_session_control_backtrack_with_abort(self, replace):
output_no_session = response["text"]
print("second request output without session:")
print(output_no_session)
assert second_output == output_no_session
assert (
second_output == output_no_session
), f"second_output: {second_output}, output_no_session: {output_no_session}"

def test_session_control_backtrack_with_abort(self):
asyncio.run(self.run_session_control_backtrack_with_abort(replace=True))
@@ -355,6 +362,7 @@ def run_session_control_with_branching(
assert len(x) == len(chunks_per_step[0])

# 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
@@ -459,7 +467,9 @@ def run_session_control_with_branching(
print(outputs_from_session)
print("====== outputs from normal queries: =======")
print(outputs_normal)
assert outputs_from_session == outputs_normal
assert (
outputs_from_session == outputs_normal
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"

def test_session_control_with_branching(self):
root_prompt = "First, let me explain in one sentence about AI"
@@ -525,6 +535,7 @@ def test_session_control(self):
gen_len = 32

# 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
@@ -691,7 +702,9 @@ def test_session_control(self):
print(outputs_from_session)
print("outputs from normal queries:")
print(outputs_normal)
assert outputs_from_session == outputs_normal
assert (
outputs_from_session == outputs_normal
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"


if __name__ == "__main__":
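As a usage note on the test changes: each session-control scenario now POSTs to /flush_cache before opening a session, presumably so that previously cached prefixes cannot influence the outputs being compared. A minimal sketch of that preamble follows; only the endpoints and the capacity_of_str_len field come from the diff, while the server address and the response handling are assumptions.

```python
import requests

base_url = "http://localhost:30000"  # assumed server address

# Flush the server cache before exercising session control (added in this PR).
requests.post(base_url + "/flush_cache")

# Open a session; the capacity value matches what the test sends.
response = requests.post(
    base_url + "/open_session",
    json={"capacity_of_str_len": 1000},
)
session_id = response.json()  # response handling assumed, not shown in the diff
print("opened session:", session_id)
```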