[Bugfix] Clean up and fix multi-modal processors #13012

Merged · 8 commits · Feb 10, 2025
2 changes: 1 addition & 1 deletion docs/source/features/compatibility_matrix.md
@@ -297,7 +297,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardware combination.
* ✅
* ✅
* ?
* [✗](gh-issue:7968>)
* [✗](gh-issue:7968)
* ?
* ✅
*
10 changes: 10 additions & 0 deletions tests/models/decoder_only/language/test_models.py
@@ -26,6 +26,9 @@
"google/gemma-1.1-2b-it", # gemma
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
pytest.param(
"THUDM/chatglm3-6b", # ChatGLM (text-only)
),
pytest.param(
"meta-llama/Llama-3.2-1B-Instruct", # llama
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
@@ -43,6 +46,9 @@
"microsoft/phi-2", # phi
marks=[pytest.mark.core_model],
),
pytest.param(
"Qwen/Qwen-7B", # qwen (text-only)
),
pytest.param(
"Qwen/Qwen2.5-0.5B-Instruct", # qwen2
marks=[pytest.mark.core_model],
@@ -68,6 +74,10 @@ def test_models(
) -> None:

with hf_runner(model, dtype=dtype) as hf_model:
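        # Rationale (best guess, not stated in this diff): ChatGLM's remote-code
        # model class does not expose its output projection through
        # get_output_embeddings(), which the greedy-logprobs helper below relies
        # on, so the test points it at transformer.output_layer instead.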
if model.startswith("THUDM/chatglm3"):
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.transformer.output_layer

hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)

2 changes: 1 addition & 1 deletion tests/models/multimodal/processing/test_common.py
@@ -89,7 +89,7 @@ def _test_processing_correctness(
mm_data = {
k:
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
for _ in range(rng.randint(limit))]
for _ in range(rng.randint(limit + 1))]
for k, limit in limit_mm_per_prompt.items()
}
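
The `+ 1` matters because NumPy's `randint` samples from a half-open interval, so without it a prompt could never carry the full per-modality `limit`. A minimal standalone sketch of the off-by-one being fixed (not part of the diff):

    import numpy as np

    rng = np.random.RandomState(0)
    limit = 3
    # rng.randint(limit) draws from [0, limit); rng.randint(limit + 1) draws
    # from [0, limit], so the maximum item count can actually be reached.
    counts = [rng.randint(limit + 1) for _ in range(1000)]
    assert all(0 <= c <= limit for c in counts)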

3 changes: 0 additions & 3 deletions tests/multimodal/utils.py
@@ -17,10 +17,7 @@ def random_video(
min_wh: int,
max_wh: int,
):
# Temporary workaround for https://github.com/huggingface/transformers/issues/35412
num_frames = rng.randint(min_frames, max_frames)
num_frames = (num_frames // 2) * 2

w, h = rng.randint(min_wh, max_wh, size=(2, ))
return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)

160 changes: 63 additions & 97 deletions vllm/model_executor/models/chatglm.py
@@ -4,8 +4,8 @@
# https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace
from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple,
TypedDict, Union)
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)

import torch
from torch import nn
@@ -19,7 +19,6 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -37,12 +36,10 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BatchFeature,
BoundPromptReplacement,
MultiModalFieldConfig,
PlaceholderFeaturesInfo,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -53,39 +50,6 @@
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)

logger = init_logger(__name__)

IMAGE_TOKEN_ID = 151329


def build_normalization_transform(image_size: int) -> transforms.Compose:
"""
Build a normalization transform which can be applied to one or
more input images from which we want to extract visual features.

Args:
image_size: size of the image to be processed for visual embeddings.

Returns:
Callable transform for normalizing and resizing one RGB image.
"""

return transforms.Compose([
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
(0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711),
),
])


def calculate_image_placeholder(vision_config):
return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2


class GLMImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
@@ -109,9 +73,20 @@ def __init__(
self.config = config
self.tokenizer = tokenizer

if hasattr(self.config, "vision_config"):
self.image_transform = build_normalization_transform(
config.vision_config["image_size"])
if vision_config := getattr(config, "vision_config", None):
image_size = vision_config["image_size"]

self.image_transform = transforms.Compose([
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
])
else:
self.image_transform = None
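        # Note (added for reference, not part of the diff): this Compose maps a
        # PIL RGB image to a float tensor of shape (3, image_size, image_size),
        # normalized with the standard OpenAI CLIP mean/std.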

@@ -150,9 +125,19 @@ def __call__(

class GLM4VProcessingInfo(BaseProcessingInfo):

def __init__(self, ctx):
super().__init__(ctx)
self._pre_calculate()
def get_tokenizer(self):
tokenizer = self.ctx.tokenizer
assert isinstance(tokenizer, PreTrainedTokenizer)
return tokenizer

def get_hf_config(self):
return self.ctx.get_hf_config(ChatGLMConfig)

def get_hf_processor(self) -> GLM4VProcessor:
return GLM4VProcessor(
self.get_hf_config(),
self.get_tokenizer(),
)

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
@@ -162,27 +147,21 @@ def get_mm_max_tokens_per_item(
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_feature_tokens()}

return {"image": self.image_token_num + 2}

def _pre_calculate(self):
def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
self.image_token_num = calculate_image_placeholder(vision_config)
self.image_size = vision_config["image_size"]
if not (vision_config := getattr(hf_config, "vision_config", None)):
return 0

def get_num_image_tokens(self) -> int:
return self.image_token_num + 2
image_size = vision_config["image_size"]
patch_size = vision_config["patch_size"]
grid_length = image_size // patch_size // 2
return grid_length * grid_length

def get_image_size(self) -> ImageSize:

return ImageSize(height=self.image_size, width=self.image_size)

def get_hf_processor(self) -> GLM4VProcessor:
return GLM4VProcessor(
self.get_hf_config(),
self.get_tokenizer(),
)
def get_num_image_feature_tokens(self) -> int:
# EVA2CLIPModel has embeddings for boi and eoi tokens as well
return self.get_num_image_tokens() + 2
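
As a sanity check, a worked example under assumed GLM-4V defaults (the image_size=1120 and patch_size=14 values are assumptions, not taken from this diff):

    image_size, patch_size = 1120, 14             # assumed GLM-4V vision config
    grid_length = image_size // patch_size // 2   # 1120 // 14 // 2 == 40
    num_image_tokens = grid_length * grid_length  # 40 * 40 == 1600
    num_feature_tokens = num_image_tokens + 2     # 1602, counting boi and eoi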


class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
@@ -192,18 +171,24 @@ def get_dummy_processor_inputs(
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.info.get_hf_config()
if not (vision_config := getattr(hf_config, "vision_config", None)):
return ProcessorInputs(prompt_text="", mm_data={})

target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_image_size()

mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
text = "<|begin_of_image|><|endoftext|><|end_of_image|>"

base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"

return ProcessorInputs(
prompt_text=text,
prompt_text=base_text * num_images,
mm_data=mm_data,
)
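
For example (illustrative only), with mm_counts = {"image": 2} the dummy prompt is just the placeholder string repeated once per dummy image:

    num_images = 2
    base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
    prompt_text = base_text * num_images
    assert prompt_text.count("<|begin_of_image|>") == num_images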

@@ -223,47 +208,28 @@ def _get_prompt_replacements(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.info.get_hf_config()
if not hasattr(hf_config, "vision_config"):
return []

boi_token_id = hf_config.boi_token_id
image_token_id = hf_config.pad_token_id
eoi_token_id = hf_config.eoi_token_id

def get_replacement(item_idx: int):
image_tokens = self.info.image_token_num
return [IMAGE_TOKEN_ID] * image_tokens
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [image_token_id] * num_image_tokens

return [boi_token_id] + image_tokens + [eoi_token_id]

return [
PromptReplacement(
modality="image",
target=[IMAGE_TOKEN_ID],
target=[boi_token_id, image_token_id, eoi_token_id],
replacement=get_replacement,
),
]
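
Concretely, the replacement now expands the three-token placeholder in place, which removes the need for the _apply_prompt_replacements override shown next. A sketch with illustrative ids (the boi/eoi values are assumptions; 151329 matches the removed IMAGE_TOKEN_ID constant):

    boi, img, eoi = 151339, 151329, 151340    # assumed GLM-4V special token ids
    num_image_tokens = 1600                   # from the worked example above
    target = [boi, img, eoi]
    replacement = [boi] + [img] * num_image_tokens + [eoi]
    assert len(replacement) == num_image_tokens + 2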

def _apply_prompt_replacements(
self,
token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids,
mm_prompt_repls=mm_prompt_repls,
mm_item_counts=mm_item_counts,
)
hf_config = self.info.get_hf_config()
boi_token_id = hf_config.boi_token_id
eoi_token_id = hf_config.eoi_token_id
placeholders = {
modality: [
PlaceholderFeaturesInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=p.start_idx - 1,
tokens=[boi_token_id] + p.tokens + [eoi_token_id],
) for p in ps
]
for modality, ps in placeholders.items()
}

return token_ids, text, placeholders


class GLMAttention(nn.Module):

@@ -618,7 +584,7 @@ def get_input_embeddings(
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=[
self.config.boi_token_id,
IMAGE_TOKEN_ID,
self.config.pad_token_id,
self.config.eoi_token_id,
],
)