Skip to content

Commit

Permalink
[Core] Skip PrimExpr index int32 downcasting for batching
Browse files Browse the repository at this point in the history
This PR makes the ForceNarrowIndexToInt32 to skip application
when batching is enabled.

The reason is because the flattened index of the KV cache append
function may exceed the range of int32 when the cache is large.
For example, in Llama-7b, when a KV cache supports more than
8192 tokens, the total cache size will be at least
```
8192 * 2 (K/V) * 32 (layers) * 4096 = 2147483648,
```
which reaches the maximum int32 value.
  • Loading branch information
MasterJH5574 committed Nov 11, 2023
1 parent c12fe04 commit a54b4bd
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions mlc_llm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@
rwkv,
stablelm_3b,
)
from mlc_llm.relax_model.commons import create_shard_info_func, create_shard_transformation_func
from mlc_llm.relax_model.param_manager import transform_params_for_each_rank, chain_parameter_transforms
from mlc_llm.relax_model.commons import (
create_shard_info_func,
create_shard_transformation_func,
)
from mlc_llm.relax_model.param_manager import (
chain_parameter_transforms,
transform_params_for_each_rank,
)
from mlc_llm.transform import fuse_split_rotary_embedding, rewrite_attention


Expand Down Expand Up @@ -679,6 +685,7 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
mod_deploy
)
)
if not args.enable_batching:
mod_deploy = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_deploy)

if args.debug_load_script:
Expand Down Expand Up @@ -806,7 +813,9 @@ def build_model_from_args(args: argparse.Namespace):
mod_transform = seq(mod_transform)

params = utils.convert_weights(mod_transform, param_manager, params, args)
utils.save_params(params, args.artifact_path, args.num_shards if args.use_presharded_weights else 1)
utils.save_params(
params, args.artifact_path, args.num_shards if args.use_presharded_weights else 1
)

if args.model_category != "minigpt":
utils.copy_tokenizer(args)
Expand Down

0 comments on commit a54b4bd

Please sign in to comment.