From 3d6d53d831af138e0b68e2c2c70100bfcda7fb9f Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sat, 11 Nov 2023 13:48:08 -0500
Subject: [PATCH] [Llama] Support batched prefill

This PR adds batched prefill support to the Llama model, which improves
the overall prefill throughput in serving.

In addition, the PR splits the attention function used in the batching
setting into two separate functions, so that we no longer need to
dispatch between the prefill and decode attention functions at runtime.
---
 mlc_llm/relax_model/llama.py | 80 +++++++++++++++++++++----------
 1 file changed, 47 insertions(+), 33 deletions(-)

diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py
index 49ff7862fa..07531c7c33 100644
--- a/mlc_llm/relax_model/llama.py
+++ b/mlc_llm/relax_model/llama.py
@@ -387,7 +387,8 @@ def __init__(self, config: LlamaConfig):
         super().__init__(config)
         ctx_mod = relax.BlockBuilder.current().get()
         self.kv_cache_transpose_append = ctx_mod.get_global_var("kv_cache_transpose_append")
-        self.attention_compute = ctx_mod.get_global_var("attention")
+        self.attention_compute_prefill = ctx_mod.get_global_var("attention_prefill")
+        self.attention_compute_decode = ctx_mod.get_global_var("attention_decode")
 
     def attention_fwd(
         self,
@@ -416,12 +417,13 @@ def attention_fwd(
         )
 
         f_kv_cache_attention = relax.extern("vm.builtin.paged_attention_kv_cache_attention")
+        is_decode = query_states.struct_info.shape[1] == 1
         attn_output = nn.emit(
             relax.call_dps_packed(
                 f_kv_cache_attention,
                 [
                     past_key_values,
-                    self.attention_compute,
+                    self.attention_compute_decode if is_decode else self.attention_compute_prefill,
                     query_states,
                     relax.PrimValue(layer_id),
                     True,
@@ -825,6 +827,7 @@ def forward(
         inputs: relax.Expr,
         all_seq_len_shape: Optional[relax.Expr],
         past_key_values: relax.Expr,
+        logit_positions: Optional[relax.Expr] = None,
     ):
         hidden_states, key_value_cache = self.model(
             inputs=inputs,
@@ -840,7 +843,13 @@ def te_slicing(x: te.Tensor):
                 name="slice",
             )
 
-        logits = self.lm_head(nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice"))
+        if hidden_states.struct_info.shape[1] != 1:
+            if logit_positions is None:
+                hidden_states = nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice")
+            else:
+                hidden_states = relax.op.take(hidden_states, logit_positions, axis=1)
+        logits = self.lm_head(hidden_states)
+
         if logits.struct_info.dtype != "float32":
             logits = nn.emit(relax.op.astype(logits, "float32"))
 
@@ -866,13 +875,12 @@ def create_embed_func(
 ) -> None:
     func_name = "embed"
 
-    bsz = tvm.tir.Var("nseq", "int64")
     seq_len = tvm.tir.Var("n", "int64")
     with bb.function(func_name):
         model = LlamaEmbedTokensWrapper(config, tvm.tir.Var("vocab_size", "int64"))
         param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)
 
-        input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids")
+        input_ids = nn.Placeholder((1, seq_len), dtype="int32", name="input_ids")
         with bb.dataflow():
             inputs_embeds = model(input_ids)
             params = [input_ids] + model.parameters()
@@ -940,8 +948,8 @@ def create_prefill_func_for_batching(
 ) -> None:
     func_name = "prefill_with_embed"
 
-    bsz = 1
-    seq_len = tvm.tir.Var("n", "int64")
+    bsz = tir.Var("nseq", "int64")
+    total_seq_len = tvm.tir.Var("n", "int64")
     hidden_size = config.hidden_size
     with bb.function(func_name):
         model = LlamaForCausalLM(
@@ -950,22 +958,24 @@ def create_prefill_func_for_batching(
         param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)
 
         inputs = nn.Placeholder(
-            (bsz, seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds"
+            (1, total_seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds"
         )
+        logit_pos = nn.Placeholder((bsz,), dtype="int32", name="logit_positions")
         past_key_values = relax.Var("kv_cache", relax.ObjectStructInfo())
         with bb.dataflow():
             logits, key_value_cache = model(
                 inputs,
                 all_seq_len_shape=None,
                 past_key_values=past_key_values,
+                logit_positions=logit_pos,
             )
-            params = [inputs, past_key_values] + model.parameters()
+            params = [inputs, logit_pos, past_key_values] + model.parameters()
             gv = bb.emit_output((logits, key_value_cache))
         bb.emit_func_output(gv, params)
 
     mod = bb.get()
     gv = mod.get_global_var(func_name)
-    bb.update_func(gv, mod[gv].with_attr("num_input", 2))
+    bb.update_func(gv, mod[gv].with_attr("num_input", 3))
 
 
 def create_decoding_func_for_single_seq(
@@ -1092,6 +1102,7 @@ def create_paged_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> N
                 relax.PrimValue(num_key_value_heads),
                 relax.PrimValue(head_dim),
                 zeros,
+                relax.PrimValue(0),
             ],
             sinfo_args=[relax.ObjectStructInfo()],
         )
@@ -1129,9 +1140,13 @@ def create_softmax_func_for_batching(bb: relax.BlockBuilder, config: LlamaConfig
         bb.emit_func_output(gv, [logits, temperature])
 
 
-def emit_paged_kv_cache_op(bb: relax.BlockBuilder, dtype: str) -> None:
+def emit_paged_kv_cache_op(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
     from tvm.script import tir as T
 
+    num_layers = config.num_hidden_layers
+    num_heads = config.num_key_value_heads
+    head_dim = config.hidden_size // config.num_attention_heads
+
     # fmt: off
     @T.prim_func
     def kv_cache_transpose_append(
@@ -1143,31 +1158,28 @@ def kv_cache_transpose_append(
         var_last_page_offset: T.handle,
         var_append_length_indptr: T.handle,
         var_pos2seqidx: T.handle,
-        layer_id: T.int32,
+        layer_id: T.int64,
     ):
-        nseq = T.int32()
-        ntoken = T.int32()
-        nhead = T.int32()
-        nfeat = T.int32()
-        nlayer = T.int32()
-        npage = T.int32()
-        page_size = T.int32()
-        num_pages = T.int32()
-
-        pages = T.match_buffer(var_pages, (num_pages, nlayer, 2, nhead, page_size, nfeat), dtype)
-        k_data = T.match_buffer(var_k_data, (ntoken, nhead, nfeat), dtype)
-        v_data = T.match_buffer(var_v_data, (ntoken, nhead, nfeat), dtype)
+        nseq = T.int64()
+        ntoken = T.SizeVar("ntoken", "int64")
+        npage = T.int64()
+        page_size = T.SizeVar("page_size", "int64")
+        num_pages = T.int64()
+
+        pages = T.match_buffer(var_pages, (num_pages, num_layers, 2, num_heads, page_size, head_dim), config.dtype)
+        k_data = T.match_buffer(var_k_data, (ntoken, num_heads, head_dim), config.dtype)
+        v_data = T.match_buffer(var_v_data, (ntoken, num_heads, head_dim), config.dtype)
         last_page_offset = T.match_buffer(var_last_page_offset, (nseq,), "int32")
         page_table_indptr = T.match_buffer(var_page_table_indptr, (nseq + 1,), "int32")
         page_table_values = T.match_buffer(var_page_table_values, (npage,), "int32")
         append_length_indptr = T.match_buffer(var_append_length_indptr, (nseq + 1,), "int32")
         pos2seqidx = T.match_buffer(var_pos2seqidx, (ntoken,), "int32")
 
-        for global_pos, h, f in T.grid(ntoken, nhead, nfeat):
+        for global_pos, h, f in T.grid(ntoken, num_heads, head_dim):
             with T.block("k_transpose_append"):
                 vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
-                seq_idx = pos2seqidx[vgpos]
-                seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx]
+                seq_idx: T.int64 = T.Cast("int64", pos2seqidx[vgpos])
+                seqlen: T.int64 = T.Cast("int64", (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx])
                 pages[
                     page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)],
                     layer_id,
@@ -1178,8 +1190,8 @@ def kv_cache_transpose_append(
                 ] = k_data[vgpos, vh, vf]
             with T.block("v_transpose_append"):
                 vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
-                seq_idx = pos2seqidx[vgpos]
-                seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx]
+                seq_idx: T.int64 = T.Cast("int64", pos2seqidx[vgpos])
+                seqlen: T.int64 = T.Cast("int64", (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx])
                 pages[
                     page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)],
                     layer_id,
@@ -1191,8 +1203,8 @@ def kv_cache_transpose_append(
                 ] = v_data[vgpos, vh, vf]
     # fmt: on
     bb.add_func(kv_cache_transpose_append, "kv_cache_transpose_append")
-    # Todo: integrating attention TIR func/kernel.
-    bb.add_func(relax.extern("attention_func"), "attention")
+    bb.add_func(relax.extern("paged_kv_cache.attention_kernel_prefill"), "attention_prefill")
+    bb.add_func(relax.extern("paged_kv_cache.attention_kernel_decode"), "attention_decode")
 
 
 def setup_params(mod, param_manager, dtype, config, args):
@@ -1318,7 +1330,9 @@ def get_model(args, hf_config):
             build_model_only=args.build_model_only,
         )
     else:
-        raise Exception("The model config should contain information about maximum sequence length.")
+        raise Exception(
+            "The model config should contain information about maximum sequence length."
+        )
 
     # If there is a user-provided maximum sequence length, override hf config.
     if args.max_seq_len != -1:
@@ -1331,7 +1345,7 @@ def get_model(args, hf_config):
     create_embed_func(bb, param_manager, config, args.quantization)
 
     if enable_batching:
-        emit_paged_kv_cache_op(bb, dtype)
+        emit_paged_kv_cache_op(bb, config)
         create_prefill_func_for_batching(bb, param_manager, config, args.quantization)
         create_decoding_func_for_batching(bb, param_manager, config, args.quantization)
         create_paged_kv_cache_func(bb, config)
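
As an illustration of the calling convention introduced by this patch (not part of the patch itself), the sketch below shows how a serving loop could drive the new "prefill_with_embed" entry: the embeddings of all pending requests are concatenated along the sequence axis into a single (1, total_seq_len, hidden_size) tensor, and "logit_positions" holds the flattened index of each request's last token, so logits come out with shape (1, nseq, vocab_size), one row per request. The VM/session setup and helper names here are assumptions made for the example.

# Illustrative sketch only: feeding the batched prefill entry built by this patch.
import numpy as np
import tvm

def batched_prefill(vm, device, kv_cache, params, request_token_ids):
    embeds = []
    logit_positions = []
    total = 0
    for token_ids in request_token_ids:
        # "embed" takes a (1, seq_len) int32 tensor of token ids and returns
        # a (1, seq_len, hidden_size) embedding tensor.
        ids = tvm.nd.array(np.asarray([token_ids], dtype="int32"), device)
        embeds.append(vm["embed"](ids, *params).numpy())
        total += len(token_ids)
        logit_positions.append(total - 1)  # flattened index of this request's last token
    # Flatten the batch along the sequence axis: (1, total_seq_len, hidden_size).
    inputs_embeds = tvm.nd.array(np.concatenate(embeds, axis=1), device)
    pos = tvm.nd.array(np.asarray(logit_positions, dtype="int32"), device)
    # "prefill_with_embed" now takes (inputs, logit_positions, kv_cache) as its
    # three runtime inputs (num_input = 3), followed by the model parameters.
    out = vm["prefill_with_embed"](inputs_embeds, pos, kv_cache, *params)
    logits, kv_cache = out[0], out[1]  # logits: (1, nseq, vocab_size)
    return logits, kv_cache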