diff --git a/benchmark/kernels/minmax-text-01-lighting_attention/benchmark_lighting_attention_decode.py b/benchmark/kernels/minmax-text-01-lighting_attention/benchmark_lighting_attention_decode.py index 1a2036dc0ae..4ce7f2b499d 100644 --- a/benchmark/kernels/minmax-text-01-lighting_attention/benchmark_lighting_attention_decode.py +++ b/benchmark/kernels/minmax-text-01-lighting_attention/benchmark_lighting_attention_decode.py @@ -23,7 +23,10 @@ def _decode_kernel( h: tl.constexpr, n: tl.constexpr, d: tl.constexpr, + d_original: tl.constexpr, e: tl.constexpr, + e_original: tl.constexpr, + BLOCK_SIZE: tl.constexpr = 32, ): off_bh = tl.program_id(0) off_h = off_bh % h @@ -39,21 +42,38 @@ def _decode_kernel( d_idx = tl.arange(0, d) e_idx = tl.arange(0, e) - q = tl.load(Q + qk_offset + d_idx) - k = tl.load(K + qk_offset + d_idx) - v = tl.load(V + v_offset + e_idx) + # Create masks for original dimensions + d_mask = d_idx < d_original + e_mask = e_idx < e_original - kv = tl.load(KV + kv_offset + d_idx[:, None] * e + e_idx[None, :]) + # Load with masking + q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0) + k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0) + v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0) + # Load KV with 2D masking + kv = tl.load( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + mask=(d_mask[:, None] & e_mask[None, :]), + other=0.0, + ) + + # Compute outer product using element-wise operations k_v_prod = k[:, None] * v[None, :] kv = ratio * kv + k_v_prod + # Store KV with 2D masking tl.store( - KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], kv.to(KV.dtype.element_ty) + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + kv.to(KV.dtype.element_ty), + mask=(d_mask[:, None] & e_mask[None, :]), ) + # Compute matrix-vector multiplication using element-wise operations and reduction o = tl.sum(q[:, None] * kv, axis=0) - tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty)) + + # Store output with masking + tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask) def lightning_attn_decode(q, k, v, kv, s): @@ -62,26 +82,27 @@ def lightning_attn_decode(q, k, v, kv, s): e = v.shape[-1] assert n == 1, "Sequence length must be 1 in decode mode" - # Pad dimensions to power of 2 + # Get padded dimensions (power of 2) d_padded = next_power_of_2(d) e_padded = next_power_of_2(e) - # Pad inputs - q_padded = F.pad(q, (0, d_padded - d)) - k_padded = F.pad(k, (0, d_padded - d)) - v_padded = F.pad(v, (0, e_padded - e)) - kv_padded = F.pad(kv, (0, e_padded - e, 0, d_padded - d)) - - # Ensure inputs are contiguous - q_padded = q_padded.contiguous() - k_padded = k_padded.contiguous() - v_padded = v_padded.contiguous() - kv_padded = kv_padded.contiguous().to(torch.float32) - s = s.contiguous() - # Create output tensor (padded) o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + # Create padded tensors without actually padding the data + q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device) + k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device) + v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + kv_padded = torch.empty( + b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device + ) + + # Copy data to padded tensors + q_padded[..., :d] = q + k_padded[..., :d] = k + v_padded[..., :e] = v + kv_padded[..., :d, :e] = kv + # Launch kernel grid = (b * h, 1) _decode_kernel[grid]( @@ -95,10 +116,12 @@ def lightning_attn_decode(q, k, v, kv, s): h=h, n=n, d=d_padded, + d_original=d, e=e_padded, + e_original=e, ) - # Remove padding + # Get unpadded outputs o = o_padded[..., :e] kv_out = kv_padded[..., :d, :e] @@ -351,6 +374,8 @@ def test_lightning_attention_implementations(model_params): msg="Lightning attention implementations produce different kv results", ) + print("✅ Two implementations match") + def _build_slope_tensor(n_attention_heads: int): def get_slopes(n): @@ -375,7 +400,7 @@ def get_slopes_power_of_2(n): def get_benchmark(): - batch_size_range = [2**i for i in range(0, 12)] # max 2048 + batch_size_range = [i for i in range(1, 33)] # max 32 seq_length_range = [1] # decode mode sequence length is fixed to 1 configs = list(itertools.product(batch_size_range, seq_length_range))