From 30e57a011e97c2556d634864ee9f4f5fb719f515 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 29 Jan 2024 09:26:32 +0000 Subject: [PATCH 01/15] test stub --- examples/python/run_llama_batched_vllm.py | 144 ++++++++++++++-------- 1 file changed, 94 insertions(+), 50 deletions(-) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index 0a2c8f0b9c..d9d519cd4f 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -172,6 +172,7 @@ def _prepare_inputs( sliding_window, dev, is_prefill, + query_token_len=1, ): block_tables = [] seq_lens = [] @@ -201,13 +202,15 @@ def _prepare_inputs( start_idx += prompt_len else: - input_ids.append(token_ids[-1]) - pos = len(token_ids) - 1 - positions.append(pos) + input_ids += token_ids[:-query_token_len] + + for i in range(query_token_len): + positions.append(len(token_ids) - (query_token_len - i)) + slot_mapping += all_slot_mappings[request_id][-query_token_len:] + block_table = all_block_tables[request_id] max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table)) block_tables.append(block_table) - slot_mapping.append(all_slot_mappings[request_id][-1]) if sliding_window: seq_lens.append(min(len(token_ids), sliding_window)) @@ -516,58 +519,99 @@ def run(args): for p, g in zip(prompts, generated): print("Prompt = '{}', generated text = '{}'".format(p, g)) - query_token_lens = [4, 3, 5, 2] + if model.disco_session: + return - eval_query_requests = [] + for query_token_lens, func_name in [ + ([4, 3, 5, 2], "evaluate_multi_query"), + ([3, 3, 3, 3], "decode_multi_query"), + ]: + if func_name == "evaluate_multi_query": + eval_query_requests = [] - for request_id, query_token_len in zip(request_ids, query_token_lens): - queries_to_eval = requests[request_id].token_ids[-query_token_len:] - num_past = len(requests[request_id].token_ids) - query_token_len - eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval)) + for request_id, query_token_len in zip(request_ids, query_token_lens): + queries_to_eval = requests[request_id].token_ids[-query_token_len:] + num_past = len(requests[request_id].token_ids) - query_token_len + eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval)) - ( - input_ids, - positions, - seq_lens, - slot_mapping, - query_lens, - past_slot_mapping, - permute_map, - ) = _prepare_eval_queries( - eval_query_requests, - cache.slot_mappings, - None, - model.dev, - ) + ( + input_ids, + positions, + seq_lens, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + ) = _prepare_eval_queries( + eval_query_requests, + cache.slot_mappings, + None, + model.dev, + ) - logits = model.mod["evaluate_multi_query"]( - input_ids, - positions, - seq_lens, - cache.cache, - slot_mapping, - query_lens, - past_slot_mapping, - permute_map, - model.params, - )[0].numpy() - - assert logits.shape[0] == sum(query_token_lens) - - logits_offset = 0 - - for request_id, query_token_len in zip(request_ids, query_token_lens): - for i in range(query_token_len - 1): - # requests[request_id].token_ids[-query_token_len:] are the "ground truth" tokens. - # Doing argmax over multi-timestep logits computed in parallel should yield the same - # tokens at the corresponding positions. 
- past_tokens = requests[request_id].token_ids[:-query_token_len] - assert ( - np.argmax(logits[logits_offset + i]) - == requests[request_id].token_ids[len(past_tokens) + i + 1] + logits = model.mod[func_name]( + input_ids, + positions, + seq_lens, + cache.cache, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + model.params, + )[0].numpy() + else: + decode_multi_query_requests = requests + + query_len = query_token_lens[0] + + ( + input_ids, + positions, + seq_lens, + slot_mapping, + _, + block_tables, + ) = _prepare_inputs( + decode_multi_query_requests, + cache.slot_mappings, + cache.block_tables, + model.sliding_window, + model.dev, + False, + query_len, ) - logits_offset += query_token_len + input_ids = tvm.nd.array(np.reshape(input_ids.numpy(), [-1, query_len]), dev) + + logits = model.mod[func_name]( + input_ids, + positions, + seq_lens, + cache.cache, + slot_mapping, + block_tables, + model.params, + )[0].numpy() + + logits = np.reshape(logits, (-1, logits.shape[-1])) + + assert logits.shape[0] == sum(query_token_lens) + + logits_offset = 0 + + for request_id, query_token_len in zip(request_ids, query_token_lens): + for i in range(query_token_len - 1): + # requests[request_id].token_ids[-query_token_len:] are the "ground truth" tokens. + # Doing argmax over multi-timestep logits computed in parallel should yield the same + # tokens at the corresponding positions. + past_tokens = requests[request_id].token_ids[:-query_token_len] + assert ( + np.argmax(logits[logits_offset + i]) + == requests[request_id].token_ids[len(past_tokens) + i + 1] + ) + + logits_offset += query_token_len if __name__ == "__main__": From 6f3429af889e55bdb8c7a7a3bf1d4cf0fc407765 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 29 Jan 2024 11:16:48 +0000 Subject: [PATCH 02/15] wip --- examples/python/run_llama_batched_vllm.py | 156 ++++++------ mlc_llm/relax_model/llama_batched_vllm.py | 293 ++++++++++++++-------- 2 files changed, 276 insertions(+), 173 deletions(-) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index d9d519cd4f..356384328f 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -21,6 +21,7 @@ class KVCache: def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, disco_session): + # TODO: use tvm.contrib.flash_attn.allocate_kv_cache if disco_session: init_cache_func = disco_session.get_global_func("tvm.contrib.vllm.allocate_kv_cache") else: @@ -522,80 +523,7 @@ def run(args): if model.disco_session: return - for query_token_lens, func_name in [ - ([4, 3, 5, 2], "evaluate_multi_query"), - ([3, 3, 3, 3], "decode_multi_query"), - ]: - if func_name == "evaluate_multi_query": - eval_query_requests = [] - - for request_id, query_token_len in zip(request_ids, query_token_lens): - queries_to_eval = requests[request_id].token_ids[-query_token_len:] - num_past = len(requests[request_id].token_ids) - query_token_len - eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval)) - - ( - input_ids, - positions, - seq_lens, - slot_mapping, - query_lens, - past_slot_mapping, - permute_map, - ) = _prepare_eval_queries( - eval_query_requests, - cache.slot_mappings, - None, - model.dev, - ) - - logits = model.mod[func_name]( - input_ids, - positions, - seq_lens, - cache.cache, - slot_mapping, - query_lens, - past_slot_mapping, - permute_map, - model.params, - )[0].numpy() - else: - decode_multi_query_requests = requests - - query_len = 
query_token_lens[0] - - ( - input_ids, - positions, - seq_lens, - slot_mapping, - _, - block_tables, - ) = _prepare_inputs( - decode_multi_query_requests, - cache.slot_mappings, - cache.block_tables, - model.sliding_window, - model.dev, - False, - query_len, - ) - - input_ids = tvm.nd.array(np.reshape(input_ids.numpy(), [-1, query_len]), dev) - - logits = model.mod[func_name]( - input_ids, - positions, - seq_lens, - cache.cache, - slot_mapping, - block_tables, - model.params, - )[0].numpy() - - logits = np.reshape(logits, (-1, logits.shape[-1])) - + def verify_logits(logits, query_token_lens): assert logits.shape[0] == sum(query_token_lens) logits_offset = 0 @@ -613,6 +541,86 @@ def run(args): logits_offset += query_token_len + # query_token_lens = [4, 3, 5, 2] + # func_name = "evaluate_multi_query" + + # eval_query_requests = [] + + # for request_id, query_token_len in zip(request_ids, query_token_lens): + # queries_to_eval = requests[request_id].token_ids[-query_token_len:] + # num_past = len(requests[request_id].token_ids) - query_token_len + # eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval)) + + # ( + # input_ids, + # positions, + # seq_lens, + # slot_mapping, + # query_lens, + # past_slot_mapping, + # permute_map, + # ) = _prepare_eval_queries( + # eval_query_requests, + # cache.slot_mappings, + # None, + # model.dev, + # ) + + # logits = model.mod[func_name]( + # input_ids, + # positions, + # seq_lens, + # cache.cache, + # slot_mapping, + # query_lens, + # past_slot_mapping, + # permute_map, + # model.params, + # )[0].numpy() + + # verify_logits(logits, query_token_lens) + + # TODO: check KV type is flash decode + query_token_lens = [3, 3, 3, 3] + func_name = "decode_multi_query" + + decode_multi_query_requests = requests + + query_len = query_token_lens[0] + + ( + input_ids, + positions, + seq_lens, + slot_mapping, + _, + block_tables, + ) = _prepare_inputs( + decode_multi_query_requests, + cache.slot_mappings, + cache.block_tables, + model.sliding_window, + model.dev, + False, + query_len, + ) + + input_ids = tvm.nd.array(np.reshape(input_ids.numpy(), [-1, query_len]), dev) + + logits = model.mod[func_name]( + input_ids, + positions, + seq_lens, + cache.cache, + slot_mapping, + block_tables, + model.params, + )[0].numpy() + + logits = np.reshape(logits, (-1, logits.shape[-1])) + + verify_logits(logits, query_token_lens) + if __name__ == "__main__": run(parse_args()) diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index 33b0966f54..2969bba374 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -1,4 +1,5 @@ from typing import Optional, Tuple, Union +from enum import Enum, auto from dataclasses import dataclass @@ -47,6 +48,11 @@ def rotary_compute(*idx): return q_embed, k_embed +class KVCacheType(Enum): + VLLM = auto() + FlashDecoding = auto() + + @dataclass class PrefillAttentionInput: seq_start: Optional[relax.Expr] # (num_seq + 1,) @@ -57,6 +63,7 @@ class PrefillAttentionInput: class DecodeAttentionInput: seq_lens: relax.Expr # (num_seq,) block_tables: Optional[relax.Expr] # (num_seq, max_num_blocks_per_seq) + seqlen_q: Optional[tvm.tir.SizeVar] # For flash decoding @dataclass @@ -78,6 +85,7 @@ class AttentionInput: slot_mapping: Optional[relax.Expr] # (num_query_token,) max_seqlen: relax.Expr # (), must be on CPU aux_info: Union[PrefillAttentionInput, DecodeAttentionInput, EvaluateMultiQueryInput] + kv_type: KVCacheType class 
LlamaAttentionBatched(LlamaAttentionBase): @@ -89,16 +97,20 @@ def __init__(self, config: LlamaConfig): self.sliding_window = T.IntImm("int32", config.sliding_window) max_context_length = config.sliding_window or config.max_sequence_length - partition_size = 512 # partition_size in vLLM attention - self.max_num_partitions = (max_context_length + partition_size - 1) // partition_size + + if False: # TODO + partition_size = 512 # partition_size in vLLM attention + self.max_num_partitions = (max_context_length + partition_size - 1) // partition_size + else: + self.max_num_partitons = 128 def forward( self, - hidden_states: relax.Expr, # (num_query_token, hidden_size) + hidden_states: relax.Expr, # (num_query_token, hidden_size) or (num_seq, seqlen_q, hidden_size) positions: relax.Expr, # (num_query_token,), for batched RoPE attn_input: AttentionInput, ): - num_query_tokens, _ = hidden_states.struct_info.shape + num_query_tokens = positions.struct_info.shape queries, keys, values = self.project_qkv( hidden_states, @@ -129,17 +141,30 @@ def forward( slot_mapping = attn_input.slot_mapping # kv caches are updated inplace, but make it look like a pure operation - kv = nn.emit( - relax.op.call_pure_packed( - "tvm.contrib.vllm.reshape_and_cache", - keys_to_cache, - values_to_cache, - k_cache, - v_cache, - slot_mapping, - sinfo_args=[k_cache.struct_info, v_cache.struct_info], + if attn_input.kv_type == KVCacheType.FlashDecoding: + kv = nn.emit( + relax.op.call_pure_packed( + "tvm.contrib.vllm.reshape_and_cache", + keys_to_cache, + values_to_cache, + k_cache, + v_cache, + slot_mapping, + sinfo_args=[k_cache.struct_info, v_cache.struct_info], + ) + ) + else: + kv = nn.emit( + relax.op.call_pure_packed( + "tvm.contrib.flash_attn.update_cache", + keys_to_cache, + values_to_cache, + k_cache, + v_cache, + slot_mapping, + sinfo_args=[k_cache.struct_info, v_cache.struct_info], + ) ) - ) k_cache, v_cache = kv[0], kv[1] else: @@ -153,15 +178,27 @@ def forward( kv_shape = (num_past_token, num_kv_head, head_size) kv_sinfo = relax.TensorStructInfo(kv_shape, k_cache.struct_info.dtype) - kv_tensors = nn.emit( - relax.op.call_pure_packed( - "tvm.contrib.vllm.reconstruct_from_cache", - k_cache, - v_cache, - attn_input.aux_info.past_slot_mapping, - sinfo_args=[kv_sinfo, kv_sinfo], + if attn_input.kv_type == KVCacheType.FlashDecoding: + kv_tensors = nn.emit( + relax.op.call_pure_packed( + "tvm.contrib.vllm.reconstruct_from_cache", + k_cache, + v_cache, + attn_input.aux_info.past_slot_mapping, + sinfo_args=[kv_sinfo, kv_sinfo], + ) ) - ) + else: + kv_tensors = nn.emit( + relax.op.call_pure_packed( + "tvm.contrib.flash_attn.reconstruct_from_cache", + k_cache, + v_cache, + attn_input.aux_info.past_slot_mapping, + sinfo_args=[kv_sinfo, kv_sinfo], + ) + ) + keys_past, values_past = kv_tensors[0], kv_tensors[1] # Say we have past tokens [P1, P2, P3] and the current ones [C1, C2, C3]. # Each of P1, C1 etc is a sequence of tokens. 
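# --- Illustrative aside (not part of the patch above or below) ---
# The hunks around this point branch between two paged-KV-cache layouts:
# the vLLM layout (num_blocks, num_heads, head_size // vec_size, block_size, vec_size)
# with block_size 16, and the Flash-Decoding layout
# (num_blocks, block_size, num_heads, head_size) with block_size 256, both taken
# from get_inputs() later in this file. The minimal NumPy sketch below shows what
# a cache write driven by a flat slot_mapping does under each layout. The
# slot -> (block, offset) split follows the usual vLLM convention and is an
# assumption here; the real work is done on device by
# tvm.contrib.vllm.reshape_and_cache / tvm.contrib.flash_attn.update_cache.
import numpy as np

num_heads, head_size, vec_size = 8, 64, 8

def write_vllm(k_cache, key, slot_mapping, block_size=16):
    # k_cache: (num_blocks, num_heads, head_size // vec_size, block_size, vec_size)
    # key:     (num_tokens, num_heads, head_size)
    for i, slot in enumerate(slot_mapping):
        block, offset = slot // block_size, slot % block_size
        k_cache[block, :, :, offset, :] = key[i].reshape(
            num_heads, head_size // vec_size, vec_size
        )

def write_flash_decoding(k_cache, key, slot_mapping, block_size=256):
    # k_cache: (num_blocks, block_size, num_heads, head_size)
    for i, slot in enumerate(slot_mapping):
        block, offset = slot // block_size, slot % block_size
        k_cache[block, offset] = key[i]

k_vllm = np.zeros((4, num_heads, head_size // vec_size, 16, vec_size), "float16")
k_flash = np.zeros((4, 256, num_heads, head_size), "float16")
key = np.random.rand(3, num_heads, head_size).astype("float16")
write_vllm(k_vllm, key, [0, 1, 17])            # slots 0, 1 -> block 0; slot 17 -> block 1, offset 1
write_flash_decoding(k_flash, key, [0, 1, 257])  # slot 257 -> block 1, offset 1
# --- End of illustrative aside ---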
@@ -211,62 +248,99 @@ def forward( ) ) else: - # Decode, using vLLM kernel + # Decode, using vLLM or Flash-Decoding kernel assert isinstance(attn_input.aux_info, DecodeAttentionInput) - exp_sums = nn.emit( - relax.op.builtin.alloc_tensor( - relax.ShapeExpr( - (num_query_tokens, self.num_query_heads, self.max_num_partitions) - ), - dtype="float32", - runtime_device_index=0, + if attn_input.kv_type == KVCacheType.VLLM: + exp_sums = nn.emit( + relax.op.builtin.alloc_tensor( + relax.ShapeExpr( + (num_query_tokens, self.num_query_heads, self.max_num_partitions) + ), + dtype="float32", + runtime_device_index=0, + ) ) - ) - max_logits = nn.emit( - relax.op.builtin.alloc_tensor( - relax.ShapeExpr( - (num_query_tokens, self.num_query_heads, self.max_num_partitions) - ), - dtype="float32", - runtime_device_index=0, + max_logits = nn.emit( + relax.op.builtin.alloc_tensor( + relax.ShapeExpr( + (num_query_tokens, self.num_query_heads, self.max_num_partitions) + ), + dtype="float32", + runtime_device_index=0, + ) ) - ) - tmp_out = nn.emit( - relax.op.builtin.alloc_tensor( - relax.ShapeExpr( - ( - num_query_tokens, - self.num_query_heads, - self.max_num_partitions, - self.head_dim, - ) - ), - dtype=queries.struct_info.dtype, - runtime_device_index=0, + tmp_out = nn.emit( + relax.op.builtin.alloc_tensor( + relax.ShapeExpr( + ( + num_query_tokens, + self.num_query_heads, + self.max_num_partitions, + self.head_dim, + ) + ), + dtype=queries.struct_info.dtype, + runtime_device_index=0, + ) ) - ) - attn_output = nn.emit( - relax.op.call_dps_packed( - "tvm.contrib.vllm.single_query_cached_kv_attention", + attn_output = nn.emit( + relax.op.call_dps_packed( + "tvm.contrib.vllm.single_query_cached_kv_attention", + [ + queries, + k_cache, + v_cache, + attn_input.aux_info.block_tables, + attn_input.aux_info.seq_lens, + 16, # block_size + attn_input.max_seqlen, + exp_sums, + max_logits, + tmp_out, + ], + out_sinfo=queries.struct_info, + ) + ) + else: + num_seq, seqlen_q = hidden_states.struct_info.shape + queries = nn.emit(reshape(queries, (num_seq, seqlen_q, self.num_query_heads, self.head_dim))) + + softmax_lse_accum = nn.emit( + relax.op.builtin.alloc_tensor( + relax.ShapeExpr( + (self.max_num_partitions, num_seq, self.num_query_heads, seqlen_q) + ), + dtype="float32", + runtime_device_index=0, + ) + ) + output_accum = nn.emit( + relax.op.builtin.alloc_tensor( + relax.ShapeExpr( + (self.max_num_partitions, num_seq, self.num_query_heads, seqlen_q, self.head_dim) + ), + dtype="float32", + runtime_device_index=0, + ) + ) + + attn_output = R.call_dps_packed( + "tvm.contrib.flash_attn.flash_decoding_with_paged_kvcache", [ queries, k_cache, v_cache, attn_input.aux_info.block_tables, attn_input.aux_info.seq_lens, - 16, # block_size - attn_input.max_seqlen, - exp_sums, - max_logits, - tmp_out, + softmax_lse_accum, + output_accum, ], out_sinfo=queries.struct_info, ) - ) attn_output = nn.emit( - reshape(attn_output, (num_query_tokens, self.num_query_heads * self.head_dim)) + reshape(attn_output, hidden_states.struct_info.shape) ) attn_output = self.o_proj(attn_output) @@ -352,7 +426,7 @@ def forward( else: cache = None - attn_input = AttentionInput(cache, slot_mapping, max_seqlen, attn_aux_info) + attn_input = AttentionInput(cache, slot_mapping, max_seqlen, attn_aux_info, KVCacheType.VLLM) hidden_states, new_kv = decoder_layer( hidden_states, @@ -390,7 +464,7 @@ def __init__( def forward( self, - input_ids: relax.Expr, # (num_query_token,) + input_ids: relax.Expr, # (num_query_token,) or (num_seq, seqlen_q) positions: 
relax.Expr, # (num_query_token,), for batched RoPE seq_lens: relax.Expr, # (num_seq,) kv_caches: Optional[relax.Expr], # For prefill and decode, not needed for evaluate @@ -504,14 +578,20 @@ def get_logits_last_tokens(x, seq_len_tensor, seq_start): def get_inputs( - num_query_token, num_seq, config, max_num_blocks_per_seq=None, sep_embed=False, need_cache=True + num_query_token, num_seq, config, seqlen_q=None, max_num_blocks_per_seq=None, sep_embed=False, need_cache=True ): hidden_size = config.hidden_size + if seqlen_q is None: + input_shape = (num_query_token,) + else: + input_shape = (num_seq, seqlen_q) + num_query_token = num_seq * seqlen_q + inputs = ( - nn.Placeholder((num_query_token, hidden_size), dtype=config.dtype, name="inputs_embeds") + nn.Placeholder(input_shape + (hidden_size,), dtype=config.dtype, name="inputs_embeds") if sep_embed - else nn.Placeholder((num_query_token,), dtype="int32", name="input_ids") + else nn.Placeholder(input_shape, dtype="int32", name="input_ids") ) seq_lens = nn.Placeholder((num_seq,), dtype="int32", name="seq_lens") @@ -686,48 +766,63 @@ def create_decoding_func( config: LlamaConfig, cpu_dev: VDevice, quant_scheme: QuantizationScheme, + use_flash_decoding=False, ) -> None: """Batched decoding with vLLM paged KV cache.""" func_name = "decode" num_seq = tvm.tir.SizeVar("num_seq", "int64") - max_num_blocks_per_seq = tvm.tir.SizeVar("max_num_blocks_per_seq", "int64") + seqlen_q = tvm.tir.SizeVar("seqlen_q", "int64") - with bb.function(func_name): - inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables = get_inputs( - num_seq, num_seq, config, max_num_blocks_per_seq - ) + seqlen_q_info = [("decode", 1)] - with bb.dataflow(): - model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64")) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + if use_flash_decoding: + seqlen_q_info.append(("decode_multi_query", seqlen_q)) - logits, new_kvs = model( - inputs, - positions, - seq_lens, - past_key_values, - slot_mapping, - block_tables, - None, - None, - None, - None, + for (func_name, seqlen_q) in seqlen_q_info: + max_num_blocks_per_seq = tvm.tir.SizeVar("max_num_blocks_per_seq", "int64") + + # This if / else is probably not needed + if seqlen_q == 1: + num_query_token = num_seq + else: + num_query_token = num_seq * seqlen_q + + with bb.function(func_name): + inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables = get_inputs( + num_query_token, num_seq, config, seqlen_q, max_num_blocks_per_seq ) - params = [ - inputs, - positions, - seq_lens, - past_key_values, - slot_mapping, - block_tables, - ] + model.parameters() - gv = bb.emit_output((logits, relax.Tuple(new_kvs))) - bb.emit_func_output(gv, params) - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 6)) + with bb.dataflow(): + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + logits, new_kvs = model( + inputs, + positions, + seq_lens, + past_key_values, + slot_mapping, + block_tables, + None, + None, + None, + None, + ) + params = [ + inputs, + positions, + seq_lens, + past_key_values, + slot_mapping, + block_tables, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(new_kvs))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 6)) def 
create_evaluate_multi_query_func( From 97a43665528c37841d773c39583cd909f34f6a25 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 29 Jan 2024 11:22:41 +0000 Subject: [PATCH 03/15] wip --- mlc_llm/relax_model/llama_batched_vllm.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index 2969bba374..052f527e69 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -410,6 +410,7 @@ def forward( slot_mapping: Optional[relax.Expr], max_seqlen: relax.Expr, attn_aux_info: Union[PrefillAttentionInput, DecodeAttentionInput, EvaluateMultiQueryInput], + kv_type: KVCacheType, ): if self.embed_tokens: inputs_embeds = self.embed_tokens(inputs) @@ -426,7 +427,7 @@ def forward( else: cache = None - attn_input = AttentionInput(cache, slot_mapping, max_seqlen, attn_aux_info, KVCacheType.VLLM) + attn_input = AttentionInput(cache, slot_mapping, max_seqlen, attn_aux_info, kv_type) hidden_states, new_kv = decoder_layer( hidden_states, @@ -692,6 +693,7 @@ def create_encoding_func( bb: relax.BlockBuilder, param_manager: ParamManager, config: LlamaConfig, + kv_type: KVCacheType, cpu_dev: VDevice, quant_scheme: QuantizationScheme, sep_embed: bool = False, @@ -764,9 +766,9 @@ def create_decoding_func( bb: relax.BlockBuilder, param_manager: ParamManager, config: LlamaConfig, + kv_type: KVCacheType, cpu_dev: VDevice, quant_scheme: QuantizationScheme, - use_flash_decoding=False, ) -> None: """Batched decoding with vLLM paged KV cache.""" func_name = "decode" @@ -776,7 +778,7 @@ def create_decoding_func( seqlen_q_info = [("decode", 1)] - if use_flash_decoding: + if kv_type == KVCacheType.FlashDecoding: seqlen_q_info.append(("decode_multi_query", seqlen_q)) for (func_name, seqlen_q) in seqlen_q_info: @@ -829,6 +831,7 @@ def create_evaluate_multi_query_func( bb: relax.BlockBuilder, param_manager: ParamManager, config: LlamaConfig, + kv_type: KVCacheType, cpu_dev: VDevice, quant_scheme: QuantizationScheme, ) -> None: @@ -953,10 +956,12 @@ def get_model(args, hf_config): # The CPU device to copy the result of relax.op.max(seq_lens) to CPU. 
cpu_dev = VDevice("llvm", 0, "global") + kv_type = KVCacheType.VLM + create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) - create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) - create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization) - create_evaluate_multi_query_func(bb, param_manager, config, cpu_dev, args.quantization) + create_encoding_func(bb, param_manager, config, kv_type, cpu_dev, args.quantization, sep_embed) + create_decoding_func(bb, param_manager, config, kv_type, cpu_dev, args.quantization) + create_evaluate_multi_query_func(bb, param_manager, config, kv_type, cpu_dev, args.quantization) mod = bb.get() From 7279cb6f2811778458df3d0e55b3bf48089cfb0f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 30 Jan 2024 05:21:29 +0900 Subject: [PATCH 04/15] wip --- mlc_llm/relax_model/llama_batched_vllm.py | 81 ++++++++++++++--------- 1 file changed, 48 insertions(+), 33 deletions(-) diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index 052f527e69..47a8c7f9e4 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -63,7 +63,6 @@ class PrefillAttentionInput: class DecodeAttentionInput: seq_lens: relax.Expr # (num_seq,) block_tables: Optional[relax.Expr] # (num_seq, max_num_blocks_per_seq) - seqlen_q: Optional[tvm.tir.SizeVar] # For flash decoding @dataclass @@ -85,12 +84,12 @@ class AttentionInput: slot_mapping: Optional[relax.Expr] # (num_query_token,) max_seqlen: relax.Expr # (), must be on CPU aux_info: Union[PrefillAttentionInput, DecodeAttentionInput, EvaluateMultiQueryInput] - kv_type: KVCacheType class LlamaAttentionBatched(LlamaAttentionBase): - def __init__(self, config: LlamaConfig): + def __init__(self, config: LlamaConfig, kv_type: KVCacheType): super().__init__(config) + self.kv_type = kv_type self.sliding_window = None if config.sliding_window: @@ -98,7 +97,7 @@ def __init__(self, config: LlamaConfig): max_context_length = config.sliding_window or config.max_sequence_length - if False: # TODO + if kv_type == KVCacheType.VLLM: partition_size = 512 # partition_size in vLLM attention self.max_num_partitions = (max_context_length + partition_size - 1) // partition_size else: @@ -141,7 +140,7 @@ def forward( slot_mapping = attn_input.slot_mapping # kv caches are updated inplace, but make it look like a pure operation - if attn_input.kv_type == KVCacheType.FlashDecoding: + if self.kv_type == KVCacheType.FlashDecoding: kv = nn.emit( relax.op.call_pure_packed( "tvm.contrib.vllm.reshape_and_cache", @@ -178,7 +177,7 @@ def forward( kv_shape = (num_past_token, num_kv_head, head_size) kv_sinfo = relax.TensorStructInfo(kv_shape, k_cache.struct_info.dtype) - if attn_input.kv_type == KVCacheType.FlashDecoding: + if self.kv_type == KVCacheType.FlashDecoding: kv_tensors = nn.emit( relax.op.call_pure_packed( "tvm.contrib.vllm.reconstruct_from_cache", @@ -251,7 +250,7 @@ def forward( # Decode, using vLLM or Flash-Decoding kernel assert isinstance(attn_input.aux_info, DecodeAttentionInput) - if attn_input.kv_type == KVCacheType.VLLM: + if self.kv_type == KVCacheType.VLLM: exp_sums = nn.emit( relax.op.builtin.alloc_tensor( relax.ShapeExpr( @@ -348,9 +347,9 @@ def forward( class LlamaDecoderLayerBatched(LlamaDecoderLayer): - def __init__(self, config: LlamaConfig): + def __init__(self, config: LlamaConfig, kv_type: KVCacheType): super().__init__(config, False) - self.self_attn = 
LlamaAttentionBatched(config) + self.self_attn = LlamaAttentionBatched(config, kv_type) def forward( self, @@ -389,6 +388,7 @@ def __init__( self, config: LlamaConfig, vocab_size_var: tvm.tir.Var, + kv_type: KVCacheType, sep_embed: bool = False, ): self.padding_idx = config.pad_token_id @@ -398,7 +398,7 @@ def __init__( self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) self.layers = ModuleList( - [LlamaDecoderLayerBatched(config) for _ in range(config.num_hidden_layers)] + [LlamaDecoderLayerBatched(config, kv_type) for _ in range(config.num_hidden_layers)] ) self.norm = LlamaRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps) @@ -410,7 +410,6 @@ def forward( slot_mapping: Optional[relax.Expr], max_seqlen: relax.Expr, attn_aux_info: Union[PrefillAttentionInput, DecodeAttentionInput, EvaluateMultiQueryInput], - kv_type: KVCacheType, ): if self.embed_tokens: inputs_embeds = self.embed_tokens(inputs) @@ -427,7 +426,7 @@ def forward( else: cache = None - attn_input = AttentionInput(cache, slot_mapping, max_seqlen, attn_aux_info, kv_type) + attn_input = AttentionInput(cache, slot_mapping, max_seqlen, attn_aux_info) hidden_states, new_kv = decoder_layer( hidden_states, @@ -445,11 +444,12 @@ def __init__( config: LlamaConfig, cpu_device: VDevice, vocab_size_var: tvm.tir.SizeVar, + kv_type: KVCacheType, sep_embed: bool = False, ): self.num_shards = config.num_shards self.cpu_device = cpu_device - self.model = LlamaModel(config, vocab_size_var, sep_embed) + self.model = LlamaModel(config, vocab_size_var, kv_type, sep_embed) self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) ############ Rotary embedding constants ############ @@ -579,11 +579,11 @@ def get_logits_last_tokens(x, seq_len_tensor, seq_start): def get_inputs( - num_query_token, num_seq, config, seqlen_q=None, max_num_blocks_per_seq=None, sep_embed=False, need_cache=True + num_query_token, num_seq, config, kv_type, seqlen_q=None, max_num_blocks_per_seq=None, sep_embed=False, need_cache=True ): hidden_size = config.hidden_size - if seqlen_q is None: + if kv_type == KVCacheType.VLLM: input_shape = (num_query_token,) else: input_shape = (num_seq, seqlen_q) @@ -600,20 +600,35 @@ def get_inputs( if need_cache: num_blocks = tvm.tir.Var("num_blocks", "int64") - block_size = 16 - - vec_size = 8 # 128 bit, fp16 x 8 - num_key_value_heads = config.get_num_key_value_heads() // config.num_shards - head_size = hidden_size // config.num_attention_heads - - k_cache_shape = ( - num_blocks, - num_key_value_heads, - head_size // vec_size, - block_size, - vec_size, - ) - v_cache_shape = (num_blocks, num_key_value_heads, head_size, block_size) + + if kv_type == KVCacheType.VLLM: + block_size = 16 + + vec_size = 8 # 128 bit, fp16 x 8 + num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + head_size = hidden_size // config.num_attention_heads + + k_cache_shape = ( + num_blocks, + num_key_value_heads, + head_size // vec_size, + block_size, + vec_size, + ) + v_cache_shape = (num_blocks, num_key_value_heads, head_size, block_size) + else: + block_size = 256 + + num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + head_size = hidden_size // config.num_attention_heads + + k_cache_shape = ( + num_blocks, + block_size, + num_key_value_heads, + head_size + ) + v_cache_shape = k_cache_shape get_cache_sinfo = lambda i: relax.TensorStructInfo( k_cache_shape if i % 2 == 0 else v_cache_shape, dtype="float16" @@ -711,7 +726,7 @@ def 
create_encoding_func( num_inputs = 5 with bb.function(func_name): - model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed) + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), kv_type, sep_embed) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs( @@ -796,7 +811,7 @@ def create_decoding_func( ) with bb.dataflow(): - model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64")) + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), kv_type) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) logits, new_kvs = model( @@ -845,7 +860,7 @@ def create_evaluate_multi_query_func( num_inputs = 8 with bb.function(func_name): - model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), False) + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), kv_type, False) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs( @@ -956,7 +971,7 @@ def get_model(args, hf_config): # The CPU device to copy the result of relax.op.max(seq_lens) to CPU. cpu_dev = VDevice("llvm", 0, "global") - kv_type = KVCacheType.VLM + kv_type = KVCacheType.VLLM create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) create_encoding_func(bb, param_manager, config, kv_type, cpu_dev, args.quantization, sep_embed) From 7348f0e9ad8179cf93e41572d2053cdbd4c8ba87 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 29 Jan 2024 20:56:39 +0000 Subject: [PATCH 05/15] compiled --- examples/python/run_llama_batched_vllm.py | 29 +++++++++++++---------- mlc_llm/relax_model/llama_batched_vllm.py | 23 +++++++----------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index 356384328f..24978403aa 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -20,15 +20,8 @@ class KVCache: - def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, disco_session): - # TODO: use tvm.contrib.flash_attn.allocate_kv_cache - if disco_session: - init_cache_func = disco_session.get_global_func("tvm.contrib.vllm.allocate_kv_cache") - else: - init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache") - + def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, init_cache_func): self.cache = init_cache_func(head_size, num_layers, num_heads, block_size, num_blocks) - self.block_tables = defaultdict(list) self.slot_mappings = defaultdict(list) self.block_size = block_size @@ -38,12 +31,12 @@ class CacheManager: block_size: int = 16 def __init__( - self, num_blocks, num_layers, num_heads, head_size, disco_session=None, sliding_window=None + self, num_blocks, num_layers, num_heads, head_size, init_cache_func, sliding_window=None ): self.num_blocks = num_blocks self.free_blocks = list(range(num_blocks)) self.kv_cache = KVCache( - num_blocks, self.block_size, num_layers, num_heads, head_size, disco_session + num_blocks, self.block_size, num_layers, num_heads, head_size, init_cache_func ) if sliding_window: @@ -464,12 +457,25 @@ def run(args): head_size = config.hidden_size // config.num_attention_heads num_blocks = 500 + # TODO: 
check KV type is flash decode + use_flash_decoding = True + + if use_flash_decoding: + allocate_func_name = "tvm.contrib.flash_attn.allocate_kv_cache" + else: + allocate_func_name = "tvm.contrib.vllm.allocate_kv_cache" + + if model.disco_session: + init_cache_func = model.disco_session.get_global_func(allocate_func_name) + else: + init_cache_func = tvm.get_global_func(allocate_func_name) + cache_manager = CacheManager( num_blocks, config.num_hidden_layers, num_kv_heads, head_size, - model.disco_session, + init_cache_func, sliding_window=config.sliding_window, ) cache = cache_manager.get() @@ -580,7 +586,6 @@ def verify_logits(logits, query_token_lens): # verify_logits(logits, query_token_lens) - # TODO: check KV type is flash decode query_token_lens = [3, 3, 3, 3] func_name = "decode_multi_query" diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index 47a8c7f9e4..9cfc784c7b 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -109,7 +109,7 @@ def forward( positions: relax.Expr, # (num_query_token,), for batched RoPE attn_input: AttentionInput, ): - num_query_tokens = positions.struct_info.shape + num_query_tokens = positions.struct_info.shape[0] queries, keys, values = self.project_qkv( hidden_states, @@ -579,16 +579,10 @@ def get_logits_last_tokens(x, seq_len_tensor, seq_start): def get_inputs( - num_query_token, num_seq, config, kv_type, seqlen_q=None, max_num_blocks_per_seq=None, sep_embed=False, need_cache=True + num_query_token, num_seq, input_shape, config, kv_type=None, max_num_blocks_per_seq=None, sep_embed=False ): hidden_size = config.hidden_size - if kv_type == KVCacheType.VLLM: - input_shape = (num_query_token,) - else: - input_shape = (num_seq, seqlen_q) - num_query_token = num_seq * seqlen_q - inputs = ( nn.Placeholder(input_shape + (hidden_size,), dtype=config.dtype, name="inputs_embeds") if sep_embed @@ -598,7 +592,7 @@ def get_inputs( seq_lens = nn.Placeholder((num_seq,), dtype="int32", name="seq_lens") positions = nn.Placeholder((num_query_token,), dtype="int32", name="positions") - if need_cache: + if kv_type: num_blocks = tvm.tir.Var("num_blocks", "int64") if kv_type == KVCacheType.VLLM: @@ -675,7 +669,7 @@ def create_evaluate_func( param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) inputs, positions, seq_lens, _, _, _ = get_inputs( - num_query_token, num_seq, config, sep_embed=sep_embed + num_query_token, num_seq, (num_query_token,), config, sep_embed=sep_embed ) with bb.dataflow(): @@ -730,7 +724,7 @@ def create_encoding_func( param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs( - num_query_token, num_seq, config, sep_embed=sep_embed + num_query_token, num_seq, (num_query_token,), config, kv_type, sep_embed=sep_embed ) with bb.dataflow(): @@ -799,15 +793,16 @@ def create_decoding_func( for (func_name, seqlen_q) in seqlen_q_info: max_num_blocks_per_seq = tvm.tir.SizeVar("max_num_blocks_per_seq", "int64") - # This if / else is probably not needed if seqlen_q == 1: num_query_token = num_seq + input_shape = (num_query_token,) else: num_query_token = num_seq * seqlen_q + input_shape = (num_seq, seqlen_q) with bb.function(func_name): inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables = get_inputs( - num_query_token, num_seq, config, seqlen_q, max_num_blocks_per_seq + num_query_token, num_seq, input_shape, 
config, kv_type, max_num_blocks_per_seq ) with bb.dataflow(): @@ -864,7 +859,7 @@ def create_evaluate_multi_query_func( param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs( - num_query_token, num_seq, config, sep_embed=False + num_query_token, num_seq, (num_query_token,), config, kv_type, sep_embed=False ) query_lens = nn.Placeholder((num_seq,), dtype="int32", name="query_lens") From b69237629d3390cfb1d471eb3baf3da71cf91192 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 29 Jan 2024 21:09:11 +0000 Subject: [PATCH 06/15] wip --- examples/python/run_llama_batched_vllm.py | 3 +-- mlc_llm/relax_model/llama_batched_vllm.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index 24978403aa..93a5d93c2e 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -457,8 +457,7 @@ def run(args): head_size = config.hidden_size // config.num_attention_heads num_blocks = 500 - # TODO: check KV type is flash decode - use_flash_decoding = True + use_flash_decoding = False if use_flash_decoding: allocate_func_name = "tvm.contrib.flash_attn.allocate_kv_cache" diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index 9cfc784c7b..4ca01d2783 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -140,7 +140,7 @@ def forward( slot_mapping = attn_input.slot_mapping # kv caches are updated inplace, but make it look like a pure operation - if self.kv_type == KVCacheType.FlashDecoding: + if self.kv_type == KVCacheType.VLLM: kv = nn.emit( relax.op.call_pure_packed( "tvm.contrib.vllm.reshape_and_cache", @@ -177,7 +177,7 @@ def forward( kv_shape = (num_past_token, num_kv_head, head_size) kv_sinfo = relax.TensorStructInfo(kv_shape, k_cache.struct_info.dtype) - if self.kv_type == KVCacheType.FlashDecoding: + if self.kv_type == KVCacheType.VLLM: kv_tensors = nn.emit( relax.op.call_pure_packed( "tvm.contrib.vllm.reconstruct_from_cache", From 1df6cac0253024484225c0119e622ceac7dfa486 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 29 Jan 2024 21:24:05 +0000 Subject: [PATCH 07/15] fix --- examples/python/run_llama_batched_vllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index 93a5d93c2e..f05667774b 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -196,7 +196,7 @@ def _prepare_inputs( start_idx += prompt_len else: - input_ids += token_ids[:-query_token_len] + input_ids += token_ids[-query_token_len:] for i in range(query_token_len): positions.append(len(token_ids) - (query_token_len - i)) From 8c8872c1f303a74b766b080c7c2c061eb03329be Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 29 Jan 2024 21:42:31 +0000 Subject: [PATCH 08/15] fix --- examples/python/run_llama_batched_vllm.py | 77 ++++++++++++----------- mlc_llm/relax_model/llama_batched_vllm.py | 21 +++++-- 2 files changed, 55 insertions(+), 43 deletions(-) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index f05667774b..5771a3df9d 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -546,45 +546,46 @@ def 
verify_logits(logits, query_token_lens): logits_offset += query_token_len - # query_token_lens = [4, 3, 5, 2] - # func_name = "evaluate_multi_query" - - # eval_query_requests = [] - - # for request_id, query_token_len in zip(request_ids, query_token_lens): - # queries_to_eval = requests[request_id].token_ids[-query_token_len:] - # num_past = len(requests[request_id].token_ids) - query_token_len - # eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval)) - - # ( - # input_ids, - # positions, - # seq_lens, - # slot_mapping, - # query_lens, - # past_slot_mapping, - # permute_map, - # ) = _prepare_eval_queries( - # eval_query_requests, - # cache.slot_mappings, - # None, - # model.dev, - # ) - - # logits = model.mod[func_name]( - # input_ids, - # positions, - # seq_lens, - # cache.cache, - # slot_mapping, - # query_lens, - # past_slot_mapping, - # permute_map, - # model.params, - # )[0].numpy() - - # verify_logits(logits, query_token_lens) + query_token_lens = [4, 3, 5, 2] + func_name = "evaluate_multi_query" + eval_query_requests = [] + + for request_id, query_token_len in zip(request_ids, query_token_lens): + queries_to_eval = requests[request_id].token_ids[-query_token_len:] + num_past = len(requests[request_id].token_ids) - query_token_len + eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval)) + + ( + input_ids, + positions, + seq_lens, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + ) = _prepare_eval_queries( + eval_query_requests, + cache.slot_mappings, + None, + model.dev, + ) + + logits = model.mod[func_name]( + input_ids, + positions, + seq_lens, + cache.cache, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + model.params, + )[0].numpy() + + verify_logits(logits, query_token_lens) + + return query_token_lens = [3, 3, 3, 3] func_name = "decode_multi_query" diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index 4ca01d2783..ed13e213cf 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -101,7 +101,7 @@ def __init__(self, config: LlamaConfig, kv_type: KVCacheType): partition_size = 512 # partition_size in vLLM attention self.max_num_partitions = (max_context_length + partition_size - 1) // partition_size else: - self.max_num_partitons = 128 + self.max_num_partitions = 128 def forward( self, @@ -171,8 +171,14 @@ def forward( if isinstance(attn_input.aux_info, EvaluateMultiQueryInput): assert k_cache and v_cache - num_kv_head = v_cache.struct_info.shape[1] - head_size = v_cache.struct_info.shape[2] + + if self.kv_type == KVCacheType.VLLM: + num_kv_head = v_cache.struct_info.shape[1] + else: + num_kv_head = v_cache.struct_info.shape[2] + + head_size = v_cache.struct_info.shape[-1] + num_past_token = attn_input.aux_info.past_slot_mapping.struct_info.shape[0] kv_shape = (num_past_token, num_kv_head, head_size) kv_sinfo = relax.TensorStructInfo(kv_shape, k_cache.struct_info.dtype) @@ -302,7 +308,12 @@ def forward( ) ) else: - num_seq, seqlen_q = hidden_states.struct_info.shape + if len(hidden_states.struct_info.shape) == 3: + num_seq, seqlen_q, _ = hidden_states.struct_info.shape + else: + num_seq = hidden_states.struct_info.shape[0] + seqlen_q = 1 + queries = nn.emit(reshape(queries, (num_seq, seqlen_q, self.num_query_heads, self.head_dim))) softmax_lse_accum = nn.emit( @@ -966,7 +977,7 @@ def get_model(args, hf_config): # The CPU device to copy the result of relax.op.max(seq_lens) to CPU. 
cpu_dev = VDevice("llvm", 0, "global") - kv_type = KVCacheType.VLLM + kv_type = KVCacheType.FlashDecoding create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) create_encoding_func(bb, param_manager, config, kv_type, cpu_dev, args.quantization, sep_embed) From 6a8272f746da14db9ee9f7f3925612bbccbb1272 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 29 Jan 2024 22:08:42 +0000 Subject: [PATCH 09/15] wip, decode with flash decoding works --- examples/python/run_llama_batched_vllm.py | 93 ++++++++++++----------- 1 file changed, 48 insertions(+), 45 deletions(-) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index 5771a3df9d..9e154f3a6b 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -28,11 +28,10 @@ def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, ini class CacheManager: - block_size: int = 16 - def __init__( - self, num_blocks, num_layers, num_heads, head_size, init_cache_func, sliding_window=None + self, num_blocks, block_size, num_layers, num_heads, head_size, init_cache_func, sliding_window=None ): + self.block_size = block_size self.num_blocks = num_blocks self.free_blocks = list(range(num_blocks)) self.kv_cache = KVCache( @@ -455,14 +454,17 @@ def run(args): num_kv_heads = config.get_num_key_value_heads() // args.num_shards head_size = config.hidden_size // config.num_attention_heads - num_blocks = 500 - use_flash_decoding = False + use_flash_decoding = True if use_flash_decoding: allocate_func_name = "tvm.contrib.flash_attn.allocate_kv_cache" + block_size = 256 + num_blocks = 30 else: allocate_func_name = "tvm.contrib.vllm.allocate_kv_cache" + block_size = 16 + num_blocks = 500 if model.disco_session: init_cache_func = model.disco_session.get_global_func(allocate_func_name) @@ -471,6 +473,7 @@ def run(args): cache_manager = CacheManager( num_blocks, + block_size, config.num_hidden_layers, num_kv_heads, head_size, @@ -546,46 +549,46 @@ def verify_logits(logits, query_token_lens): logits_offset += query_token_len - query_token_lens = [4, 3, 5, 2] - func_name = "evaluate_multi_query" - - eval_query_requests = [] - - for request_id, query_token_len in zip(request_ids, query_token_lens): - queries_to_eval = requests[request_id].token_ids[-query_token_len:] - num_past = len(requests[request_id].token_ids) - query_token_len - eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval)) - - ( - input_ids, - positions, - seq_lens, - slot_mapping, - query_lens, - past_slot_mapping, - permute_map, - ) = _prepare_eval_queries( - eval_query_requests, - cache.slot_mappings, - None, - model.dev, - ) - - logits = model.mod[func_name]( - input_ids, - positions, - seq_lens, - cache.cache, - slot_mapping, - query_lens, - past_slot_mapping, - permute_map, - model.params, - )[0].numpy() - - verify_logits(logits, query_token_lens) - - return + # query_token_lens = [4, 3, 5, 2] + # func_name = "evaluate_multi_query" + + # eval_query_requests = [] + + # for request_id, query_token_len in zip(request_ids, query_token_lens): + # queries_to_eval = requests[request_id].token_ids[-query_token_len:] + # num_past = len(requests[request_id].token_ids) - query_token_len + # eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval)) + + # ( + # input_ids, + # positions, + # seq_lens, + # slot_mapping, + # query_lens, + # past_slot_mapping, + # permute_map, + # ) = _prepare_eval_queries( + # 
eval_query_requests, + # cache.slot_mappings, + # None, + # model.dev, + # ) + + # logits = model.mod[func_name]( + # input_ids, + # positions, + # seq_lens, + # cache.cache, + # slot_mapping, + # query_lens, + # past_slot_mapping, + # permute_map, + # model.params, + # )[0].numpy() + + # verify_logits(logits, query_token_lens) + + # return query_token_lens = [3, 3, 3, 3] func_name = "decode_multi_query" From 487129cb2f6fffee76eec6e26a8617a9a42c0090 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 29 Jan 2024 23:47:56 +0000 Subject: [PATCH 10/15] all work --- examples/python/run_llama_batched_vllm.py | 131 +++++++++++++--------- mlc_llm/relax_model/llama_batched_vllm.py | 14 +-- 2 files changed, 83 insertions(+), 62 deletions(-) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index 9e154f3a6b..aece4b081b 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -29,7 +29,14 @@ def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, ini class CacheManager: def __init__( - self, num_blocks, block_size, num_layers, num_heads, head_size, init_cache_func, sliding_window=None + self, + num_blocks, + block_size, + num_layers, + num_heads, + head_size, + init_cache_func, + sliding_window=None, ): self.block_size = block_size self.num_blocks = num_blocks @@ -199,7 +206,8 @@ def _prepare_inputs( for i in range(query_token_len): positions.append(len(token_ids) - (query_token_len - i)) - slot_mapping += all_slot_mappings[request_id][-query_token_len:] + + slot_mapping += all_slot_mappings[request_id][-query_token_len:] block_table = all_block_tables[request_id] max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table)) @@ -312,7 +320,15 @@ def _prepare_eval_queries( class Model: def __init__( - self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window + self, + artifact_path, + model_name, + quant, + vocab_size, + num_shards, + dev, + sliding_window, + block_size, ): self.mod, self.params, self.disco_session = get_tvm_model( artifact_path, model_name, quant, num_shards, dev @@ -322,7 +338,8 @@ def __init__( self.sliding_window = sliding_window if sliding_window: - self.block_sliding_window = sliding_window // CacheManager.block_size + # TODO + self.block_sliding_window = sliding_window // block_size else: self.block_sliding_window = None @@ -440,6 +457,18 @@ def run(args): with open(os.path.join(model_path, "config.json"), encoding="utf-8") as i_f: config = LlamaConfig(**json.load(i_f)) + # TODO + use_flash_decoding = True + + if use_flash_decoding: + allocate_func_name = "tvm.contrib.flash_attn.allocate_kv_cache" + block_size = 256 + num_blocks = 30 + else: + allocate_func_name = "tvm.contrib.vllm.allocate_kv_cache" + block_size = 16 + num_blocks = 500 + model = Model( artifact_path, model_name, @@ -448,6 +477,7 @@ def run(args): args.num_shards, dev, config.sliding_window, + block_size, ) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False) @@ -455,17 +485,6 @@ def run(args): num_kv_heads = config.get_num_key_value_heads() // args.num_shards head_size = config.hidden_size // config.num_attention_heads - use_flash_decoding = True - - if use_flash_decoding: - allocate_func_name = "tvm.contrib.flash_attn.allocate_kv_cache" - block_size = 256 - num_blocks = 30 - else: - allocate_func_name = "tvm.contrib.vllm.allocate_kv_cache" - block_size = 16 - num_blocks = 500 - if model.disco_session: init_cache_func = 
model.disco_session.get_global_func(allocate_func_name) else: @@ -549,46 +568,48 @@ def verify_logits(logits, query_token_lens): logits_offset += query_token_len - # query_token_lens = [4, 3, 5, 2] - # func_name = "evaluate_multi_query" - - # eval_query_requests = [] - - # for request_id, query_token_len in zip(request_ids, query_token_lens): - # queries_to_eval = requests[request_id].token_ids[-query_token_len:] - # num_past = len(requests[request_id].token_ids) - query_token_len - # eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval)) - - # ( - # input_ids, - # positions, - # seq_lens, - # slot_mapping, - # query_lens, - # past_slot_mapping, - # permute_map, - # ) = _prepare_eval_queries( - # eval_query_requests, - # cache.slot_mappings, - # None, - # model.dev, - # ) - - # logits = model.mod[func_name]( - # input_ids, - # positions, - # seq_lens, - # cache.cache, - # slot_mapping, - # query_lens, - # past_slot_mapping, - # permute_map, - # model.params, - # )[0].numpy() - - # verify_logits(logits, query_token_lens) - - # return + query_token_lens = [4, 3, 5, 2] + func_name = "evaluate_multi_query" + + eval_query_requests = [] + + for request_id, query_token_len in zip(request_ids, query_token_lens): + queries_to_eval = requests[request_id].token_ids[-query_token_len:] + num_past = len(requests[request_id].token_ids) - query_token_len + eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval)) + + ( + input_ids, + positions, + seq_lens, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + ) = _prepare_eval_queries( + eval_query_requests, + cache.slot_mappings, + None, + model.dev, + ) + + logits = model.mod[func_name]( + input_ids, + positions, + seq_lens, + cache.cache, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + model.params, + )[0].numpy() + + verify_logits(logits, query_token_lens) + + if not use_flash_decoding: + return + query_token_lens = [3, 3, 3, 3] func_name = "decode_multi_query" diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index ed13e213cf..0ca1c06f20 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -174,10 +174,10 @@ def forward( if self.kv_type == KVCacheType.VLLM: num_kv_head = v_cache.struct_info.shape[1] + head_size = v_cache.struct_info.shape[2] else: num_kv_head = v_cache.struct_info.shape[2] - - head_size = v_cache.struct_info.shape[-1] + head_size = v_cache.struct_info.shape[-1] num_past_token = attn_input.aux_info.past_slot_mapping.struct_info.shape[0] kv_shape = (num_past_token, num_kv_head, head_size) @@ -794,20 +794,20 @@ def create_decoding_func( func_name = "decode" num_seq = tvm.tir.SizeVar("num_seq", "int64") - seqlen_q = tvm.tir.SizeVar("seqlen_q", "int64") - seqlen_q_info = [("decode", 1)] + func_names = ["decode"] if kv_type == KVCacheType.FlashDecoding: - seqlen_q_info.append(("decode_multi_query", seqlen_q)) + func_names.append("decode_multi_query") - for (func_name, seqlen_q) in seqlen_q_info: + for func_name in func_names: max_num_blocks_per_seq = tvm.tir.SizeVar("max_num_blocks_per_seq", "int64") - if seqlen_q == 1: + if func_name == "decode": num_query_token = num_seq input_shape = (num_query_token,) else: + seqlen_q = tvm.tir.SizeVar("seqlen_q", "int64") num_query_token = num_seq * seqlen_q input_shape = (num_seq, seqlen_q) From 8114197740d751fd2c7c6617ef84704df83cfc4e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 30 Jan 2024 00:06:50 
+0000 Subject: [PATCH 11/15] add paged_kv_cache_type option --- mlc_llm/core.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index f7afbbb693..30ffc73d23 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -402,6 +402,10 @@ class BuildArgs: "action": "store_true", }, ) + paged_kv_cache_type: str = field( + default="vllm", + metadata={"help": "The type of paged KV cache, either vllm or flash-decoding"}, + ) @property def convert_weight_only(self): @@ -595,6 +599,9 @@ def mod_transform_before_build( model_names.append("evaluate") model_names.append("evaluate_multi_query") + if args.paged_kv_cache_type == "flash-decoding": + model_names.append("decode_multi_query") + if args.sep_embed: model_names = ["embed", "prefill_with_embed"] + model_names[1:] if args.enable_batching: @@ -706,6 +713,7 @@ def dump_build_config( config: Dict[str, Any] = { "num_shards": args.num_shards, "quantization": args.quantization.name, + "paged_kv_cache_type": args.paged_kv_cache_type, "library_name": args.lib_name, "build_options": str(args) } From 2d6c81b827c52d2353bad687efd18c8f49cc2447 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 30 Jan 2024 00:13:34 +0000 Subject: [PATCH 12/15] read kv_type from artifact --- examples/python/run_llama_batched_vllm.py | 18 +++++++++++++++--- mlc_llm/core.py | 1 + 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index aece4b081b..ddba9bcbe2 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -422,6 +422,17 @@ def generate( ] +def get_paged_kv_cache_type(model_artifact_path): + config_file_path = os.path.join(model_artifact_path, "build_config.json") + + assert os.path.exists(config_file_path) + + with open(config_file_path, mode="rt", encoding="utf-8") as f: + build_cfg = json.load(f) + + return build_cfg["paged_kv_cache_type"] + + def parse_args(): # Example # python build.py --model vicuna-v1-7b --quantization q4f16_ft --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention @@ -457,8 +468,9 @@ def run(args): with open(os.path.join(model_path, "config.json"), encoding="utf-8") as i_f: config = LlamaConfig(**json.load(i_f)) - # TODO - use_flash_decoding = True + kv_type = get_paged_kv_cache_type(args.artifact_path) + + use_flash_decoding = kv_type == "flash-decoding" if use_flash_decoding: allocate_func_name = "tvm.contrib.flash_attn.allocate_kv_cache" @@ -630,7 +642,7 @@ def verify_logits(logits, query_token_lens): cache.block_tables, model.sliding_window, model.dev, - False, + False, # is_prefill query_len, ) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 30ffc73d23..b834ff9c33 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -392,6 +392,7 @@ class BuildArgs: "action": "store_true", }, ) + # TODO(masahi): Remove the use of this option with paged_kv_cache_type use_vllm_attention: bool = field( default=False, metadata={ From 67353b21f10f3a8874fc505f4aee6e0f56386906 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 30 Jan 2024 00:31:03 +0000 Subject: [PATCH 13/15] black --- examples/python/run_llama_batched_vllm.py | 4 -- mlc_llm/relax_model/llama_batched_vllm.py | 48 +++++++++++++++-------- 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index ddba9bcbe2..1e0167724c 100644 --- a/examples/python/run_llama_batched_vllm.py +++ 
b/examples/python/run_llama_batched_vllm.py @@ -338,7 +338,6 @@ def __init__( self.sliding_window = sliding_window if sliding_window: - # TODO self.block_sliding_window = sliding_window // block_size else: self.block_sliding_window = None @@ -424,12 +423,10 @@ def generate( def get_paged_kv_cache_type(model_artifact_path): config_file_path = os.path.join(model_artifact_path, "build_config.json") - assert os.path.exists(config_file_path) with open(config_file_path, mode="rt", encoding="utf-8") as f: build_cfg = json.load(f) - return build_cfg["paged_kv_cache_type"] @@ -469,7 +466,6 @@ def run(args): config = LlamaConfig(**json.load(i_f)) kv_type = get_paged_kv_cache_type(args.artifact_path) - use_flash_decoding = kv_type == "flash-decoding" if use_flash_decoding: diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index 0ca1c06f20..24cb73dd64 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -314,7 +314,9 @@ def forward( num_seq = hidden_states.struct_info.shape[0] seqlen_q = 1 - queries = nn.emit(reshape(queries, (num_seq, seqlen_q, self.num_query_heads, self.head_dim))) + queries = nn.emit( + reshape(queries, (num_seq, seqlen_q, self.num_query_heads, self.head_dim)) + ) softmax_lse_accum = nn.emit( relax.op.builtin.alloc_tensor( @@ -328,7 +330,13 @@ def forward( output_accum = nn.emit( relax.op.builtin.alloc_tensor( relax.ShapeExpr( - (self.max_num_partitions, num_seq, self.num_query_heads, seqlen_q, self.head_dim) + ( + self.max_num_partitions, + num_seq, + self.num_query_heads, + seqlen_q, + self.head_dim, + ) ), dtype="float32", runtime_device_index=0, @@ -349,9 +357,7 @@ def forward( out_sinfo=queries.struct_info, ) - attn_output = nn.emit( - reshape(attn_output, hidden_states.struct_info.shape) - ) + attn_output = nn.emit(reshape(attn_output, hidden_states.struct_info.shape)) attn_output = self.o_proj(attn_output) return attn_output, (k_cache, v_cache) @@ -590,7 +596,13 @@ def get_logits_last_tokens(x, seq_len_tensor, seq_start): def get_inputs( - num_query_token, num_seq, input_shape, config, kv_type=None, max_num_blocks_per_seq=None, sep_embed=False + num_query_token, + num_seq, + input_shape, + config, + kv_type=None, + max_num_blocks_per_seq=None, + sep_embed=False, ): hidden_size = config.hidden_size @@ -627,12 +639,7 @@ def get_inputs( num_key_value_heads = config.get_num_key_value_heads() // config.num_shards head_size = hidden_size // config.num_attention_heads - k_cache_shape = ( - num_blocks, - block_size, - num_key_value_heads, - head_size - ) + k_cache_shape = (num_blocks, block_size, num_key_value_heads, head_size) v_cache_shape = k_cache_shape get_cache_sinfo = lambda i: relax.TensorStructInfo( @@ -731,7 +738,9 @@ def create_encoding_func( num_inputs = 5 with bb.function(func_name): - model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), kv_type, sep_embed) + model = LlamaForCausalLM( + config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), kv_type, sep_embed + ) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs( @@ -817,7 +826,9 @@ def create_decoding_func( ) with bb.dataflow(): - model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), kv_type) + model = LlamaForCausalLM( + config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), kv_type + ) param_manager.register_params(model, func_name, quant_scheme, 
get_param_quant_kind) logits, new_kvs = model( @@ -866,7 +877,9 @@ def create_evaluate_multi_query_func( num_inputs = 8 with bb.function(func_name): - model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), kv_type, False) + model = LlamaForCausalLM( + config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), kv_type, False + ) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs( @@ -977,7 +990,10 @@ def get_model(args, hf_config): # The CPU device to copy the result of relax.op.max(seq_lens) to CPU. cpu_dev = VDevice("llvm", 0, "global") - kv_type = KVCacheType.FlashDecoding + if args.paged_kv_cache_type == "flash-decoding": + kv_type = KVCacheType.FlashDecoding + else: + kv_type = KVCacheType.VLLM create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) create_encoding_func(bb, param_manager, config, kv_type, cpu_dev, args.quantization, sep_embed) From b9e41e10f6429aa4b5a0ebecde8a81b8943c239e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 30 Jan 2024 04:38:49 +0000 Subject: [PATCH 14/15] refactor attention backend --- mlc_llm/relax_model/llama_batched_vllm.py | 417 +++++++++++++--------- 1 file changed, 250 insertions(+), 167 deletions(-) diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index 24cb73dd64..97e7656c9b 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -86,23 +86,241 @@ class AttentionInput: aux_info: Union[PrefillAttentionInput, DecodeAttentionInput, EvaluateMultiQueryInput] +class AttentionBackend: + def __init__(self, num_query_heads, num_key_value_heads, head_dim): + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + + def decode_attention( + self, + queries, + k_cache, + v_cache, + block_tables, + context_lens, + max_context_len, + num_seq, + seqlen_q, + ): + pass + + def update_cache(self, keys, values, k_cache, v_cache, slot_mapping): + pass + + def reconstruct_from_cache(self, k_cache, v_cache, past_slot_mapping): + pass + + +class VllmAttention(AttentionBackend): + block_size: int = 16 + + def __init__(self, num_query_heads, num_key_value_heads, head_dim, max_context_length): + super().__init__(num_query_heads, num_key_value_heads, head_dim) + + partition_size = 512 # partition_size in vLLM attention + self.max_num_partitions = (max_context_length + partition_size - 1) // partition_size + + def decode_attention( + self, + queries, + k_cache, + v_cache, + block_tables, + context_lens, + max_context_len, + num_seq, + seqlen_q, + ): + num_query_tokens = queries.struct_info.shape[0] + exp_sums = nn.emit( + relax.op.builtin.alloc_tensor( + relax.ShapeExpr((num_query_tokens, self.num_query_heads, self.max_num_partitions)), + dtype="float32", + runtime_device_index=0, + ) + ) + max_logits = nn.emit( + relax.op.builtin.alloc_tensor( + relax.ShapeExpr((num_query_tokens, self.num_query_heads, self.max_num_partitions)), + dtype="float32", + runtime_device_index=0, + ) + ) + tmp_out = nn.emit( + relax.op.builtin.alloc_tensor( + relax.ShapeExpr( + ( + num_query_tokens, + self.num_query_heads, + self.max_num_partitions, + self.head_dim, + ) + ), + dtype=queries.struct_info.dtype, + runtime_device_index=0, + ) + ) + return nn.emit( + relax.op.call_dps_packed( + "tvm.contrib.vllm.single_query_cached_kv_attention", + [ + queries, + k_cache, + 
v_cache, + block_tables, + context_lens, + 16, # block_size + max_context_len, + exp_sums, + max_logits, + tmp_out, + ], + out_sinfo=queries.struct_info, + ) + ) + + def update_cache(self, keys, values, k_cache, v_cache, slot_mapping): + return nn.emit( + relax.op.call_pure_packed( + "tvm.contrib.vllm.reshape_and_cache", + keys, + values, + k_cache, + v_cache, + slot_mapping, + sinfo_args=[k_cache.struct_info, v_cache.struct_info], + ) + ) + + def reconstruct_from_cache(self, k_cache, v_cache, past_slot_mapping): + num_kv_head = v_cache.struct_info.shape[1] + head_size = v_cache.struct_info.shape[2] + + num_past_token = past_slot_mapping.struct_info.shape[0] + kv_shape = (num_past_token, num_kv_head, head_size) + kv_sinfo = relax.TensorStructInfo(kv_shape, k_cache.struct_info.dtype) + + return nn.emit( + relax.op.call_pure_packed( + "tvm.contrib.vllm.reconstruct_from_cache", + k_cache, + v_cache, + past_slot_mapping, + sinfo_args=[kv_sinfo, kv_sinfo], + ) + ) + + +class FlashDecodingAttention(AttentionBackend): + block_size: int = 256 + + def __init__(self, num_query_heads, num_key_value_heads, head_dim): + super().__init__(num_query_heads, num_key_value_heads, head_dim) + self.max_num_partitions = 128 + + def decode_attention( + self, + queries, + k_cache, + v_cache, + block_tables, + context_lens, + max_context_len, + num_seq, + seqlen_q, + ): + queries = nn.emit( + reshape(queries, (num_seq, seqlen_q, self.num_query_heads, self.head_dim)) + ) + + softmax_lse_accum = nn.emit( + relax.op.builtin.alloc_tensor( + relax.ShapeExpr((self.max_num_partitions, num_seq, self.num_query_heads, seqlen_q)), + dtype="float32", + runtime_device_index=0, + ) + ) + output_accum = nn.emit( + relax.op.builtin.alloc_tensor( + relax.ShapeExpr( + ( + self.max_num_partitions, + num_seq, + self.num_query_heads, + seqlen_q, + self.head_dim, + ) + ), + dtype="float32", + runtime_device_index=0, + ) + ) + + return R.call_dps_packed( + "tvm.contrib.flash_attn.flash_decoding_with_paged_kvcache", + [ + queries, + k_cache, + v_cache, + block_tables, + context_lens, + softmax_lse_accum, + output_accum, + ], + out_sinfo=queries.struct_info, + ) + + def update_cache(self, keys, values, k_cache, v_cache, slot_mapping): + return nn.emit( + relax.op.call_pure_packed( + "tvm.contrib.flash_attn.update_cache", + keys, + values, + k_cache, + v_cache, + slot_mapping, + sinfo_args=[k_cache.struct_info, v_cache.struct_info], + ) + ) + + def reconstruct_from_cache(self, k_cache, v_cache, past_slot_mapping): + num_kv_head = v_cache.struct_info.shape[2] + head_size = v_cache.struct_info.shape[-1] + + num_past_token = past_slot_mapping.struct_info.shape[0] + kv_shape = (num_past_token, num_kv_head, head_size) + kv_sinfo = relax.TensorStructInfo(kv_shape, k_cache.struct_info.dtype) + + return nn.emit( + relax.op.call_pure_packed( + "tvm.contrib.flash_attn.reconstruct_from_cache", + k_cache, + v_cache, + past_slot_mapping, + sinfo_args=[kv_sinfo, kv_sinfo], + ) + ) + + class LlamaAttentionBatched(LlamaAttentionBase): def __init__(self, config: LlamaConfig, kv_type: KVCacheType): super().__init__(config) - self.kv_type = kv_type + if kv_type == KVCacheType.VLLM: + max_context_length = config.sliding_window or config.max_sequence_length + self.attn_backend = VllmAttention( + self.num_query_heads, self.num_key_value_heads, self.head_dim, max_context_length + ) + else: + self.attn_backend = FlashDecodingAttention( + self.num_query_heads, self.num_key_value_heads, self.head_dim + ) + self.sliding_window = None if config.sliding_window: 
self.sliding_window = T.IntImm("int32", config.sliding_window) - max_context_length = config.sliding_window or config.max_sequence_length - - if kv_type == KVCacheType.VLLM: - partition_size = 512 # partition_size in vLLM attention - self.max_num_partitions = (max_context_length + partition_size - 1) // partition_size - else: - self.max_num_partitions = 128 - def forward( self, hidden_states: relax.Expr, # (num_query_token, hidden_size) or (num_seq, seqlen_q, hidden_size) @@ -140,31 +358,9 @@ def forward( slot_mapping = attn_input.slot_mapping # kv caches are updated inplace, but make it look like a pure operation - if self.kv_type == KVCacheType.VLLM: - kv = nn.emit( - relax.op.call_pure_packed( - "tvm.contrib.vllm.reshape_and_cache", - keys_to_cache, - values_to_cache, - k_cache, - v_cache, - slot_mapping, - sinfo_args=[k_cache.struct_info, v_cache.struct_info], - ) - ) - else: - kv = nn.emit( - relax.op.call_pure_packed( - "tvm.contrib.flash_attn.update_cache", - keys_to_cache, - values_to_cache, - k_cache, - v_cache, - slot_mapping, - sinfo_args=[k_cache.struct_info, v_cache.struct_info], - ) - ) - + kv = self.attn_backend.update_cache( + keys_to_cache, values_to_cache, k_cache, v_cache, slot_mapping + ) k_cache, v_cache = kv[0], kv[1] else: k_cache = v_cache = None @@ -172,38 +368,9 @@ def forward( if isinstance(attn_input.aux_info, EvaluateMultiQueryInput): assert k_cache and v_cache - if self.kv_type == KVCacheType.VLLM: - num_kv_head = v_cache.struct_info.shape[1] - head_size = v_cache.struct_info.shape[2] - else: - num_kv_head = v_cache.struct_info.shape[2] - head_size = v_cache.struct_info.shape[-1] - - num_past_token = attn_input.aux_info.past_slot_mapping.struct_info.shape[0] - kv_shape = (num_past_token, num_kv_head, head_size) - kv_sinfo = relax.TensorStructInfo(kv_shape, k_cache.struct_info.dtype) - - if self.kv_type == KVCacheType.VLLM: - kv_tensors = nn.emit( - relax.op.call_pure_packed( - "tvm.contrib.vllm.reconstruct_from_cache", - k_cache, - v_cache, - attn_input.aux_info.past_slot_mapping, - sinfo_args=[kv_sinfo, kv_sinfo], - ) - ) - else: - kv_tensors = nn.emit( - relax.op.call_pure_packed( - "tvm.contrib.flash_attn.reconstruct_from_cache", - k_cache, - v_cache, - attn_input.aux_info.past_slot_mapping, - sinfo_args=[kv_sinfo, kv_sinfo], - ) - ) - + kv_tensors = self.attn_backend.reconstruct_from_cache( + k_cache, v_cache, attn_input.aux_info.past_slot_mapping + ) keys_past, values_past = kv_tensors[0], kv_tensors[1] # Say we have past tokens [P1, P2, P3] and the current ones [C1, C2, C3]. # Each of P1, C1 etc is a sequence of tokens. 
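# --- Illustrative aside (not part of the patch) ---------------------------
# The refactor in this commit replaces the per-call-site
# `if self.kv_type == KVCacheType.VLLM: ... else: ...` branches with an
# AttentionBackend hierarchy (VllmAttention / FlashDecodingAttention) that
# LlamaAttentionBatched delegates to. Below is a minimal, framework-free
# sketch of the same dispatch shape; every name ending in "Sketch" and the
# make_backend helper are hypothetical stand-ins for the Relax-emitting
# classes added above, not part of the real code.
from abc import ABC, abstractmethod


class AttentionBackendSketch(ABC):
    block_size: int

    @abstractmethod
    def decode_attention(self, queries):
        """Describe the decode-attention kernel this backend would dispatch to."""


class VllmAttentionSketch(AttentionBackendSketch):
    block_size = 16  # matches VllmAttention.block_size in the patch

    def decode_attention(self, queries):
        return f"vllm.single_query_cached_kv_attention over {len(queries)} query tokens"


class FlashDecodingAttentionSketch(AttentionBackendSketch):
    block_size = 256  # matches FlashDecodingAttention.block_size in the patch

    def decode_attention(self, queries):
        return f"flash_attn.flash_decoding_with_paged_kvcache over {len(queries)} query tokens"


def make_backend(paged_kv_cache_type: str) -> AttentionBackendSketch:
    # Mirrors the kv_type selection added to get_model(): "flash-decoding"
    # picks the Flash-Decoding backend, anything else falls back to vLLM.
    if paged_kv_cache_type == "flash-decoding":
        return FlashDecodingAttentionSketch()
    return VllmAttentionSketch()


if __name__ == "__main__":
    backend = make_backend("flash-decoding")
    print(backend.decode_attention(queries=[101, 102, 103]))
# ---------------------------------------------------------------------------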
@@ -256,106 +423,22 @@ def forward( # Decode, using vLLM or Flash-Decoding kernel assert isinstance(attn_input.aux_info, DecodeAttentionInput) - if self.kv_type == KVCacheType.VLLM: - exp_sums = nn.emit( - relax.op.builtin.alloc_tensor( - relax.ShapeExpr( - (num_query_tokens, self.num_query_heads, self.max_num_partitions) - ), - dtype="float32", - runtime_device_index=0, - ) - ) - max_logits = nn.emit( - relax.op.builtin.alloc_tensor( - relax.ShapeExpr( - (num_query_tokens, self.num_query_heads, self.max_num_partitions) - ), - dtype="float32", - runtime_device_index=0, - ) - ) - tmp_out = nn.emit( - relax.op.builtin.alloc_tensor( - relax.ShapeExpr( - ( - num_query_tokens, - self.num_query_heads, - self.max_num_partitions, - self.head_dim, - ) - ), - dtype=queries.struct_info.dtype, - runtime_device_index=0, - ) - ) - attn_output = nn.emit( - relax.op.call_dps_packed( - "tvm.contrib.vllm.single_query_cached_kv_attention", - [ - queries, - k_cache, - v_cache, - attn_input.aux_info.block_tables, - attn_input.aux_info.seq_lens, - 16, # block_size - attn_input.max_seqlen, - exp_sums, - max_logits, - tmp_out, - ], - out_sinfo=queries.struct_info, - ) - ) + if len(hidden_states.struct_info.shape) == 3: + num_seq, seqlen_q, _ = hidden_states.struct_info.shape else: - if len(hidden_states.struct_info.shape) == 3: - num_seq, seqlen_q, _ = hidden_states.struct_info.shape - else: - num_seq = hidden_states.struct_info.shape[0] - seqlen_q = 1 - - queries = nn.emit( - reshape(queries, (num_seq, seqlen_q, self.num_query_heads, self.head_dim)) - ) - - softmax_lse_accum = nn.emit( - relax.op.builtin.alloc_tensor( - relax.ShapeExpr( - (self.max_num_partitions, num_seq, self.num_query_heads, seqlen_q) - ), - dtype="float32", - runtime_device_index=0, - ) - ) - output_accum = nn.emit( - relax.op.builtin.alloc_tensor( - relax.ShapeExpr( - ( - self.max_num_partitions, - num_seq, - self.num_query_heads, - seqlen_q, - self.head_dim, - ) - ), - dtype="float32", - runtime_device_index=0, - ) - ) - - attn_output = R.call_dps_packed( - "tvm.contrib.flash_attn.flash_decoding_with_paged_kvcache", - [ - queries, - k_cache, - v_cache, - attn_input.aux_info.block_tables, - attn_input.aux_info.seq_lens, - softmax_lse_accum, - output_accum, - ], - out_sinfo=queries.struct_info, - ) + num_seq = hidden_states.struct_info.shape[0] + seqlen_q = 1 + + attn_output = self.attn_backend.decode_attention( + queries, + k_cache, + v_cache, + attn_input.aux_info.block_tables, + attn_input.aux_info.seq_lens, + attn_input.max_seqlen, + num_seq, + seqlen_q, + ) attn_output = nn.emit(reshape(attn_output, hidden_states.struct_info.shape)) attn_output = self.o_proj(attn_output) @@ -619,7 +702,7 @@ def get_inputs( num_blocks = tvm.tir.Var("num_blocks", "int64") if kv_type == KVCacheType.VLLM: - block_size = 16 + block_size = VllmAttention.block_size vec_size = 8 # 128 bit, fp16 x 8 num_key_value_heads = config.get_num_key_value_heads() // config.num_shards @@ -634,7 +717,7 @@ def get_inputs( ) v_cache_shape = (num_blocks, num_key_value_heads, head_size, block_size) else: - block_size = 256 + block_size = FlashDecodingAttention.block_size num_key_value_heads = config.get_num_key_value_heads() // config.num_shards head_size = hidden_size // config.num_attention_heads From 910e31b90f083f0a722abaa15d8605a4960f9a03 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 30 Jan 2024 10:22:57 +0000 Subject: [PATCH 15/15] minor clean up --- examples/python/run_llama_batched_vllm.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff 
--git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index 1e0167724c..98cbbe9c3e 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -577,8 +577,6 @@ def verify_logits(logits, query_token_lens): logits_offset += query_token_len query_token_lens = [4, 3, 5, 2] - func_name = "evaluate_multi_query" - eval_query_requests = [] for request_id, query_token_len in zip(request_ids, query_token_lens): @@ -601,7 +599,7 @@ def verify_logits(logits, query_token_lens): model.dev, ) - logits = model.mod[func_name]( + logits = model.mod["evaluate_multi_query"]( input_ids, positions, seq_lens, @@ -619,10 +617,7 @@ def verify_logits(logits, query_token_lens): return query_token_lens = [3, 3, 3, 3] - func_name = "decode_multi_query" - decode_multi_query_requests = requests - query_len = query_token_lens[0] ( @@ -644,7 +639,7 @@ def verify_logits(logits, query_token_lens): input_ids = tvm.nd.array(np.reshape(input_ids.numpy(), [-1, query_len]), dev) - logits = model.mod[func_name]( + logits = model.mod["decode_multi_query"]( input_ids, positions, seq_lens,