Commit: handle encoder_lens and wrapper init

hnyls2002 committed Oct 1, 2024
1 parent 7f23930 commit 64a5ebc
Showing 7 changed files with 56 additions and 14 deletions.
python/sglang/srt/layers/attention/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -24,7 +24,7 @@ def init_forward_metadata_capture_cuda_graph(
         raise NotImplementedError()

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices, seq_lens, encoder_lens=None
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()

python/sglang/srt/layers/attention/flashinfer_backend.py (17 changes: 14 additions & 3 deletions)
@@ -127,12 +127,21 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             total_num_tokens = torch.sum(forward_batch.seq_lens).item()
             extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()

+            encoder_lens = [
+                im.num_image_tokens if im is not None else 0
+                for im in forward_batch.image_inputs
+            ]
+            encoder_lens = torch.tensor(
+                encoder_lens, device="cuda", dtype=forward_batch.seq_lens.dtype
+            )
+
             update_flashinfer_indices(
                 forward_batch.forward_mode,
                 self.model_runner,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 prefix_lens,
+                encoder_lens=encoder_lens,
                 use_ragged=use_ragged,
             )

@@ -187,23 +196,25 @@ def init_forward_metadata_capture_cuda_graph(
             req_pool_indices,
             seq_lens,
             None,
-            decode_wrappers,
+            encoder_lens=torch.zeros_like(seq_lens),
+            decode_wrappers=decode_wrappers,
         )

         self.cuda_graph_metadata[bs] = decode_wrappers

         self.forward_metadata = (False, False, None, decode_wrappers)

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices, seq_lens, encoder_lens=None
     ):
         update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
             req_pool_indices[:bs],
             seq_lens[:bs],
             None,
-            self.cuda_graph_metadata[bs],
+            encoder_lens=encoder_lens,
+            decode_wrappers=self.cuda_graph_metadata[bs],
         )

     def get_cuda_graph_seq_len_fill_value(self):

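The encoder_lens tensor introduced above is simply the per-request count of image (encoder) tokens, with 0 for requests that carry no image; during CUDA-graph capture the backend substitutes torch.zeros_like(seq_lens) because the real lengths are only known at replay time. Below is a minimal standalone sketch of that derivation, using a hypothetical FakeImageInputs stand-in and a hypothetical build_encoder_lens helper (neither name exists in the repository):

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class FakeImageInputs:
    # Stand-in for the num_image_tokens field this commit adds to ImageInputs.
    num_image_tokens: int


def build_encoder_lens(
    image_inputs: List[Optional[FakeImageInputs]], seq_lens: torch.Tensor
) -> torch.Tensor:
    # Requests without an image contribute zero encoder tokens.
    lens = [im.num_image_tokens if im is not None else 0 for im in image_inputs]
    return torch.tensor(lens, device=seq_lens.device, dtype=seq_lens.dtype)


# Toy batch of three requests; only the first carries an image (576 patch tokens).
seq_lens = torch.tensor([600, 32, 47], dtype=torch.int32)
image_inputs = [FakeImageInputs(num_image_tokens=576), None, None]
print(build_encoder_lens(image_inputs, seq_lens))  # tensor([576, 0, 0], dtype=torch.int32)
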
python/sglang/srt/layers/attention/flashinfer_utils.py (27 changes: 24 additions & 3 deletions)
@@ -56,13 +56,15 @@ def __init__(
         prefix_lens,
         decode_wrappers=None,
         use_ragged=False,
+        encoder_lens=None,
     ):
         self.forward_mode = forward_mode
         self.model_runner = model_runner
         self.req_pool_indices = req_pool_indices
         self.seq_lens = seq_lens
         self.prefix_lens = prefix_lens
         self.use_ragged = use_ragged
+        self.encoder_lens = encoder_lens

         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -161,6 +163,15 @@ def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
             # full attention
             paged_kernel_lens = self.seq_lens
             self.kv_start_idx = self.seq_lens - paged_kernel_lens
+        elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            if wrapper_id == 0:
+                # Text Attention
+                paged_kernel_lens = self.seq_lens - self.encoder_lens
+                self.kv_start_idx = self.encoder_lens
+            else:
+                # Image Attention
+                paged_kernel_lens = self.encoder_lens
+                self.kv_start_idx = None

         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
@@ -192,7 +203,15 @@ def _update_indicess_single_wrapper(self):
         )

     def _update_indices_cross_attention(self):
-        pass
+        for wrapper_id in range(2):
+            self._get_indices(WrapperDispatch.CROSS_ATTENTION, wrapper_id)
+            if self.forward_mode.is_decode():
+                self._update_decode_indices(self.decode_wrappers[wrapper_id])
+            else:
+                self._update_extend_indices(
+                    None,
+                    self.prefill_wrappers_paged[wrapper_id],
+                )

     def _update_indices_sliding_window(self):
         assert self.use_ragged is False
@@ -213,6 +232,7 @@ def update_flashinfer_indices(
     req_pool_indices,
     seq_lens,
     prefix_lens,
+    encoder_lens=None,
     decode_wrappers=None,
     use_ragged=False,
 ):
@@ -222,8 +242,9 @@
         req_pool_indices,
         seq_lens,
         prefix_lens,
-        decode_wrappers,
-        use_ragged,
+        decode_wrappers=decode_wrappers,
+        use_ragged=use_ragged,
+        encoder_lens=encoder_lens,
     )

     dispatch_reason = model_runner.attn_backend.dispatch_reason

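For WrapperDispatch.CROSS_ATTENTION the updater now fills indices for two wrappers per batch: wrapper 0 handles the text part of each request's KV cache (length seq_lens - encoder_lens, starting at offset encoder_lens), while wrapper 1 handles the image part (length encoder_lens, with kv_start_idx left as None, i.e. from the start of the request). The toy sketch below walks through that arithmetic; it assumes each request stores its encoder tokens ahead of its decoder tokens, which is what the kv_start_idx values above imply, and the kv_indptr helper is illustrative rather than repository code:

import torch

# Two requests: the first has 576 image tokens plus 24 text tokens, the second is text-only.
seq_lens = torch.tensor([600, 40], dtype=torch.int32)
encoder_lens = torch.tensor([576, 0], dtype=torch.int32)

# wrapper_id == 0: text attention skips past the encoder tokens.
text_lens = seq_lens - encoder_lens            # [24, 40]
text_start = encoder_lens                      # [576, 0]

# wrapper_id == 1: image attention covers only the encoder tokens.
image_lens = encoder_lens                      # [576, 0]
image_start = torch.zeros_like(encoder_lens)   # [0, 0]


def kv_indptr(lens: torch.Tensor) -> torch.Tensor:
    # Exclusive prefix sum over per-request KV lengths (the usual flashinfer
    # indptr layout), same (batch_size + 1) shape as the buffer in _get_indices.
    indptr = torch.zeros(lens.numel() + 1, dtype=torch.int32)
    indptr[1:] = torch.cumsum(lens, dim=0)
    return indptr


print(kv_indptr(text_lens))   # [0, 24, 64]
print(kv_indptr(image_lens))  # [0, 576, 576]
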
python/sglang/srt/layers/attention/triton_backend.py (3 changes: 2 additions & 1 deletion)
@@ -89,8 +89,9 @@ def init_forward_metadata_capture_cuda_graph(
         )

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices, seq_lens, encoder_lens=None
     ):
+        assert encoder_lens is None
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

python/sglang/srt/managers/schedule_batch.py (8 changes: 4 additions & 4 deletions)
@@ -123,6 +123,7 @@ class ImageInputs:
     image_offsets: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
+    num_image_tokens: Optional[int] = None

     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
@@ -781,15 +782,14 @@ def merge_batch(self, other: "ScheduleBatch"):

     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
-            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
-                image_inputs
-            ) = None
+            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
-            image_inputs = [r.image_inputs for r in self.reqs]

+        # NOTE: decode also has image_inputs
+        image_inputs = [r.image_inputs for r in self.reqs]
         lora_paths = [req.lora_path for req in self.reqs]
         self.sampling_info.regex_fsm_states = [req.regex_fsm_state for req in self.reqs]

python/sglang/srt/model_executor/cuda_graph_runner.py (11 changes: 10 additions & 1 deletion)
@@ -250,9 +250,18 @@ def replay(self, forward_batch: ForwardBatch):
         self.seq_lens[:raw_bs] = forward_batch.seq_lens
         self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc

+        # Encoder lens to initialize the attention wrappers
+        encoder_lens = [
+            im.num_image_tokens if im is not None else 0
+            for im in forward_batch.image_inputs
+        ]
+        encoder_lens = torch.tensor(
+            encoder_lens, device="cuda", dtype=forward_batch.seq_lens.dtype
+        )
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
-            bs, self.req_pool_indices, self.seq_lens
+            bs, self.req_pool_indices, self.seq_lens, encoder_lens=encoder_lens
         )

         # Replay

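Taken together with the flashinfer_backend.py hunks above, the CUDA-graph flow is: capture records the decode wrappers with placeholder zero encoder lengths, and every replay recomputes the real lengths from the batch's image inputs and re-runs the wrapper index update. A condensed, hypothetical sketch of the replay side, reusing the build_encoder_lens helper from the first example (forward_batch, model_runner, bs, req_pool_indices, and seq_lens stand for the objects replay() already holds):

# Capture time (flashinfer_backend.py above): placeholder zeros are fed into
# update_flashinfer_indices via encoder_lens=torch.zeros_like(seq_lens).
# Replay time (this hunk): the real per-request encoder lengths are rebuilt and
# forwarded, so the captured decode wrappers index the correct KV slices.
encoder_lens = build_encoder_lens(forward_batch.image_inputs, forward_batch.seq_lens)
model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
    bs, req_pool_indices, seq_lens, encoder_lens=encoder_lens
)
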
python/sglang/srt/model_executor/forward_batch_info.py (2 changes: 1 addition & 1 deletion)
@@ -127,6 +127,7 @@ def init_new(
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             out_cache_loc=batch.out_cache_loc,
+            image_inputs=batch.image_inputs,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             lora_paths=batch.lora_paths,
@@ -150,7 +151,6 @@
             device=device,
         ).to(torch.int64)

-        ret.image_inputs = batch.image_inputs
         ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
         ret.extend_prefix_lens = torch.tensor(
             batch.extend_prefix_lens, device=device
