Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[KVCache] Add max num threads awareness to KVCache kernels #1822

Merged
merged 5 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions python/mlc_chat/model/model_preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,30 @@
"context_window_size": 2048,
"prefill_chunk_size": 2048,
},
"tinyllama_1b_chat_v1.0": {
"architectures": ["LlamaForCausalLM"],
"attention_bias": False,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 5632,
"max_position_embeddings": 2048,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 22,
"num_key_value_heads": 4,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": None,
"rope_theta": 10000.0,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.35.0",
"use_cache": True,
"vocab_size": 32000,
},
"mistral_7b": {
"architectures": ["MistralForCausalLM"],
"bos_token_id": 1,
Expand Down
97 changes: 61 additions & 36 deletions python/mlc_chat/nn/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
rope_freq,
)

from ..support.max_thread_check import check_max_num_threads


class RopeMode(enum.IntEnum):
"""The RoPE mode of the Paged KV cache.
Expand Down Expand Up @@ -477,10 +479,20 @@ def _attention_prefill(h_kv, h_q, d, dtype, target: Target): # pylint: disable=
group_size = h_q // h_kv
sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))

bdx = 32
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not exactly sure about substituting 32 with bdx in the entire _attention_prefill() and _attention_prefill_ragged(); is it the intended semantics?

num_warps = 4
tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
L_per_cta = tile_x // group_size

# Otherwise we would exceed maxComputeWorkgroupStorageSize
if (
str(target.kind) == "webgpu"
and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4
):
tile_z = 8
num_warps = 2
check_max_num_threads(target, bdx=bdx, bdy=num_warps, bdz=1)

def mask(causal, row, col, kv_len, qo_len):
return T.if_then_else(
causal > 0,
Expand Down Expand Up @@ -529,7 +541,7 @@ def batch_prefill_paged_kv(
for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"):
for lby in T.thread_binding(h_kv, thread="blockIdx.y"):
for lty in T.thread_binding(num_warps, thread="threadIdx.y"):
for ltx in T.thread_binding(32, thread="threadIdx.x"):
for ltx in T.thread_binding(bdx, thread="threadIdx.x"):
with T.block("attn"):
bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
T.reads()
Expand All @@ -553,9 +565,9 @@ def batch_prefill_paged_kv(
m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")

m_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
m_prev = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
d_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")
m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")
d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")

## get tile_no, batch_idx, batch_tiles, batch_rows
tile_id[0] = bx
Expand Down Expand Up @@ -588,8 +600,8 @@ def batch_prefill_paged_kv(
T.tvm_storage_sync("shared")

# init states
for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
row: T.int32 = i * 32 * num_warps + ty * 32 + tx
for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
row: T.int32 = i * bdx * num_warps + ty * bdx + tx
if row < tile_x:
m_smem[row] = -5e4
d_smem[row] = 1.0
Expand Down Expand Up @@ -667,8 +679,8 @@ def batch_prefill_paged_kv(
T.tvm_storage_sync("shared")

# Update S, m, d
for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
row: T.int32 = i * 32 * num_warps + ty * 32 + tx
for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
row: T.int32 = i * bdx * num_warps + ty * bdx + tx
if row < tile_x:
with T.block("update1"):
m_prev[i] = m_smem[row]
Expand All @@ -683,8 +695,8 @@ def batch_prefill_paged_kv(
m_new[i] = T.max(m_new[i], S_smem[row, j])
d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])

for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
row: T.int32 = i * 32 * num_warps + ty * 32 + tx
for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
row: T.int32 = i * bdx * num_warps + ty * bdx + tx
with T.block("update"):
for j in T.serial(tile_z):
# this is to avoid sync inside condition branch
Expand All @@ -698,8 +710,8 @@ def batch_prefill_paged_kv(
else:
S_smem[row, j] = T.exp2(-5e4 - m_new[i])

for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
row: T.int32 = i * 32 * num_warps + ty * 32 + tx
for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
row: T.int32 = i * bdx * num_warps + ty * bdx + tx
if row < tile_x:
with T.block("update"):
for j in T.serial(tile_z):
Expand Down Expand Up @@ -752,7 +764,7 @@ def apply_to_qkv_load(sch: tir.Schedule, block):
loop_x, loop_y = sch.get_loops(block)[-2:]
loop = sch.fuse(loop_x, loop_y)
_, ty, tx, vec = sch.split(
loop, factors=[None, num_warps, 32, LOAD_VEC], preserve_unit_iters=True
loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True
)
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
Expand All @@ -764,7 +776,7 @@ def apply_to_so_ewise(sch: tir.Schedule, block, tile):
yo, yi = sch.split(loop_y, factors=[None, tile[1]])
sch.reorder(xo, yo, xi, yi)
t = sch.fuse(xo, yo)
ty, tx = sch.split(t, factors=[num_warps, 32])
ty, tx = sch.split(t, factors=[num_warps, bdx])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")

Expand All @@ -776,7 +788,7 @@ def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument
yo, yi = sch.split(loop_y, factors=[None, tile[1]])
sch.reorder(xo, yo, xi, yi)
t = sch.fuse(xo, yo)
ty, tx = sch.split(t, factors=[num_warps, 32])
ty, tx = sch.split(t, factors=[num_warps, bdx])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")

Expand All @@ -789,12 +801,12 @@ def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument

def apply_to_md(sch, block):
loop = sch.get_loops(block)[-1]
_, ty, tx = sch.split(loop, factors=[None, num_warps, 32])
_, ty, tx = sch.split(loop, factors=[None, num_warps, bdx])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")

tile_s = get_tile_size(tile_x, tile_z, 32 * num_warps)
tile_o = get_tile_size(tile_x, tile_y, 32 * num_warps)
tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps)
tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps)
apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True)
apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False)
apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s)
Expand Down Expand Up @@ -833,6 +845,7 @@ def _attention_decode(
bdz = threads_per_CTA // (bdx * bdy)
tile_size_per_bdx = 2 if GROUP_SIZE == 1 else 1
log2e = math.log2(math.exp(1))
check_max_num_threads(target, bdx=bdx, bdy=bdy, bdz=bdz)

# pylint: disable=line-too-long,too-many-arguments,too-many-branches
# fmt: off
Expand Down Expand Up @@ -1049,6 +1062,9 @@ def _merge_state_inplace(
VEC_SIZE = min(max(8 // v_dtype_bytes, head_dim // 32), 4)
bdx = head_dim // VEC_SIZE
bdy = num_heads
while bdx * bdy > target.max_num_threads and bdy > 1:
bdy //= 2
check_max_num_threads(target, bdx=bdx, bdy=bdy, bdz=1)
CharlieFRuan marked this conversation as resolved.
Show resolved Hide resolved

@T.prim_func
def merge_state_inplace(
Expand Down Expand Up @@ -1119,10 +1135,19 @@ def _attention_prefill_ragged(
group_size = h_q // h_kv
sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))

bdx = 32
num_warps = 4
tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
L_per_cta = tile_x // group_size

# Otherwise we would exceed maxComputeWorkgroupStorageSize
if (
str(target.kind) == "webgpu"
and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4
):
tile_z = 8
num_warps = 2

def mask(causal, row, col, kv_len, qo_len):
return T.if_then_else(
causal > 0,
Expand Down Expand Up @@ -1166,7 +1191,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran
for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"):
for lby in T.thread_binding(h_kv, thread="blockIdx.y"):
for lty in T.thread_binding(num_warps, thread="threadIdx.y"):
for ltx in T.thread_binding(32, thread="threadIdx.x"):
for ltx in T.thread_binding(bdx, thread="threadIdx.x"):
with T.block("attn"):
bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
T.reads()
Expand All @@ -1190,9 +1215,9 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran
m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")

m_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
m_prev = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
d_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")
m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")
d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")

## get tile_no, batch_idx, batch_tiles, batch_rows
tile_id[0] = bx
Expand All @@ -1218,8 +1243,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran
T.tvm_storage_sync("shared")

# init states
for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
row: T.int32 = i * 32 * num_warps + ty * 32 + tx
for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
row: T.int32 = i * bdx * num_warps + ty * bdx + tx
if row < tile_x:
m_smem[row] = -5e4
d_smem[row] = 1.0
Expand Down Expand Up @@ -1294,8 +1319,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran
T.tvm_storage_sync("shared")

# Update S, m, d
for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
row: T.int32 = i * 32 * num_warps + ty * 32 + tx
for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
row: T.int32 = i * bdx * num_warps + ty * bdx + tx
if row < tile_x:
with T.block("update1"):
m_prev[i] = m_smem[row]
Expand All @@ -1310,8 +1335,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran
m_new[i] = T.max(m_new[i], S_smem[row, j])
d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])

for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
row: T.int32 = i * 32 * num_warps + ty * 32 + tx
for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
row: T.int32 = i * bdx * num_warps + ty * bdx + tx
with T.block("update"):
for j in T.serial(tile_z):
# this is to avoid sync inside condition branch
Expand All @@ -1325,8 +1350,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran
else:
S_smem[row, j] = T.exp2(-5e4 - m_new[i])

for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
row: T.int32 = i * 32 * num_warps + ty * 32 + tx
for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
row: T.int32 = i * bdx * num_warps + ty * bdx + tx
if row < tile_x:
with T.block("update"):
for j in T.serial(tile_z):
Expand Down Expand Up @@ -1379,7 +1404,7 @@ def apply_to_qkv_load(sch: tir.Schedule, block):
loop_x, loop_y = sch.get_loops(block)[-2:]
loop = sch.fuse(loop_x, loop_y)
_, ty, tx, vec = sch.split(
loop, factors=[None, num_warps, 32, LOAD_VEC], preserve_unit_iters=True
loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True
)
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
Expand All @@ -1391,7 +1416,7 @@ def apply_to_so_ewise(sch: tir.Schedule, block, tile):
yo, yi = sch.split(loop_y, factors=[None, tile[1]])
sch.reorder(xo, yo, xi, yi)
t = sch.fuse(xo, yo)
ty, tx = sch.split(t, factors=[num_warps, 32])
ty, tx = sch.split(t, factors=[num_warps, bdx])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")

Expand All @@ -1403,7 +1428,7 @@ def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument
yo, yi = sch.split(loop_y, factors=[None, tile[1]])
sch.reorder(xo, yo, xi, yi)
t = sch.fuse(xo, yo)
ty, tx = sch.split(t, factors=[num_warps, 32])
ty, tx = sch.split(t, factors=[num_warps, bdx])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")

Expand All @@ -1416,12 +1441,12 @@ def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument

def apply_to_md(sch, block):
loop = sch.get_loops(block)[-1]
_, ty, tx = sch.split(loop, factors=[None, num_warps, 32])
_, ty, tx = sch.split(loop, factors=[None, num_warps, bdx])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")

tile_s = get_tile_size(tile_x, tile_z, 32 * num_warps)
tile_o = get_tile_size(tile_x, tile_y, 32 * num_warps)
tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps)
tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps)
apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True)
apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False)
apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s)
Expand Down
21 changes: 15 additions & 6 deletions python/mlc_chat/op/position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from tvm.script import tir as T
from tvm.target import Target

from ..support.max_thread_check import check_max_num_threads

# pylint: disable=invalid-name


Expand Down Expand Up @@ -313,6 +315,13 @@ def llama_inplace_rope(
if rotary_dim is None:
rotary_dim = head_dim

VEC_SIZE = 4
bdx = (head_dim + VEC_SIZE - 1) // VEC_SIZE # T.ceildiv(head_dim, VEC_SIZE)
bdy = 32
while bdx * bdy > target.max_num_threads and bdy > 1:
bdy //= 2
check_max_num_threads(target, bdx=bdx, bdy=bdy, bdz=1)

def _rope(
x: T.Buffer,
s: tir.Var,
Expand Down Expand Up @@ -359,12 +368,12 @@ def tir_rotary( # pylint: disable=too-many-locals
instance_offset: T.int32 = append_len_indptr[b]
rope_offset: T.int32 = rope_offsets[b]
append_len: T.int32 = append_len_indptr[b + 1] - append_len_indptr[b]
for s0 in range(T.ceildiv(append_len, 32)):
for s1 in T.thread_binding(32, thread="threadIdx.y"):
for d0 in T.thread_binding(T.ceildiv(head_dim, 4), thread="threadIdx.x"):
for d1 in T.vectorized(4):
s: T.int32 = s0 * 32 + s1
d: T.int32 = d0 * 4 + d1
for s0 in range(T.ceildiv(append_len, bdy)):
for s1 in T.thread_binding(bdy, thread="threadIdx.y"):
for d0 in T.thread_binding(bdx, thread="threadIdx.x"):
for d1 in T.vectorized(VEC_SIZE):
s: T.int32 = s0 * bdy + s1
d: T.int32 = d0 * VEC_SIZE + d1
if s < append_len and d < rotary_dim:
if h < num_q_heads:
q[s + instance_offset, h, d] = _rope(q, s, h, d, rope_offset, instance_offset)
Expand Down
14 changes: 14 additions & 0 deletions python/mlc_chat/support/max_thread_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Helper functions for checking max num thread."""

from tvm.target import Target


def check_max_num_threads(target: Target, bdx: int, bdy: int, bdz: int):
"""Check whether max num threads exceeded given a target."""
assert (
bdx * bdy * bdz <= target.max_num_threads
), f"{target.kind} max num threads exceeded: {bdx}*{bdy}*{bdz}>{target.max_num_threads}"

if str(target.kind) != "webgpu":
CharlieFRuan marked this conversation as resolved.
Show resolved Hide resolved
# https://gpuweb.github.io/gpuweb/#dom-supported-limits-maxcomputeworkgroupsizez
assert bdz <= 64, f"webgpu's z dimension cannot exceed 64, but got bdz={bdz}"
Loading