Commit

Fix test cases
merrymercy committed Jan 26, 2025
1 parent 0847053 commit 7aaf4b0
Showing 3 changed files with 34 additions and 29 deletions.
python/sglang/srt/managers/schedule_batch.py (19 changes: 11 additions & 8 deletions)
@@ -772,15 +772,18 @@ def prepare_for_extend(self):
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
-            # Compute the relative logprob_start_len in an extend batch
-            if req.logprob_start_len >= pre_len:
-                extend_logprob_start_len = min(
-                    req.logprob_start_len - pre_len, req.extend_input_len - 1
-                )
-            else:
-                raise RuntimeError("This should never happen")
+            if req.return_logprob:
+                # Compute the relative logprob_start_len in an extend batch
+                if req.logprob_start_len >= pre_len:
+                    extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len, req.extend_input_len - 1
+                    )
+                else:
+                    raise RuntimeError(
+                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
+                    )
+                req.extend_logprob_start_len = extend_logprob_start_len
 
-            req.extend_logprob_start_len = extend_logprob_start_len
             req.is_retracted = False
             pre_lens.append(pre_len)
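
As an aside, the clamp in the hunk above is easy to misread, so here is a minimal standalone sketch of the arithmetic (hypothetical values and helper name, not part of the commit): logprob_start_len is an absolute position in the full sequence, while the extend batch only covers the tokens after the cached prefix of length pre_len, as the subtraction suggests.

# Standalone sketch, not sglang code; names and numbers are illustrative.
def relative_logprob_start_len(logprob_start_len: int, pre_len: int, extend_input_len: int) -> int:
    # The requested start position must not fall inside the already-cached prefix.
    assert logprob_start_len >= pre_len
    # Clamp so the offset never points past the last token of the extend input.
    return min(logprob_start_len - pre_len, extend_input_len - 1)

print(relative_logprob_start_len(6, pre_len=4, extend_input_len=4))  # 2
print(relative_logprob_start_len(9, pre_len=4, extend_input_len=4))  # 3 (clamped)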

python/sglang/srt/managers/scheduler.py (37 changes: 17 additions & 20 deletions)
@@ -1051,28 +1051,26 @@ def run_batch(
     ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
         """Run a batch."""
         self.forward_ct += 1
+        assert batch.extend_num_tokens != 0
 
         if self.is_generation:
-            if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
-                if self.spec_algorithm.is_none():
-                    model_worker_batch = batch.get_model_worker_batch()
-                    logits_output, next_token_ids = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                else:
-                    (
-                        logits_output,
-                        next_token_ids,
-                        model_worker_batch,
-                        num_accepted_tokens,
-                    ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                    self.spec_num_total_accepted_tokens += (
-                        num_accepted_tokens + batch.batch_size()
-                    )
-                    self.spec_num_total_forward_ct += batch.batch_size()
-                    self.num_generated_tokens += num_accepted_tokens
+            if self.spec_algorithm.is_none():
+                model_worker_batch = batch.get_model_worker_batch()
+                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                    model_worker_batch
+                )
             else:
-                assert False, "batch.extend_num_tokens == 0, this is unexpected!"
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.spec_num_total_accepted_tokens += (
+                    num_accepted_tokens + batch.batch_size()
+                )
+                self.spec_num_total_forward_ct += batch.batch_size()
+                self.num_generated_tokens += num_accepted_tokens
             batch.output_ids = next_token_ids
 
             ret = GenerationBatchResult(
@@ -1081,7 +1079,6 @@ def run_batch(
                 bid=model_worker_batch.bid,
             )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(
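
For context on the two speculative-decoding counters updated in the first hunk, here is a standalone sketch of one plausible reading of the bookkeeping (illustrative names and numbers, not sglang's API, and the exact counter semantics are an assumption): each forward pass yields one regular token per request plus any accepted draft tokens, so the ratio of the two counters gives an average accept length per request per step.

# Standalone sketch (not sglang code); mirrors the counter updates shown above.
class SpecStats:
    def __init__(self) -> None:
        self.total_accepted_tokens = 0
        self.total_forward_ct = 0

    def record(self, num_accepted_tokens: int, batch_size: int) -> None:
        # One base token per request per pass, plus the accepted draft tokens.
        self.total_accepted_tokens += num_accepted_tokens + batch_size
        self.total_forward_ct += batch_size

    def avg_accept_length(self) -> float:
        return self.total_accepted_tokens / self.total_forward_ct

stats = SpecStats()
stats.record(num_accepted_tokens=5, batch_size=4)  # hypothetical step
stats.record(num_accepted_tokens=2, batch_size=4)  # hypothetical step
print(f"{stats.avg_accept_length():.2f}")  # 1.88 tokens per request per forward pass
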
test/srt/test_srt_endpoint.py (7 changes: 6 additions & 1 deletion)
@@ -5,6 +5,7 @@
 
 import json
 import random
+import time
 import unittest
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
@@ -325,7 +326,11 @@ def test_custom_logit_processor_batch_mixed(self):
             list(executor.map(self.run_custom_logit_processor, target_token_ids))
 
     def test_cache_tokens(self):
-        response = requests.post(self.base_url + "/flush_cache")
+        for _ in range(5):
+            response = requests.post(self.base_url + "/flush_cache")
+            if response.status_code == 200:
+                break
+            time.sleep(1)
         assert response.status_code == 200
 
         def send_and_check_cached_tokens(input_ids):
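
The retry loop added to test_cache_tokens assumes /flush_cache can briefly return a non-200 status, for example while earlier requests are still draining (an assumption about why the test was flaky). A standalone sketch of the same pattern, with an illustrative helper name and local URL:

# Standalone sketch of the retry-until-200 pattern used above; the URL and
# timings are illustrative assumptions, not part of the commit.
import time
import requests

def post_until_ok(url: str, attempts: int = 5, delay_s: float = 1.0) -> requests.Response:
    for _ in range(attempts):
        response = requests.post(url)
        if response.status_code == 200:
            break
        time.sleep(delay_s)
    return response  # last response, whether or not it succeeded

resp = post_until_ok("http://127.0.0.1:30000/flush_cache")  # hypothetical local server
assert resp.status_code == 200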
