Skip to content

Commit

Permalink
Support exporting Nemotron-340B for TensorRT-LLM (NVIDIA#11015)
Browse files (browse the repository at this point in the history)
Signed-off-by: Jinyang Yuan <[email protected]>
Co-authored-by: Jinyang Yuan <[email protected]>
Co-authored-by: meatybobby <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
  • Loading branch information
3 people authored and Hainan Xu committed Nov 5, 2024
1 parent dcd0446 commit e75e99c
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions scripts/export/export_to_trt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_args(argv):
parser.add_argument(
"-mr", "--model_repository", required=True, default=None, type=str, help="Folder for the trt-llm model files"
)
- parser.add_argument("-ng", "--num_gpus", default=1, type=int, help="Number of GPUs for the deployment")
+ parser.add_argument("-ng", "--num_gpus", default=None, type=int, help="Number of GPUs for the deployment")
parser.add_argument("-tps", "--tensor_parallelism_size", default=1, type=int, help="Tensor parallelism size")
parser.add_argument("-pps", "--pipeline_parallelism_size", default=1, type=int, help="Pipeline parallelism size")
parser.add_argument(
Expand All @@ -64,7 +64,14 @@ def get_args(argv):
"-mpet", "--max_prompt_embedding_table_size", default=None, type=int, help="Max prompt embedding table size"
)
parser.add_argument(
- "-npkc", "--no_paged_kv_cache", default=False, action='store_true', help="Enable paged kv cache."
+ "-upe",
+ "--use_parallel_embedding",
+ default=False,
+ action='store_true',
+ help="Use parallel embedding.",
+ )
+ parser.add_argument(
+ "-npkc", "--no_paged_kv_cache", default=False, action='store_true', help="Disable paged kv cache."
)
parser.add_argument(
"-drip",
Expand Down Expand Up @@ -183,6 +190,7 @@ def nemo_export_trt_llm(argv):
max_num_tokens=args.max_num_tokens,
opt_num_tokens=args.opt_num_tokens,
max_prompt_embedding_table_size=args.max_prompt_embedding_table_size,
+ use_parallel_embedding=args.use_parallel_embedding,
paged_kv_cache=(not args.no_paged_kv_cache),
remove_input_padding=(not args.disable_remove_input_padding),
dtype=args.dtype,
Expand All @@ -191,6 +199,7 @@ def nemo_export_trt_llm(argv):
max_lora_rank=args.max_lora_rank,
fp8_quantized=args.export_fp8_quantized,
fp8_kvcache=args.use_fp8_kv_cache,
+ load_model=False,
)

LOGGER.info("Export is successful.")
Expand Down

0 comments on commit e75e99c

Please sign in to comment.