Multihead latent attention kernel #699

Draft · wants to merge 16 commits into main_perf

Conversation

@juuso-oskari commented Jan 13, 2025

MLA.py

Kernel that implements Multi-head Latent Attention (MLA) from DeepSeek-V3. I'm following the reference PyTorch implementation here.
For this first version, the idea is simply to fuse the absorbed GEMMs into Flash Attention.
In pseudocode:

"""
Q_NOPE: shape bhsd, query tokens with no positional embedding applied
Q_PE: shape bhsr, query tokens with rotary positional embedding applied
KV: shape btc, latent representation for keys and values
K_PE: shape btr, decoupled key with rotary positional embedding applied (shared across heads)
WKV_B: shape h(d+d)c, projection matrix for KV; per head, the first d rows produce the no-PE keys and the last d rows the values

b: batch size, h: num heads, s: query sequence length, t: key/value sequence length,
d: no-PE/value head dim, r: pos. emb. head dim, c: latent repr. dim

I'm currently running with: b=2, h=16, s=128, t=128, d=64, r=32, c=256. And with:
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2},
                      num_stages=1, num_warps=1),
"""
@triton.jit
def attn_fwd(Q_NOPE, Q_PE, KV, K_PE, WKV_B, ...):
   q_nope_ptrs = ... # of size BLOCK_M x d
   kv_ptrs = ... # of size c x BLOCK_N
   k_pe_ptrs = ... # of size r x BLOCK_N
   q_pe_ptrs = ... # of size BLOCK_M x r
   wkv_b_ptrs1 = ... # of size d x c
   wkv_b_ptrs2 = ... # of size d x c

   q_nope = tl.load(q_nope_ptrs)
   q_pe = tl.load(q_pe_ptrs)
   wkv_b1 = tl.load(wkv_b_ptrs1) # WKV_B[:, :self.qk_nope_head_dim], absorbed into q_nope
   wkv_b = tl.load(wkv_b_ptrs2) # WKV_B[:, -self.v_head_dim:], needed in _attn_fwd_inner

   q_nope = tl.dot(q_nope, wkv_b1) # absorption: fold the key projection into the query
   acc, l_i, m_i = _attn_fwd_inner(
                        acc, l_i, m_i, q_nope, q_pe, wkv_b, kv_ptrs, k_pe_ptrs, ...)
   
   l_recip = 1 / l_i[:, None]
   acc = acc * l_recip
   tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask)

@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q_nope, q_pe, wkv_b, kv_ptrs, k_pe_ptrs, ...):
   for start_n in range(block_min, block_max, BLOCK_N):
      kv = load_fn(kv_ptrs, ...)
      k_pe = load_fn(k_pe_ptrs, ...)

      qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
      # -- compute qk ----
      qk += tl.dot(q_nope.to(kv.type.element_ty), kv)
      qk += tl.dot(q_pe, k_pe)
      qk *= QK_SCALE # assumed to already fold in the log2(e) factor, since exp2 is used below

      # softmax
      m_ij = tl.maximum(m_i, tl.max(qk, 1))
      qk = qk - m_ij[:, None]
      p = tl.math.exp2(qk)

      l_ij = tl.sum(p, 1)

      alpha = tl.math.exp2(m_i - m_ij)
      acc = acc * alpha[:, None]

      # -- update m_i and l_i
      l_i = l_i * alpha + l_ij
      m_i = m_ij

      v = tl.dot(wkv_b, kv).trans() # project the latent block to value space; not the actual cached v, but it plays that role here
      acc += tl.dot(p.to(v.type.element_ty), v)

      kv_ptrs += BLOCK_N * stride_kv_n
      k_pe_ptrs += BLOCK_N * stride_k_pe_n
   
   return acc, l_i, m_i
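
For reference, the "absorption" above is just a reassociation of the matmuls: instead of expanding the latent KV to per-head no-PE keys and dotting them with q_nope, the key projection is folded into the query. A minimal PyTorch sketch with the shapes from the docstring (illustration only, not the kernel code):

import torch

b, h, s, t, d, r, c = 2, 16, 128, 128, 64, 32, 256
q_nope = torch.randn(b, h, s, d, dtype=torch.float64)
kv     = torch.randn(b, t, c,    dtype=torch.float64)   # latent KV cache
wkv_b1 = torch.randn(h, d, c,    dtype=torch.float64)   # WKV_B[:, :qk_nope_head_dim]

# naive order: expand the latent to per-head no-PE keys, then dot with q_nope
k_nope       = torch.einsum("hdc,btc->bhtd", wkv_b1, kv)
scores_naive = torch.einsum("bhsd,bhtd->bhst", q_nope, k_nope)

# absorbed order: fold the key projection into the query, then dot with the latent directly
q_absorbed      = torch.einsum("bhsd,hdc->bhsc", q_nope, wkv_b1)
scores_absorbed = torch.einsum("bhsc,btc->bhst", q_absorbed, kv)

torch.testing.assert_close(scores_naive, scores_absorbed)

The RoPE contribution (q_pe · k_pe) is added on top of this in both orders, so only the no-PE score changes association.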

Launch of the kernel:

class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, Q_NOPE, Q_PE, KV, K_PE, O, WKV_B, metadata: MetaData):
      batch, nheads_q = Q_NOPE.shape[0], Q_NOPE.shape[1]  # launch grid dims from the bhsd query layout
      grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), nheads_q, batch)
      attn_fwd[grid](
         Q_NOPE, Q_PE, KV, K_PE, WKV_B, ..., metadata.sm_scale, ..., O, ...)

      return O, None, None
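
The Function is then presumably exposed via an `attention = _attention.apply` alias (an assumption here, matching how flash_attention below calls it):

attention = _attention.apply  # assumed alias, matching the call site in flash_attention below

# O is pre-allocated by the caller and filled by the kernel
o, _, _ = attention(Q_NOPE, Q_PE, KV, K_PE, O, WKV_B, metadata)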

And then how the kernel gets called:

global attn_impl

def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
   """
   Forward pass for the Multi-head Latent Attention (MLA) layer.

   Args:
      x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
      start_pos (int): Starting position in the sequence for caching.
      freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
      mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

   Returns:
      torch.Tensor: Output tensor with the same shape as the input.
   """
   bsz, seqlen, _ = x.size()
   end_pos = start_pos + seqlen
   if self.q_lora_rank == 0:
      q = self.wq(x)
   else:
      q = self.wq_b(self.q_norm(self.wq_a(x)))
   q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
   q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
   q_pe = apply_rotary_emb(q_pe, freqs_cis)
   kv = self.wkv_a(x)
   kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
   k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)

   if attn_impl == "flash":
      # "absorb" and Flash Attention fused
      wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) 
      wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
      kv = self.kv_norm(kv)
      x = flash_attention(q_nope, q_pe, kv, k_pe.squeeze(2), wkv_b, self.qk_nope_head_dim, self.v_head_dim, self.softmax_scale)
   else:
      if attn_impl == "naive":
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
      else: # absorb
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(
               self.wkv_b.weight, self.wkv_b.scale, block_size)
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                     torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale

      if mask is not None:
            scores += mask.unsqueeze(1)
      scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
      if attn_impl == "naive":
            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
      else:
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
   x = self.wo(x.flatten(2))

   return x

def flash_attention(q_nope, q_pe, kv, k_pe, wkv_b, qk_nope_head_dim, v_head_dim, sm_scale):
    # q comes in as bshd, we assume bhsd inside kernel. So permute for now.
    q_nope = q_nope.permute((0, 2, 1, 3))
    q_pe = q_pe.permute((0, 2, 1, 3))
    o = torch.zeros((*q_nope.shape[:-1], v_head_dim), dtype=q_nope.dtype, device=q_nope.device)
    Z, H, N, D = q_nope.shape
    _, _, _, input_metadata = input_helper(Z, H, H, N, N, D, q_nope.dtype, "bhsd", requires_grad=False)
    input_metadata.qk_nope_head_dim = qk_nope_head_dim
    input_metadata.v_head_dim = v_head_dim
    input_metadata.sm_scale = sm_scale
    o, _, _ = attention(q_nope, q_pe, kv, k_pe, o, wkv_b, input_metadata)
    return o.permute((0, 2, 1, 3))  # permute back to bshd
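
For completeness, a rough sketch of how this forward pass gets driven, following the DeepSeek-V3 reference (the causal mask construction and freqs_cis slicing mirror the reference Transformer.forward; `attn` (an MLA instance) and `args` (a ModelArgs) are placeholders):

attn_impl = "flash"                                 # select the fused Triton path above
bsz, seqlen = 2, 128
x = torch.randn(bsz, seqlen, args.dim, dtype=torch.bfloat16, device="cuda")
freqs_cis = precompute_freqs_cis(args)[0:seqlen]    # precomputed as in the reference model.py
mask = torch.full((seqlen, seqlen), float("-inf"), device=x.device).triu_(1)  # causal mask
out = attn(x, start_pos=0, freqs_cis=freqs_cis, mask=mask)  # same shape as x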

TODO:

  • The problem is that even the reference "naive" and "absorb" implementations do not produce the same result, which makes it hard to tell whether the MLA kernel implementation is correct.

  • Check that the math is right. Can we even fuse the "absorbed" implementation with FA, given that it has the additional output projection x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])? I suppose it shouldn't matter, since that projection is linear; a quick check is sketched below.
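
On the second point: the extra output projection is linear and applied per position, so it commutes with the softmax-weighted sum over key blocks. Projecting each KV block first (what the kernel does via v = tl.dot(wkv_b, kv).trans() in the inner loop) gives the same result as projecting the latent accumulator at the end. A minimal sketch, assuming the shapes from the docstring (the probabilities p here are a random stand-in for the softmax weights):

import torch

b, h, s, t, d, c = 2, 16, 128, 128, 64, 256
p      = torch.randn(b, h, s, t, dtype=torch.float64)   # stand-in for softmax probabilities
kv     = torch.randn(b, t, c,    dtype=torch.float64)   # latent KV cache
wkv_b2 = torch.randn(h, d, c,    dtype=torch.float64)   # WKV_B[:, -v_head_dim:]

# reference "absorb" order: accumulate in latent space, project to value space at the end
x_lat   = torch.einsum("bhst,btc->bhsc", p, kv)
out_ref = torch.einsum("bhsc,hdc->bhsd", x_lat, wkv_b2)

# fused order used in the kernel: project each KV block first, then accumulate
v         = torch.einsum("hdc,btc->bhtd", wkv_b2, kv)   # per-block tl.dot(wkv_b, kv).trans()
out_fused = torch.einsum("bhst,bhtd->bhsd", p, v)

torch.testing.assert_close(out_ref, out_fused)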

@juuso-oskari self-assigned this Jan 13, 2025
@juuso-oskari marked this pull request as draft January 14, 2025 09:56
@juuso-oskari requested a review from vgokhale January 14, 2025 09:56