[Misc] Consolidate ModelConfig code related to HF config (vllm-project#10104)

Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 authored Nov 7, 2024
1 parent 511cfe5 commit ab1930c
Showing 10 changed files with 68 additions and 43 deletions.
2 changes: 1 addition & 1 deletion docs/source/serving/compatibility_matrix.rst
@@ -359,7 +359,7 @@ Feature x Hardware
- ✅
- ✅
- ✅
- `<https://github.com/vllm-project/vllm/blob/a84e598e2125960d3b4f716b78863f24ac562947/vllm/worker/cpu_model_runner.py#L125>`__
-
- ✗
* - :abbr:`logP (Logprobs)`
- ✅
38 changes: 38 additions & 0 deletions tests/test_config.py
@@ -165,3 +165,41 @@ def test_rope_customization():
assert getattr(longchat_model_config.hf_config, "rope_scaling",
None) == TEST_ROPE_SCALING
assert longchat_model_config.max_model_len == 4096


@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
("facebook/opt-125m", False),
("facebook/bart-base", True),
("meta-llama/Llama-3.2-1B", False),
("meta-llama/Llama-3.2-11B-Vision", True),
])
def test_is_encoder_decoder(model_id, is_encoder_decoder):
config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float16",
seed=0,
)

assert config.is_encoder_decoder == is_encoder_decoder


@pytest.mark.parametrize(("model_id", "uses_mrope"), [
("facebook/opt-125m", False),
("Qwen/Qwen2-VL-2B-Instruct", True),
])
def test_uses_mrope(model_id, uses_mrope):
config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float16",
seed=0,
)

assert config.uses_mrope == uses_mrope
14 changes: 8 additions & 6 deletions vllm/config.py
@@ -15,7 +15,8 @@
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_hf_text_config)
get_hf_text_config,
is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
print_warning_once)

@@ -667,12 +668,13 @@ def get_multimodal_config(self) -> "MultiModalConfig":
return self.multimodal_config

@property
def is_encoder_decoder_model(self) -> bool:
def is_encoder_decoder(self) -> bool:
"""Extract the HF encoder/decoder model flag."""
return getattr(
self.hf_config, "is_encoder_decoder",
False) or (hasattr(self.hf_config, "text_config") and getattr(
self.hf_config.text_config, "is_encoder_decoder", False))
return is_encoder_decoder(self.hf_config)

@property
def uses_mrope(self) -> bool:
return uses_mrope(self.hf_config)

@property
def is_multimodal_model(self) -> bool:
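
For context, a minimal usage sketch of the two consolidated ModelConfig properties. It mirrors the parametrized tests added in tests/test_config.py above and assumes the referenced HF configs can be downloaded; it is not part of this commit.

from vllm.config import ModelConfig

def make_config(model_id: str) -> ModelConfig:
    # Same constructor arguments as the new tests above.
    return ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )

# Both flags are now read from ModelConfig rather than per-runner helper methods.
assert make_config("facebook/bart-base").is_encoder_decoder
assert make_config("Qwen/Qwen2-VL-2B-Instruct").uses_mrope
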
2 changes: 1 addition & 1 deletion vllm/inputs/preprocess.py
@@ -580,4 +580,4 @@ async def preprocess_async(
)

def is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model
return self.model_config.is_encoder_decoder
9 changes: 9 additions & 0 deletions vllm/transformers_utils/config.py
@@ -129,6 +129,15 @@ def uses_mrope(config: PretrainedConfig) -> bool:
return "mrope_section" in rope_scaling


def is_encoder_decoder(config: PretrainedConfig) -> bool:
"""Detect if the model with this config is used as an encoder/decoder."""
text_config = getattr(config, "text_config", None)
if text_config is not None:
return is_encoder_decoder(text_config)

return getattr(config, "is_encoder_decoder", False)


def get_config(
model: Union[str, Path],
trust_remote_code: bool,
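
As a rough illustration of the resolution order in the new helper (hypothetical stand-in config objects, not part of this commit): it defers to a nested text_config when one exists, otherwise it reads the top-level flag.

from types import SimpleNamespace

from vllm.transformers_utils.config import is_encoder_decoder

# Stand-ins for HF configs: a plain text model, and a multimodal model whose
# encoder/decoder flag lives on the nested text_config.
plain = SimpleNamespace(is_encoder_decoder=False)
nested = SimpleNamespace(text_config=SimpleNamespace(is_encoder_decoder=True))

assert not is_encoder_decoder(plain)   # top-level attribute read directly
assert is_encoder_decoder(nested)      # resolved recursively via text_config
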
4 changes: 0 additions & 4 deletions vllm/utils.py
@@ -88,9 +88,6 @@
"currently supported with encoder/"
"decoder models.")

STR_NOT_IMPL_ENC_DEC_CPU = ("CPU is not currently supported with "
"encoder/decoder models.")

# Efficiently import all enc/dec error strings
# rather than having to import all of the above
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
@@ -105,7 +102,6 @@
"STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
"STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
"STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU
}

# Constants related to forcing the attention backend selection
9 changes: 1 addition & 8 deletions vllm/worker/cpu_model_runner.py
@@ -18,7 +18,6 @@
MultiModalInputs, MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
@@ -163,7 +162,7 @@ def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata,

# special processing for mrope position deltas.
mrope_positions = None
if self.runner.model_is_mrope:
if self.runner.model_config.uses_mrope:
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
assert image_grid_thw is not None or video_grid_thw is not None, (
@@ -446,12 +445,6 @@ def __init__(
# Lazy initialization.
self.model: nn.Module # Set after init_Model

@property
def model_is_mrope(self) -> bool:
"""Detect if the model has "mrope" rope_scaling type.
mrope requires keep "rope_deltas" between prompt and decoding phases."""
return uses_mrope(self.model_config.hf_config)

def load_model(self) -> None:
self.model = get_model(vllm_config=self.vllm_config)

5 changes: 1 addition & 4 deletions vllm/worker/cpu_worker.py
@@ -151,7 +151,7 @@ def __init__(
self.local_omp_cpuid = omp_cpuids.split("|")[rank]

ModelRunnerClass: Type[CPUModelRunner] = CPUModelRunner
if self._is_encoder_decoder_model():
if self.model_config.is_encoder_decoder:
ModelRunnerClass = CPUEncoderDecoderModelRunner
self.model_runner: CPUModelRunner = ModelRunnerClass(
vllm_config=vllm_config,
@@ -188,9 +188,6 @@ def stop_profile(self):
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()

def _is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model

def init_device(self) -> None:
if self.local_omp_cpuid != "all":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
23 changes: 8 additions & 15 deletions vllm/worker/model_runner.py
@@ -47,7 +47,6 @@
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
async_tensor_h2d, flatten_2d_lists,
is_pin_memory_available, supports_dynamo,
@@ -493,7 +492,7 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
context_len = seq_data.get_num_computed_tokens()
seq_len = min(seq_len, context_len + token_chunk_size)
elif self.runner.scheduler_config.is_multi_step or \
self.runner.model_config.is_encoder_decoder_model:
self.runner.model_config.is_encoder_decoder:
context_len = seq_len - 1
else:
context_len = seq_data.get_num_computed_tokens()
@@ -666,7 +665,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
inter_data.multi_modal_placeholder_maps = placeholder_maps

# special processing for mrope position deltas.
if self.runner.model_is_mrope:
if self.runner.model_config.uses_mrope:
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
assert image_grid_thw is not None or video_grid_thw is not None, (
@@ -711,7 +710,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):

encoder_seq_len = 0

if self.runner.model_config.is_encoder_decoder_model:
if self.runner.model_config.is_encoder_decoder:
encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len()

inter_data = self.init_cached_inter_data(
@@ -837,7 +836,7 @@ def build(self) -> ModelInputForGPU:
if not inter_data.is_prompt:
max_decode_seq_len = max(max_decode_seq_len,
max(inter_data.seq_lens))
if self.runner.model_config.is_encoder_decoder_model:
if self.runner.model_config.is_encoder_decoder:
max_encoder_seq_len = max(max_encoder_seq_len,
inter_data.encoder_seq_len)

@@ -1375,12 +1374,6 @@ def list_prompt_adapters(self) -> Set[int]:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.list_adapters()

@property
def model_is_mrope(self) -> bool:
"""Detect if the model has "mrope" rope_scaling type.
mrope requires keep "rope_deltas" between prompt and decoding phases."""
return uses_mrope(self.model_config.hf_config)

@torch.inference_mode()
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
"""Cuda graph capture a model.
@@ -1411,7 +1404,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
max_batch_size = self.max_batchsize_to_capture
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
if self.model_is_mrope:
if self.model_config.uses_mrope:
input_positions = torch.tile(input_positions, (3, 1))
# Prepare dummy previous_hidden_states only if needed by the model.
# This is used by draft models such as EAGLE.
@@ -1447,7 +1440,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
self.attn_state.graph_capture_get_metadata_for_batch(
batch_size,
is_encoder_decoder_model=self.model_config.
is_encoder_decoder_model))
is_encoder_decoder))

if self.lora_config:
lora_mapping = LoRAMapping(
@@ -1466,7 +1459,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
graph_runner = CUDAGraphRunner(
self.model, self.attn_backend.get_name(),
self.attn_state.graph_clone(batch_size),
self.model_config.is_encoder_decoder_model)
self.model_config.is_encoder_decoder)

capture_inputs = {
"input_ids":
@@ -1497,7 +1490,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
self.model.get_seqlen_agnostic_capture_inputs(
batch_size)
})
if self.model_config.is_encoder_decoder_model:
if self.model_config.is_encoder_decoder:
# add the additional inputs to capture for
# encoder-decoder models.
self._update_inputs_to_capture_for_enc_dec_model(
5 changes: 1 addition & 4 deletions vllm/worker/worker.py
@@ -77,7 +77,7 @@ def __init__(
ModelRunnerClass = model_runner_cls
elif model_config.task == "embedding":
ModelRunnerClass = EmbeddingModelRunner
elif self._is_encoder_decoder_model():
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = EncoderDecoderModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
vllm_config=self.vllm_config,
@@ -119,9 +119,6 @@ def stop_profile(self):
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()

def _is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model

def init_device(self) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until