Simplify the usage of device (#1734)
merrymercy authored Oct 21, 2024
1 parent 554fbf9 commit e12358d
Showing 3 changed files with 29 additions and 23 deletions.
41 changes: 23 additions & 18 deletions python/sglang/srt/managers/schedule_batch.py
@@ -425,7 +425,6 @@ class ScheduleBatch:
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-
     output_ids: torch.Tensor = None
 
     # For processing logprobs
@@ -442,27 +441,23 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False
 
-    # device
-    device: str = "cuda"
-
     # Has regex
     has_regex: bool = False
 
+    # device
+    device: str = "cuda"
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-        return_logprob = any(req.return_logprob for req in reqs)
-        has_stream = any(req.stream for req in reqs)
-        has_regex = any(req.regex_fsm for req in reqs)
-
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
-            return_logprob=return_logprob,
-            has_stream=has_stream,
+            return_logprob=any(req.return_logprob for req in reqs),
+            has_stream=any(req.stream for req in reqs),
+            has_regex=any(req.regex_fsm for req in reqs),
             device=req_to_token_pool.device,
-            has_regex=has_regex,
         )
 
     def batch_size(self):
@@ -754,7 +749,7 @@ def check_for_jump_forward(self, pad_input_ids_func):
 
         return jump_forward_reqs
 
-    def prepare_for_decode(self):
+    def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE
 
         self.input_ids = self.output_ids
@@ -767,10 +762,19 @@ def prepare_for_decode(self):
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)
-        self.req_to_token_pool.write(
-            (self.req_pool_indices, self.seq_lens), self.out_cache_loc
-        )
-        self.seq_lens.add_(1)
+
+        if enable_overlap:
+            # Do not use in-place operations in the overlap mode
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+            )
+            self.seq_lens = self.seq_lens + 1
+        else:
+            # A faster in-place version
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+            )
+            self.seq_lens.add_(1)
 
     def filter_batch(
         self,
@@ -882,6 +886,7 @@ def get_model_worker_batch(self):
         )
 
     def copy(self):
+        # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
             forward_mode=self.forward_mode,
@@ -940,9 +945,9 @@ def copy(self):
         return ModelWorkerBatch(
             bid=self.bid,
             forward_mode=self.forward_mode,
-            input_ids=self.input_ids.clone(),
+            input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens.clone(),
+            seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
             req_to_token_pool_records=self.req_to_token_pool_records,
             return_logprob=self.return_logprob,
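
Why the overlap path avoids in-place updates: with the overlap schedule, the CPU scheduler prepares the next batch while the GPU worker may still hold a reference to the seq_lens tensor it was handed at dispatch time. Below is a minimal aliasing sketch in plain PyTorch (illustrative only, not sglang code) of the hazard that the out-of-place seq_lens + 1 above is working around:

import torch

seq_lens = torch.tensor([5, 8, 3])
captured = seq_lens                     # reference held by an in-flight worker

seq_lens.add_(1)                        # in-place: fast, but mutates shared storage
assert captured.tolist() == [6, 9, 4]   # the worker's view changed underneath it

seq_lens = torch.tensor([5, 8, 3])
captured = seq_lens
seq_lens = seq_lens + 1                 # out-of-place: allocates a new tensor
assert captured.tolist() == [5, 8, 3]   # the worker still sees dispatch-time values

The same reasoning explains dropping the .clone() calls in get_model_worker_batch: once the decode path stops mutating input_ids and seq_lens in place, handing the worker the tensors themselves is safe, and the defensive copies become pure overhead.
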
7 changes: 4 additions & 3 deletions python/sglang/srt/managers/scheduler.py
@@ -103,6 +103,7 @@ def __init__(
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
+        self.enable_overlap = server_args.enable_overlap_schedule
 
         # Init inter-process communication
         context = zmq.Context(2)
@@ -146,7 +147,7 @@ def __init__(
         )
 
         # Launch a tensor parallel worker
-        if self.server_args.enable_overlap_schedule:
+        if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
@@ -670,7 +671,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
 
         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
-            self.running_batch.prepare_for_decode()
+            self.running_batch.prepare_for_decode(self.enable_overlap)
             new_batch.mix_with_running(self.running_batch)
             new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
@@ -717,7 +718,7 @@ def update_running_batch(self):
             return
 
         # Update batch tensors
-        batch.prepare_for_decode()
+        batch.prepare_for_decode(self.enable_overlap)
 
     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
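
The scheduler side is mechanical: read the flag from server_args once in __init__, then thread it into every prepare_for_decode call. A condensed, hypothetical view of the resulting pattern (the real classes are much larger; only the shape of the change is shown):

class Scheduler:
    def __init__(self, server_args):
        # Cache the flag once instead of reaching into server_args at each call site.
        self.enable_overlap = server_args.enable_overlap_schedule

    def update_running_batch(self, batch):
        # The cached flag selects the in-place or out-of-place decode path.
        batch.prepare_for_decode(self.enable_overlap)
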
4 changes: 2 additions & 2 deletions python/sglang/srt/sampling/sampling_batch_info.py
@@ -51,7 +51,7 @@ def from_schedule_batch(
         disable_penalizer: bool,
     ):
         reqs = batch.reqs
-        device = batch.input_ids.device
+        device = batch.device
         temperatures = (
             torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
@@ -95,7 +95,7 @@ def from_schedule_batch(
         ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
             vocab_size=vocab_size,
             batch=batch,
-            device=batch.input_ids.device,
+            device=batch.device,
             Penalizers={
                 penaltylib.BatchedFrequencyPenalizer,
                 penaltylib.BatchedMinNewTokensPenalizer,
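
This is the simplification the commit title refers to: device is now a plain field on ScheduleBatch, set once from req_to_token_pool.device in init_new, so callers no longer dereference batch.input_ids.device, which only works after input_ids has been populated. A toy sketch of the difference (an assumed minimal dataclass, not the real ScheduleBatch):

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class Batch:
    device: str = "cuda"
    input_ids: Optional[torch.Tensor] = None

batch = Batch(device="cpu")               # device is known before any tensors exist
x = torch.zeros(4, device=batch.device)   # fine
# batch.input_ids.device would raise AttributeError here: input_ids is still None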
