Update model definition to support Flash-Decoding #177

Merged
15 commits merged on Jan 30, 2024
examples/python/run_llama_batched_vllm.py (150 changes: 117 additions, 33 deletions)
@@ -20,29 +20,29 @@


class KVCache:
def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, disco_session):
if disco_session:
init_cache_func = disco_session.get_global_func("tvm.contrib.vllm.allocate_kv_cache")
else:
init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache")

def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, init_cache_func):
self.cache = init_cache_func(head_size, num_layers, num_heads, block_size, num_blocks)

self.block_tables = defaultdict(list)
self.slot_mappings = defaultdict(list)
self.block_size = block_size


class CacheManager:
block_size: int = 16

def __init__(
self, num_blocks, num_layers, num_heads, head_size, disco_session=None, sliding_window=None
self,
num_blocks,
block_size,
num_layers,
num_heads,
head_size,
init_cache_func,
sliding_window=None,
):
self.block_size = block_size
self.num_blocks = num_blocks
self.free_blocks = list(range(num_blocks))
self.kv_cache = KVCache(
num_blocks, self.block_size, num_layers, num_heads, head_size, disco_session
num_blocks, self.block_size, num_layers, num_heads, head_size, init_cache_func
)

if sliding_window:
@@ -172,6 +172,7 @@ def _prepare_inputs(
sliding_window,
dev,
is_prefill,
query_token_len=1,
):
block_tables = []
seq_lens = []
@@ -201,13 +202,16 @@
start_idx += prompt_len

else:
input_ids.append(token_ids[-1])
pos = len(token_ids) - 1
positions.append(pos)
input_ids += token_ids[-query_token_len:]

for i in range(query_token_len):
positions.append(len(token_ids) - (query_token_len - i))

slot_mapping += all_slot_mappings[request_id][-query_token_len:]

block_table = all_block_tables[request_id]
max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table))
block_tables.append(block_table)
slot_mapping.append(all_slot_mappings[request_id][-1])

if sliding_window:
seq_lens.append(min(len(token_ids), sliding_window))
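To make the multi-token decode branch above concrete, here is a small worked example of the index arithmetic (the values are illustrative, not taken from the PR):

# Illustrative only: a sequence with 10 tokens, decoding 3 query tokens at once.
token_ids = list(range(100, 110))
query_token_len = 3
positions = [len(token_ids) - (query_token_len - i) for i in range(query_token_len)]
assert positions == [7, 8, 9]                            # positions of the last three tokens
assert token_ids[-query_token_len:] == [107, 108, 109]   # the tokens fed to the model
# The last three entries of the request's slot mapping are appended in the same way.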
@@ -316,7 +320,15 @@ def _prepare_eval_queries(

class Model:
def __init__(
self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window
self,
artifact_path,
model_name,
quant,
vocab_size,
num_shards,
dev,
sliding_window,
block_size,
):
self.mod, self.params, self.disco_session = get_tvm_model(
artifact_path, model_name, quant, num_shards, dev
@@ -326,7 +338,7 @@ def __init__(
self.sliding_window = sliding_window

if sliding_window:
self.block_sliding_window = sliding_window // CacheManager.block_size
self.block_sliding_window = sliding_window // block_size
else:
self.block_sliding_window = None

@@ -409,6 +421,15 @@ def generate(
]


def get_paged_kv_cache_type(model_artifact_path):
config_file_path = os.path.join(model_artifact_path, "build_config.json")
assert os.path.exists(config_file_path)

with open(config_file_path, mode="rt", encoding="utf-8") as f:
build_cfg = json.load(f)
return build_cfg["paged_kv_cache_type"]
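For context, a minimal usage sketch of this helper (the artifact path is hypothetical; the field names match what dump_build_config in mlc_llm/core.py writes, shown in the second file of this diff):

# build_config.json is produced at build time by dump_build_config and includes
# num_shards, quantization, paged_kv_cache_type, library_name, and build_options.
kv_type = get_paged_kv_cache_type("dist/vicuna-v1-7b-q4f16_ft")  # hypothetical path
assert kv_type in ("vllm", "flash-decoding")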


def parse_args():
# Example
# python build.py --model vicuna-v1-7b --quantization q4f16_ft --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention
@@ -444,6 +465,18 @@ def run(args):
with open(os.path.join(model_path, "config.json"), encoding="utf-8") as i_f:
config = LlamaConfig(**json.load(i_f))

kv_type = get_paged_kv_cache_type(args.artifact_path)
use_flash_decoding = kv_type == "flash-decoding"

if use_flash_decoding:
allocate_func_name = "tvm.contrib.flash_attn.allocate_kv_cache"
block_size = 256
num_blocks = 30
else:
allocate_func_name = "tvm.contrib.vllm.allocate_kv_cache"
block_size = 16
num_blocks = 500
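As a quick sanity check on these defaults (plain arithmetic, not taken from the PR), the two settings give a comparable total number of KV cache slots; what differs is the block granularity handed to the attention kernel:

# 30 blocks of 256 slots vs. 500 blocks of 16 slots
print(30 * 256, 500 * 16)  # 7680 and 8000 token slots in total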

model = Model(
artifact_path,
model_name,
@@ -452,20 +485,26 @@
args.num_shards,
dev,
config.sliding_window,
block_size,
)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False)

num_kv_heads = config.get_num_key_value_heads() // args.num_shards
head_size = config.hidden_size // config.num_attention_heads
num_blocks = 500

if model.disco_session:
init_cache_func = model.disco_session.get_global_func(allocate_func_name)
else:
init_cache_func = tvm.get_global_func(allocate_func_name)

cache_manager = CacheManager(
num_blocks,
block_size,
config.num_hidden_layers,
num_kv_heads,
head_size,
model.disco_session,
init_cache_func,
sliding_window=config.sliding_window,
)
cache = cache_manager.get()
@@ -516,8 +555,28 @@ def run(args):
for p, g in zip(prompts, generated):
print("Prompt = '{}', generated text = '{}'".format(p, g))

query_token_lens = [4, 3, 5, 2]
if model.disco_session:
return

def verify_logits(logits, query_token_lens):
assert logits.shape[0] == sum(query_token_lens)

logits_offset = 0

for request_id, query_token_len in zip(request_ids, query_token_lens):
for i in range(query_token_len - 1):
# requests[request_id].token_ids[-query_token_len:] are the "ground truth" tokens.
# Doing argmax over multi-timestep logits computed in parallel should yield the same
# tokens at the corresponding positions.
past_tokens = requests[request_id].token_ids[:-query_token_len]
assert (
np.argmax(logits[logits_offset + i])
== requests[request_id].token_ids[len(past_tokens) + i + 1]
)

logits_offset += query_token_len

query_token_lens = [4, 3, 5, 2]
eval_query_requests = []

for request_id, query_token_len in zip(request_ids, query_token_lens):
@@ -552,22 +611,47 @@ def run(args):
model.params,
)[0].numpy()

assert logits.shape[0] == sum(query_token_lens)
verify_logits(logits, query_token_lens)

logits_offset = 0
if not use_flash_decoding:
return

for request_id, query_token_len in zip(request_ids, query_token_lens):
for i in range(query_token_len - 1):
# requests[request_id].token_ids[-query_token_len:] are the "ground truth" tokens.
# Doing argmax over multi-timestep logits computed in parallel should yield the same
# tokens at the corresponding positions.
past_tokens = requests[request_id].token_ids[:-query_token_len]
assert (
np.argmax(logits[logits_offset + i])
== requests[request_id].token_ids[len(past_tokens) + i + 1]
)
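The block below exercises the new decode_multi_query entry point added for Flash-Decoding builds. A short sketch of the shape flow, inferred from the reshape calls that follow (sizes match the query_token_lens used here):

# 4 sequences, query_len = 3 tokens decoded per sequence and step
# input_ids from _prepare_inputs: flat list of 4 * 3 = 12 tokens -> reshaped to [4, 3]
# logits returned by decode_multi_query: reshaped back to one row per token (12 rows)
#   so that verify_logits can be reused unchanged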
query_token_lens = [3, 3, 3, 3]
decode_multi_query_requests = requests
query_len = query_token_lens[0]

(
input_ids,
positions,
seq_lens,
slot_mapping,
_,
block_tables,
) = _prepare_inputs(
decode_multi_query_requests,
cache.slot_mappings,
cache.block_tables,
model.sliding_window,
model.dev,
False, # is_prefill
query_len,
)

input_ids = tvm.nd.array(np.reshape(input_ids.numpy(), [-1, query_len]), dev)

logits = model.mod["decode_multi_query"](
input_ids,
positions,
seq_lens,
cache.cache,
slot_mapping,
block_tables,
model.params,
)[0].numpy()

logits = np.reshape(logits, (-1, logits.shape[-1]))

logits_offset += query_token_len
verify_logits(logits, query_token_lens)


if __name__ == "__main__":
mlc_llm/core.py (9 changes: 9 additions, 0 deletions)
@@ -392,6 +392,7 @@ class BuildArgs:
"action": "store_true",
},
)
# TODO(masahi): Remove the use of this option with paged_kv_cache_type
use_vllm_attention: bool = field(
default=False,
metadata={
@@ -402,6 +403,10 @@
"action": "store_true",
},
)
paged_kv_cache_type: str = field(
default="vllm",
metadata={"help": "The type of paged KV cache, either vllm or flash-decoding"},
Member Author:
This new option makes --use_vllm_attention obsolete. Since removing it is a breaking change, I'll do that later when I integrate Flash-Decoding into mlc-serve. @sunggg
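For reference, a build invocation using the new option might look like the following. This is a sketch modeled on the example command in run_llama_batched_vllm.py's parse_args; the exact flag spelling depends on how BuildArgs fields are exposed on the command line, and --use-vllm-attention is kept only because it has not been removed yet:

python build.py --model vicuna-v1-7b --quantization q4f16_ft --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention --paged-kv-cache-type flash-decoding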

)

@property
def convert_weight_only(self):
@@ -595,6 +600,9 @@ def mod_transform_before_build(
model_names.append("evaluate")
model_names.append("evaluate_multi_query")

if args.paged_kv_cache_type == "flash-decoding":
model_names.append("decode_multi_query")

if args.sep_embed:
model_names = ["embed", "prefill_with_embed"] + model_names[1:]
if args.enable_batching:
@@ -706,6 +714,7 @@ def dump_build_config(
config: Dict[str, Any] = {
"num_shards": args.num_shards,
"quantization": args.quantization.name,
"paged_kv_cache_type": args.paged_kv_cache_type,
"library_name": args.lib_name,
"build_options": str(args)
}