Commit

Optimize Triton decoding kernel for long context (#2394)
ispobock authored Dec 8, 2024
Parent: 1f09e84 · Commit: 7dc66fc
Showing 4 changed files with 328 additions and 360 deletions. Only python/sglang/srt/layers/attention/triton_backend.py was loaded in this view; the other three changed files are not shown.
python/sglang/srt/layers/attention/triton_backend.py (13 additions, 8 deletions)
@@ -40,6 +40,9 @@ def __init__(self, model_runner: ModelRunner):
         else:
             self.reduce_dtype = torch.float16
 
+        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+
         self.forward_metadata = None
 
         self.cuda_graph_max_seq_len = model_runner.model_config.context_len
@@ -53,10 +56,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
             start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
 
-            total_num_tokens = forward_batch.seq_lens_sum
             attn_logits = torch.empty(
-                (self.num_head, total_num_tokens),
-                dtype=self.reduce_dtype,
+                (
+                    forward_batch.batch_size,
+                    self.num_head,
+                    self.num_kv_splits,
+                    self.v_head_dim + 1,
+                ),
+                dtype=torch.float32,
                 device=self.device,
             )
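
The reshaped attn_logits buffer is the signature of a split-KV (flash-decoding-style) decode kernel: the KV cache of each request is processed in num_kv_splits independent chunks, and each (batch, head, split) slot stores a partial output of width v_head_dim plus one extra entry that, by the usual convention, holds the chunk's log-sum-exp. Keeping these partials in torch.float32 instead of the old reduce_dtype presumably avoids precision loss when chunks are renormalized. Below is a minimal sketch of the combine step, assuming each split stores its chunk-normalized partial output and its log-sum-exp; combine_splits and the shape names are illustrative, not the kernel's actual code:

    import torch

    B, H, S, Dv = 2, 4, 8, 64  # batch, query heads, num_kv_splits, v_head_dim

    # One (Dv + 1)-wide slot per (batch, head, split): partial output + LSE.
    attn_logits = torch.randn(B, H, S, Dv + 1, dtype=torch.float32)

    def combine_splits(attn_logits: torch.Tensor) -> torch.Tensor:
        """Merge per-split partial outputs into the final decode output."""
        partial_out = attn_logits[..., :-1]   # (B, H, S, Dv)
        lse = attn_logits[..., -1]            # (B, H, S)
        # Each split is normalized within its own chunk, so the exact global
        # result weights the chunks by a softmax over the per-split LSEs.
        weights = torch.softmax(lse, dim=-1)  # (B, H, S)
        return (weights.unsqueeze(-1) * partial_out).sum(dim=-2)  # (B, H, Dv)

    out = combine_splits(attn_logits)
    assert out.shape == (B, H, Dv)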

@@ -75,11 +82,8 @@ def init_cuda_graph_state(self, max_bs: int):
             (max_bs,), dtype=torch.int32, device=self.device
         )
         self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
+            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
+            dtype=torch.float32,
             device="cuda",
         )

@@ -189,6 +193,7 @@ def forward_decode(
             forward_batch.seq_lens,
             attn_logits,
             max_seq_len,
+            self.num_kv_splits,
             layer.scaling,
             layer.logit_cap,
         )
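
As a sanity check on the math (of the split-KV scheme, not of the Triton kernel itself), the following self-contained demo verifies that chunked softmax attention recombined through per-chunk log-sum-exps reproduces a single full-softmax pass, which is the invariant the new (batch, head, num_kv_splits, v_head_dim + 1) buffer exists to support. All names here are illustrative:

    import torch

    torch.manual_seed(0)
    H, Dk, Dv, T, S = 4, 64, 64, 512, 8  # heads, head dims, KV length, splits
    q = torch.randn(H, Dk)
    k = torch.randn(T, H, Dk)
    v = torch.randn(T, H, Dv)

    scores = torch.einsum("hd,thd->ht", q, k) * Dk**-0.5  # (H, T)

    # Reference: one softmax over the entire KV sequence.
    ref = torch.einsum("ht,thd->hd", torch.softmax(scores, dim=-1), v)

    # Split-KV: each chunk yields a normalized partial output and its LSE.
    partials, lses = [], []
    for idx in torch.chunk(torch.arange(T), S):
        s = scores[:, idx]  # (H, T // S)
        partials.append(
            torch.einsum("ht,thd->hd", torch.softmax(s, dim=-1), v[idx])
        )
        lses.append(torch.logsumexp(s, dim=-1))  # (H,)

    lse = torch.stack(lses, dim=1)  # (H, S)
    out = (
        torch.softmax(lse, dim=-1).unsqueeze(-1) * torch.stack(partials, dim=1)
    ).sum(dim=1)  # (H, Dv)
    assert torch.allclose(out, ref, atol=1e-5)

A larger num_kv_splits buys more parallelism across SMs for long sequences, at the cost of a bigger partial buffer and a heavier combine pass, which is presumably why the split count is exposed as a server argument rather than hard-coded.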
