
Commit

add unit tests
xiangxu-google committed Oct 8, 2024
1 parent 7238dbc commit 463d2a5
Showing 2 changed files with 66 additions and 45 deletions.
tests/models/encoder_decoder/vision_language/test_mllama.py (108 changes: 66 additions & 42 deletions)
@@ -7,22 +7,31 @@
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from ....conftest import HfRunner, PromptImageInput, VllmRunner, _ImageAssets
from ....utils import large_gpu_test
from ...utils import check_logprobs_close

_LIMIT_IMAGE_PER_PROMPT = 1

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|image|><|begin_of_text|>The meaning of the image is",
"cherry_blossom":
"<|image|><|begin_of_text|>The city is",
})

text_only_prompts = [
"The color of the sky is blue but sometimes it can also be",
_LIMIT_IMAGE_PER_PROMPT = 2

# ("prompt", [image indices])
# image-0: stop sign, image-1: cherry blossom
# For each entry, we will generate a batch of
# samples with different image sizes.
PROMPTS = [
# Single leading image.
("<|image|><|begin_of_text|>The meaning of the image is", [0]),
("<|image|><|begin_of_text|>The city is", [1]),
# Single interleaved image.
("<|begin_of_text|>The meaning of the image<|image|> is", [0]),
# Multi leading images.
("<|image|><|image|><|begin_of_text|>Between the first and second image, "
"which is stop sign and which is cherry blossom?", [0, 1]),
# Multi interleaved images.
("<|begin_of_text|>Between the first image<|image|> and second "
"image<|image|>, which is stop sign and which is cherry blossom?", [0,
1]),
# Text only.
("The color of the sky is blue but sometimes it can also be", []),
]

models = [
@@ -59,32 +68,49 @@ def _get_inputs(
*,
size_factors: Optional[List[float]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
) -> List[Tuple[List[str], PromptImageInput]]:
images = [asset.pil_image for asset in image_assets]

if size_factors is not None:
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
elif sizes is not None:
inputs_per_image = [(
[
prompt if size is not None else text_only_prompts[0]
for size in sizes
],
[
image.resize(size) if size is not None else None
for size in sizes
],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
if len(sizes) == 0:
inputs_per_image.append(
(text_only_prompts, [None] * len(text_only_prompts)))
else:
raise ValueError("You must provide either `size_factors` or `sizes`")

return inputs_per_image
) -> List[Tuple[List[str], List[Optional[PromptImageInput]]]]:
assets = [asset.pil_image for asset in image_assets]
assert len(assets) >= 2

# Inputs is a list of batches, a batch is a tuple of
# (prompts, images), prompts is a list of strings,
# images is a nested list of PIL images.
# len(prompts) == len(images)
# A batch will trigger a generate run.
inputs = []
for entry in PROMPTS:
prompt, image_indices = entry
images = [assets[i] for i in image_indices]
batch_prompts = []
batch_images = []
if size_factors is not None:
for factor in size_factors:
if factor is None:
batch_prompts.append(PROMPTS[-1][0])
batch_images.append(None)
else:
batch_prompts.append(prompt)
resized_images = [
rescale_image_size(image, factor) for image in images
]
batch_images.append(
resized_images if resized_images else None)
elif sizes is not None:
for size in sizes:
if size is None:
batch_prompts.append(PROMPTS[-1][0])
batch_images.append(None)
else:
batch_prompts.append(prompt)
resized_images = [image.resize(size) for image in images]
batch_images.append(
resized_images if resized_images else None)
else:
raise ValueError(
"You must provide either `size_factors` or `sizes`")
assert len(batch_prompts) == len(batch_images)
inputs.append((batch_prompts, batch_images))
return inputs


@overload
@@ -151,7 +177,7 @@ def run_test(
def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput]],
inputs: List[Tuple[List[str], List[Optional[PromptImageInput]]]],
model: str,
*,
dtype: str,
@@ -226,8 +252,6 @@ def process(hf_inputs: BatchEncoding):
@pytest.mark.parametrize(
"sizes",
[
# Text only
[],
# Single-size
[(512, 512)],
# Single-size, batched
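As an aside, the batch layout that the reworked _get_inputs builds can be exercised with a small standalone sketch. It is illustrative only and not taken from the repository: sketch_get_inputs, SKETCH_PROMPTS, and the dummy Pillow images are stand-ins for the test's asset fixtures, and only the `sizes` branch of the helper is mirrored here.

from typing import List, Optional, Tuple

from PIL import Image

# Trimmed-down copy of the PROMPTS table above: (prompt, [image indices]).
SKETCH_PROMPTS = [
    ("<|image|><|begin_of_text|>The meaning of the image is", [0]),
    ("<|image|><|image|><|begin_of_text|>Between the first and second image, "
     "which is stop sign and which is cherry blossom?", [0, 1]),
    ("The color of the sky is blue but sometimes it can also be", []),
]


def sketch_get_inputs(
    assets: List[Image.Image],
    sizes: List[Optional[Tuple[int, int]]],
) -> List[Tuple[List[str], List[Optional[List[Image.Image]]]]]:
    inputs = []
    for prompt, image_indices in SKETCH_PROMPTS:
        images = [assets[i] for i in image_indices]
        batch_prompts: List[str] = []
        batch_images: List[Optional[List[Image.Image]]] = []
        for size in sizes:
            if size is None:
                # A None size degrades that sample to the text-only prompt.
                batch_prompts.append(SKETCH_PROMPTS[-1][0])
                batch_images.append(None)
            else:
                batch_prompts.append(prompt)
                resized = [image.resize(size) for image in images]
                batch_images.append(resized if resized else None)
        # Each batch keeps prompts and image lists aligned one-to-one.
        assert len(batch_prompts) == len(batch_images)
        inputs.append((batch_prompts, batch_images))
    return inputs


if __name__ == "__main__":
    dummy_assets = [Image.new("RGB", (64, 64)), Image.new("RGB", (48, 48))]
    for prompts, images in sketch_get_inputs(dummy_assets, [(512, 512), None]):
        print(len(prompts), [im if im is None else len(im) for im in images])

Running the sketch prints one line per prompt entry, showing that a text-only slot carries None while an image slot carries a list whose length matches the number of <|image|> placeholders in the prompt.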
vllm/model_executor/models/mllama.py (3 changes: 0 additions & 3 deletions)
@@ -136,7 +136,6 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
assert hf_config.vision_config.image_size % 14 == 0, \
"chunk size should be multiple of 14"
token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
print(f"vllm num_tiles: {num_tiles}")
num_tokens = num_tiles * token_per_chunk
llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens
llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID
@@ -1151,7 +1150,6 @@ def get_cross_attention_states(
) -> Tuple[torch.Tensor, torch.Tensor]:
# NOTE: llama's reference implementation runs vision model on CPU
pixel_values = image_inputs['data']
print(f"pixel_values={pixel_values.shape}")
aspect_ratio_ids = image_inputs['aspect_ratio_ids']
aspect_ratio_mask = image_inputs['aspect_ratio_mask']
cross_attention_states = self.vision_model(pixel_values,
@@ -1189,7 +1187,6 @@ def get_cross_attention_mask(
get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID)
for t in batch_token_ids
]
print(f"sparse_mask={sparse_mask}")

# Skip generating cross-attention mask if all samples
# are text-only or have only 1 leading image.
