From 798500cd470a05628424807a6b991081e4bea915 Mon Sep 17 00:00:00 2001
From: Siddharth Venkatesan
Date: Fri, 24 Jan 2025 08:36:38 -0800
Subject: [PATCH] [lmi][vllm][trtllm] add support for generation parameters
 from generation_config.json

---
 .../properties_manager/properties.py          |  1 +
 .../rolling_batch/lmi_dist_rolling_batch.py   | 26 +++++------
 .../rolling_batch/trtllm_rolling_batch.py     | 30 ++++++++++++-
 .../rolling_batch/vllm_rolling_batch.py       | 14 +++---
 .../lmi/user_guides/lmi-dist_user_guide.md    | 45 ++++++++++---------
 .../lmi/user_guides/trt_llm_user_guide.md     |  2 +-
 .../docs/lmi/user_guides/vllm_user_guide.md   | 38 ++++++++--------
 7 files changed, 91 insertions(+), 65 deletions(-)

diff --git a/engines/python/setup/djl_python/properties_manager/properties.py b/engines/python/setup/djl_python/properties_manager/properties.py
index 2b17c84db..a82ab80c8 100644
--- a/engines/python/setup/djl_python/properties_manager/properties.py
+++ b/engines/python/setup/djl_python/properties_manager/properties.py
@@ -64,6 +64,7 @@ class Properties(BaseModel):
     tgi_compat: Optional[bool] = False
     bedrock_compat: Optional[bool] = False
     enable_lora: Optional[bool] = False
+    generation_config: Optional[str] = None
 
     # Spec_dec
     draft_model_id: Optional[str] = None
diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
index 0c2dd666a..443847f20 100644
--- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
@@ -13,14 +13,14 @@
 import logging
 import os
 from typing import List, Optional
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict
 
 from lmi_dist.api import Request, RequestParams
 from lmi_dist.arg_utils import VllmEngineArgs
 from lmi_dist.init_engine import engine_from_args
 from lmi_dist.seq2seq_engine import Seq2SeqPreprocessor
-from vllm import SamplingParams
 from vllm.utils import AtomicCounter
+from vllm.engine.llm_engine import _load_generation_config_dict
 
 from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params
 from djl_python.rolling_batch.rolling_batch_vllm_utils import (
@@ -96,15 +96,6 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
             kwargs = {}
         logging.info(f"engine_args: {engine_args}, kwargs: {kwargs}")
 
-        if self.lmi_dist_config.max_rolling_batch_prefill_tokens is None:
-            logging.warning(
-                "djl-serving/lmi has changed the default behavior for max_rolling_batch_prefill_tokens in 0.30.0 (lmi v12). "
-                "Previously, when max_rolling_batch_prefill_tokens was unset, djl-serving would use a warmup prefill limit of 4096 tokens. "
-                "This behavior differs from vLLM's default behavior, which (essentially) defaults to max_model_len. As a result of this change, "
-                "model deployments that worked previously may fail due to higher memory requirements at model loading time for the warmup phase. 
" - "For more information on this change, and guidance on what configurations to set, please see " - "https://github.com/deepjavalibrary/djl-serving/tree/master/serving/docs/lmi/announcements/breaking_changes.md" - ) self.engine = engine_from_args(engine_args, **kwargs) self.request_cache = OrderedDict() self.lora_id_counter = AtomicCounter(0) @@ -112,6 +103,9 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs): self.is_mistral_tokenizer = self.lmi_dist_config.tokenizer_mode == 'mistral' self.is_t5_model = isinstance(self.engine.preprocessor, Seq2SeqPreprocessor) + self.default_generation_params = _load_generation_config_dict( + self.engine.model_config + ) if self.lmi_dist_config.generation_config == 'auto' else {} def reset(self) -> None: """ @@ -139,11 +133,6 @@ def get_huggingface_model_config(self): # an interface method and retrieve it from there after v12 return self.engine.preprocessor.model_config.hf_config if not self.is_t5_model else None - def get_huggingface_model_config(self): - # TODO: this is a hack right now to get the model config from the engine. We should expose this as - # an interface method and retrieve it from there after v12 - return self.engine.preprocessor.model_config.hf_config if not self.is_t5_model else None - def translate_lmi_dist_params(self, parameters: dict): """ Helper function to convert DJL Serving parameter names to parameter names @@ -154,6 +143,11 @@ def translate_lmi_dist_params(self, parameters: dict): :return: The same parameters dict, but with lmi-dist style parameter names. """ parameters["max_tokens"] = parameters.pop("max_new_tokens", 30) + # when a default generation_config.json is provided, we still respect any overrides + # sent directly from the request + for k, v in self.default_generation_params.items(): + if k not in parameters and k in LMI_DIST_GENERATION_PARAMS: + parameters[k] = v do_sample = parameters.pop("do_sample", None) if do_sample is not None and do_sample is False: parameters["temperature"] = 0.0 diff --git a/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py index e1ba4651c..2cc60cb10 100644 --- a/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py @@ -16,7 +16,24 @@ from djl_python.request import Request from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception from djl_python.request_io import Token -from typing import List +from typing import List, Optional +from transformers import GenerationConfig + + +# https://github.com/vllm-project/vllm/blame/3132a933b65d8ed3383e082264c682940d92d803/vllm/config.py#L873 +def try_get_generation_config( + model: str, + trust_remote_code: bool, + revision: Optional[str] = None, +) -> dict: + try: + return GenerationConfig.from_pretrained( + model, + trust_remote_code=trust_remote_code, + revision=revision, + ).to_diff_dict() + except OSError: # Not found + return {} class TRTLLMRollingBatch(RollingBatch): @@ -41,6 +58,9 @@ def __init__(self, model_id_or_path: str, properties: dict, self.model = tensorrt_llm_toolkit.init_inference( model_id_or_path, **properties) self.request_cache = {} + self.default_generation_params = try_get_generation_config( + self.model.model.model_path, configs.trust_remote_code, + configs.revision) if configs.generation_config == 'auto' else {} def get_tokenizer(self): return self.model.tokenizer @@ -63,6 +83,11 @@ def 
translate_triton_params(self, parameters: dict) -> dict: :return: The same parameters dict, but with TensorRT-LLM style parameter names. """ + # when a default generation_config.json is provided, we still respect any overrides + # sent directly from the request + for k, v in self.default_generation_params.items(): + if k not in parameters: + parameters[k] = v if "request_output_len" not in parameters: parameters["request_output_len"] = parameters.pop( "max_new_tokens", 30) @@ -72,7 +97,8 @@ def translate_triton_params(self, parameters: dict) -> dict: parameters["runtime_top_p"] = parameters.pop("top_p") if "seed" in parameters: parameters["random_seed"] = int(parameters.pop("seed")) - if parameters.pop("do_sample", False): + do_sample = parameters.pop("do_sample", None) + if do_sample is not None and do_sample: parameters["runtime_top_k"] = parameters.get("runtime_top_k", 5) parameters["runtime_top_p"] = parameters.get("runtime_top_p", 0.85) parameters["temperature"] = parameters.get("temperature", 0.8) diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py index 45262af82..1869099e8 100644 --- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py @@ -10,8 +10,7 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. -import logging -from collections import OrderedDict, defaultdict +from collections import OrderedDict from vllm import LLMEngine, SamplingParams from vllm.utils import random_uuid, AtomicCounter @@ -24,10 +23,7 @@ from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties from typing import List, Optional -# FIXME: Once all vllm versions are past 0.6.0 we can move to just struct_fields -VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__) if hasattr( - SamplingParams(), "__struct_fields__") else set( - SamplingParams().__dict__.keys()) +VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__) class VLLMRollingBatch(RollingBatch): @@ -54,6 +50,7 @@ def __init__(self, model_id_or_path: str, properties: dict, self.lora_id_counter = AtomicCounter(0) self.lora_requests = {} self.is_mistral_tokenizer = self.vllm_configs.tokenizer_mode == 'mistral' + self.default_generation_params = self.engine.generation_config_fields if self.vllm_configs.generation_config == 'auto' else {} def get_tokenizer(self): return self.engine.tokenizer.tokenizer @@ -85,6 +82,11 @@ def translate_vllm_params(self, parameters: dict) -> dict: :return: The same parameters dict, but with VLLM style parameter names. 
""" + # when a default generation_config.json is provided, we still respect any overrides + # sent directly from the request + for k, v in self.default_generation_params.items(): + if k not in parameters and k in VLLM_GENERATION_PARAMS: + parameters[k] = v parameters["max_tokens"] = parameters.pop("max_new_tokens", 30) do_sample = parameters.pop("do_sample", None) if do_sample is not None and do_sample is False: diff --git a/serving/docs/lmi/user_guides/lmi-dist_user_guide.md b/serving/docs/lmi/user_guides/lmi-dist_user_guide.md index 030350b91..5df95cdb1 100644 --- a/serving/docs/lmi/user_guides/lmi-dist_user_guide.md +++ b/serving/docs/lmi/user_guides/lmi-dist_user_guide.md @@ -126,25 +126,26 @@ If you omit the `option.quantize` configuration, then the engine will determine Here are the advanced parameters that are available when using LMI-Dist. -| Item | LMI Version | Configuration Type | Description | Example value | -|-----------------------------------------|-------------|--------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------| -| option.quantize | \>= 0.23.0 | LMI | Quantize the model with the supported quantization methods(`gptq`, `awq`, `squeezellm`) | `awq` Default: `None` | -| option.max_rolling_batch_prefill_tokens | \>= 0.24.0 | Pass Through | Limits the number of tokens for prefill(a.k.a prompt processing). This needs to be tuned based on GPU memory available and request lengths. Setting this value too high can limit the number of kv cache blocks or run into GPU OOM. If you don't set this, `lmi-dist` will default to max model length from Hugging Face config(also accounts for rope scaling if applicable). | Default: `None` | -| option.max_model_len | \>= 0.27.0 | Pass Through | The maximum length (input+output) of the request. The request will be stopped if more tokens are generated. `lmi-dist` will default to max model length from Hugging Face config(also accounts for rope scaling if applicable). For models with larger maximum length support(for e.g. 32k for Mistral 7B), it could lead to GPU OOM. In such cases, to deploy on a smaller instances, reduce this value. | Default: `None` | -| option.load_format | \>= 0.27.0 | Pass Through | The checkpoint format of the model. Default is auto and means bin/safetensors will be used if found. | Default: `auto` | -| option.enforce_eager | \>= 0.27.0 | Pass Through | `lmi-dist` by default will run with CUDA graph optimization to reach to the best performance. However, in the situation of very less GPU memory, having CUDA graph enabled will cause OOM. So if you set this option to true, we will use PyTorch Eager mode and disable CUDA graph to save some GBs of memory. `T5` model will not use cuda graphs. | Default: `False` | -| option.gpu_memory_utilization | \>= 0.27.0 | Pass Through | This config controls the amount of GPU memory allocated to KV cache. Setting higher value will allocate more memory for KV cache. Default is 0.9. It recommended to reduce this value if GPU OOM's are encountered. 
| Default: `0.9` | -| option.speculative_draft_model | \>= 0.27.0 | Pass Through | Model id or path to speculative decoding draft model | Default: `None` | -| option.draft_model_tp_size | \>= 0.27.0 | Pass Through | Tensor parallel degree of speculative decoding draft model. Accepted values are `1` and target model's tensor parallel size(`option.tensor_parallel_degree`) | Default: `1` | -| option.speculative_length | \>= 0.27.0 | Pass Through | Determines the number of tokens draft model generates before verifying against target model | Default: `5` | -| option.record_acceptance_rate | \>= 0.27.0 | LMI | Enables logging speculative decoding acceptance rate | Default: `False` | -| option.enable_lora | \>= 0.27.0 | Pass Through | This config enables support for LoRA adapters. | Default: `false` | -| option.max_loras | \>= 0.27.0 | Pass Through | This config determines the maximum number of LoRA adapters that can be run at once. Allocates GPU memory for those number adapters. | Default: `4` | -| option.max_lora_rank | \>= 0.27.0 | Pass Through | This config determines the maximum rank allowed for a LoRA adapter. Set this value to maximum rank of your adapters. Setting a larger value will enable more adapters at a greater memory usage cost. | Default: `16` | -| option.lora_extra_vocab_size | \>= 0.27.0 | Pass Through | This config determines the maximum additional vocabulary that can be added through a LoRA adapter. | Default: `256` | -| option.max_cpu_loras | \>= 0.27.0 | Pass Through | This config determines the maximum number of LoRA adapters to cache in memory. All others will be evicted to disk. | Default: `None` | -| option.enable_chunked_prefill | \>= 0.29.0 | Pass Through | This config enables chunked prefill support. With chunked prefill, longer prompts will be chunked and batched with decode requests to reduce inter token latency. This option is EXPERIMENTAL and tested for llama and falcon models only. This does not work with LoRA and speculative decoding yet. | Default: `None` | -| option.cpu_offload_gb_per_gpu | \>= 0.29.0 | Pass Through | This config allows offloading model weights into CPU to enable large model running with limited GPU memory. | Default: `0` | -| option.enable_prefix_caching | \>= 0.29.0 | Pass Through | This config allows the engine to cache the context memory and reuse to speed up inference. | Default: `False` | -| option.disable_sliding_window | \>= 0.30.0 | Pass Through | This config disables sliding window, capping to sliding window size inference. | Default: `False` | -| option.tokenizer_mode | \>= 0.30.0 | Pass Through | This config sets the tokenizer mode for vllm. When using mistral models with mistral tokenizers, you must set this to `mistral` explicitly. 
| Default: `auto` | +| Item | LMI Version | Configuration Type | Description | Example value | +|-----------------------------------------|-------------|--------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------| +| option.quantize | \>= 0.23.0 | LMI | Quantize the model with the supported quantization methods(`gptq`, `awq`, `squeezellm`) | `awq` Default: `None` | +| option.max_rolling_batch_prefill_tokens | \>= 0.24.0 | Pass Through | Limits the number of tokens for prefill(a.k.a prompt processing). This needs to be tuned based on GPU memory available and request lengths. Setting this value too high can limit the number of kv cache blocks or run into GPU OOM. If you don't set this, `lmi-dist` will default to max model length from Hugging Face config(also accounts for rope scaling if applicable). | Default: `None` | +| option.max_model_len | \>= 0.27.0 | Pass Through | The maximum length (input+output) of the request. The request will be stopped if more tokens are generated. `lmi-dist` will default to max model length from Hugging Face config(also accounts for rope scaling if applicable). For models with larger maximum length support(for e.g. 32k for Mistral 7B), it could lead to GPU OOM. In such cases, to deploy on a smaller instances, reduce this value. | Default: `None` | +| option.load_format | \>= 0.27.0 | Pass Through | The checkpoint format of the model. Default is auto and means bin/safetensors will be used if found. | Default: `auto` | +| option.enforce_eager | \>= 0.27.0 | Pass Through | `lmi-dist` by default will run with CUDA graph optimization to reach to the best performance. However, in the situation of very less GPU memory, having CUDA graph enabled will cause OOM. So if you set this option to true, we will use PyTorch Eager mode and disable CUDA graph to save some GBs of memory. `T5` model will not use cuda graphs. | Default: `False` | +| option.gpu_memory_utilization | \>= 0.27.0 | Pass Through | This config controls the amount of GPU memory allocated to KV cache. Setting higher value will allocate more memory for KV cache. Default is 0.9. It recommended to reduce this value if GPU OOM's are encountered. | Default: `0.9` | +| option.speculative_draft_model | \>= 0.27.0 | Pass Through | Model id or path to speculative decoding draft model | Default: `None` | +| option.draft_model_tp_size | \>= 0.27.0 | Pass Through | Tensor parallel degree of speculative decoding draft model. Accepted values are `1` and target model's tensor parallel size(`option.tensor_parallel_degree`) | Default: `1` | +| option.speculative_length | \>= 0.27.0 | Pass Through | Determines the number of tokens draft model generates before verifying against target model | Default: `5` | +| option.record_acceptance_rate | \>= 0.27.0 | LMI | Enables logging speculative decoding acceptance rate | Default: `False` | +| option.enable_lora | \>= 0.27.0 | Pass Through | This config enables support for LoRA adapters. | Default: `false` | +| option.max_loras | \>= 0.27.0 | Pass Through | This config determines the maximum number of LoRA adapters that can be run at once. Allocates GPU memory for those number adapters. 
| Default: `4` | +| option.max_lora_rank | \>= 0.27.0 | Pass Through | This config determines the maximum rank allowed for a LoRA adapter. Set this value to maximum rank of your adapters. Setting a larger value will enable more adapters at a greater memory usage cost. | Default: `16` | +| option.lora_extra_vocab_size | \>= 0.27.0 | Pass Through | This config determines the maximum additional vocabulary that can be added through a LoRA adapter. | Default: `256` | +| option.max_cpu_loras | \>= 0.27.0 | Pass Through | This config determines the maximum number of LoRA adapters to cache in memory. All others will be evicted to disk. | Default: `None` | +| option.enable_chunked_prefill | \>= 0.29.0 | Pass Through | This config enables chunked prefill support. With chunked prefill, longer prompts will be chunked and batched with decode requests to reduce inter token latency. This option is EXPERIMENTAL and tested for llama and falcon models only. This does not work with LoRA and speculative decoding yet. | Default: `None` | +| option.cpu_offload_gb_per_gpu | \>= 0.29.0 | Pass Through | This config allows offloading model weights into CPU to enable large model running with limited GPU memory. | Default: `0` | +| option.enable_prefix_caching | \>= 0.29.0 | Pass Through | This config allows the engine to cache the context memory and reuse to speed up inference. | Default: `False` | +| option.disable_sliding_window | \>= 0.30.0 | Pass Through | This config disables sliding window, capping to sliding window size inference. | Default: `False` | +| option.tokenizer_mode | \>= 0.30.0 | Pass Through | This config sets the tokenizer mode for vllm. When using mistral models with mistral tokenizers, you must set this to `mistral` explicitly. | Default: `auto` | +| option.generation_config | >= 0.32.0 | LMI | Set to 'auto' to use the generation parameters from the model's `generation_config.json`. Any generation parameters in the request payload take priority over the generation parameters from this config | Possible values: 'auto', None. Default: None | diff --git a/serving/docs/lmi/user_guides/trt_llm_user_guide.md b/serving/docs/lmi/user_guides/trt_llm_user_guide.md index 63f88d55a..e00d71375 100644 --- a/serving/docs/lmi/user_guides/trt_llm_user_guide.md +++ b/serving/docs/lmi/user_guides/trt_llm_user_guide.md @@ -123,6 +123,6 @@ In that situation, there is nothing LMI can do until the issue is fixed in the b | option.use_fp8_context_fmha | >= 0.28.0 | Pass Through | Paged attention for fp8; should only be turned on for p5 instances | `true`, `false`.
Default is `false` | | option.calib_size | >= 0.27.0 | Pass Through | This is applied when `option.quantize` is set to `fp8`. Number of samples for calibration. | Default is `512` | | option.calib_batch_size | >= 0.28.0 | Pass Through | This is applied when `option.quantize` is set to `fp8`. Batch size for calibration. | Default is `32` | - +| option.generation_config | >= 0.32.0 | LMI | Set to 'auto' to use the generation parameters from the model's `generation_config.json`. Any generation parameters in the request payload take priority over the generation parameters from this config | Possible values: 'auto', None. Default: None | diff --git a/serving/docs/lmi/user_guides/vllm_user_guide.md b/serving/docs/lmi/user_guides/vllm_user_guide.md index c5911fe30..0ddcdc8ed 100644 --- a/serving/docs/lmi/user_guides/vllm_user_guide.md +++ b/serving/docs/lmi/user_guides/vllm_user_guide.md @@ -109,22 +109,24 @@ For `LMI` configurations, if we determine an issue with the configuration, we wi For `Pass Through` configurations it is possible that our investigation reveals an issue with the backend library. In that situation, there is nothing LMI can do until the issue is fixed in the backend library. -| Item | LMI Version | Configuration Type | Description | Example value | -|-----------------------------------------|-------------|--------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------| -| option.quantize | \>= 0.26.0 | LMI | Quantize the model with the supported quantization methods. LMI uses this to set the right quantization configs in VLLM | `awq` Default: `None` | -| option.max_rolling_batch_prefill_tokens | \>= 0.26.0 | LMI | Limits the number of tokens for prefill(a.k.a prompt processing). This needs to be tuned based on GPU memory available and request lengths. Setting this value too high can limit the number of kv cache blocks or run into OOM. If you don't set this, `vllm` will default to max model length from Hugging Face config(also accounts for rope scaling if applicable). | Default: `None` | -| option.max_model_len | \>= 0.26.0 | Pass Through | the maximum length (input+output) vLLM should preserve memory for. If not specified, will use the default length the model is capable in config.json. Sometimes model's maximum length could go to 32k (Mistral 7B) and way beyond the supported KV token size. In that case to deploy on a small instance, we need to adjust this value within the range of KV Cache limit. | Default: `None` | -| option.load_format | \>= 0.26.0 | Pass Through | The checkpoint format of the model. Default is auto and means bin/safetensors will be used if found. | Default: `auto` | -| option.enforce_eager | \>= 0.27.0 | Pass Through | vLLM by default will run with CUDA graph optimization to reach to the best performance. However, in the situation of very less GPU memory, having CUDA graph enabled will cause OOM. So if you set this option to true, we will use PyTorch Eager mode and disable CUDA graph to save some GBs of memory. | Default: `False` | -| option.gpu_memory_utilization | \>= 0.27.0 | Pass Through | This config controls the amount of GPU memory allocated to KV cache. 
Setting higher value will allocate more memory for KV cache.Default is 0.9. It recommended to reduce this value if GPU OOM's are encountered. | Default: `0.9` | -| option.enable_lora | \>= 0.27.0 | Pass Through | This config enables support for LoRA adapters. | Default: `false` | -| option.max_loras | \>= 0.27.0 | Pass Through | This config determines the maximum number of LoRA adapters that can be run at once. Allocates GPU memory for those number of adapters. | Default: `4` | -| option.max_lora_rank | \>= 0.27.0 | Pass Through | This config determines the maximum rank allowed for a LoRA adapter. Set this value to maximum rank of your adapters. Setting a larger value will enable more adapters at a greater memory usage cost. | Default: `16` | -| option.lora_extra_vocab_size | \>= 0.27.0 | Pass Through | This config determines the maximum additional vocabulary that can be added through a LoRA adapter. | Default: `256` | -| option.max_cpu_loras | \>= 0.27.0 | Pass Through | This config determines the maximum number of LoRA adapters to cache in memory. All others will be evicted to disk. | Default: `None` | -| option.enable_chunked_prefill | \>= 0.29.0 | Pass Through | This config enables chunked prefill support. With chunked prefill, longer prompts will be chunked and batched with decode requests to reduce inter token latency. This option is EXPERIMENTAL and tested for llama and falcon models only. This does not work with LoRA and speculative decoding yet. | Default: `None` | -| option.cpu_offload_gb_per_gpu | \>= 0.29.0 | Pass Through | This config allows offloading model weights into CPU to enable large model running with limited GPU memory. | Default: `0` | -| option.enable_prefix_caching | \>= 0.29.0 | Pass Through | This config allows the engine to cache the context memory and reuse to speed up inference. | Default: `False` | -| option.disable_sliding_window | \>= 0.30.0 | Pass Through | This config disables sliding window, capping to sliding window size inference. | Default: `False` | -| option.tokenizer_mode | \>= 0.30.0 | Pass Through | This config sets the tokenizer mode for vllm. When using mistral models with mistral tokenizers, you must set this to `mistral` explicitly. | Default: `auto` | +| Item | LMI Version | Configuration Type | Description | Example value | +|-----------------------------------------|-------------|--------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------| +| option.quantize | \>= 0.26.0 | LMI | Quantize the model with the supported quantization methods. LMI uses this to set the right quantization configs in VLLM | `awq` Default: `None` | +| option.max_rolling_batch_prefill_tokens | \>= 0.26.0 | LMI | Limits the number of tokens for prefill(a.k.a prompt processing). This needs to be tuned based on GPU memory available and request lengths. Setting this value too high can limit the number of kv cache blocks or run into OOM. If you don't set this, `vllm` will default to max model length from Hugging Face config(also accounts for rope scaling if applicable). 
| Default: `None` | +| option.max_model_len | \>= 0.26.0 | Pass Through | the maximum length (input+output) vLLM should preserve memory for. If not specified, will use the default length the model is capable in config.json. Sometimes model's maximum length could go to 32k (Mistral 7B) and way beyond the supported KV token size. In that case to deploy on a small instance, we need to adjust this value within the range of KV Cache limit. | Default: `None` | +| option.load_format | \>= 0.26.0 | Pass Through | The checkpoint format of the model. Default is auto and means bin/safetensors will be used if found. | Default: `auto` | +| option.enforce_eager | \>= 0.27.0 | Pass Through | vLLM by default will run with CUDA graph optimization to reach to the best performance. However, in the situation of very less GPU memory, having CUDA graph enabled will cause OOM. So if you set this option to true, we will use PyTorch Eager mode and disable CUDA graph to save some GBs of memory. | Default: `False` | +| option.gpu_memory_utilization | \>= 0.27.0 | Pass Through | This config controls the amount of GPU memory allocated to KV cache. Setting higher value will allocate more memory for KV cache.Default is 0.9. It recommended to reduce this value if GPU OOM's are encountered. | Default: `0.9` | +| option.enable_lora | \>= 0.27.0 | Pass Through | This config enables support for LoRA adapters. | Default: `false` | +| option.max_loras | \>= 0.27.0 | Pass Through | This config determines the maximum number of LoRA adapters that can be run at once. Allocates GPU memory for those number of adapters. | Default: `4` | +| option.max_lora_rank | \>= 0.27.0 | Pass Through | This config determines the maximum rank allowed for a LoRA adapter. Set this value to maximum rank of your adapters. Setting a larger value will enable more adapters at a greater memory usage cost. | Default: `16` | +| option.lora_extra_vocab_size | \>= 0.27.0 | Pass Through | This config determines the maximum additional vocabulary that can be added through a LoRA adapter. | Default: `256` | +| option.max_cpu_loras | \>= 0.27.0 | Pass Through | This config determines the maximum number of LoRA adapters to cache in memory. All others will be evicted to disk. | Default: `None` | +| option.enable_chunked_prefill | \>= 0.29.0 | Pass Through | This config enables chunked prefill support. With chunked prefill, longer prompts will be chunked and batched with decode requests to reduce inter token latency. This option is EXPERIMENTAL and tested for llama and falcon models only. This does not work with LoRA and speculative decoding yet. | Default: `None` | +| option.cpu_offload_gb_per_gpu | \>= 0.29.0 | Pass Through | This config allows offloading model weights into CPU to enable large model running with limited GPU memory. | Default: `0` | +| option.enable_prefix_caching | \>= 0.29.0 | Pass Through | This config allows the engine to cache the context memory and reuse to speed up inference. | Default: `False` | +| option.disable_sliding_window | \>= 0.30.0 | Pass Through | This config disables sliding window, capping to sliding window size inference. | Default: `False` | +| option.tokenizer_mode | \>= 0.30.0 | Pass Through | This config sets the tokenizer mode for vllm. When using mistral models with mistral tokenizers, you must set this to `mistral` explicitly. | Default: `auto` | +| option.generation_config | >= 0.32.0 | LMI | Set to 'auto' to use the generation parameters from the model's `generation_config.json`. 
Any generation parameters in the request payload take priority over the generation parameters from this config. | Possible values: 'auto', None. Default: None |
+
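
Note for reviewers: the sketch below is illustrative only and is not part of the patch. It mirrors the precedence rule implemented above for all three backends: when `option.generation_config` is set to `auto`, defaults are read from the model's `generation_config.json` (via `GenerationConfig.to_diff_dict()`, as in `try_get_generation_config`), and a default is applied only if the request payload did not already set that parameter. The model id, function names, and the `allowed` set here are placeholders, not LMI APIs.

```python
# Standalone sketch of the "request overrides generation_config.json" rule.
# Requires `transformers`; the model id used below is a placeholder.
from typing import Optional

from transformers import GenerationConfig


def load_generation_defaults(model: str,
                             trust_remote_code: bool = False,
                             revision: Optional[str] = None) -> dict:
    """Best-effort read of generation_config.json; empty dict if the model has none."""
    try:
        return GenerationConfig.from_pretrained(
            model, trust_remote_code=trust_remote_code,
            revision=revision).to_diff_dict()
    except OSError:  # no generation_config.json published for this model
        return {}


def merge_request_params(request_params: dict, defaults: dict,
                         allowed: set) -> dict:
    """Fill in defaults only for recognized parameters the request did not set."""
    merged = dict(request_params)
    for key, value in defaults.items():
        if key not in merged and key in allowed:
            merged[key] = value
    return merged


if __name__ == "__main__":
    # Placeholder model id; any HF model that ships a generation_config.json works.
    defaults = load_generation_defaults("Qwen/Qwen2.5-0.5B-Instruct")
    request = {"max_new_tokens": 128, "temperature": 0.1}
    # The request's temperature (0.1) is kept even if generation_config.json
    # ships a different default; unset parameters such as top_p fall back to it.
    print(merge_request_params(request, defaults, {"temperature", "top_p", "top_k"}))
```

In LMI itself the behavior is opt-in: set `option.generation_config=auto` (for example in serving.properties) to pick up the model's `generation_config.json`; leaving it unset preserves the previous behavior of ignoring that file.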