Fix Triton decode kernel & ut (sgl-project#1819)
ispobock authored and zolinthecow committed Oct 29, 2024
1 parent fff0abd commit fc7444e
Showing 5 changed files with 216 additions and 40 deletions.
131 changes: 101 additions & 30 deletions python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -296,12 +296,18 @@ def _fwd_grouped_kernel_stage1(
Lk: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_kv_head = tl.program_id(1)
cur_head_id = tl.program_id(1)
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
start_n = tl.program_id(2)

reduce_dtype = Att_Out.dtype.element_ty
cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_kv_head + 1) * kv_group_num

if BLOCK_H < kv_group_num:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)

offs_d = tl.arange(0, BLOCK_DMODEL)
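The indexing change is easier to read outside Triton. The following is a minimal host-side sketch in plain Python of what each program id now computes (the helper name and example sizes are illustrative, not part of the patch): with BLOCK_H capped, one KV head may span several head blocks, and VALID_BLOCK_H keeps the head offsets and mask consistent whether BLOCK_H is smaller or larger than kv_group_num.

import math

def grouped_head_mapping(cur_head_id, kv_group_num, q_head_num, BLOCK_H):
    # Host-side rendition of the new kernel indexing above (illustrative only).
    # One KV head may span several head blocks when kv_group_num > BLOCK_H.
    cur_kv_head = cur_head_id // math.ceil(kv_group_num / BLOCK_H)

    # Only min(BLOCK_H, kv_group_num) lanes of the block map to real query heads.
    valid_block_h = min(BLOCK_H, kv_group_num)
    cur_head = [cur_head_id * valid_block_h + i for i in range(BLOCK_H)]
    mask_h = [
        h < (cur_head_id + 1) * valid_block_h and h < q_head_num for h in cur_head
    ]
    return cur_kv_head, cur_head, mask_h

# Example: 128 query heads sharing 1 KV head, BLOCK_H capped at 64.
# Program id 1 covers query heads 64..127, all attending to KV head 0.
print(grouped_head_mapping(cur_head_id=1, kv_group_num=128, q_head_num=128, BLOCK_H=64))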
@@ -400,10 +406,15 @@ def _fwd_grouped_kernel_stage2(
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_kv_head = tl.program_id(1)
cur_head_id = tl.program_id(1)
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)

cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
if BLOCK_H < kv_group_num:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)

cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
@@ -485,7 +496,7 @@ def _decode_grouped_att_m_fwd(
batch, head_num = B_req_idx.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[1]

BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
grid = (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
@@ -534,7 +545,7 @@ def _decode_grouped_softmax_reducev_fwd(
BLOCK = 128
batch, head_num = b_seq_len.shape[0], logits.shape[0]
kv_group_num = logits.shape[0] // v_buffer.shape[1]
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)

num_warps = 8
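Capping BLOCK_H at 64 means very large group sizes (for example 128 query heads sharing a single KV head, as in the new grouped test below) are split across several head blocks instead of one oversized block; the second grid dimension counts head blocks, which coincide with KV heads only when the whole group fits in one block. A standalone sketch of the launch-grid sizing, mirroring the grid expression in _decode_grouped_softmax_reducev_fwd above (helper name and example sizes are illustrative):

import triton

def grouped_decode_grid(batch, head_num, kv_group_num):
    # Same BLOCK_H clamp and grid expression as in the launcher above.
    BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
    grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
    return BLOCK_H, grid

# 128 query heads sharing one KV head: BLOCK_H = 64 and grid = (2, 2, 1),
# i.e. two head blocks per batch element rather than one block of 128 heads.
print(grouped_decode_grid(batch=2, head_num=128, kv_group_num=128))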
@@ -567,6 +578,80 @@ def _decode_grouped_softmax_reducev_fwd(
)


def decode_attention_fwd_normal(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
attn_logits,
max_len_in_batch,
sm_scale,
logit_cap=0.0,
):
_decode_att_m_fwd(
q,
k_buffer,
attn_logits,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
max_len_in_batch,
sm_scale,
logit_cap,
)
_decode_softmax_reducev_fwd(
attn_logits,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
)


def decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
attn_logits,
max_len_in_batch,
sm_scale,
logit_cap=0.0,
):
_decode_grouped_att_m_fwd(
q,
k_buffer,
attn_logits,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
max_len_in_batch,
sm_scale,
logit_cap,
)
_decode_grouped_softmax_reducev_fwd(
attn_logits,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
)


def decode_attention_fwd(
q,
k_buffer,
@@ -585,47 +670,33 @@ def decode_attention_fwd(

if kv_group_num == 1:
# MHA
_decode_att_m_fwd(
decode_attention_fwd_normal(
q,
k_buffer,
attn_logits,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
attn_logits,
max_len_in_batch,
sm_scale,
logit_cap,
)
_decode_softmax_reducev_fwd(
attn_logits,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
)
else:
# GQA/MQA/MLA
_decode_grouped_att_m_fwd(
decode_attention_fwd_grouped(
q,
k_buffer,
attn_logits,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
attn_logits,
max_len_in_batch,
sm_scale,
logit_cap,
)
_decode_grouped_softmax_reducev_fwd(
attn_logits,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
)
@@ -168,7 +168,7 @@ def _fwd_kernel(
def context_attention_fwd(
q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
if is_cuda_available and CUDA_CAPABILITY[0] > 8:
BLOCK = 128
else:
BLOCK = 64
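Only the comparison changes in this hunk; CUDA_CAPABILITY itself comes from the module prelude, which is not visible here. A hedged sketch of the overall selection, assuming the capability is queried with torch.cuda.get_device_capability():

import torch

is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()

def pick_block_size():
    # After this change, the 128-wide block is used only above capability 8.x,
    # so Ampere (sm_80/sm_86) now takes the 64-wide path as well.
    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
        return 128
    return 64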
3 changes: 2 additions & 1 deletion test/srt/run_suite.py
@@ -26,7 +26,8 @@
"test_srt_endpoint.py",
"test_torch_compile.py",
"test_torchao.py",
"test_triton_attn_backend.py",
"test_triton_attention_kernels.py",
"test_triton_attention_backend.py",
"test_update_weights.py",
"test_vision_openai_server.py",
],
File renamed without changes.
120 changes: 112 additions & 8 deletions test/srt/test_triton_attention_kernels.py
@@ -3,7 +3,11 @@

import torch

from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
from sglang.srt.layers.attention.triton_ops.decode_attention import (
decode_attention_fwd,
decode_attention_fwd_grouped,
decode_attention_fwd_normal,
)
from sglang.srt.layers.attention.triton_ops.extend_attention import (
extend_attention_fwd,
redundant_attention,
@@ -13,7 +17,7 @@
)


class TestExtendAttention(unittest.TestCase):
class TestTritonAttention(unittest.TestCase):

def _set_all_seeds(self, seed):
"""Set all random seeds for reproducibility."""
@@ -127,7 +131,7 @@ def test_extend_attention(self):
for value in attention_values:
self._test_extend_attention_once(19, 12331, 12, 4, value)

def _test_context_attention_once(self, head_dim):
def _test_context_attention_once(self, head_dim, is_causal):
# Set up a simple test case
num_heads = 4
seq_lens = [8, 12]
@@ -143,15 +147,35 @@ def _test_context_attention_once(self, head_dim):
b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda")
b_seq_len = torch.tensor(seq_lens, device="cuda")

context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len)
context_attention_fwd(
q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal
)

cu_seq_lens = [0] * (len(seq_lens) + 1)
for i, seq_len in enumerate(seq_lens):
cu_seq_lens[i + 1] = cu_seq_lens[i] + seq_len

for i in range(len(seq_lens)):
start, end = cu_seq_lens[i], cu_seq_lens[i + 1]
o_torch = torch.nn.functional.scaled_dot_product_attention(
q[start:end].permute(1, 0, 2),
k[start:end].permute(1, 0, 2),
v[start:end].permute(1, 0, 2),
is_causal=is_causal,
).permute(1, 0, 2)

cos_sim = torch.nn.functional.cosine_similarity(
o[start:end].flatten(), o_torch.flatten(), dim=0
)
self.assertTrue(cos_sim.item() > 1 - (1e-5))
self.assertTrue(torch.allclose(o[start:end], o_torch, atol=1e-2))

def test_context_attention(self):
# Here we just to ensure there is no error
# TODO: correctnesss test
head_dim = [128, 96, 80, 13]

for dim in head_dim:
self._test_context_attention_once(dim)
for is_causal in [True, False]:
self._test_context_attention_once(dim, is_causal)

def _test_decode_attention_once(self, B, H_Q, H_KV, D):
dtype = torch.bfloat16
@@ -174,6 +198,12 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D):
b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda")

attn_logits = torch.empty(
(H_Q, total_tokens),
dtype=dtype,
device="cuda",
)

decode_attention_fwd(
q,
k_buffer,
@@ -183,8 +213,8 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D):
b_req_idx,
b_start_loc,
b_seq_len,
attn_logits,
seq_len,
total_tokens,
sm_scale,
)

@@ -203,6 +233,80 @@ def test_decode_attention(self):
for B, H_Q, H_KV, D in configs:
self._test_decode_attention_once(B, H_Q, H_KV, D)

def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
dtype = torch.bfloat16
seq_len = 10 # This represents the number of tokens already in the sequence
total_tokens = B * seq_len
sm_scale = 1.0 / (D**0.5)

# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")

# k_buffer and v_buffer represent all previous tokens
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")

# o will have the same shape as q
o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")

req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
b_req_idx = torch.arange(B, device="cuda")
b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda")

attn_logits = torch.empty(
(H_Q, total_tokens),
dtype=dtype,
device="cuda",
)

decode_attention_fwd_normal(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
attn_logits,
seq_len,
sm_scale,
)

decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o_grouped,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
attn_logits,
seq_len,
sm_scale,
)

cos_sim = torch.nn.functional.cosine_similarity(
o.flatten(), o_grouped.flatten(), dim=0
)
self.assertTrue(cos_sim.item() > 0.99)
self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))

def test_grouped_decode_attention(self):
configs = [
(2, 16, 1, 64, 64),
(2, 64, 1, 13, 13),
(2, 128, 1, 80, 80),
(2, 128, 2, 512, 512),
(2, 128, 1, 576, 512),
]

for B, H_Q, H_KV, D, D_V in configs:
self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V)


if __name__ == "__main__":
unittest.main()
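Because the file ends with unittest.main(), the new kernel tests can also be run standalone (for example python3 test/srt/test_triton_attention_kernels.py on a CUDA machine), in addition to being picked up through the run_suite.py entry added above. The grouped-decode test cross-checks decode_attention_fwd_grouped against decode_attention_fwd_normal, including a configuration with different key and value head dimensions (576 and 512), matching the MLA case called out in the GQA/MQA/MLA comment in decode_attention_fwd.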
