diff --git a/scripts/export/export_to_trt_llm.py b/scripts/export/export_to_trt_llm.py
index 6b246131b69e..d9e846547c68 100644
--- a/scripts/export/export_to_trt_llm.py
+++ b/scripts/export/export_to_trt_llm.py
@@ -44,7 +44,7 @@ def get_args(argv):
     parser.add_argument(
         "-mr", "--model_repository", required=True, default=None, type=str, help="Folder for the trt-llm model files"
     )
-    parser.add_argument("-ng", "--num_gpus", default=1, type=int, help="Number of GPUs for the deployment")
+    parser.add_argument("-ng", "--num_gpus", default=None, type=int, help="Number of GPUs for the deployment")
     parser.add_argument("-tps", "--tensor_parallelism_size", default=1, type=int, help="Tensor parallelism size")
     parser.add_argument("-pps", "--pipeline_parallelism_size", default=1, type=int, help="Pipeline parallelism size")
     parser.add_argument(
@@ -64,7 +64,14 @@ def get_args(argv):
         "-mpet", "--max_prompt_embedding_table_size", default=None, type=int, help="Max prompt embedding table size"
     )
     parser.add_argument(
-        "-npkc", "--no_paged_kv_cache", default=False, action='store_true', help="Enable paged kv cache."
+        "-upe",
+        "--use_parallel_embedding",
+        default=False,
+        action='store_true',
+        help="Use parallel embedding.",
+    )
+    parser.add_argument(
+        "-npkc", "--no_paged_kv_cache", default=False, action='store_true', help="Disable paged kv cache."
     )
     parser.add_argument(
         "-drip",
@@ -183,6 +190,7 @@ def nemo_export_trt_llm(argv):
             max_num_tokens=args.max_num_tokens,
             opt_num_tokens=args.opt_num_tokens,
             max_prompt_embedding_table_size=args.max_prompt_embedding_table_size,
+            use_parallel_embedding=args.use_parallel_embedding,
             paged_kv_cache=(not args.no_paged_kv_cache),
             remove_input_padding=(not args.disable_remove_input_padding),
             dtype=args.dtype,
@@ -191,6 +199,7 @@ def nemo_export_trt_llm(argv):
             max_lora_rank=args.max_lora_rank,
             fp8_quantized=args.export_fp8_quantized,
             fp8_kvcache=args.use_fp8_kv_cache,
+            load_model=False,
         )

         LOGGER.info("Export is successful.")
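
For context, the new `--use_parallel_embedding` flag simply toggles a keyword on the exporter call that this script wraps. Below is a minimal sketch of the equivalent programmatic use, assuming the `TensorRTLLM` exporter from `nemo.export.tensorrt_llm`; the paths are placeholders and the keyword names are taken from the call site in the diff rather than a verified signature.

```python
# Sketch only: drive the same export path programmatically instead of via the CLI.
# Keyword names mirror the call site in the diff above; paths are placeholders.
from nemo.export.tensorrt_llm import TensorRTLLM

exporter = TensorRTLLM(model_dir="/tmp/trt_llm_model_dir", load_model=False)
exporter.export(
    nemo_checkpoint_path="/path/to/model.nemo",  # placeholder checkpoint
    tensor_parallelism_size=2,
    use_parallel_embedding=True,   # new option exposed by -upe/--use_parallel_embedding
    paged_kv_cache=True,           # equivalent to leaving --no_paged_kv_cache unset
    dtype="bfloat16",
    load_model=False,              # matches the kwarg hard-coded in this change
)
```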