From 2a7ba6ddc6b77479cff51b37b90a861e84e333eb Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 7 Nov 2024 14:00:21 +0800 Subject: [PATCH] [Misc] Consolidate ModelConfig code related to HF config (#10104) Signed-off-by: DarkLight1337 --- docs/source/serving/compatibility_matrix.rst | 2 +- tests/test_config.py | 38 ++++++++++++++++++++ vllm/config.py | 14 ++++---- vllm/inputs/preprocess.py | 2 +- vllm/transformers_utils/config.py | 9 +++++ vllm/utils.py | 4 --- vllm/worker/cpu_model_runner.py | 9 +---- vllm/worker/cpu_worker.py | 5 +-- vllm/worker/model_runner.py | 23 +++++------- vllm/worker/worker.py | 5 +-- 10 files changed, 68 insertions(+), 43 deletions(-) diff --git a/docs/source/serving/compatibility_matrix.rst b/docs/source/serving/compatibility_matrix.rst index cab19e4ec5b6c..f629b3ca78318 100644 --- a/docs/source/serving/compatibility_matrix.rst +++ b/docs/source/serving/compatibility_matrix.rst @@ -359,7 +359,7 @@ Feature x Hardware - ✅ - ✅ - ✅ - - `✗ `__ + - ✅ - ✗ * - :abbr:`logP (Logprobs)` - ✅ diff --git a/tests/test_config.py b/tests/test_config.py index 69918b67607d9..5211049bf0011 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -165,3 +165,41 @@ def test_rope_customization(): assert getattr(longchat_model_config.hf_config, "rope_scaling", None) == TEST_ROPE_SCALING assert longchat_model_config.max_model_len == 4096 + + +@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [ + ("facebook/opt-125m", False), + ("facebook/bart-base", True), + ("meta-llama/Llama-3.2-1B", False), + ("meta-llama/Llama-3.2-11B-Vision", True), +]) +def test_is_encoder_decoder(model_id, is_encoder_decoder): + config = ModelConfig( + model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + ) + + assert config.is_encoder_decoder == is_encoder_decoder + + +@pytest.mark.parametrize(("model_id", "uses_mrope"), [ + ("facebook/opt-125m", False), + ("Qwen/Qwen2-VL-2B-Instruct", True), +]) +def test_uses_mrope(model_id, uses_mrope): + config = ModelConfig( + model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + ) + + assert config.uses_mrope == uses_mrope diff --git a/vllm/config.py b/vllm/config.py index 91bbbfec4b7b3..c7fad3a261858 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -15,7 +15,8 @@ from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.transformers_utils.config import (ConfigFormat, get_config, get_hf_image_processor_config, - get_hf_text_config) + get_hf_text_config, + is_encoder_decoder, uses_mrope) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, print_warning_once) @@ -667,12 +668,13 @@ def get_multimodal_config(self) -> "MultiModalConfig": return self.multimodal_config @property - def is_encoder_decoder_model(self) -> bool: + def is_encoder_decoder(self) -> bool: """Extract the HF encoder/decoder model flag.""" - return getattr( - self.hf_config, "is_encoder_decoder", - False) or (hasattr(self.hf_config, "text_config") and getattr( - self.hf_config.text_config, "is_encoder_decoder", False)) + return is_encoder_decoder(self.hf_config) + + @property + def uses_mrope(self) -> bool: + return uses_mrope(self.hf_config) @property def is_multimodal_model(self) -> bool: diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index a5c787a56b5a9..509b0448b9e51 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -580,4 +580,4 @@ 
async def preprocess_async( ) def is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model + return self.model_config.is_encoder_decoder diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 1a5870aa4f84c..415d8bf7cc2bb 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -129,6 +129,15 @@ def uses_mrope(config: PretrainedConfig) -> bool: return "mrope_section" in rope_scaling +def is_encoder_decoder(config: PretrainedConfig) -> bool: + """Detect if the model with this config is used as an encoder/decoder.""" + text_config = getattr(config, "text_config", None) + if text_config is not None: + return is_encoder_decoder(text_config) + + return getattr(config, "is_encoder_decoder", False) + + def get_config( model: Union[str, Path], trust_remote_code: bool, diff --git a/vllm/utils.py b/vllm/utils.py index d78130873d3dc..13d7f6d475346 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -88,9 +88,6 @@ "currently supported with encoder/" "decoder models.") -STR_NOT_IMPL_ENC_DEC_CPU = ("CPU is not currently supported with " - "encoder/decoder models.") - # Efficiently import all enc/dec error strings # rather than having to import all of the above STR_NOT_IMPL_ENC_DEC_ERR_STRS = { @@ -105,7 +102,6 @@ "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC, "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, - "STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU } # Constants related to forcing the attention backend selection diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index fdd72a452f2ad..26a15ed645c43 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -18,7 +18,6 @@ MultiModalInputs, MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) -from vllm.transformers_utils.config import uses_mrope from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, @@ -163,7 +162,7 @@ def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata, # special processing for mrope position deltas. mrope_positions = None - if self.runner.model_is_mrope: + if self.runner.model_config.uses_mrope: image_grid_thw = mm_kwargs.get("image_grid_thw", None) video_grid_thw = mm_kwargs.get("video_grid_thw", None) assert image_grid_thw is not None or video_grid_thw is not None, ( @@ -446,12 +445,6 @@ def __init__( # Lazy initialization. self.model: nn.Module # Set after init_Model - @property - def model_is_mrope(self) -> bool: - """Detect if the model has "mrope" rope_scaling type. 
- mrope requires keep "rope_deltas" between prompt and decoding phases.""" - return uses_mrope(self.model_config.hf_config) - def load_model(self) -> None: self.model = get_model(vllm_config=self.vllm_config) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 3778707ae07e8..2914f520d823c 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -151,7 +151,7 @@ def __init__( self.local_omp_cpuid = omp_cpuids.split("|")[rank] ModelRunnerClass: Type[CPUModelRunner] = CPUModelRunner - if self._is_encoder_decoder_model(): + if self.model_config.is_encoder_decoder: ModelRunnerClass = CPUEncoderDecoderModelRunner self.model_runner: CPUModelRunner = ModelRunnerClass( vllm_config=vllm_config, @@ -188,9 +188,6 @@ def stop_profile(self): raise RuntimeError("Profiler is not enabled.") self.profiler.stop() - def _is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model - def init_device(self) -> None: if self.local_omp_cpuid != "all": ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a260e1f6c1a36..f50eb37829c9f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -47,7 +47,6 @@ LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.transformers_utils.config import uses_mrope from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_pin_memory_available, supports_dynamo, @@ -493,7 +492,7 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, context_len = seq_data.get_num_computed_tokens() seq_len = min(seq_len, context_len + token_chunk_size) elif self.runner.scheduler_config.is_multi_step or \ - self.runner.model_config.is_encoder_decoder_model: + self.runner.model_config.is_encoder_decoder: context_len = seq_len - 1 else: context_len = seq_data.get_num_computed_tokens() @@ -666,7 +665,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, inter_data.multi_modal_placeholder_maps = placeholder_maps # special processing for mrope position deltas. - if self.runner.model_is_mrope: + if self.runner.model_config.uses_mrope: image_grid_thw = mm_kwargs.get("image_grid_thw", None) video_grid_thw = mm_kwargs.get("video_grid_thw", None) assert image_grid_thw is not None or video_grid_thw is not None, ( @@ -711,7 +710,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): encoder_seq_len = 0 - if self.runner.model_config.is_encoder_decoder_model: + if self.runner.model_config.is_encoder_decoder: encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() inter_data = self.init_cached_inter_data( @@ -837,7 +836,7 @@ def build(self) -> ModelInputForGPU: if not inter_data.is_prompt: max_decode_seq_len = max(max_decode_seq_len, max(inter_data.seq_lens)) - if self.runner.model_config.is_encoder_decoder_model: + if self.runner.model_config.is_encoder_decoder: max_encoder_seq_len = max(max_encoder_seq_len, inter_data.encoder_seq_len) @@ -1379,12 +1378,6 @@ def list_prompt_adapters(self) -> Set[int]: raise RuntimeError("PromptAdapter is not enabled.") return self.prompt_adapter_manager.list_adapters() - @property - def model_is_mrope(self) -> bool: - """Detect if the model has "mrope" rope_scaling type. 
- mrope requires keep "rope_deltas" between prompt and decoding phases.""" - return uses_mrope(self.model_config.hf_config) - @torch.inference_mode() def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: """Cuda graph capture a model. @@ -1415,7 +1408,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: max_batch_size = self.max_batchsize_to_capture input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() - if self.model_is_mrope: + if self.model_config.uses_mrope: input_positions = torch.tile(input_positions, (3, 1)) # Prepare dummy previous_hidden_states only if needed by the model. # This is used by draft models such as EAGLE. @@ -1451,7 +1444,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self.attn_state.graph_capture_get_metadata_for_batch( batch_size, is_encoder_decoder_model=self.model_config. - is_encoder_decoder_model)) + is_encoder_decoder)) if self.lora_config: lora_mapping = LoRAMapping( @@ -1470,7 +1463,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: graph_runner = CUDAGraphRunner( self.model, self.attn_backend.get_name(), self.attn_state.graph_clone(batch_size), - self.model_config.is_encoder_decoder_model) + self.model_config.is_encoder_decoder) capture_inputs = { "input_ids": @@ -1501,7 +1494,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self.model.get_seqlen_agnostic_capture_inputs( batch_size) }) - if self.model_config.is_encoder_decoder_model: + if self.model_config.is_encoder_decoder: # add the additional inputs to capture for # encoder-decoder models. self._update_inputs_to_capture_for_enc_dec_model( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index e5d68becc93b4..c6d97a70c8adb 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -77,7 +77,7 @@ def __init__( ModelRunnerClass = model_runner_cls elif model_config.task == "embedding": ModelRunnerClass = EmbeddingModelRunner - elif self._is_encoder_decoder_model(): + elif self.model_config.is_encoder_decoder: ModelRunnerClass = EncoderDecoderModelRunner self.model_runner: GPUModelRunnerBase = ModelRunnerClass( vllm_config=self.vllm_config, @@ -119,9 +119,6 @@ def stop_profile(self): raise RuntimeError("Profiler is not enabled.") self.profiler.stop() - def _is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model - def init_device(self) -> None: if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until
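
Reviewer note (not part of the applied diff): a minimal sketch of how the consolidated helpers added to vllm/transformers_utils/config.py are expected to behave. It builds bare PretrainedConfig objects by hand instead of loading real checkpoints, and the mrope_section values are illustrative placeholders, not taken from any particular model.

    from transformers import PretrainedConfig

    from vllm.transformers_utils.config import is_encoder_decoder, uses_mrope

    # Decoder-only style config: neither flag is set.
    plain = PretrainedConfig()
    assert not is_encoder_decoder(plain)
    assert not uses_mrope(plain)

    # Flag present directly on the config (BART-style encoder/decoder models).
    enc_dec = PretrainedConfig(is_encoder_decoder=True)
    assert is_encoder_decoder(enc_dec)

    # Flag nested under text_config (Mllama-style multimodal models): the helper
    # recurses into text_config before falling back to the top-level attribute.
    nested = PretrainedConfig(text_config=PretrainedConfig(is_encoder_decoder=True))
    assert is_encoder_decoder(nested)

    # M-RoPE is detected via "mrope_section" inside rope_scaling (Qwen2-VL-style);
    # the section sizes below are placeholders.
    mrope_cfg = PretrainedConfig(
        rope_scaling={"type": "mrope", "mrope_section": [16, 24, 24]})
    assert uses_mrope(mrope_cfg)

End-to-end behaviour through the new ModelConfig.is_encoder_decoder and ModelConfig.uses_mrope properties is exercised by the parametrized cases added to tests/test_config.py above.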