resolve comments
xiangxu-google committed Oct 10, 2024
1 parent b0d5064 commit a69f7f9
Showing 3 changed files with 154 additions and 93 deletions.
2 changes: 0 additions & 2 deletions examples/offline_inference_vision_language_multi_image.py
@@ -246,8 +246,6 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)},
)

question = "Between the two images, " \
"which one is a lion and which one is a duck?"
prompt = f"<|image|><|image|><|begin_of_text|>{question}"
return ModelRequestData(
llm=llm,
193 changes: 125 additions & 68 deletions tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -7,31 +7,22 @@
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

from ....conftest import HfRunner, PromptImageInput, VllmRunner, _ImageAssets
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from ....utils import large_gpu_test
from ...utils import check_logprobs_close

_LIMIT_IMAGE_PER_PROMPT = 2

# ("prompt", [image indices])
# image-0: stop sign, image-1: cherry blossom
# For each entry, we will generate a batch of
# samples with different image sizes.
PROMPTS = [
# Single leading image.
("<|image|><|begin_of_text|>The meaning of the image is", [0]),
("<|image|><|begin_of_text|>The city is", [1]),
# Single interleaved image.
("<|begin_of_text|>The meaning of the image<|image|> is", [0]),
# Multi leading images.
("<|image|><|image|><|begin_of_text|>Between the first and second image, "
"which is stop sign and which is cherry blossom?", [0, 1]),
# Multi interleaved images.
("<|begin_of_text|>Between the first image<|image|> and second "
"image<|image|>, which is stop sign and which is cherry blossom?", [0,
1]),
# Text only.
("The color of the sky is blue but sometimes it can also be", []),
_LIMIT_IMAGE_PER_PROMPT = 3

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|image|><|begin_of_text|>The meaning of the image is",
"cherry_blossom":
"<|image|><|begin_of_text|>The city is",
})

text_only_prompts = [
"The color of the sky is blue but sometimes it can also be",
]

models = [
@@ -68,49 +59,32 @@ def _get_inputs(
*,
size_factors: Optional[List[float]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
) -> List[Tuple[List[str], List[Optional[PromptImageInput]]]]:
assets = [asset.pil_image for asset in image_assets]
assert len(assets) >= 2

# Inputs is a list of batches, a batch is a tuple of
# (prompts, images), prompts is a list of strings,
# images is a nested list of PIL images.
# len(prompts) == len(images)
# A batch will trigger a generate run.
inputs = []
for entry in PROMPTS:
prompt, image_indices = entry
images = [assets[i] for i in image_indices]
batch_prompts = []
batch_images = []
if size_factors is not None:
for factor in size_factors:
if factor is None:
batch_prompts.append(PROMPTS[-1][0])
batch_images.append(None)
else:
batch_prompts.append(prompt)
resized_images = [
rescale_image_size(image, factor) for image in images
]
batch_images.append(
resized_images if resized_images else None)
elif sizes is not None:
for size in sizes:
if size is None:
batch_prompts.append(PROMPTS[-1][0])
batch_images.append(None)
else:
batch_prompts.append(prompt)
resized_images = [image.resize(size) for image in images]
batch_images.append(
resized_images if resized_images else None)
else:
raise ValueError(
"You must provide either `size_factors` or `sizes`")
assert len(batch_prompts) == len(batch_images)
inputs.append((batch_prompts, batch_images))
return inputs
) -> List[Tuple[List[str], PromptImageInput]]:
images = [asset.pil_image for asset in image_assets]

if size_factors is not None:
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
elif sizes is not None:
inputs_per_image = [(
[
prompt if size is not None else text_only_prompts[0]
for size in sizes
],
[
image.resize(size) if size is not None else None
for size in sizes
],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
if len(sizes) == 0:
inputs_per_image.append(
(text_only_prompts, [None] * len(text_only_prompts)))
else:
raise ValueError("You must provide either `size_factors` or `sizes`")

return inputs_per_image


@overload
@@ -177,7 +151,7 @@ def run_test(
def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], List[Optional[PromptImageInput]]]],
inputs: List[Tuple[List[str], PromptImageInput]],
model: str,
*,
dtype: str,
@@ -252,6 +226,8 @@ def process(hf_inputs: BatchEncoding):
@pytest.mark.parametrize(
"sizes",
[
# Text only
[],
# Single-size
[(512, 512)],
# Single-size, batched
@@ -268,8 +244,9 @@ def process(hf_inputs: BatchEncoding):
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
max_tokens, num_logprobs) -> None:
def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
model, sizes, dtype, max_tokens,
num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
@@ -281,3 +258,83 @@ def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
model, dtype, max_tokens,
num_logprobs) -> None:

stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image

inputs = [(
[
"<|image|><|image|><|begin_of_text|>Describe 2 images.", # noqa: E501
"<|image|><|image|><|begin_of_text|>Describe 2 images.", # noqa: E501
"<|image|><|image|><|image|><|begin_of_text|>Describe 3 images.", # noqa: E501
],
[
[stop_sign, cherry_blossom],
# Images with different sizes.
[
stop_sign.resize((512, 512)),
stop_sign,
],
[
stop_sign,
stop_sign.resize((512, 1536)),
cherry_blossom.resize((512, 1024)),
],
])]

_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
dtype, max_tokens, num_logprobs) -> None:

stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image

inputs = [(
[
"<|begin_of_text|>The meaning of the image <|image|> is", # noqa: E501
"<|begin_of_text|>Is this <|image|> a stop sign and is this <|image|> a cherry blossom?", # noqa: E501
],
[
[stop_sign.resize((1536, 512))],
[
stop_sign.resize((1024, 512)),
cherry_blossom.resize((512, 2028)),
],
])]

_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
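
For context, a minimal standalone sketch of the per-image batch structure the reworked _get_inputs produces under `size_factors`: one (prompts, images) tuple per asset, with prompts and images aligned by index. The helper name `build_size_factor_batches`, its rescaling logic, and the dummy assets are illustrative assumptions, not part of this diff.

from typing import List, Tuple

from PIL import Image


def build_size_factor_batches(
    images: List[Image.Image],
    prompts: List[str],
    size_factors: List[float],
) -> List[Tuple[List[str], List[Image.Image]]]:
    # Hypothetical sketch: one (prompts, images) batch per asset. The prompt is
    # repeated once per size factor and paired with the image rescaled by it.
    def rescale(image: Image.Image, factor: float) -> Image.Image:
        w, h = image.size
        return image.resize((max(1, int(w * factor)), max(1, int(h * factor))))

    return [(
        [prompt for _ in size_factors],
        [rescale(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, prompts)]


# Example with two dummy assets and three size factors -> two batches of three.
assets = [Image.new("RGB", (640, 480)), Image.new("RGB", (1280, 720))]
batches = build_size_factor_batches(
    assets,
    ["<|image|><|begin_of_text|>The meaning of the image is",
     "<|image|><|begin_of_text|>The city is"],
    size_factors=[0.25, 0.5, 1.0],
)
assert len(batches) == 2 and len(batches[0][0]) == len(batches[0][1]) == 3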
52 changes: 29 additions & 23 deletions vllm/model_executor/models/mllama.py
@@ -1128,27 +1128,14 @@ def flat_encoder_result(self, cross_attention_states: torch.Tensor,
start_pos:end_pos] = vision_token_in_batch[:seq_len]
start_pos = end_pos
cross_attention_states = cross_attention_states_flat

full_text_row_masked_out_mask = torch.ones(
(attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
start_pos = 0
for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
attn_metadata.encoder_seq_lens):
if encoder_seq_len == 0:
full_text_row_masked_out_mask[start_pos:start_pos +
seq_len] = False
start_pos += seq_len
full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
cross_attention_states.device)

return cross_attention_states, full_text_row_masked_out_mask
return cross_attention_states

def get_cross_attention_states(
self,
image_inputs: MllamaImagePixelInputs,
attn_metadata: AttentionMetadata,
actual_encoder_seq_lens: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor]:
# NOTE: llama's reference implementation runs vision model on CPU
pixel_values = image_inputs['data']
aspect_ratio_ids = image_inputs['aspect_ratio_ids']
@@ -1163,12 +1150,10 @@ def get_cross_attention_states(
cross_attention_states = cross_attention_states.view(
bsz, -1, image_token_dim)

cross_attention_states, full_text_row_masked_out_mask = \
self.flat_encoder_result(
cross_attention_states, attn_metadata,
actual_encoder_seq_lens)
cross_attention_states = self.flat_encoder_result(
cross_attention_states, attn_metadata, actual_encoder_seq_lens)

return cross_attention_states, full_text_row_masked_out_mask
return cross_attention_states

def get_cross_attention_mask(
self,
@@ -1206,6 +1191,24 @@ def get_cross_attention_mask(

return cross_attention_mask, kv_range_for_decode

def get_full_text_row_masked_out_mask(
self,
attn_metadata: AttentionMetadata,
device: torch.device,
) -> torch.Tensor:
full_text_row_masked_out_mask = torch.ones(
(attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
start_pos = 0
for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
attn_metadata.encoder_seq_lens):
if encoder_seq_len == 0:
full_text_row_masked_out_mask[start_pos:start_pos +
seq_len] = False
start_pos += seq_len
full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
device)
return full_text_row_masked_out_mask

def forward(
self,
input_ids: torch.Tensor,
@@ -1248,9 +1251,12 @@ def forward(
actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
assert actual_len >= last_group_len

cross_attention_states, full_text_row_masked_out_mask = \
self.get_cross_attention_states(
image_inputs, attn_metadata, actual_encoder_seq_lens)
cross_attention_states = self.get_cross_attention_states(
image_inputs, attn_metadata, actual_encoder_seq_lens)

full_text_row_masked_out_mask = \
self.get_full_text_row_masked_out_mask(
attn_metadata, input_ids.device)

cross_attention_mask, kv_range_for_decode = \
self.get_cross_attention_mask(
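
For reference, a minimal standalone sketch of the row-masking logic this commit moves out of flat_encoder_result into the new get_full_text_row_masked_out_mask: prefill rows belonging to sequences with no encoder tokens (text-only, encoder_seq_len == 0) are set to False so cross-attention contributes nothing for them. The function name and example lengths below are illustrative only, not part of this diff.

import torch


def full_text_row_mask_sketch(seq_lens, encoder_seq_lens, num_prefill_tokens):
    # Start with all prefill rows enabled, one boolean per token.
    mask = torch.ones((num_prefill_tokens, 1), dtype=torch.bool)
    start = 0
    for seq_len, enc_len in zip(seq_lens, encoder_seq_lens):
        # Sequences without encoder tokens (no image) are masked out entirely.
        if enc_len == 0:
            mask[start:start + seq_len] = False
        start += seq_len
    return mask


# Two prefill sequences of lengths 3 and 2; the second has no image.
print(full_text_row_mask_sketch([3, 2], [4, 0], 5).squeeze(1))
# tensor([ True,  True,  True, False, False])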
