From 7aaf4b0b17f1e026377762ba00ab42b71d4b3e15 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 26 Jan 2025 00:05:34 -0800
Subject: [PATCH] Fix test cases

---
 python/sglang/srt/managers/schedule_batch.py | 19 +++++-----
 python/sglang/srt/managers/scheduler.py      | 37 +++++++++-----------
 test/srt/test_srt_endpoint.py                |  7 +++-
 3 files changed, 34 insertions(+), 29 deletions(-)

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 717d16550b5..197f11b79c5 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -772,15 +772,18 @@ def prepare_for_extend(self):
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
-            # Compute the relative logprob_start_len in an extend batch
-            if req.logprob_start_len >= pre_len:
-                extend_logprob_start_len = min(
-                    req.logprob_start_len - pre_len, req.extend_input_len - 1
-                )
-            else:
-                raise RuntimeError("This should never happen")
+            if req.return_logprob:
+                # Compute the relative logprob_start_len in an extend batch
+                if req.logprob_start_len >= pre_len:
+                    extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len, req.extend_input_len - 1
+                    )
+                else:
+                    raise RuntimeError(
+                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
+                    )
+                req.extend_logprob_start_len = extend_logprob_start_len
 
-            req.extend_logprob_start_len = extend_logprob_start_len
             req.is_retracted = False
             pre_lens.append(pre_len)
 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index d1bfe7ac82f..82063862224 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1051,28 +1051,26 @@ def run_batch(
     ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
         """Run a batch."""
         self.forward_ct += 1
+        assert batch.extend_num_tokens != 0
 
         if self.is_generation:
-            if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
-                if self.spec_algorithm.is_none():
-                    model_worker_batch = batch.get_model_worker_batch()
-                    logits_output, next_token_ids = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                else:
-                    (
-                        logits_output,
-                        next_token_ids,
-                        model_worker_batch,
-                        num_accepted_tokens,
-                    ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                    self.spec_num_total_accepted_tokens += (
-                        num_accepted_tokens + batch.batch_size()
-                    )
-                    self.spec_num_total_forward_ct += batch.batch_size()
-                    self.num_generated_tokens += num_accepted_tokens
+            if self.spec_algorithm.is_none():
+                model_worker_batch = batch.get_model_worker_batch()
+                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                    model_worker_batch
+                )
             else:
-                assert False, "batch.extend_num_tokens == 0, this is unexpected!"
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.spec_num_total_accepted_tokens += (
+                    num_accepted_tokens + batch.batch_size()
+                )
+                self.spec_num_total_forward_ct += batch.batch_size()
+                self.num_generated_tokens += num_accepted_tokens
 
             batch.output_ids = next_token_ids
             ret = GenerationBatchResult(
@@ -1081,7 +1079,6 @@ def run_batch(
                 bid=model_worker_batch.bid,
             )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index 8d0eaa3bb90..630658643d0 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -5,6 +5,7 @@
 
 import json
 import random
+import time
 import unittest
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
@@ -325,7 +326,11 @@ def test_custom_logit_processor_batch_mixed(self):
         list(executor.map(self.run_custom_logit_processor, target_token_ids))
 
     def test_cache_tokens(self):
-        response = requests.post(self.base_url + "/flush_cache")
+        for _ in range(5):
+            response = requests.post(self.base_url + "/flush_cache")
+            if response.status_code == 200:
+                break
+            time.sleep(1)
         assert response.status_code == 200
 
         def send_and_check_cached_tokens(input_ids):
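Note (illustrative sketch, not part of the patch): the updated test_cache_tokens retries POST /flush_cache a few times instead of asserting on the first response. Below is a minimal standalone Python illustration of that retry pattern; the helper name flush_cache_with_retry, the base URL, and the retry/delay values are assumptions for the example, not code from the repository.

import time

import requests


def flush_cache_with_retry(base_url, attempts=5, delay=1.0):
    """Poll the /flush_cache endpoint until it returns HTTP 200 or retries run out."""
    response = None
    for _ in range(attempts):
        response = requests.post(base_url + "/flush_cache")
        if response.status_code == 200:
            break
        time.sleep(delay)  # wait before retrying, as the test does with time.sleep(1)
    return response


if __name__ == "__main__":
    # Hypothetical local endpoint; adjust to wherever the server is running.
    resp = flush_cache_with_retry("http://127.0.0.1:30000")
    assert resp is not None and resp.status_code == 200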