Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core][Frontend] Add Support for Inference Time mm_processor_kwargs #9131

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/offline_inference_vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def run_phi3v(question, modality):
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs={"num_crops": 16},
)
stop_token_ids = None
Expand Down
110 changes: 70 additions & 40 deletions tests/multimodal/test_processor_kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ def mm_model_cls():
# lambda whose signature matches max token calcs extra & mapper + extra kwargs
get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
"num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
"pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
}


### Test for default processor logic & mm_processor_kwargs wrapping
### Tests for default processor logic & mm_processor_kwargs wrapping
def test_default_processor_is_a_noop():
"""Ensure that by default, there is no processor override."""
dummy_registry = InputRegistry()
Expand All @@ -89,23 +89,46 @@ def test_default_processor_is_a_noop():
assert proc_inputs is proc_outputs


@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_processor_default_kwargs(use_processor_mock, num_crops):
"""Ensure input processors can use processor kwargs."""
dummy_registry = InputRegistry()
def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
"""Get the init / inference kwargs and expected num_crops for this test."""
# If we have a value for num_crops, pass the override value and make
# sure we get that value as a return-value from out mock processor,
# otherwise fall back to the default value
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
init_kwargs = None if init_num_crops is None else {
"num_crops": init_num_crops
}
expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
inference_kwargs = None if inference_num_crops is None else {
"num_crops": inference_num_crops
}
if inference_num_crops is not None:
expected_seq_count = inference_num_crops
elif init_num_crops is not None:
expected_seq_count = init_num_crops
else:
expected_seq_count = DEFAULT_NUM_CROPS
return init_kwargs, inference_kwargs, expected_seq_count


@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
(None, None),
(NUM_CROPS_OVERRIDE, None),
(DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
])
def test_input_processor_kwargs(use_processor_mock, init_num_crops,
inference_num_crops):
"""Ensure input processors can use processor kwargs."""
dummy_registry = InputRegistry()

init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
init_num_crops, inference_num_crops)

num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
assert num_crops_val == expected_num_crops
ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor(
LLMInputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=inference_kwargs))
assert num_crops_val == expected_seq_count


@pytest.mark.parametrize(
Expand All @@ -124,11 +147,16 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
mm_processor_kwargs):
"""Ensure that input processors filter out invalid mm_processor_kwargs"""
dummy_registry = InputRegistry()
# Should filter out the init time kwargs
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)

processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
# Should filter out the inference time kwargs
num_crops_val = processor(
LLMInputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=mm_processor_kwargs))
assert num_crops_val == DEFAULT_NUM_CROPS


Expand Down Expand Up @@ -271,32 +299,34 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1


@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
(None, None),
(NUM_CROPS_OVERRIDE, None),
(DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
])
def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
inference_num_crops):
"""Ensure custom mappers can use processor kwargs."""
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
}
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
init_num_crops, inference_num_crops)

ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
mm_processor_kwargs=init_kwargs,
limit_mm_per_prompt={"image": 1})

mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
image = image_assets[0].pil_image
mm_inputs = {"image": image}

with patch.object(
mm_registry._get_plugin("image"),
"_default_input_mapper",
{mm_model_cls(): custom_mapper},
):
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
mm_model_cls())
DarkLight1337 marked this conversation as resolved.
Show resolved Hide resolved
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs,
inference_kwargs)

assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1

Expand All @@ -316,24 +346,24 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
mm_processor_kwargs):
"""Ensure that custom mappers filters out invalid mm_processor_kwargs"""
# Should filter out the init time kwargs
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})

mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
image = image_assets[0].pil_image
mm_inputs = {"image": image}

with patch.object(
mm_registry._get_plugin("image"),
"_default_input_mapper",
{mm_model_cls(): custom_mapper},
):
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
mm_model_cls())
# Should filter out the inference time kwargs
mapped_inputs = mm_registry.map_input(
ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs)

assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1
26 changes: 26 additions & 0 deletions tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest

from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.parse import parse_and_batch_prompt

STRING_INPUTS = [
Expand Down Expand Up @@ -51,3 +52,28 @@ def test_parse_single_batch_token_consistent(token_input: List[int]):
def test_parse_single_batch_string_slice(inputs_slice: slice):
assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
== parse_and_batch_prompt(STRING_INPUTS[inputs_slice])


# yapf: disable
@pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [
(None, [{}, {}]),
({}, [{}, {}]),
({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
])
# yapf: enable
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
"""Test mm_processor_kwargs init for zipping enc/dec prompts."""
encoder_prompts = ['An encoder prompt', 'Another encoder prompt']
decoder_prompts = ['A decoder prompt', 'Another decoder prompt']
zipped_prompts = zip_enc_dec_prompts(encoder_prompts, decoder_prompts,
mm_processor_kwargs)
assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
for enc, dec, exp_kwargs, zipped in zip(encoder_prompts, decoder_prompts,
expected_mm_kwargs,
zipped_prompts):
assert isinstance(zipped, dict)
assert len(zipped.keys()) == 3
assert zipped['encoder_prompt'] == enc
assert zipped['decoder_prompt'] == dec
assert zipped['mm_processor_kwargs'] == exp_kwargs
32 changes: 31 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
get_open_port, merge_async_iterators)
get_open_port, merge_async_iterators, supports_kw)

from .utils import error_on_warning

Expand Down Expand Up @@ -236,3 +236,33 @@ def test_no_model_tag(parser_with_config):
with pytest.raises(ValueError):
parser_with_config.parse_args(
['serve', '--config', './data/test_config.yaml'])


# yapf: enable
@pytest.mark.parametrize(
"callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
[
# Tests for positional argument support
(lambda foo: None, "foo", True, True, False),
(lambda foo: None, "foo", False, True, True),
# Tests for positional or keyword / keyword only
(lambda foo=100: None, "foo", True, True, False),
(lambda *, foo: None, "foo", False, True, True),
# Tests to make sure the names of variadic params are NOT supported
(lambda *args: None, "args", False, True, False),
(lambda **kwargs: None, "kwargs", False, True, False),
# Tests for if we allow var kwargs to add support
(lambda foo: None, "something_else", False, True, False),
(lambda foo, **kwargs: None, "something_else", False, True, True),
(lambda foo, **kwargs: None, "kwargs", True, True, False),
(lambda foo, **kwargs: None, "foo", True, True, False),
])
# yapf: disable
def test_supports_kw(callable,kw_name,requires_kw_only,
allow_var_kwargs,is_supported):
assert supports_kw(
callable=callable,
kw_name=kw_name,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs
) == is_supported
1 change: 1 addition & 0 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,7 @@ def schedule(
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None,
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
else:
Expand Down
7 changes: 7 additions & 0 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,13 @@ def add_request(
)
processed_inputs = self.input_processor(preprocessed_inputs)

# This is a bit of a hack - copy the mm_processor_kwargs that were
# used in the input processor to the processed output, since these
# kwargs are presumed to be immutable and the values should be aligned
# between the input processor (here) and the input mapper.
processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
"mm_processor_kwargs")

self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
Expand Down
9 changes: 9 additions & 0 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ def chat(
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: Optional[List[Dict[str, Any]]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> List[RequestOutput]:
"""
Generate responses for a chat conversation.
Expand Down Expand Up @@ -532,6 +533,8 @@ def chat(
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be `True`
if `add_generation_prompt` is also `True`.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.

Returns:
A list of ``RequestOutput`` objects containing the generated
Expand All @@ -553,6 +556,9 @@ def chat(
tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()

# NOTE: _parse_chat_message_content_parts() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it.
conversation, mm_data = parse_chat_messages(
msgs, model_config, tokenizer)

Expand Down Expand Up @@ -585,6 +591,9 @@ def chat(
if mm_data is not None:
prompt["multi_modal_data"] = mm_data

if mm_processor_kwargs is not None:
prompt["mm_processor_kwargs"] = mm_processor_kwargs

prompts.append(prompt)

return self.generate(
Expand Down
Loading
Loading