diff --git a/tests/test_config.py b/tests/test_config.py
index 19db10630bbae..676065f9f0305 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,5 +1,29 @@
+import pytest
+
 from vllm.config import ModelConfig
 
+MODEL_IDS_EXPECTED = [
+    ("Qwen/Qwen1.5-7B", 32768),
+    ("mistralai/Mistral-7B-v0.1", 4096),
+    ("mistralai/Mistral-7B-Instruct-v0.2", 32768),
+]
+
+
+@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
+def test_disable_sliding_window(model_id_expected):
+    model_id, expected = model_id_expected
+    model_config = ModelConfig(
+        model_id,
+        model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+        disable_sliding_window=True,
+    )
+    assert model_config.max_model_len == expected
+
 
 def test_get_sliding_window():
     TEST_SLIDING_WINDOW = 4096
@@ -36,4 +60,4 @@ def test_get_sliding_window():
 
     assert mistral_model_config.get_sliding_window() is None
     mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
-    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
\ No newline at end of file
+    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
diff --git a/vllm/config.py b/vllm/config.py
index aedb589247646..bc01e877dfbe5 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -25,6 +25,16 @@
 VLLM_USE_MODELSCOPE = os.environ.get("VLLM_USE_MODELSCOPE",
                                      "False").lower() == "true"
 
+_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.float16,
+    "float16": torch.float16,
+    "float": torch.float32,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+}
+
+_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
+
 _GB = 1 << 30
 
 
@@ -66,6 +76,10 @@ class ModelConfig:
         max_context_len_to_capture: Maximum context len covered by CUDA graphs.
             When a sequence has context length larger than this, we fall back
             to eager mode.
+        disable_sliding_window: Whether to disable sliding window. If True,
+            we will disable the sliding window functionality of the model.
+            If the model does not support sliding window, this argument is
+            ignored.
         skip_tokenizer_init: If true, skip initialization of tokenizer and
             detokenizer.
     """
@@ -87,6 +101,7 @@ def __init__(
         enforce_eager: bool = False,
         max_context_len_to_capture: Optional[int] = None,
         max_logprobs: int = 5,
+        disable_sliding_window: bool = False,
         skip_tokenizer_init: bool = False,
     ) -> None:
         self.model = model
@@ -102,14 +117,15 @@ def __init__(
         self.enforce_eager = enforce_eager
         self.max_context_len_to_capture = max_context_len_to_capture
         self.max_logprobs = max_logprobs
+        self.disable_sliding_window = disable_sliding_window
         self.skip_tokenizer_init = skip_tokenizer_init
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
                                     code_revision)
         self.hf_text_config = get_hf_text_config(self.hf_config)
-        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
-        self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
-                                                     max_model_len)
+        self.dtype = self._get_and_verify_dtype(self.hf_text_config, dtype)
+        self.max_model_len = self._get_and_verify_max_len(
+            self.hf_text_config, self.disable_sliding_window, max_model_len)
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
         self._verify_quantization()
@@ -197,10 +213,7 @@ def verify_with_parallel_config(
                 "must be divisible by pipeline parallel size "
                 f"({pipeline_parallel_size}).")
 
-    def get_sliding_window(self) -> Optional[int]:
-        """Get the sliding window size, or None if disabled.
-        """
-
+    def get_hf_config_sliding_window(self) -> Optional[int]:
         # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
         # addition to sliding window size. We check if that field is present
         # and if it's False, return None.
@@ -209,6 +222,18 @@ def get_sliding_window(self) -> Optional[int]:
             return None
         return getattr(self.hf_text_config, "sliding_window", None)
 
+    def get_sliding_window(self) -> Optional[int]:
+        """Get the sliding window size, or None if disabled.
+        """
+        # If user disables sliding window, return None.
+        if self.disable_sliding_window:
+            logger.info("Sliding window is disabled per configuration. "
+                        "Model max length will be capped at sliding window "
+                        "length.")
+            return None
+        # Otherwise get the value from the hf config.
+        return self.get_hf_config_sliding_window()
+
     def get_vocab_size(self) -> int:
         return self.hf_text_config.vocab_size
 
@@ -275,6 +300,160 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_text_config.num_hidden_layers
         return total_num_hidden_layers // parallel_config.pipeline_parallel_size
 
+    def _get_and_verify_dtype(
+        self,
+        config: PretrainedConfig,
+        dtype: Union[str, torch.dtype],
+    ) -> torch.dtype:
+        # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
+        # because config.torch_dtype can be None.
+        config_dtype = getattr(config, "torch_dtype", None)
+        if config_dtype is None:
+            config_dtype = torch.float32
+
+        if isinstance(dtype, str):
+            dtype = dtype.lower()
+            if dtype == "auto":
+                if config_dtype == torch.float32:
+                    # Following the common practice, we use float16 for float32
+                    # models.
+                    torch_dtype = torch.float16
+                else:
+                    torch_dtype = config_dtype
+            else:
+                if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+                    raise ValueError(f"Unknown dtype: {dtype}")
+                torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+        elif isinstance(dtype, torch.dtype):
+            torch_dtype = dtype
+        else:
+            raise ValueError(f"Unknown dtype: {dtype}")
+
+        if is_hip() and torch_dtype == torch.float32:
+            rocm_supported_dtypes = [
+                k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
+                if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
+            ]
+            raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
+                             f"Supported dtypes are {rocm_supported_dtypes}")
+
+        # Verify the dtype.
+        if torch_dtype != config_dtype:
+            if torch_dtype == torch.float32:
+                # Upcasting to float32 is allowed.
+                pass
+            elif config_dtype == torch.float32:
+                # Downcasting from float32 to float16 or bfloat16 is allowed.
+                pass
+            else:
+                # Casting between float16 and bfloat16 is allowed with warning.
+                logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
+
+        return torch_dtype
+
+    def _get_and_verify_max_len(
+        self,
+        hf_config: PretrainedConfig,
+        disable_sliding_window: bool,
+        max_model_len: Optional[int],
+    ) -> int:
+        """Get and verify the model's maximum length."""
+        derived_max_model_len = float("inf")
+        possible_keys = [
+            # OPT
+            "max_position_embeddings",
+            # GPT-2
+            "n_positions",
+            # MPT
+            "max_seq_len",
+            # ChatGLM2
+            "seq_length",
+            # Command-R
+            "model_max_length",
+            # Others
+            "max_sequence_length",
+            "max_seq_length",
+            "seq_len",
+        ]
+        # Choose the smallest "max_length" from the possible keys.
+        max_len_key = None
+        for key in possible_keys:
+            max_len = getattr(hf_config, key, None)
+            if max_len is not None:
+                max_len_key = key if max_len < derived_max_model_len \
+                    else max_len_key
+                derived_max_model_len = min(derived_max_model_len, max_len)
+
+        # If sliding window is manually disabled, max_length should be less
+        # than the sliding window length in the model config.
+        max_len = self.get_hf_config_sliding_window()
+        if disable_sliding_window and max_len is not None:
+            max_len_key = "sliding_window" \
+                if max_len < derived_max_model_len else max_len_key
+            derived_max_model_len = min(derived_max_model_len, max_len)
+
+        # If none of the keys were found in the config, use a default and
+        # log a warning.
+        if derived_max_model_len == float("inf"):
+            if max_model_len is not None:
+                # If max_model_len is specified, we use it.
+                return max_model_len
+
+            default_max_len = 2048
+            logger.warning(
+                "The model's config.json does not contain any of the following "
+                "keys to determine the original maximum length of the model: "
+                "%s. Assuming the model's maximum length is %d.",
+                possible_keys, default_max_len)
+            derived_max_model_len = default_max_len
+
+        rope_scaling = getattr(hf_config, "rope_scaling", None)
+        if rope_scaling is not None and rope_scaling["type"] != "su":
+            if disable_sliding_window:
+                # TODO(robertgshaw): Find a model that supports rope_scaling
+                # with sliding window to see if this case should be allowed.
+                raise NotImplementedError(
+                    "Disabling sliding window is not supported for models "
+                    "with rope_scaling. Please raise an issue so we can "
+                    "investigate.")
+            assert "factor" in rope_scaling
+            scaling_factor = rope_scaling["factor"]
+            if rope_scaling["type"] == "yarn":
+                derived_max_model_len = rope_scaling[
+                    "original_max_position_embeddings"]
+            derived_max_model_len *= scaling_factor
+
+        # If the user specified a max length, make sure it is smaller than the
+        # derived length from the HF model config.
+        if max_model_len is None:
+            max_model_len = int(derived_max_model_len)
+        elif max_model_len > derived_max_model_len:
+            # Some models might have a separate key for specifying
+            # model_max_length that will be bigger than derived_max_model_len.
+            # We compare user input with model_max_length and allow this
+            # override when it's smaller.
+            model_max_length = getattr(hf_config, "model_max_length", None)
+            if (model_max_length is not None
+                    and max_model_len <= model_max_length):
+                if disable_sliding_window:
+                    # TODO(robertgshaw): Find a model that has model_max_length
+                    # with sliding window to see if this case should be allowed.
+                    raise NotImplementedError(
+                        "Disabling sliding window is not supported for "
+                        "models with model_max_length in the config. Please "
+                        "raise an issue so we can investigate.")
+                pass
+            else:
+                raise ValueError(
+                    f"User-specified max_model_len ({max_model_len}) is "
+                    "greater than the derived max_model_len "
+                    f"({max_len_key}={derived_max_model_len} or "
+                    f"model_max_length={model_max_length} in model's "
+                    "config.json). This may lead to incorrect model outputs "
+                    "or CUDA errors. Make sure the value is correct and "
+                    "within the model context size.")
+        return int(max_model_len)
+
 
 class CacheConfig:
     """Configuration for the KV cache.
@@ -308,6 +487,7 @@ def __init__(
         self.enable_prefix_caching = enable_prefix_caching
         self._verify_args()
         self._verify_cache_dtype()
+        self._verify_prefix_caching()
 
         # Will be set after profiling.
         self.num_gpu_blocks = None
@@ -344,6 +524,19 @@ def _verify_cache_dtype(self) -> None:
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
 
+    def _verify_prefix_caching(self) -> None:
+        if not self.enable_prefix_caching:
+            return
+
+        if self.sliding_window is not None:
+            raise NotImplementedError(
+                "Prefix caching is not supported with sliding window. "
+                "Run with --disable-sliding-window to use prefix caching.")
+        if self.cache_dtype == "fp8":
+            raise NotImplementedError(
+                "Prefix caching is not supported for fp8 cache_dtype. "
+                "Run with --kv-cache-dtype auto to use prefix caching.")
+
     def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
@@ -952,139 +1145,6 @@ def get_image_input_enum_type(
                     f"{[x.name for x in cls.ImageInputType]}.") from e
 
 
-_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.float16,
-    "float16": torch.float16,
-    "float": torch.float32,
-    "float32": torch.float32,
-    "bfloat16": torch.bfloat16,
-}
-
-_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
-
-
-def _get_and_verify_dtype(
-    config: PretrainedConfig,
-    dtype: Union[str, torch.dtype],
-) -> torch.dtype:
-    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
-    # because config.torch_dtype can be None.
-    config_dtype = getattr(config, "torch_dtype", None)
-    if config_dtype is None:
-        config_dtype = torch.float32
-
-    if isinstance(dtype, str):
-        dtype = dtype.lower()
-        if dtype == "auto":
-            if config_dtype == torch.float32:
-                # Following the common practice, we use float16 for float32
-                # models.
-                torch_dtype = torch.float16
-            else:
-                torch_dtype = config_dtype
-        else:
-            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
-                raise ValueError(f"Unknown dtype: {dtype}")
-            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
-    elif isinstance(dtype, torch.dtype):
-        torch_dtype = dtype
-    else:
-        raise ValueError(f"Unknown dtype: {dtype}")
-
-    if is_hip() and torch_dtype == torch.float32:
-        rocm_supported_dtypes = [
-            k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
-            if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
-        ]
-        raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
-                         f"Supported dtypes are {rocm_supported_dtypes}")
-
-    # Verify the dtype.
-    if torch_dtype != config_dtype:
-        if torch_dtype == torch.float32:
-            # Upcasting to float32 is allowed.
-            pass
-        elif config_dtype == torch.float32:
-            # Downcasting from float32 to float16 or bfloat16 is allowed.
-            pass
-        else:
-            # Casting between float16 and bfloat16 is allowed with a warning.
-            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
-
-    return torch_dtype
-
-
-def _get_and_verify_max_len(
-    hf_config: PretrainedConfig,
-    max_model_len: Optional[int],
-) -> int:
-    """Get and verify the model's maximum length."""
-    derived_max_model_len = float("inf")
-    possible_keys = [
-        # OPT
-        "max_position_embeddings",
-        # GPT-2
-        "n_positions",
-        # MPT
-        "max_seq_len",
-        # ChatGLM2
-        "seq_length",
-        # Command-R
-        "model_max_length",
-        # Others
-        "max_sequence_length",
-        "max_seq_length",
-        "seq_len",
-    ]
-    max_len_key = None
-    for key in possible_keys:
-        max_len = getattr(hf_config, key, None)
-        if max_len is not None:
-            max_len_key = key if max_len < derived_max_model_len \
-                else max_len_key
-            derived_max_model_len = min(derived_max_model_len, max_len)
-    if derived_max_model_len == float("inf"):
-        if max_model_len is not None:
-            # If max_model_len is specified, we use it.
-            return max_model_len
-
-        default_max_len = 2048
-        logger.warning(
-            "The model's config.json does not contain any of the following "
-            "keys to determine the original maximum length of the model: "
-            "%d. Assuming the model's maximum length is %d.", possible_keys,
-            default_max_len)
-        derived_max_model_len = default_max_len
-
-    rope_scaling = getattr(hf_config, "rope_scaling", None)
-    if rope_scaling is not None and rope_scaling["type"] != "su":
-        assert "factor" in rope_scaling
-        scaling_factor = rope_scaling["factor"]
-        if rope_scaling["type"] == "yarn":
-            derived_max_model_len = rope_scaling[
-                "original_max_position_embeddings"]
-        derived_max_model_len *= scaling_factor
-
-    if max_model_len is None:
-        max_model_len = int(derived_max_model_len)
-    elif max_model_len > derived_max_model_len:
-        # Some models might have a separate key for specifying model_max_length
-        # that will be bigger than derived_max_model_len. We compare user input
-        # with model_max_length and allow this override when it's smaller.
-        model_max_length = getattr(hf_config, "model_max_length", None)
-        if model_max_length is not None and max_model_len <= model_max_length:
-            pass
-        else:
-            raise ValueError(
-                f"User-specified max_model_len ({max_model_len}) is greater "
-                "than the derived max_model_len "
-                f"({max_len_key}={derived_max_model_len} or model_max_length="
-                f"{model_max_length} in model's config.json). This may lead "
-                "to incorrect model outputs or CUDA errors. Make sure the "
-                "value is correct and within the model context size.")
-    return int(max_model_len)
-
-
 @dataclass
 class DecodingConfig:
     """Dataclass which contains the decoding strategy of the engine"""
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index bd6437ee44c28..1e6428277b80e 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -32,6 +32,7 @@ class EngineArgs:
     max_parallel_loading_workers: Optional[int] = None
     block_size: int = 16
     enable_prefix_caching: bool = False
+    disable_sliding_window: bool = False
    use_v2_block_manager: bool = False
     swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90
@@ -246,6 +247,10 @@ def add_cli_args(
         parser.add_argument('--enable-prefix-caching',
                             action='store_true',
                             help='Enables automatic prefix caching.')
+        parser.add_argument('--disable-sliding-window',
+                            action='store_true',
+                            help='Disables sliding window if the model '
+                            'supports sliding window')
         parser.add_argument('--use-v2-block-manager',
                             action='store_true',
                             help='Use BlockSpaceMangerV2.')
@@ -476,7 +481,8 @@ def create_engine_config(self, ) -> EngineConfig:
             self.code_revision, self.tokenizer_revision, self.max_model_len,
             self.quantization, self.quantization_param_path,
             self.enforce_eager, self.max_context_len_to_capture,
-            self.max_logprobs, self.skip_tokenizer_init)
+            self.max_logprobs, self.disable_sliding_window,
+            self.skip_tokenizer_init)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
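
For reference, a minimal usage sketch of the new option (illustrative only, not part of the patch; it assumes the vllm.LLM entrypoint forwards extra keyword arguments to EngineArgs as usual). With a sliding-window model such as mistralai/Mistral-7B-v0.1, passing disable_sliding_window=True caps max_model_len to the 4096-token sliding window (mirroring the new test above) and satisfies the new CacheConfig._verify_prefix_caching check, which rejects prefix caching while a sliding window is active:

# Illustrative sketch, not part of the diff above.
from vllm import LLM

llm = LLM(
    model="mistralai/Mistral-7B-v0.1",
    # Required before prefix caching is accepted:
    # CacheConfig._verify_prefix_caching raises NotImplementedError
    # while a sliding window is active.
    disable_sliding_window=True,
    enable_prefix_caching=True,
)

# With the window disabled, _get_and_verify_max_len caps the context length
# to the 4096-token sliding window from the model's config.json.
print(llm.llm_engine.model_config.max_model_len)  # 4096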