[Core] Increase default max_num_batched_tokens for multimodal models #8028

Merged · 4 commits · merged on Aug 30, 2024
Changes from 3 commits
36 changes: 26 additions & 10 deletions vllm/config.py
@@ -32,6 +32,7 @@
 logger = init_logger(__name__)
 
 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
+_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 8192
 
 _PP_SUPPORTED_MODELS = [
     "AquilaModel",
@@ -571,6 +572,10 @@ def is_embedding_model(self) -> bool:
         """Extract the embedding model flag."""
         return self.embedding_mode
 
+    @property
+    def is_multimodal_model(self) -> bool:
+        return self.multimodal_config is not None
+
 
 class CacheConfig:
     """Configuration for the KV cache.
@@ -947,25 +952,36 @@ def __init__(self,
                  num_lookahead_slots: int = 0,
                  delay_factor: float = 0.0,
                  enable_chunked_prefill: bool = False,
-                 embedding_mode: Optional[bool] = False,
+                 embedding_mode: bool = False,
+                 is_multimodal_model: bool = False,
                  preemption_mode: Optional[str] = None,
                  num_scheduler_steps: int = 1,
                  send_delta_data: bool = False) -> None:
-        if max_num_batched_tokens is not None:
-            self.max_num_batched_tokens = max_num_batched_tokens
-        else:
+        if max_num_batched_tokens is None:
             if enable_chunked_prefill:
                 # It is the values that have the best balance between ITL
                 # and TTFT on A100. Note it is not optimized for throughput.
-                self.max_num_batched_tokens = 512
-            elif embedding_mode:
-                # For embedding, choose specific value for higher throughput
-                self.max_num_batched_tokens = max(
-                    max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
+                max_num_batched_tokens = 512
             else:
                 # If max_model_len is too short, use 2048 as the default value
                 # for higher throughput.
-                self.max_num_batched_tokens = max(max_model_len, 2048)
+                max_num_batched_tokens = max(max_model_len, 2048)
+
+        if embedding_mode:
+            # For embedding, choose specific value for higher throughput
+            max_num_batched_tokens = max(
+                max_num_batched_tokens,
+                _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
+            )
+        if is_multimodal_model:
+            # The value needs to be at least the number of multimodal tokens
+            max_num_batched_tokens = max(
+                max_num_batched_tokens,
+                _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+            )
+
+        self.max_num_batched_tokens = max_num_batched_tokens
+
         if enable_chunked_prefill:
             logger.info(
                 "Chunked prefill is enabled with max_num_batched_tokens=%d.",
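The defaulting logic in this hunk resolves in two steps: pick a base value (512 with chunked prefill, otherwise max(max_model_len, 2048)), then raise it to at least 32768 for embedding models and at least 8192 for multimodal models. The sketch below mirrors the logic shown in this diff as a standalone function; the helper name and the final assert are illustrative and not part of vLLM.

```python
# A minimal sketch of the defaulting logic above, pulled out of
# SchedulerConfig.__init__ for illustration. Constants come from the diff;
# the function name `resolve_max_num_batched_tokens` is made up.
from typing import Optional

_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 8192


def resolve_max_num_batched_tokens(
    max_num_batched_tokens: Optional[int],
    max_model_len: int,
    enable_chunked_prefill: bool = False,
    embedding_mode: bool = False,
    is_multimodal_model: bool = False,
) -> int:
    if max_num_batched_tokens is None:
        # Chunked prefill keeps batches small to balance ITL and TTFT;
        # otherwise fall back to at least 2048 for throughput.
        max_num_batched_tokens = (512 if enable_chunked_prefill else
                                  max(max_model_len, 2048))
    if embedding_mode:
        max_num_batched_tokens = max(
            max_num_batched_tokens, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
    if is_multimodal_model:
        # The batch must be able to hold all multimodal placeholder tokens.
        max_num_batched_tokens = max(
            max_num_batched_tokens, _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS)
    return max_num_batched_tokens


# Example: a multimodal model with a 4096-token context now defaults to 8192
# batched tokens instead of max(4096, 2048) == 4096.
assert resolve_max_num_batched_tokens(None, 4096, is_multimodal_model=True) == 8192
```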
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
@@ -921,6 +921,7 @@ def create_engine_config(self) -> EngineConfig:
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
             embedding_mode=model_config.embedding_mode,
+            is_multimodal_model=model_config.is_multimodal_model,
             preemption_mode=self.preemption_mode,
             num_scheduler_steps=self.num_scheduler_steps,
             send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
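This one-line change threads the model's multimodal flag into the scheduler config. A rough usage sketch, assuming vLLM's public EngineArgs API at the time of this PR; the model name is only an example of a multimodal model and the printed value is the expected default, not a guaranteed one:

```python
# With no explicit max_num_batched_tokens, a multimodal model should now
# default to at least 8192 instead of max(max_model_len, 2048).
from vllm import EngineArgs

engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")
engine_config = engine_args.create_engine_config()
print(engine_config.scheduler_config.max_num_batched_tokens)  # expected: 8192
```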
6 changes: 5 additions & 1 deletion vllm/engine/llm_engine.py
@@ -2019,7 +2019,7 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs,
         if prompt_ids is None or len(prompt_ids) == 0:
             raise ValueError("Prompt cannot be empty")
 
-        if self.model_config.multimodal_config is not None:
+        if self.model_config.is_multimodal_model:
             max_prompt_len = self.model_config.max_model_len
 
             if len(prompt_ids) > max_prompt_len:
@@ -2030,3 +2030,7 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs,
                 "number of text tokens plus multimodal tokens. For image "
                 "inputs, the number of image tokens depends on the number "
                 "of images, and possibly their aspect ratios as well.")
+
+            # TODO: Find out how many placeholder tokens are there so we can
+            # check that chunked prefill does not truncate them
+            # max_batch_len = self.scheduler_config.max_num_batched_tokens
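The TODO above points at a follow-up check: once the number of multimodal placeholder tokens in a prompt is known, chunked prefill could verify that a single batch can hold all of them. A hypothetical sketch of such a check; none of these names exist in vLLM, they are placeholders for illustration:

```python
def check_mm_placeholders_fit(num_placeholder_tokens: int,
                              max_num_batched_tokens: int) -> None:
    """Hypothetical guard: multimodal placeholder tokens must fit in one
    prefill batch, otherwise chunked prefill would split the image tokens."""
    if num_placeholder_tokens > max_num_batched_tokens:
        raise ValueError(
            f"Prompt contains {num_placeholder_tokens} multimodal placeholder "
            f"tokens, but max_num_batched_tokens is {max_num_batched_tokens}; "
            "increase --max-num-batched-tokens or disable chunked prefill.")
```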
2 changes: 1 addition & 1 deletion vllm/worker/utils.py
@@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario(
         raise NotImplementedError(
             STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
 
-    if enc_dec_mr.model_config.multimodal_config is not None:
+    if enc_dec_mr.model_config.is_multimodal_model:
         raise NotImplementedError(
             STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])
 