Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce hardcoded logic of kernel usage #707

Merged
merged 2 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/sglang/srt/layers/radix_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
return o

def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
self.store_kv_cache(k, v, input_metadata)
if not input_metadata.use_ragged:
self.store_kv_cache(k, v, input_metadata)

if input_metadata.total_num_tokens <= 4096:
o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
Expand Down Expand Up @@ -122,6 +122,8 @@ def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):

o, _ = merge_state(o1, s1, o2, s2)

self.store_kv_cache(k, v, input_metadata)

if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
torch.cuda.synchronize()

Expand Down
29 changes: 18 additions & 11 deletions python/sglang/srt/managers/controller/infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,7 @@ class InputMetadata:
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
use_ragged = False

@classmethod
def create(
Expand All @@ -742,13 +743,16 @@ def create(
skip_flashinfer_init=False,
):
if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
if forward_mode != ForwardMode.DECODE and total_num_tokens > 4096:
use_ragged = True
init_flashinfer_args(
forward_mode,
model_runner,
req_pool_indices,
seq_lens,
prefix_lens,
model_runner.flashinfer_decode_wrapper,
use_ragged,
)

batch_size = len(req_pool_indices)
Expand Down Expand Up @@ -803,6 +807,7 @@ def create(
flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
use_ragged=use_ragged,
)

if model_runner.server_args.disable_flashinfer:
Expand All @@ -823,6 +828,7 @@ def init_flashinfer_args(
seq_lens,
prefix_lens,
flashinfer_decode_wrapper,
use_ragged=False,
):
"""Init auxiliary variables for FlashInfer attention backend."""
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
Expand All @@ -831,10 +837,10 @@ def init_flashinfer_args(
batch_size = len(req_pool_indices)
total_num_tokens = int(torch.sum(seq_lens))

if forward_mode == ForwardMode.DECODE or total_num_tokens <= 4096:
paged_kernel_lens = seq_lens
else:
if use_ragged:
paged_kernel_lens = prefix_lens
else:
paged_kernel_lens = seq_lens

kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
Expand Down Expand Up @@ -867,14 +873,15 @@ def init_flashinfer_args(
qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
qo_indptr,
qo_indptr,
num_qo_heads,
num_kv_heads,
head_dim,
)
if use_ragged:
model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
qo_indptr,
qo_indptr,
num_qo_heads,
num_kv_heads,
head_dim,
)

# cached part
model_runner.flashinfer_prefill_wrapper_paged.end_forward()
Expand Down
Loading