Commit bd81e71

test passed
zifeitong committed Aug 19, 2024
1 parent 46b527f commit bd81e71
Showing 4 changed files with 24 additions and 19 deletions.
8 changes: 4 additions & 4 deletions vllm/model_executor/models/clip.py
@@ -1,7 +1,7 @@
-"""Minimal implementation of CLIPVisionModel intended to be only used
+"""Minimal implementation of CLIPVisionModel intended to be only used
 within a vision language model."""
 from array import array
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional, Tuple, Union, List

 import torch
 import torch.nn as nn
@@ -84,7 +84,7 @@ def input_processor_for_clip(
     llm_inputs: LLMInputs,
     *,
     image_token_id: int,
-    image_feature_size_override: Optional[int] = None,
+    image_feature_size_override: Optional[Union[int, List[int]]] = None,
 ):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
@@ -217,7 +217,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

 class CLIPEncoder(nn.Module):
     """
-    Transformer encoder consisting of `config.num_hidden_layers` self
+    Transformer encoder consisting of `config.num_hidden_layers` self
     attention layers. Each layer is a [`CLIPEncoderLayer`].

     Args:
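
The widened image_feature_size_override type (the same change appears in siglip.py below) lets a caller supply one feature size per image instead of a single value shared by all images. A minimal sketch of how such an override could be normalized, using a hypothetical normalize_feature_sizes helper that is not part of this commit:

from typing import List, Optional, Union


def normalize_feature_sizes(
    override: Optional[Union[int, List[int]]],
    num_images: int,
    default_size: int,
) -> List[int]:
    # No override: every image uses the size derived from the vision config.
    if override is None:
        return [default_size] * num_images
    # A bare int applies the same feature size to every image.
    if isinstance(override, int):
        return [override] * num_images
    # A list is expected to already hold one entry per image.
    assert len(override) == num_images
    return list(override)


print(normalize_feature_sizes(None, 2, 576))           # [576, 576]
print(normalize_feature_sizes([576, 2928], 2, 576))    # [576, 2928]
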
18 changes: 9 additions & 9 deletions vllm/model_executor/models/llava_next.py
@@ -19,6 +19,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SamplerOutput
+from vllm.utils import is_list_of

 from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_clip_image_feature_size,
@@ -223,14 +224,13 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
             input_height=height,
             input_width=width,
         )
-    elif isinstance(image_data, list):
-        width, height = image_data[0].size
-
-        image_feature_size = get_llava_next_image_feature_size(
-            hf_config,
-            input_height=height,
-            input_width=width,
-        )
+    elif is_list_of(image_data, Image.Image):
+        image_feature_size = [
+            get_llava_next_image_feature_size(hf_config,
+                                              input_height=img.height,
+                                              input_width=img.width)
+            for img in image_data
+        ]
     elif isinstance(image_data, torch.Tensor):
         image_feature_size = image_data.shape[0]
     else:
@@ -435,7 +435,7 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                 )
                 num_patches = num_patch_height * num_patch_width

-                # image patches might be padded for batch.
+                # Image patches might be padded for batch processing
                 other_patch_embeds = other_patch_embeds[:num_patches] \
                     .view(num_patch_height, num_patch_width, height, width, -1)

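
The new is_list_of branch computes one feature size per PIL image instead of reusing the first image's dimensions for the whole list. A self-contained sketch of that dispatch, with get_image_feature_size standing in for get_llava_next_image_feature_size (the real function derives the count from the anyres grid in hf_config) and a plain isinstance check standing in for vllm's is_list_of helper:

from typing import List, Union

import torch
from PIL import Image


def get_image_feature_size(height: int, width: int) -> int:
    # Stand-in for get_llava_next_image_feature_size; the real computation
    # depends on the model's image_grid_pinpoints configuration.
    return (height // 14) * (width // 14)


def compute_feature_sizes(
    image_data: Union[Image.Image, List[Image.Image], torch.Tensor],
) -> Union[int, List[int]]:
    if isinstance(image_data, Image.Image):
        return get_image_feature_size(image_data.height, image_data.width)
    if isinstance(image_data, list) and all(
            isinstance(img, Image.Image) for img in image_data):
        # One feature size per image, mirroring the new branch above.
        return [
            get_image_feature_size(img.height, img.width)
            for img in image_data
        ]
    if isinstance(image_data, torch.Tensor):
        # Pre-computed image embeddings: the first dimension is the count.
        return image_data.shape[0]
    raise TypeError(f"Unexpected image type: {type(image_data)}")
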
4 changes: 2 additions & 2 deletions vllm/model_executor/models/siglip.py
@@ -3,7 +3,7 @@

 import math
 from array import array
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional, Tuple, Union, List

 import torch
 from PIL import Image
@@ -93,7 +93,7 @@ def input_processor_for_siglip(
     llm_inputs: LLMInputs,
     *,
     image_token_id: int,
-    image_feature_size_override: Optional[int] = None,
+    image_feature_size_override: Optional[Union[int, List[int]]] = None,
 ):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
13 changes: 9 additions & 4 deletions vllm/multimodal/image.py
@@ -1,5 +1,5 @@
 from functools import lru_cache
-from typing import List, Optional, Tuple, TypeVar
+from typing import List, Optional, Tuple, TypeVar, Union

 import torch
 from PIL import Image
@@ -44,10 +44,13 @@ def repeat_and_pad_image_tokens(
     prompt_token_ids: List[int],
     *,
     image_token_id: int,
-    repeat_count: int = 1,
+    repeat_count: Union[int, List[int]] = 1,
     pad_token_left: Optional[int] = None,
     pad_token_right: Optional[int] = None,
 ) -> Tuple[Optional[str], List[int]]:
+    if not isinstance(repeat_count, list):
+        repeat_count = [repeat_count] * len(prompt_token_ids)
+
     if prompt is None:
         new_prompt = None
     else:
@@ -59,7 +62,7 @@ def repeat_and_pad_image_tokens(
         replacement_str = "".join(
             repeat_and_pad_token(
                 image_token_str,
-                repeat_count=repeat_count,
+                repeat_count=repeat_count[0],
                 pad_token_left=pad_token_str_left,
                 pad_token_right=pad_token_str_right,
             ))
@@ -76,15 +79,17 @@
         new_prompt = prompt.replace(image_token_str, replacement_str)

     new_token_ids: List[int] = []
+    idx = 0
     for i, token in enumerate(prompt_token_ids):
         if token == image_token_id:
             replacement_ids = repeat_and_pad_token(
                 image_token_id,
-                repeat_count=repeat_count,
+                repeat_count=repeat_count[idx],
                 pad_token_left=pad_token_left,
                 pad_token_right=pad_token_right,
             )
             new_token_ids.extend(replacement_ids)
+            idx += 1
         else:
             new_token_ids.append(token)

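
With repeat_count now accepted as either an int or a per-image list, each image placeholder in the prompt can expand to a different number of tokens. A self-contained sketch of the token-id loop above, using plain repetition in place of repeat_and_pad_token (padding omitted) and a placeholder image token id:

from typing import List, Union

IMAGE_TOKEN_ID = 32000  # placeholder id, for illustration only


def expand_image_tokens(
    prompt_token_ids: List[int],
    image_token_id: int,
    repeat_count: Union[int, List[int]] = 1,
) -> List[int]:
    # Normalize a scalar count into a list, as the patched function does.
    if not isinstance(repeat_count, list):
        repeat_count = [repeat_count] * len(prompt_token_ids)

    new_token_ids: List[int] = []
    idx = 0  # advances once per image token encountered
    for token in prompt_token_ids:
        if token == image_token_id:
            new_token_ids.extend([image_token_id] * repeat_count[idx])
            idx += 1
        else:
            new_token_ids.append(token)
    return new_token_ids


# Two image placeholders expanded to different lengths.
print(expand_image_tokens([1, IMAGE_TOKEN_ID, 2, IMAGE_TOKEN_ID],
                          IMAGE_TOKEN_ID, repeat_count=[2, 3]))
# [1, 32000, 32000, 2, 32000, 32000, 32000]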
