[Llama] Support batched prefill
This PR adds support for batched prefill in the Llama model,
which can bring higher throughput for the overall prefill
process in serving.
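
For illustration, here is a minimal NumPy sketch (not part of this commit) of how batched prefill selects one hidden vector per sequence before the LM head; the names seq_lens, hidden_states, lm_head_weight, and logit_positions are made up for the example and mirror the relax.op.take(hidden_states, logit_positions, axis=1) path added in the diff below.

import numpy as np

# Three prompts are prefilled together as one call: their tokens are
# concatenated along a single flattened sequence axis.
hidden_size, vocab_size = 8, 32
seq_lens = [4, 2, 5]                      # illustrative per-sequence prompt lengths
total_seq_len = sum(seq_lens)

hidden_states = np.random.randn(1, total_seq_len, hidden_size).astype("float32")
lm_head_weight = np.random.randn(hidden_size, vocab_size).astype("float32")

# Index of the last token of each sequence in the flattened axis: [3, 5, 10].
logit_positions = np.cumsum(seq_lens) - 1

# Analogue of relax.op.take(hidden_states, logit_positions, axis=1):
# only one hidden vector per sequence flows into the LM head.
selected = hidden_states[:, logit_positions, :]   # (1, nseq, hidden_size)
logits = selected @ lm_head_weight                # (1, nseq, vocab_size)
assert logits.shape == (1, len(seq_lens), vocab_size)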

In addition, the PR splits the attention function used in the
batching setting into two separate functions, so that we no
longer dispatch between the prefill and decode attention
functions at runtime.
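
A rough Python sketch of the compile-time dispatch idea (illustrative only, assuming the query's sequence-length dimension is statically known while the module is traced; the helper select_attention_kernel is hypothetical, and the kernel names mirror the attention_prefill / attention_decode globals registered in the diff below).

# The branch on the query's sequence length happens while building the module,
# mirroring `is_decode = query_states.struct_info.shape[1] == 1` in the diff,
# so the emitted attention call refers to exactly one kernel.
def select_attention_kernel(query_seq_len: int) -> str:
    is_decode = query_seq_len == 1
    return "attention_decode" if is_decode else "attention_prefill"

assert select_attention_kernel(1) == "attention_decode"     # single-token decode step
assert select_attention_kernel(17) == "attention_prefill"   # multi-token (batched) prefill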
MasterJH5574 committed Nov 11, 2023
1 parent fab4486 commit 3d6d53d
Showing 1 changed file with 47 additions and 33 deletions.
80 changes: 47 additions & 33 deletions mlc_llm/relax_model/llama.py
@@ -387,7 +387,8 @@ def __init__(self, config: LlamaConfig):
super().__init__(config)
ctx_mod = relax.BlockBuilder.current().get()
self.kv_cache_transpose_append = ctx_mod.get_global_var("kv_cache_transpose_append")
self.attention_compute = ctx_mod.get_global_var("attention")
self.attention_compute_prefill = ctx_mod.get_global_var("attention_prefill")
self.attention_compute_decode = ctx_mod.get_global_var("attention_decode")

def attention_fwd(
self,
@@ -416,12 +417,13 @@ def attention_fwd(
)

f_kv_cache_attention = relax.extern("vm.builtin.paged_attention_kv_cache_attention")
is_decode = query_states.struct_info.shape[1] == 1
attn_output = nn.emit(
relax.call_dps_packed(
f_kv_cache_attention,
[
past_key_values,
self.attention_compute,
self.attention_compute_decode if is_decode else self.attention_compute_prefill,
query_states,
relax.PrimValue(layer_id),
True,
@@ -825,6 +827,7 @@ def forward(
inputs: relax.Expr,
all_seq_len_shape: Optional[relax.Expr],
past_key_values: relax.Expr,
logit_positions: Optional[relax.Expr] = None,
):
hidden_states, key_value_cache = self.model(
inputs=inputs,
@@ -840,7 +843,13 @@ def te_slicing(x: te.Tensor):
name="slice",
)

logits = self.lm_head(nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice"))
if hidden_states.struct_info.shape[1] != 1:
if logit_positions is None:
hidden_states = nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice")
else:
hidden_states = relax.op.take(hidden_states, logit_positions, axis=1)
logits = self.lm_head(hidden_states)

if logits.struct_info.dtype != "float32":
logits = nn.emit(relax.op.astype(logits, "float32"))

@@ -866,13 +875,12 @@ def create_embed_func(
) -> None:
func_name = "embed"

bsz = tvm.tir.Var("nseq", "int64")
seq_len = tvm.tir.Var("n", "int64")
with bb.function(func_name):
model = LlamaEmbedTokensWrapper(config, tvm.tir.Var("vocab_size", "int64"))
param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)

input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids")
input_ids = nn.Placeholder((1, seq_len), dtype="int32", name="input_ids")
with bb.dataflow():
inputs_embeds = model(input_ids)
params = [input_ids] + model.parameters()
@@ -940,8 +948,8 @@ def create_prefill_func_for_batching(
) -> None:
func_name = "prefill_with_embed"

bsz = 1
seq_len = tvm.tir.Var("n", "int64")
bsz = tir.Var("nseq", "int64")
total_seq_len = tvm.tir.Var("n", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -950,22 +958,24 @@
param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)

inputs = nn.Placeholder(
(bsz, seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds"
(1, total_seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds"
)
logit_pos = nn.Placeholder((bsz,), dtype="int32", name="logit_positions")
past_key_values = relax.Var("kv_cache", relax.ObjectStructInfo())
with bb.dataflow():
logits, key_value_cache = model(
inputs,
all_seq_len_shape=None,
past_key_values=past_key_values,
logit_positions=logit_pos,
)
params = [inputs, past_key_values] + model.parameters()
params = [inputs, logit_pos, past_key_values] + model.parameters()
gv = bb.emit_output((logits, key_value_cache))
bb.emit_func_output(gv, params)

mod = bb.get()
gv = mod.get_global_var(func_name)
bb.update_func(gv, mod[gv].with_attr("num_input", 2))
bb.update_func(gv, mod[gv].with_attr("num_input", 3))


def create_decoding_func_for_single_seq(
@@ -1092,6 +1102,7 @@ def create_paged_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> N
relax.PrimValue(num_key_value_heads),
relax.PrimValue(head_dim),
zeros,
relax.PrimValue(0),
],
sinfo_args=[relax.ObjectStructInfo()],
)
@@ -1129,9 +1140,13 @@ def create_softmax_func_for_batching(bb: relax.BlockBuilder, config: LlamaConfig
bb.emit_func_output(gv, [logits, temperature])


def emit_paged_kv_cache_op(bb: relax.BlockBuilder, dtype: str) -> None:
def emit_paged_kv_cache_op(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
from tvm.script import tir as T

num_layers = config.num_hidden_layers
num_heads = config.num_key_value_heads
head_dim = config.hidden_size // config.num_attention_heads

# fmt: off
@T.prim_func
def kv_cache_transpose_append(
@@ -1143,31 +1158,28 @@ def kv_cache_transpose_append(
var_last_page_offset: T.handle,
var_append_length_indptr: T.handle,
var_pos2seqidx: T.handle,
layer_id: T.int32,
layer_id: T.int64,
):
nseq = T.int32()
ntoken = T.int32()
nhead = T.int32()
nfeat = T.int32()
nlayer = T.int32()
npage = T.int32()
page_size = T.int32()
num_pages = T.int32()

pages = T.match_buffer(var_pages, (num_pages, nlayer, 2, nhead, page_size, nfeat), dtype)
k_data = T.match_buffer(var_k_data, (ntoken, nhead, nfeat), dtype)
v_data = T.match_buffer(var_v_data, (ntoken, nhead, nfeat), dtype)
nseq = T.int64()
ntoken = T.SizeVar("ntoken", "int64")
npage = T.int64()
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()

pages = T.match_buffer(var_pages, (num_pages, num_layers, 2, num_heads, page_size, head_dim), config.dtype)
k_data = T.match_buffer(var_k_data, (ntoken, num_heads, head_dim), config.dtype)
v_data = T.match_buffer(var_v_data, (ntoken, num_heads, head_dim), config.dtype)
last_page_offset = T.match_buffer(var_last_page_offset, (nseq,), "int32")
page_table_indptr = T.match_buffer(var_page_table_indptr, (nseq + 1,), "int32")
page_table_values = T.match_buffer(var_page_table_values, (npage,), "int32")
append_length_indptr = T.match_buffer(var_append_length_indptr, (nseq + 1,), "int32")
pos2seqidx = T.match_buffer(var_pos2seqidx, (ntoken,), "int32")

for global_pos, h, f in T.grid(ntoken, nhead, nfeat):
for global_pos, h, f in T.grid(ntoken, num_heads, head_dim):
with T.block("k_transpose_append"):
vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
seq_idx = pos2seqidx[vgpos]
seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx]
seq_idx: T.int64 = T.Cast("int64", pos2seqidx[vgpos])
seqlen: T.int64 = T.Cast("int64", (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx])
pages[
page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)],
layer_id,
Expand All @@ -1178,8 +1190,8 @@ def kv_cache_transpose_append(
] = k_data[vgpos, vh, vf]
with T.block("v_transpose_append"):
vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
seq_idx = pos2seqidx[vgpos]
seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx]
seq_idx: T.int64 = T.Cast("int64", pos2seqidx[vgpos])
seqlen: T.int64 = T.Cast("int64", (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx])
pages[
page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)],
layer_id,
@@ -1191,8 +1203,8 @@
# fmt: on

bb.add_func(kv_cache_transpose_append, "kv_cache_transpose_append")
# Todo: integrating attention TIR func/kernel.
bb.add_func(relax.extern("attention_func"), "attention")
bb.add_func(relax.extern("paged_kv_cache.attention_kernel_prefill"), "attention_prefill")
bb.add_func(relax.extern("paged_kv_cache.attention_kernel_decode"), "attention_decode")


def setup_params(mod, param_manager, dtype, config, args):
@@ -1318,7 +1330,9 @@ def get_model(args, hf_config):
build_model_only=args.build_model_only,
)
else:
raise Exception("The model config should contain information about maximum sequence length.")
raise Exception(
"The model config should contain information about maximum sequence length."
)

# If there is a user-provided maximum sequence length, override hf config.
if args.max_seq_len != -1:
@@ -1331,7 +1345,7 @@
create_embed_func(bb, param_manager, config, args.quantization)

if enable_batching:
emit_paged_kv_cache_op(bb, dtype)
emit_paged_kv_cache_op(bb, config)
create_prefill_func_for_batching(bb, param_manager, config, args.quantization)
create_decoding_func_for_batching(bb, param_manager, config, args.quantization)
create_paged_kv_cache_func(bb, config)