diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index c1398894ac2..d71cf55f6fd 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -425,7 +425,6 @@ class ScheduleBatch:
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    output_ids: torch.Tensor = None

     # For processing logprobs
     return_logprob: bool = False
@@ -442,27 +441,23 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False

-    # device
-    device: str = "cuda"
-
     # Has regex
     has_regex: bool = False

+    # device
+    device: str = "cuda"
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-        return_logprob = any(req.return_logprob for req in reqs)
-        has_stream = any(req.stream for req in reqs)
-        has_regex = any(req.regex_fsm for req in reqs)
-
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
-            return_logprob=return_logprob,
-            has_stream=has_stream,
+            return_logprob=any(req.return_logprob for req in reqs),
+            has_stream=any(req.stream for req in reqs),
+            has_regex=any(req.regex_fsm for req in reqs),
             device=req_to_token_pool.device,
-            has_regex=has_regex,
         )

     def batch_size(self):
@@ -754,7 +749,7 @@ def check_for_jump_forward(self, pad_input_ids_func):

         return jump_forward_reqs

-    def prepare_for_decode(self):
+    def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE

         self.input_ids = self.output_ids
@@ -767,10 +762,19 @@ def prepare_for_decode(self):
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)
-        self.req_to_token_pool.write(
-            (self.req_pool_indices, self.seq_lens), self.out_cache_loc
-        )
-        self.seq_lens.add_(1)
+
+        if enable_overlap:
+            # Do not use in-place operations in overlap mode
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+            )
+            self.seq_lens = self.seq_lens + 1
+        else:
+            # A faster in-place version
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+            )
+            self.seq_lens.add_(1)

     def filter_batch(
         self,
@@ -882,6 +886,7 @@ def get_model_worker_batch(self):
         )

     def copy(self):
+        # Only include fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
             forward_mode=self.forward_mode,
@@ -940,9 +945,9 @@ def copy(self):
         return ModelWorkerBatch(
             bid=self.bid,
             forward_mode=self.forward_mode,
-            input_ids=self.input_ids.clone(),
+            input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens.clone(),
+            seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
             req_to_token_pool_records=self.req_to_token_pool_records,
             return_logprob=self.return_logprob,
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 7d20689ff9e..df4b5dfb4ac 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -103,6 +103,7 @@ def __init__(
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
+        self.enable_overlap = server_args.enable_overlap_schedule

         # Init inter-process communication
         context = zmq.Context(2)
@@ -146,7 +147,7 @@ def __init__(
         )

         # Launch a tensor parallel worker
-        if self.server_args.enable_overlap_schedule:
+        if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
@@ -670,7 +671,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:

         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
-            self.running_batch.prepare_for_decode()
+            self.running_batch.prepare_for_decode(self.enable_overlap)
             new_batch.mix_with_running(self.running_batch)
             new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
@@ -717,7 +718,7 @@ def update_running_batch(self):
             return

         # Update batch tensors
-        batch.prepare_for_decode()
+        batch.prepare_for_decode(self.enable_overlap)

     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 5f5f92ca0ac..27a2d07fb27 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -51,7 +51,7 @@ def from_schedule_batch(
         disable_penalizer: bool,
     ):
         reqs = batch.reqs
-        device = batch.input_ids.device
+        device = batch.device
         temperatures = (
             torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
@@ -95,7 +95,7 @@ def from_schedule_batch(
         ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
             vocab_size=vocab_size,
             batch=batch,
-            device=batch.input_ids.device,
+            device=batch.device,
             Penalizers={
                 penaltylib.BatchedFrequencyPenalizer,
                 penaltylib.BatchedMinNewTokensPenalizer,
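Note on the overlap-mode branch in prepare_for_decode above: the sketch below is a minimal, self-contained illustration of the aliasing hazard it guards against. make_worker_batch and this simplified prepare_for_decode are hypothetical stand-ins, not SGLang's actual classes; the point is only that an in-place add_(1) is visible through every alias of the tensor, while seq_lens + 1 rebinds the name and leaves earlier snapshots untouched, which is what lets get_model_worker_batch drop its .clone() calls.

    import torch


    def make_worker_batch(seq_lens: torch.Tensor) -> dict:
        # Mirrors get_model_worker_batch after this patch: the snapshot
        # aliases the scheduler's tensor instead of cloning it.
        return {"seq_lens": seq_lens}


    def prepare_for_decode(seq_lens: torch.Tensor, enable_overlap: bool) -> torch.Tensor:
        if enable_overlap:
            # Out-of-place: allocates a fresh tensor, so snapshots taken
            # earlier keep seeing the old values.
            return seq_lens + 1
        # In-place: faster (no allocation), but the increment shows up
        # through every alias of the tensor.
        seq_lens.add_(1)
        return seq_lens


    # In-place update corrupts an outstanding snapshot:
    seq_lens = torch.tensor([5, 9, 3])
    snapshot = make_worker_batch(seq_lens)
    seq_lens = prepare_for_decode(seq_lens, enable_overlap=False)
    print(snapshot["seq_lens"])  # tensor([ 6, 10,  4]) -- mutated under the worker

    # Out-of-place update leaves it intact:
    seq_lens = torch.tensor([5, 9, 3])
    snapshot = make_worker_batch(seq_lens)
    seq_lens = prepare_for_decode(seq_lens, enable_overlap=True)
    print(snapshot["seq_lens"])  # tensor([5, 9, 3]) -- the worker's view is safe

Under this reading, the synchronous path can keep the faster in-place version because the worker finishes with each ModelWorkerBatch before the scheduler mutates seq_lens again, whereas the overlap scheduler runs the worker one step behind. Similarly, having SamplingBatchInfo read batch.device instead of batch.input_ids.device drops the assumption that input_ids is already populated when the sampling info is constructed.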