Skip to content

Commit

Permalink
Add a test case for cached_tokens (#3145)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Jan 26, 2025
1 parent f8b28e4 commit d1a0863
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 63 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
| [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) |

## News
- [2024/12] 🔥 SGLang v0.4: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
- [2024/10] 🔥 The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
- [2024/09] SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
- [2024/07] Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
- [2025/01] 🔥 SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeekSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html))
- [2024/12] 🔥 v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
- [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
- [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).

<details>
<summary>More</summary>

- [2024/10] The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
- [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
- [2024/04] SGLang is used by the official **LLaVA-NeXT (video)** release ([blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)).
- [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)).
- [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).

Expand Down
29 changes: 14 additions & 15 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def __init__(

# The number of cached tokens, that were already cached in the KV cache
self.cached_tokens = 0
self.already_computed = 0

def extend_image_inputs(self, image_inputs):
if self.image_inputs is None:
Expand Down Expand Up @@ -750,13 +751,6 @@ def prepare_for_extend(self):

pt = 0
for i, req in enumerate(reqs):
already_computed = (
req.extend_logprob_start_len + 1 + req.cached_tokens
if req.extend_logprob_start_len > 0
else 0
)
req.cached_tokens += len(req.prefix_indices) - already_computed

req.req_pool_idx = req_pool_indices[i]
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
seq_lens.append(seq_len)
Expand All @@ -772,15 +766,20 @@ def prepare_for_extend(self):
# If req.input_embeds is already a list, append its content directly
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting

# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
extend_logprob_start_len = min(
req.logprob_start_len - pre_len, req.extend_input_len - 1
)
else:
extend_logprob_start_len = req.extend_input_len - 1
if req.return_logprob:
# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
extend_logprob_start_len = min(
req.logprob_start_len - pre_len, req.extend_input_len - 1
)
else:
raise RuntimeError(
f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
)
req.extend_logprob_start_len = extend_logprob_start_len

req.extend_logprob_start_len = extend_logprob_start_len
req.cached_tokens += pre_len - req.already_computed
req.already_computed = seq_len
req.is_retracted = False
pre_lens.append(pre_len)

Expand Down
58 changes: 29 additions & 29 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,24 +660,23 @@ def handle_generate_request(
self.waiting_queue.append(req)
return

# Copy more attributes
req.logprob_start_len = recv_req.logprob_start_len

if req.logprob_start_len == -1:
# By default, only return the logprobs for output tokens
req.logprob_start_len = len(req.origin_input_ids) - 1

# Validate prompts length
error_msg = validate_input_length(
req,
self.max_req_input_len,
self.server_args.allow_auto_truncate,
)

if error_msg:
self.waiting_queue.append(req)
return

# Copy more attributes
if recv_req.logprob_start_len == -1:
# By default, only return the logprobs for output tokens
req.logprob_start_len = len(req.origin_input_ids) - 1
else:
req.logprob_start_len = recv_req.logprob_start_len

req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
Expand Down Expand Up @@ -725,12 +724,17 @@ def handle_embedding_request(
req.tokenizer = self.tokenizer

# Validate prompts length
validate_input_length(
error_msg = validate_input_length(
req,
self.max_req_input_len,
self.server_args.allow_auto_truncate,
)
if error_msg:
self.waiting_queue.append(req)
return

# Copy more attributes
req.logprob_start_len = len(req.origin_input_ids) - 1
self.waiting_queue.append(req)

def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
Expand Down Expand Up @@ -1044,26 +1048,23 @@ def run_batch(
self.forward_ct += 1

if self.is_generation:
if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = (
self.tp_worker.forward_batch_generation(model_worker_batch)
)
else:
(
logits_output,
next_token_ids,
model_worker_batch,
num_accepted_tokens,
) = self.draft_worker.forward_batch_speculative_generation(batch)
self.spec_num_total_accepted_tokens += (
num_accepted_tokens + batch.batch_size()
)
self.spec_num_total_forward_ct += batch.batch_size()
self.num_generated_tokens += num_accepted_tokens
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
else:
assert False, "batch.extend_num_tokens == 0, this is unexpected!"
(
logits_output,
next_token_ids,
model_worker_batch,
num_accepted_tokens,
) = self.draft_worker.forward_batch_speculative_generation(batch)
self.spec_num_total_accepted_tokens += (
num_accepted_tokens + batch.batch_size()
)
self.spec_num_total_forward_ct += batch.batch_size()
self.num_generated_tokens += num_accepted_tokens
batch.output_ids = next_token_ids

ret = GenerationBatchResult(
Expand All @@ -1072,7 +1073,6 @@ def run_batch(
bid=model_worker_batch.bid,
)
else: # embedding or reward model
assert batch.extend_num_tokens != 0
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = EmbeddingBatchResult(
Expand Down
1 change: 0 additions & 1 deletion test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"test_eagle_infer.py",
"test_embedding_openai_server.py",
"test_eval_accuracy_mini.py",
"test_get_weights_by_name.py",
"test_gguf.py",
"test_input_embeddings.py",
"test_json_constrained.py",
Expand Down
7 changes: 0 additions & 7 deletions test/srt/test_ebnf_constrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,5 @@ def test_ebnf_generate_custom_log_format(self):
)


class TestJumpForward(TestEBNFConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, disable_overlap=True)
cls.check_jump_forward = True


if __name__ == "__main__":
unittest.main()
32 changes: 26 additions & 6 deletions test/srt/test_srt_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import json
import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
Expand Down Expand Up @@ -317,19 +318,38 @@ def test_custom_logit_processor(self):
"""Test custom logit processor with a single request."""
self.run_custom_logit_processor(target_token_id=5)

def test_custom_logit_processor_batch(self):
"""Test custom logit processor with a batch of requests."""
target_token_ids = list(range(32))
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))

def test_custom_logit_processor_batch_mixed(self):
"""Test a batch of requests mixed of requests with and without custom logit processor."""
target_token_ids = list(range(32)) + [None] * 16
random.shuffle(target_token_ids)
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))

def test_cache_tokens(self):
for _ in range(2):
time.sleep(1)
response = requests.post(self.base_url + "/flush_cache")
assert response.status_code == 200

def send_and_check_cached_tokens(input_ids):
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": list(input_ids),
"sampling_params": {
"max_new_tokens": 1,
},
},
)
response_json = response.json()
return response_json["meta_info"]["cached_tokens"]

self.assertEqual(send_and_check_cached_tokens(range(0, 100)), 0)
self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 100)
self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 9999)
self.assertEqual(send_and_check_cached_tokens(range(0, 1000)), 999)
self.assertEqual(send_and_check_cached_tokens(range(0, 11000)), 10000)

def test_get_server_info(self):
response = requests.get(self.base_url + "/get_server_info")
response_json = response.json()
Expand Down

0 comments on commit d1a0863

Please sign in to comment.