diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index 5ddce76eab404..7b14c67a2e570 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -30,7 +30,7 @@ from tvm.target import Target from .position_embedding import llama_rope_with_position_map, switch_rope_freq_func -from .tree_attn import tree_attn +from .tree_attn import tree_attn, tree_attn_with_paged_kv_cache def get_max_num_threads_per_block(target: Target) -> int: @@ -257,6 +257,7 @@ def __init__( # pylint: disable=too-many-locals bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"), + bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"), rope_ext_factors, # fmt: on # pylint: enable=line-too-long diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py index 069eb4892348c..db04aa1511656 100644 --- a/python/tvm/relax/frontend/nn/llm/tree_attn.py +++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py @@ -62,11 +62,29 @@ def _rope( return expr -def _tree_mask(row, col, mask_ptr, offset, stride, kv_len): - return tir.all(col < kv_len, mask_ptr[offset + row * stride + col] == 1) +def _check_tree_order(tree_order_indptr, tree_order, batch, row, col, kv_len, qo_len): + tree_order_len = tree_order_indptr[batch + 1] - tree_order_indptr[batch] + + tree_start = kv_len - tree_order_len + child_idx_in_tree = row + tree_order_len - qo_len + parent_idx_in_tree = col - tree_start + return tir.all( + col < kv_len, + tir.any( + col < tree_start, + tir.all( + tree_order[tree_order_indptr[batch] + child_idx_in_tree, 0] + >= tree_order[tree_order_indptr[batch] + parent_idx_in_tree, 0], + tree_order[tree_order_indptr[batch] + child_idx_in_tree, 0] + < tree_order[tree_order_indptr[batch] + parent_idx_in_tree, 1], + ), + ), + ) -def tree_attn(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target): +def tree_attn( + h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target +): # pylint: disable=unused-argument """Generate tree attention kernel for batched tree attention. Parameters @@ -87,7 +105,7 @@ def tree_attn(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target) mod : tvm.IRModule The generated IR module. """ - # pylint: disable=line-too-long + # pylint: disable=invalid-name,line-too-long NUM_BLKS = 16 LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes group_size = h_q // h_kv @@ -140,7 +158,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset) q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset) mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", elem_offset=mn_indptr_elem_offset) - mask = T.match_buffer(var_mask, (tree_size,), "int32", elem_offset=mask_elem_offset) + mask = T.match_buffer(var_mask, (tree_size, 2), "int32", elem_offset=mask_elem_offset) output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable @@ -276,12 +294,13 @@ def batch_tree_attn( # pylint: disable=too-many-branches # mask out of kv_chunk_len S row_: T.int32 = (LH_start + row) // group_size for j in T.serial(tile_z): - if _tree_mask( + if _check_tree_order( row=row_, col=L_kv_start + j, - mask_ptr=mask, - offset=mn_indptr[b_idx], - stride=q_indptr[b_idx + 1] - q_indptr[b_idx], + batch=b_idx, + tree_order=mask, + tree_order_indptr=mn_indptr, + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx], kv_len=kv_chunk_len[0]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) @@ -293,12 +312,13 @@ def batch_tree_attn( # pylint: disable=too-many-branches # this is to avoid sync inside condition branch if row < tile_x: row_: T.int32 = (LH_start + row) // group_size - if _tree_mask( + if _check_tree_order( row=row_, col=L_kv_start + j, - mask_ptr=mask, - offset=mn_indptr[b_idx], - stride=q_indptr[b_idx + 1] - q_indptr[b_idx], + batch=b_idx, + tree_order=mask, + tree_order_indptr=mn_indptr, + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx], kv_len=kv_chunk_len[0]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: @@ -345,7 +365,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches # move to next tile tile_id[0] += NUM_BLKS # fmt: on - # pylint: enable=line-too-long,too-many-branches + # pylint: enable=line-too-long,invalid-name,too-many-branches sch = tir.Schedule(batch_tree_attn) def get_tile_size(x, y, t): @@ -415,3 +435,493 @@ def apply_to_md(sch, block): apply_to_md(sch, sch.get_block("lse_store")) return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +def tree_attn_with_paged_kv_cache( + h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target +): + """Generate tree attention kernel for batched tree attention with paged key-value cache. + + Parameters + ---------- + h_kv : int + Number of heads for key and value. + h_q : int + Number of heads for query. + d : int + Hidden dimension. + dtype : str + Data type. + target : Target + The target device. + + Returns + ------- + mod : tvm.IRModule + The generated IR module. + """ + # pylint: disable=import-outside-toplevel + from .kv_cache import ( + _declare_length_info, + _get_kv_chunk_len, + _get_seq_offset, + check_thread_limits, + ) + + # pylint: disable=invalid-name, line-too-long + NUM_BLKS = 16 + LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + bdx = 32 + num_warps = 4 + tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + + # Otherwise we would exceed maxComputeWorkgroupStorageSize + if ( + str(target.kind) == "webgpu" + and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 + ): + tile_z = 8 + num_warps = 2 + check_thread_limits(target, bdx=bdx, bdy=num_warps, bdz=1, gdz=1) + + global_symbol = "tree_attn_paged_kv" + sliding_window = False # Sliding window is not supported in this kernel. + + # fmt: off + @T.prim_func + def tree_attn_paged_kv( + _0: T.int32, # pylint: disable=unused-argument + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] + var_page_indptr: T.handle, # [batch_size + 1] + var_page_values: T.handle, # [nnz_pages] + var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] + var_k_rope_pos_offset: T.handle, # [b] + var_q_rope_position: T.handle, # [total_len] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + tree_order_indptr_handle: T.handle, # [batch_size + 1] + tree_order_handle: T.handle, # [total_len, 2] + ): + # pylint: disable=unused-variable, too-many-branches + T.func_attr({"global_symbol": global_symbol}) + batch_size = T.int32(is_size_var=True) + total_len = T.int32(is_size_var=True) + nnz_pages = T.int32(is_size_var=True) + max_num_pages = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + page_indptr_elem_offset = T.int32(is_size_var=True) + page_values_elem_offset = T.int32(is_size_var=True) + length_info_elem_offset = T.int32(is_size_var=True) + tree_order_elem_offset = T.int32(is_size_var=True) + tree_order_indptr_elem_offset = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (total_len, h_q, d), dtype) + q_indptr = T.match_buffer( + var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset + ) + pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype) + page_indptr = T.match_buffer( + var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset + ) + page_values = T.match_buffer( + var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset + ) + k_rope_pos_offset = T.match_buffer( + var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset + ) + q_rope_position = T.match_buffer( + var_q_rope_position, (total_len,), "int32", elem_offset=q_rope_position_elem_offset + ) + output = T.match_buffer(var_output, (total_len, h_q, d), dtype) + lse = T.match_buffer( + var_lse, (total_len, h_q), "float32" + ) # pylint: disable=unused-variable + tree_order_indptr = T.match_buffer( + tree_order_indptr_handle, + (batch_size + 1,), + "int32", + elem_offset=tree_order_indptr_elem_offset, + ) + total_tree_order_len = T.int32(is_size_var=True) + tree_order = T.match_buffer( + tree_order_handle, + (total_tree_order_len, 2), + "int32", + elem_offset=tree_order_elem_offset, + ) + # The length information of the sequences. + # - It is in shape `(3, batch_size)` when sliding window is enabled. + # For a sequence "i", location + # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), + # - "(1, i)" is the starting offset of the sliding window in the seq, + # - "(2, i)" is the attn sink length of the sequence. + # - It is in shape `(batch_size,)` when sliding window is disabled, + # denoting the "last_page_len". + length_info = _declare_length_info( + var_length_info, batch_size, sliding_window, length_info_elem_offset + ) + + T.Assert( + rotary_mode == T.int32(0), "Inline rotary mode is not supported in tree attention." + ) + + # kernel code + for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): + for lby in T.thread_binding(h_kv, thread="blockIdx.y"): + for lty in T.thread_binding(num_warps, thread="threadIdx.y"): + for ltx in T.thread_binding(bdx, thread="threadIdx.x"): + with T.block("attn"): + bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + tile_id = _var("int32") + batch_idx = _var("int32") + batch_tiles = _var("int32") + batch_rows = _var("int32") + iterator = _var("int32") + kv_chunk_len = _var("int32") + + Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") + K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") + + S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") + O_local = T.alloc_buffer((tile_x, d), "float32", scope="local") + + m_smem = T.alloc_buffer((tile_x,), "float32", scope="shared") + m_prev_smem = T.alloc_buffer((tile_x,), "float32", scope="shared") + d_smem = T.alloc_buffer((tile_x,), "float32", scope="shared") + + m_new = T.alloc_buffer( + (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local" + ) + m_prev = T.alloc_buffer( + (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local" + ) + d_new = T.alloc_buffer( + (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local" + ) + + ## get tile_no, batch_idx, batch_tiles, batch_rows + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + # advance to next tile + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] -= batch_tiles[0] + batch_idx[0] += 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = ( + q_indptr[b_idx + 1] - q_indptr[b_idx] + ) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + LH_start: T.int32 = tile_id[0] * tile_x + q_indptr_val: T.int32 = q_indptr[b_idx] + + cur_page_indptr_begin: T.int32 = page_indptr[b_idx] + cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] + kv_chunk_len[0] = T.if_then_else( + cur_page_indptr_begin != cur_page_indptr_end, + _get_kv_chunk_len( + cur_page_indptr_end - cur_page_indptr_begin, + 16, + b_idx, + length_info, + sliding_window, + ), + 0, + ) + T.tvm_storage_sync("shared") + + # init states + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + m_smem[row] = -5e4 + d_smem[row] = 1.0 + + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_init"): + i, j = T.axis.remap("SS", [li, lj]) + O_local[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Load Q from gmem to smem + for li, lj in T.grid(tile_x, tile_y): + with T.block("Q_load"): + i, j = T.axis.remap("SS", [li, lj]) + T.reads() + T.writes() + cur_L = q_indptr_val + (LH_start + i) // group_size + cur_H_qo = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope( + q, + q_rope_position[cur_L], + d, + rope_theta, + rope_scale, + (cur_L, cur_H_qo, j), + dtype, + rope_scaling, + ), + q[cur_L, cur_H_qo, j], + ) + else: + Q_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): + L_kv_start: T.int32 = iterator * tile_z + for lz, ly in T.grid(tile_z, tile_y): + with T.block("K_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + K_smem[i, j] = pages[ + page_no, 0, by, page_offset, j + ] + else: + K_smem[i, j] = 0.0 + + T.tvm_storage_sync("shared") + for lz, ly in T.grid(tile_z, tile_y): + with T.block("V_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + V_smem[i, j] = pages[ + page_no, 1, by, page_offset, j + ] + else: + V_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Compute S + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_z, tile_y): + with T.block("S_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + S_local[i, j] = 0.0 + S_local[i, j] += ( + T.cast(Q_smem[i, k], "float32") + * T.cast(K_smem[j, k], "float32") + * attn_score_scaling_factor + * sm_scale + ) + T.tvm_storage_sync("shared") + for li, lj in T.grid(tile_x, tile_z): + with T.block("S_store"): + i, j = T.axis.remap("SS", [li, lj]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + + # Update S, m, d + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update1"): + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + # mask out of kv_chunk_len S + row_: T.int32 = (LH_start + row) // group_size + for j in T.serial(tile_z): + if _check_tree_order( + tree_order_indptr=tree_order_indptr, + tree_order=tree_order, + batch=b_idx, + row=row_, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] + - q_indptr[b_idx], + ): + m_new[i] = T.max( + m_new[i], S_smem[row, j] + ) + d_new[i] = d_smem[row] * T.exp2( + m_prev[i] - m_new[i] + ) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + with T.block("update"): + for j in T.serial(tile_z): + # this is to avoid sync inside condition branch + if row < tile_x: + row_: T.int32 = ( + LH_start + row + ) // group_size + if _check_tree_order( + tree_order_indptr=tree_order_indptr, + tree_order=tree_order, + batch=b_idx, + row=row_, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] + - q_indptr[b_idx], + ): + S_smem[row, j] = T.exp2( + S_smem[row, j] - m_new[i] + ) + else: + S_smem[row, j] = T.exp2(-5e4 - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update"): + for j in T.serial(tile_z): + d_new[i] += S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + + # Update O + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_y, tile_z): + with T.block("O_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + O_local[i, j] *= T.exp2( + m_prev_smem[i] - m_smem[i] + ) + O_local[i, j] += S_smem[i, k] * T.cast( + V_smem[k, j], "float32" + ) + + # Store O from smem to gmem + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_store"): + i, j = T.axis.remap("SS", [li, lj]) + cur_L: T.int32 = ( + q_indptr[b_idx] + (LH_start + i) // group_size + ) + cur_H_qo: T.int32 = ( + by * group_size + (LH_start + i) % group_size + ) + if cur_L < q_indptr[b_idx + 1]: + output[cur_L, cur_H_qo, j] = ( + O_local[i, j] / d_smem[i] + ) + + # Store LSE to gmem + for li in T.grid(tile_x): + with T.block("lse_store"): + i = T.axis.remap("S", [li]) + cur_L: T.int32 = ( + q_indptr[b_idx] + (LH_start + i) // group_size + ) + cur_H_qo: T.int32 = ( + by * group_size + (LH_start + i) % group_size + ) + if cur_L < q_indptr[b_idx + 1]: + lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) + + # move to next tile + tile_id[0] += NUM_BLKS + + # fmt: on + # pylint: enable=line-too-long,too-many-branches + sch = tir.Schedule(tree_attn_paged_kv) + + def get_tile_size(x, y, t): + cnt = (x * y) // t + assert (x * y) % t == 0 + tile_y = (int)(math.ceil(math.sqrt(cnt))) + while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + tile_y += 1 + assert tile_y <= cnt + tile_x = cnt // tile_y + return tile_x, tile_y + + def apply_to_qkv_load(sch: tir.Schedule, block): + loop_x, loop_y = sch.get_loops(block)[-2:] + loop = sch.fuse(loop_x, loop_y) + _, ty, tx, vec = sch.split( + loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + def apply_to_so_ewise(sch: tir.Schedule, block, tile): + loop_x, loop_y = sch.get_loops(block)[-2:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + def apply_to_gemm( # pylint: disable=unused-argument + sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False + ): + loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + ko, ki = sch.split(loop_z, factors=[None, r_len]) + if k_major: + sch.reorder(ko, xi, yi, ki) + else: + sch.reorder(ko, ki, xi, yi) + sch.decompose_reduction(block, ty) + + def apply_to_md(sch, block): + loop = sch.get_loops(block)[-1] + _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) + tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) + apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) + apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) + apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_block("Q_load")) + apply_to_qkv_load(sch, sch.get_block("K_load")) + apply_to_qkv_load(sch, sch.get_block("V_load")) + apply_to_md(sch, sch.get_block("lse_store")) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 591187ab5fe78..8809a1b0729e7 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -330,9 +330,9 @@ class PagedKVCacheAuxDataManager { */ virtual NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) = 0; /*! \brief Copy the tree attention mask. */ - virtual NDArray CopyTreeAttnMaskAsync(HostMemoryVector* data) = 0; + virtual NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! \brief Copy the mn indptr of the tree attention mask. */ - virtual NDArray CopyTreeAttnMNIndptrAsync(HostMemoryVector* data) = 0; + virtual NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! \brief Commit all the attention auxiliary data copy operations since the last commit. */ virtual void CommitAttnAuxDataCopy() = 0; @@ -379,14 +379,15 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { NDArray::Empty({3, reserved_num_seqs}, dtype_aux_, device)); k_rope_pos_offset_on_depths_device_.push_back( NDArray::Empty({reserved_num_seqs}, dtype_aux_, device)); + tree_attn_mask_device_.push_back(NDArray::Empty( + {kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs}, dtype_aux_, device)); + tree_attn_mn_indptr_device_.push_back( + NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); } cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device); q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); append_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); - tree_attn_mask_device_ = NDArray::Empty( - {kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs}, dtype_aux_, device); - tree_attn_mn_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); commit_copy_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); commit_copy_src_dst_pos_in_page_table_device_ = @@ -450,15 +451,15 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyTreeAttnMaskAsync(HostMemoryVector* data) final { + NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) final { NDArray view = - tree_attn_mask_device_.CreateView({static_cast(data->size())}, dtype_aux_); + tree_attn_mask_device_[depth].CreateView({static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyTreeAttnMNIndptrAsync(HostMemoryVector* data) final { - NDArray view = - tree_attn_mn_indptr_device_.CreateView({static_cast(data->size())}, dtype_aux_); + NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { + NDArray view = tree_attn_mn_indptr_device_[depth].CreateView( + {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } @@ -557,12 +558,12 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { std::vector page_indices_on_depths_device_; std::vector length_info_on_depths_device_; std::vector k_rope_pos_offset_on_depths_device_; + std::vector tree_attn_mask_device_; + std::vector tree_attn_mn_indptr_device_; NDArray cur_append_length_indptr_device_; NDArray k_ragged_rope_pos_offset_device_; NDArray q_rope_position_map_device_; NDArray append_position_map_device_; - NDArray tree_attn_mask_device_; - NDArray tree_attn_mn_indptr_device_; NDArray commit_copy_length_indptr_device_; NDArray commit_copy_src_dst_pos_in_page_table_device_; }; @@ -630,10 +631,11 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyTreeAttnMaskAsync(HostMemoryVector* data) final { - return CopyAttnAuxVecToCache(data); + NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) final { + NDArray mask_1d = CopyAttnAuxVecToCache(data); + return mask_1d.CreateView({static_cast(data->size() / 2), 2}, mask_1d->dtype); } - NDArray CopyTreeAttnMNIndptrAsync(HostMemoryVector* data) final { + NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { return CopyAttnAuxVecToCache(data); } NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, @@ -894,7 +896,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /*! \brief The append lengths of the sequences in the current round of forwarding. */ IntTuple cur_append_lengths_; /*! \brief Whether the current batch of sequences are token chains (not token trees). */ - bool is_chain_; + std::vector is_chain_on_depths_; /*! \brief Number of fork depth in the current round of forward. */ int num_depths_; /*! \brief Whether to compute attention after appending KV into cache or not. */ @@ -930,8 +932,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector q_rope_position_map_host_; HostMemoryVector append_position_map_host_; HostMemoryVector cur_append_lengths_indptr_host_; - HostMemoryVector tree_attn_mask_host_; - HostMemoryVector tree_attn_mn_indptr_host_; + std::vector tree_attn_mask_host_; + std::vector tree_attn_mn_indptr_host_; HostMemoryVector commit_copy_length_indptr_host_; HostMemoryVector commit_copy_src_pos_in_page_table_host_; HostMemoryVector commit_copy_dst_pos_in_page_table_host_; @@ -947,8 +949,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray k_ragged_rope_pos_offset_view_; NDArray q_rope_position_map_view_; NDArray append_position_map_view_; - NDArray tree_attn_mask_view_; - NDArray tree_attn_mn_indptr_view_; NDArray temp_attn_output_view_; NDArray temp_attn_scores_view_; NDArray merged_attn_scores_view_; @@ -957,6 +957,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector page_indices_on_depths_view_; std::vector length_info_on_depths_view_; std::vector k_rope_pos_offset_view_; + std::vector tree_attn_mask_view_; + std::vector tree_attn_mn_indptr_view_; PackedFunc f_transpose_append_; PackedFunc f_compact_copy_; @@ -966,6 +968,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { PackedFunc f_attention_decode_sliding_window_; PackedFunc f_attention_prefill_ragged_; PackedFunc f_attention_prefill_with_tree_mask_; + PackedFunc f_attention_prefill_with_tree_mask_paged_kv_; Optional f_attention_prefill_ragged_begin_forward_; Optional f_attention_prefill_ragged_end_forward_; Optional f_attention_prefill_begin_forward_; @@ -996,6 +999,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_with_tree_mask, + PackedFunc f_attention_prefill_with_tree_mask_paged_kv, Optional f_attention_prefill_ragged_begin_forward, Optional f_attention_prefill_ragged_end_forward, Optional f_attention_prefill_begin_forward, @@ -1025,6 +1029,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_attention_decode_sliding_window_(std::move(f_attention_decode_sliding_window)), f_attention_prefill_ragged_(std::move(f_attention_prefill_ragged)), f_attention_prefill_with_tree_mask_(std::move(f_attention_prefill_with_tree_mask)), + f_attention_prefill_with_tree_mask_paged_kv_( + std::move(f_attention_prefill_with_tree_mask_paged_kv)), f_attention_prefill_ragged_begin_forward_( std::move(f_attention_prefill_ragged_begin_forward)), f_attention_prefill_ragged_end_forward_(std::move(f_attention_prefill_ragged_end_forward)), @@ -1059,6 +1065,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); k_rope_pos_offset_on_depths_host_.push_back( HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); + tree_attn_mask_host_.push_back(HostMemoryVector(kTreeAttnMaxTreeSize * 2 * reserved_num_seqs, + dtype_aux_, preferred_host_device)); + tree_attn_mn_indptr_host_.push_back( + HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device)); } k_ragged_rope_pos_offset_host_ = HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device); @@ -1068,11 +1078,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device); cur_append_lengths_indptr_host_ = HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device); - tree_attn_mask_host_ = - HostMemoryVector(kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs, - dtype_aux_, preferred_host_device); - tree_attn_mn_indptr_host_ = - HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device); commit_copy_length_indptr_host_ = HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device); commit_copy_src_pos_in_page_table_host_ = @@ -1092,6 +1097,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { page_indices_on_depths_view_.push_back(NDArray()); length_info_on_depths_view_.push_back(NDArray()); k_rope_pos_offset_view_.push_back(NDArray()); + tree_attn_mask_view_.push_back(NDArray()); + tree_attn_mn_indptr_view_.push_back(NDArray()); + is_chain_on_depths_.push_back(true); } // Additional workspace for the "prefill with ragged kv" kernel. if (NeedKernelBeginForward()) { @@ -1492,36 +1500,18 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { sequences.push_back(&it->second); last_block_length_before_append.push_back( global_block_pool_[it->second.last_block_idx].seq_length); - k_ragged_rope_pos_offset_host_.push_back(it->second.seq_length); + int k_rope_offset = it->second.seq_length; + if (!it->second.accepted_indices_committed) { + int tree_size = static_cast(it->second.token_tree_parent_ptr.size()); + k_rope_offset -= tree_size; + } + k_ragged_rope_pos_offset_host_.push_back(k_rope_offset); it->second.seq_length += append_lengths[i]; if (append_lengths[i] != 1) { is_decode_request_ = false; } } - // - Check token tree validity and process the token tree. - is_chain_ = true; - tree_attn_mask_host_.clear(); - tree_attn_mn_indptr_host_.clear(); - if (opt_token_tree_parent_ptr.defined()) { - is_chain_ = ConstructTokenTreeMask(sequences, opt_token_tree_parent_ptr.value()); - } else { - // The input batch does not form trees. So each sequence in the batch - // is required to have all past accepted tokens committed. - for (int i = 0; i < cur_batch_size_; ++i) { - Sequence* sequence = sequences[i]; - CHECK(sequence->accepted_indices_committed) - << "The input batch does not form a tree, in which case the sequences in the input " - "batch are expected to have their accepted tokens token tree nodes committed. " - "Please invoke CommitAcceptedTokenTreeNodes for sequence " - << seq_ids[i]; - sequence->is_chain = true; - sequence->token_tree_parent_ptr.clear(); - sequence->token_tree_node_depths.clear(); - } - is_chain_ = true; - } - auto [block_ids_on_depths, trailing_blocks] = GetBlockIdsOnDepth(sequences); num_depths_ = std::min(static_cast(block_ids_on_depths.size()), kPagedKVCacheMaxBlockDepth); @@ -1552,6 +1542,36 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::fill(use_decode_kernel_.begin(), use_decode_kernel_.end(), /*value=*/false); } + bool has_previous_tree = + std::any_of(sequences.begin(), sequences.end(), + [](const Sequence* sequence) { return !sequence->accepted_indices_committed; }); + if (has_previous_tree) { + append_before_attn_ = true; + } + + // - Check token tree validity and process the token tree. + if (opt_token_tree_parent_ptr.defined()) { + CHECK(!support_sliding_window_) << "Tree attention does not support sliding window."; + CHECK(rope_mode_ != RoPEMode::kInline) << "Tree attention does not support inline RoPE mode."; + ConstructTokenTreeMask(sequences, opt_token_tree_parent_ptr.value(), block_ids_on_depths, + trailing_blocks); + } else { + // The input batch does not form trees. So each sequence in the batch + // is required to have all past accepted tokens committed. + for (int i = 0; i < cur_batch_size_; ++i) { + Sequence* sequence = sequences[i]; + CHECK(sequence->accepted_indices_committed) + << "The input batch does not form a tree, in which case the sequences in the input " + "batch are expected to have their accepted tokens token tree nodes committed. " + "Please invoke CommitAcceptedTokenTreeNodes for sequence " + << seq_ids[i]; + sequence->is_chain = true; + sequence->token_tree_parent_ptr.clear(); + sequence->token_tree_node_depths.clear(); + } + std::fill(is_chain_on_depths_.begin(), is_chain_on_depths_.end(), true); + } + if (append_before_attn_) { // Right now we use different kernels when depth is 1 or not 1. // For the case where maximum depth is 1, we create the auxiliary @@ -1656,9 +1676,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int64_t append_length = append_lengths[i]; const Block& block = global_block_pool_[sequences[i]->last_block_idx]; for (int64_t pos = 0; pos < append_length; ++pos) { - q_rope_position_map_host_.push_back( - k_ragged_rope_pos_offset_host_[i] + - (is_chain_ ? pos : sequences[i]->token_tree_node_depths[pos])); + if (sequences[i]->token_tree_node_depths.empty()) { + q_rope_position_map_host_.push_back(k_ragged_rope_pos_offset_host_[i] + pos); + } else { + int64_t offset_in_tree = + static_cast(sequences[i]->token_tree_parent_ptr.size()) - append_length; + ICHECK_GE(offset_in_tree, 0); + q_rope_position_map_host_.push_back( + k_ragged_rope_pos_offset_host_[i] + + sequences[i]->token_tree_node_depths[offset_in_tree + pos]); + } int32_t pos_in_block = block.seq_length - append_length + pos; if (last_block_length_before_append[i] + pos < block.sink_length) { @@ -1763,12 +1790,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector sequences; sequences.reserve(num_seq_to_commit); + bool is_chain = true; for (int i = 0; i < num_seq_to_commit; ++i) { auto it = seq_map_.find(seq_ids[i]); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i] << "\" cannot be found in KV cache."; sequences.push_back(&it->second); - CHECK(!it->second.accepted_indices_committed) + is_chain = it->second.is_chain; + CHECK(leaf_indices[i] == -1 || !it->second.accepted_indices_committed) << "The accepted nodes of sequence " << seq_ids[i] << " are already committed."; CHECK_GE(leaf_indices[i], -1) << "Invalid tree index " << leaf_indices[i] << " which is less than -1"; @@ -1778,7 +1807,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { << it->second.token_tree_parent_ptr.size() << " of the sequence"; } - if (!is_chain_) { + if (!is_chain) { commit_copy_length_indptr_host_.clear(); commit_copy_src_pos_in_page_table_host_.clear(); commit_copy_dst_pos_in_page_table_host_.clear(); @@ -1787,6 +1816,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int i = 0; i < num_seq_to_commit; ++i) { if (leaf_indices[i] == -1) { // No node is accepted. All nodes in the token tree need to be popped. + commit_copy_length_indptr_host_.push_back(commit_copy_length_indptr_host_.back()); continue; } @@ -1935,78 +1965,134 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { return block_idx; } - bool ConstructTokenTreeMask(const std::vector& sequences, - const IntTuple& token_tree_parent_ptr) { - // We check if the token tree deteriorates to a chain, - // because chain cases can have simplified attention work flow. - bool is_chain = true; - int64_t sum_new_append_length = 0; - // - Construct the mn indptr array, which is the indptr of the mask size of each sequence. - tree_attn_mn_indptr_host_.push_back(0); - ICHECK_EQ(sequences.size(), cur_batch_size_); - ICHECK_EQ(cur_append_lengths_.size(), cur_batch_size_); - for (int i = 0; i < cur_batch_size_; ++i) { - int64_t append_length = cur_append_lengths_[i]; - // Update the token tree parent pointers. - sequences[i]->token_tree_parent_ptr = { - token_tree_parent_ptr->data + sum_new_append_length, - token_tree_parent_ptr->data + sum_new_append_length + cur_append_lengths_[i]}; - sum_new_append_length += cur_append_lengths_[i]; - - CHECK_LE(append_length, kTreeAttnMaxTreeSize) - << "The tree size is " << append_length << " which exceeds the maximum tree size limit " - << kTreeAttnMaxTreeSize; - tree_attn_mn_indptr_host_.push_back(tree_attn_mn_indptr_host_.back() + - append_length * append_length); - } - CHECK_EQ(token_tree_parent_ptr.size(), sum_new_append_length) - << "Invalid token tree size. The sum of \"append_lengths\" is " << sum_new_append_length - << " while there are " << token_tree_parent_ptr.size() - << " elements in \"token_tree_parent_ptr\"."; - - // - Construct the mask of each sequence. - for (int i = 0; i < cur_batch_size_; ++i) { - int64_t tree_size = sequences[i]->token_tree_parent_ptr.size(); - std::vector> mask; - std::vector depth; - mask.reserve(tree_size); - depth.reserve(tree_size); - sequences[i]->is_chain = true; - sequences[i]->accepted_indices_committed = false; - for (int64_t n = 0; n < tree_size; ++n) { - CHECK_LT(sequences[i]->token_tree_parent_ptr[n], n) - << "Invalid token tree. The parent of node " << n << " in tree " << i << " is " - << sequences[i]->token_tree_parent_ptr[n] << ", which is not smaller than " << n; - CHECK_GE(sequences[i]->token_tree_parent_ptr[n], -1) - << "Invalid token tree. The parent of node " << n << " in tree " << i << " is " - << sequences[i]->token_tree_parent_ptr[n]; - if (sequences[i]->token_tree_parent_ptr[n] != n - 1) { - // The parent of the current node is not the last node. - // Therefore the tree is not a chain. - sequences[i]->is_chain = false; - is_chain = false; + void ConstructTokenTreeMask(const std::vector& sequences, + const IntTuple& token_tree_parent_ptr, + const std::vector>& block_ids_on_depths, + const std::vector>& trailing_blocks) { + // Check whether the token tree of a sequence should be handled at the current depth. + auto check_for_sequence = [&](int seq_i, int depth) -> bool { + if (!append_before_attn_) { + return true; + } + // Check if the last block of the sequence is on the current depth. + if (block_ids_on_depths[depth][seq_i] == sequences[seq_i]->last_block_idx || + (depth + 1 == kPagedKVCacheMaxBlockDepth && !trailing_blocks[seq_i].empty())) { + return true; + } + return false; + }; + for (int d = 0; d < num_depths_; ++d) { + // We check if the token tree deteriorates to a chain, + // because chain cases can have simplified attention work flow. + ICHECK_LT(d, tree_attn_mask_host_.size()); + ICHECK_LT(d, tree_attn_mn_indptr_host_.size()); + HostMemoryVector& tree_attn_mn_indptr = tree_attn_mn_indptr_host_[d]; + HostMemoryVector& tree_attn_mask = tree_attn_mask_host_[d]; + + std::vector seq_in_current_depth(cur_batch_size_, false); + + tree_attn_mn_indptr.clear(); + tree_attn_mask.clear(); + std::fill(is_chain_on_depths_.begin(), is_chain_on_depths_.end(), true); + + bool is_chain = true; + // - Construct the mn indptr array, which is the indptr of the mask size of each sequence. + tree_attn_mn_indptr.push_back(0); + ICHECK_EQ(sequences.size(), cur_batch_size_); + ICHECK_EQ(cur_append_lengths_.size(), cur_batch_size_); + int64_t token_tree_parent_ptr_offset = 0; + for (int i = 0; i < cur_batch_size_; ++i) { + int64_t append_length = cur_append_lengths_[i]; + seq_in_current_depth[i] = check_for_sequence(i, d); + if (!seq_in_current_depth[i]) { + tree_attn_mn_indptr.push_back(tree_attn_mn_indptr.back()); + token_tree_parent_ptr_offset += append_length; // Skip the token tree of this sequence. + continue; + } + // Update the token tree parent pointers. + CHECK_LE(sequences[i]->token_tree_parent_ptr.size(), + global_block_pool_[sequences[i]->last_block_idx].seq_length) + << "The token tree size is larger than the sequence length of the last block."; + std::copy(token_tree_parent_ptr.begin() + token_tree_parent_ptr_offset, + token_tree_parent_ptr.begin() + token_tree_parent_ptr_offset + append_length, + std::back_inserter(sequences[i]->token_tree_parent_ptr)); + token_tree_parent_ptr_offset += append_length; + + CHECK_LE(sequences[i]->token_tree_parent_ptr.size(), kTreeAttnMaxTreeSize) + << "The tree size is " << append_length << " which exceeds the maximum tree size limit " + << kTreeAttnMaxTreeSize; + tree_attn_mn_indptr.push_back(tree_attn_mn_indptr.back() + + sequences[i]->token_tree_parent_ptr.size()); + } + CHECK_EQ(token_tree_parent_ptr.size(), token_tree_parent_ptr_offset) + << "Invalid token tree size. The sum of \"append_lengths\" is " + << token_tree_parent_ptr_offset << " while there are " << token_tree_parent_ptr.size() + << " elements in \"token_tree_parent_ptr\"."; + + // - Construct the mask of each sequence. + for (int i = 0; i < cur_batch_size_; ++i) { + if (!seq_in_current_depth[i]) { + continue; } + int64_t tree_size = sequences[i]->token_tree_parent_ptr.size(); + std::vector> mask; + std::vector depth; + mask.reserve(tree_size); + depth.reserve(tree_size); + sequences[i]->is_chain = true; + sequences[i]->accepted_indices_committed = false; + std::unordered_map> tree_parent_to_children; + std::vector tree_roots; + for (int n = 0; n < tree_size; ++n) { + CHECK_LT(sequences[i]->token_tree_parent_ptr[n], n) + << "Invalid token tree. The parent of node " << n << " in tree " << i << " is " + << sequences[i]->token_tree_parent_ptr[n] << ", which is not smaller than " << n; + CHECK_GE(sequences[i]->token_tree_parent_ptr[n], -1) + << "Invalid token tree. The parent of node " << n << " in tree " << i << " is " + << sequences[i]->token_tree_parent_ptr[n]; + if (sequences[i]->token_tree_parent_ptr[n] != n - 1) { + // The parent of the current node is not the last node. + // Therefore the tree is not a chain. + sequences[i]->is_chain = false; + is_chain = false; + } + tree_parent_to_children[sequences[i]->token_tree_parent_ptr[n]].push_back(n); - std::vector single_pos_mask; - if (sequences[i]->token_tree_parent_ptr[n] != -1) { - // The current node has a parent in the token tree. - single_pos_mask = {mask[sequences[i]->token_tree_parent_ptr[n]].begin(), - mask[sequences[i]->token_tree_parent_ptr[n]].end()}; - depth.push_back(depth[sequences[i]->token_tree_parent_ptr[n]] + 1); - } else { - // The current node is root in the token tree. - single_pos_mask.resize(tree_size, /*value=*/0); - depth.push_back(0); + if (sequences[i]->token_tree_parent_ptr[n] != -1) { + depth.push_back(depth[sequences[i]->token_tree_parent_ptr[n]] + 1); + } else { + depth.push_back(0); + tree_roots.push_back(n); + } + } + std::vector> tree_order(tree_size); + int order = 0; + std::function tree_dfs = [&order, &tree_order, &tree_parent_to_children, + &tree_dfs](int node) -> int { + tree_order[node].first = order++; + int upper_bound = tree_order[node].first + 1; + for (int child : tree_parent_to_children[node]) { + upper_bound = std::max(upper_bound, tree_dfs(child)); + } + tree_order[node].second = upper_bound; + return upper_bound; + }; + for (auto root : tree_roots) { + tree_dfs(root); } - single_pos_mask[n] = 1; - mask.push_back(single_pos_mask); - for (int32_t mask_val : single_pos_mask) { - tree_attn_mask_host_.push_back(mask_val); + for (int n = 0; n < tree_size; ++n) { + tree_attn_mask.push_back(tree_order[n].first); + tree_attn_mask.push_back(tree_order[n].second); } + sequences[i]->token_tree_node_depths = std::move(depth); + } + + is_chain_on_depths_[d] = is_chain; + + if (!append_before_attn_) { + break; } - sequences[i]->token_tree_node_depths = std::move(depth); } - return is_chain; } /*! @@ -2236,13 +2322,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } if (!append_before_attn_) { - if (is_chain_) { + if (is_chain_on_depths_[0]) { f_attention_prefill_ragged_begin_forward_.value()( temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(), cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_, num_kv_heads_, head_dim_, copy_stream_); - } else { - LOG(FATAL) << "Kernel BeginForward doesn't support tree attn."; } } for (int d = 0; d < num_depths_; ++d) { @@ -2285,7 +2369,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (!append_before_attn_) { // The first part of attention, which only involves the q and the newly appended k/v. is_first_kernel = false; - if (is_chain_) { + if (is_chain_on_depths_[0]) { // If the batch does not form a tree, use raggedness prefill kernel. f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, k_data, v_data, cur_append_length_indptr_view_, q_rope_position_map_view_, @@ -2296,14 +2380,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { rotary_theta_, attn_score_scaling_factor); } else { // The batch requires tree attention. - ICHECK(tree_attn_mask_view_.defined()); - ICHECK(tree_attn_mn_indptr_view_.defined()); ICHECK(f_attention_prefill_with_tree_mask_.defined()) << "Function \"f_attention_prefill_with_tree_mask_\" is not defined."; + ICHECK(tree_attn_mask_view_[0].defined()); + ICHECK(tree_attn_mn_indptr_view_[0].defined()); f_attention_prefill_with_tree_mask_( q_data, cur_append_length_indptr_view_, k_data, v_data, cur_append_length_indptr_view_, - q_rope_position_map_view_, tree_attn_mn_indptr_view_, tree_attn_mask_view_, output, - merged_attn_scores_view_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, + q_rope_position_map_view_, tree_attn_mn_indptr_view_[0], tree_attn_mask_view_[0], + output, merged_attn_scores_view_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_, attn_score_scaling_factor, cur_batch_size_); } } @@ -2321,7 +2405,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { attn_output = temp_attn_output_view_; attn_scores = temp_attn_scores_view_; } - if (use_decode_kernel_[d]) { + if (append_before_attn_ && !is_chain_on_depths_[d]) { + f_attention_prefill_with_tree_mask_paged_kv_( + /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], + page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], + length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_, + attn_output, attn_scores, + /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_, + attn_score_scaling_factor, tree_attn_mn_indptr_view_[d], tree_attn_mask_view_[d]); + } else if (use_decode_kernel_[d]) { // Use decode kernel for depth d f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], length_info_on_depths_view_[d], @@ -2446,13 +2538,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { append_position_map_view_ = aux_data_manager_->CopyAppendPositionMapAsync(&append_position_map_host_); // 10. tree_attn_mask and tree_attn_mn_indptr - if (!is_chain_) { - tree_attn_mask_view_ = aux_data_manager_->CopyTreeAttnMaskAsync(&tree_attn_mask_host_); - tree_attn_mn_indptr_view_ = - aux_data_manager_->CopyTreeAttnMNIndptrAsync(&tree_attn_mn_indptr_host_); - } else { - tree_attn_mask_view_ = NDArray{nullptr}; - tree_attn_mn_indptr_view_ = NDArray{nullptr}; + for (int d = 0; d < num_depths_; ++d) { + if (!is_chain_on_depths_[d]) { + tree_attn_mask_view_[d] = + aux_data_manager_->CopyTreeAttnMaskOnDepthAsync(&tree_attn_mask_host_[d], d); + tree_attn_mn_indptr_view_[d] = + aux_data_manager_->CopyTreeAttnMNIndptrOnDepthAsync(&tree_attn_mn_indptr_host_[d], d); + } } // 11. Create view for temporary arrays for attention computation. temp_attn_output_view_ = temp_attn_output_device_.CreateView( @@ -2477,7 +2569,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj); TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK(args.size() == 27 || args.size() == 28) + CHECK(args.size() == 28 || args.size() == 29) << "Invalid number of KV cache constructor args."; ShapeTuple cache_config = args[0]; ShapeTuple layer_indptr_tuple = args[1]; @@ -2516,10 +2608,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") Optional f_debug_get_kv = args[24]; PackedFunc f_compact_copy = args[25]; PackedFunc f_attention_prefill_with_tree_mask = args[26]; + PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[27]; Optional rope_ext_factors = NullOpt; - if (args.size() >= 28 && args[27].IsObjectRef()) { - rope_ext_factors = args[27].AsObjectRef(); + if (args.size() >= 29 && args[28].IsObjectRef()) { + rope_ext_factors = args[28].AsObjectRef(); } CHECK_EQ(cache_config.size(), 5); @@ -2542,6 +2635,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") std::move(f_attention_prefill_sliding_window), std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), std::move(f_attention_prefill_with_tree_mask), + std::move(f_attention_prefill_with_tree_mask_paged_kv), std::move(f_attention_prefill_ragged_begin_forward), std::move(f_attention_prefill_ragged_end_forward), std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward), @@ -2553,7 +2647,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK(args.size() == 21 || args.size() == 22) + CHECK(args.size() == 22 || args.size() == 23) << "Invalid number of KV cache constructor args."; ShapeTuple cache_config = args[0]; ShapeTuple layer_indptr_tuple = args[1]; @@ -2586,10 +2680,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") Optional f_debug_get_kv = args[18]; PackedFunc f_compact_copy = args[19]; PackedFunc f_attention_prefill_with_tree_mask = args[20]; + PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[21]; Optional rope_ext_factors = NullOpt; - if (args.size() >= 22 && args[21].IsObjectRef()) { - rope_ext_factors = args[21].AsObjectRef(); + if (args.size() >= 23 && args[22].IsObjectRef()) { + rope_ext_factors = args[22].AsObjectRef(); } CHECK_EQ(cache_config.size(), 5); @@ -2611,8 +2706,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") std::move(f_attention_prefill), std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window), std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), - std::move(f_attention_prefill_with_tree_mask), // - NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // + std::move(f_attention_prefill_with_tree_mask), // + std::move(f_attention_prefill_with_tree_mask_paged_kv), // + NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page), std::move(f_debug_get_kv)); *rv = AttentionKVCache(std::move(n)); diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index c35b7062cdc22..5ab96caa9bc0b 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -36,6 +36,7 @@ _merge_state_inplace, llama_rope_with_position_map, tree_attn, + tree_attn_with_paged_kv_cache, ) from tvm.runtime import ShapeTuple @@ -74,6 +75,7 @@ fattn_decode_sliding_window = None fattn_prefill_ragged = None fattn_prefill_with_tree_mask = None +fattn_prefill_with_tree_mask_paged_kv_cache = None fmerge_state = None fsplit_rotary = None fattention_rotary = None @@ -86,7 +88,7 @@ def set_global_func(head_dim, dtype): global fpopn, fbegin_forward, fend_forward, fcommit_accepted_token_tree_nodes global fattention_with_fuse_qkv, fis_empty, fdebug_get_kv global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode - global fattn_prefill_ragged, fattn_prefill_with_tree_mask + global fattn_prefill_ragged, fattn_prefill_with_tree_mask, fattn_prefill_with_tree_mask_paged_kv_cache global fattn_prefill_sliding_window, fattn_decode_sliding_window global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page, fcompact_copy @@ -124,6 +126,9 @@ def set_global_func(head_dim, dtype): num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling, target ), tree_attn(num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling, target), + tree_attn_with_paged_kv_cache( + num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling, target + ), _merge_state_inplace(num_qo_heads, head_dim, dtype, target), llama_rope_with_position_map( rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype, rope_scaling @@ -146,6 +151,7 @@ def set_global_func(head_dim, dtype): fattn_decode_sliding_window, fattn_prefill_ragged, fattn_prefill_with_tree_mask, + fattn_prefill_with_tree_mask_paged_kv_cache, fmerge_state, fsplit_rotary, fcopy_single_page, @@ -185,6 +191,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): fcopy_cache, fcompact_copy, fattn_prefill_with_tree_mask, + fattn_prefill_with_tree_mask_paged_kv_cache, None, ) return cache @@ -206,7 +213,7 @@ class RopeMode(enum.IntEnum): params=itertools.chain( itertools.product( [64, 128], - ["float16", "float32"], + ["float32", "float16"], [RopeMode.NORMAL], [False], ), @@ -296,23 +303,26 @@ def apply_attention( cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) - assert (token_tree_parent_ptr_list is None) == (accepted_leaf_indices is None) flattened_token_tree_parent_ptr = None token_tree_node_depths_list: List[Optional[List[int]]] = [None for _ in batch] if token_tree_parent_ptr_list: assert len(token_tree_node_depths_list) == len(seq_ids) - assert len(accepted_leaf_indices) == len(seq_ids) + if accepted_leaf_indices is not None: + assert len(accepted_leaf_indices) == len(seq_ids) flattened_token_tree_parent_ptr = [] for i, (token_tree_parent_ptr, append_length) in enumerate( zip(token_tree_parent_ptr_list, append_lengths) ): - assert len(token_tree_parent_ptr) == append_length - flattened_token_tree_parent_ptr += token_tree_parent_ptr + assert len(token_tree_parent_ptr) >= append_length + # parent pointer for the last `append_length` nodes (the new tokens) + append_token_tree_parent_ptr = token_tree_parent_ptr[-append_length:] + flattened_token_tree_parent_ptr += append_token_tree_parent_ptr token_tree_node_depths = [] for parent in token_tree_parent_ptr: token_tree_node_depths.append( 0 if parent == -1 else token_tree_node_depths[parent] + 1 ) + # depth of each node in the tree (this contains more than the last `append_length` nodes) token_tree_node_depths_list[i] = token_tree_node_depths fbegin_forward( @@ -337,6 +347,11 @@ def apply_attention( new_v = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype) q_array.append(new_q) + rope_offset = cached_k[seq_id].shape[1] + if token_tree_parent_ptr_list is not None: + prev_tree_size = len(token_tree_parent_ptr_list[i]) - append_length + assert prev_tree_size >= 0 + rope_offset -= prev_tree_size cached_k[seq_id] = np.concatenate( [ cached_k[seq_id], @@ -347,10 +362,12 @@ def apply_attention( if rope_mode != RopeMode.NORMAL else f_apply_rotary( new_k[l], - cached_k[seq_id].shape[1], + rope_offset, rope_scale, rope_theta, - token_tree_node_depths_list[i], + token_tree_node_depths_list[i][-append_length:] + if token_tree_node_depths_list[i] is not None + else None, ) ) for l in range(num_layers) @@ -379,7 +396,11 @@ def apply_attention( for i, (seq_id, append_length) in enumerate(batch): assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length - rope_offset = cached_k[seq_id].shape[1] - append_length + rope_offset = cached_k[seq_id].shape[1] + if token_tree_parent_ptr_list is not None: + rope_offset -= len(token_tree_parent_ptr_list[i]) + else: + rope_offset -= append_length q_seq = ( q_array[i][layer_id] if rope_mode == RopeMode.NONE @@ -388,7 +409,9 @@ def apply_attention( rope_offset, rope_scale, rope_theta, - token_tree_node_depths_list[i], + token_tree_node_depths_list[i][-append_length:] + if token_tree_node_depths_list[i] is not None + else None, ) ).transpose(1, 0, 2) k_seq = ( @@ -422,15 +445,16 @@ def apply_attention( np.full_like(softmax_input, np.finfo("float32").max), k=length_diff ) + np.triu(np.full_like(softmax_input, np.finfo("float32").min), k=length_diff + 1) if token_tree_parent_ptr_list is not None: + tree_size = len(token_tree_parent_ptr_list[i]) tree_mask = np.full( - (append_length, append_length), np.finfo("float32").min, dtype="float32" + (tree_size, tree_size), np.finfo("float32").min, dtype="float32" ) for i, parent in enumerate(token_tree_parent_ptr_list[i]): if parent != -1: tree_mask[i] = tree_mask[parent] tree_mask[i, i] = np.finfo("float32").max tree_mask = np.broadcast_to(tree_mask, (num_qo_heads, *tree_mask.shape)) - mask[:, :, length_diff:] = tree_mask + mask[:, :, -tree_size:] = tree_mask[:, -append_length:, :] softmax_input = np.minimum(softmax_input, mask) @@ -846,9 +870,12 @@ def test_paged_attention_kv_cache_sliding_window_fork(kv_cache_and_config): @tvm.testing.requires_cuda def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config): kv_cache, rope_mode, support_sliding_window = kv_cache_and_config - if support_sliding_window and rope_mode == RopeMode.NORMAL: + if support_sliding_window: # Normal RoPE mode under sliding window settings is not supported. return + if rope_mode == RopeMode.INLINE: + # Inline RoPE mode is not supported for tree attention. + return fclear(kv_cache) cached_k = {} @@ -899,6 +926,29 @@ def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config): for _ in range(5): apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], cached_k, cached_v) + # Test the cases of tree attn with cached kv. + fclear(kv_cache) + cached_k = {} + cached_v = {} + # Prefill 4 sequences + apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v) + # Do 5 rounds of tree decode. + num_seq = 4 + for i in range(5): + num_leaf_nodes = 2**i + parent_ptr = [(k - 1) // 2 for k in range(0, 2 * num_leaf_nodes - 1)] + apply_attention( + kv_cache, + rope_mode, + [(seq_id, num_leaf_nodes) for seq_id in range(num_seq)], + cached_k, + cached_v, + token_tree_parent_ptr_list=[parent_ptr for _ in range(num_seq)], + accepted_leaf_indices=( + None if i != 4 else [2, 6, -1, 4] + ), # Leaf nodes are committed all at once at the end. + ) + if __name__ == "__main__": HEAD_DIMS = [64, 128]