
[V1] Extend beyond image modality and support mixed-modality inference with Llava-OneVision #11685

Merged Jan 6, 2025 · 55 commits · diff below shows changes from 14 commits

Commits
022c6b4
initial
ywang96 Jan 1, 2025
43fdf45
fix llava ov
ywang96 Jan 1, 2025
e0fb002
iterate
ywang96 Jan 1, 2025
a9b9757
Merge branch 'vllm-project:main' into v1-llava-ov
ywang96 Jan 2, 2025
b45010b
revert padding tensor
ywang96 Jan 2, 2025
d83e25e
simplify
ywang96 Jan 2, 2025
d13b0f7
comment
ywang96 Jan 2, 2025
7d1f19a
Merge branch 'vllm-project:main' into v1-llava-ov
ywang96 Jan 2, 2025
6959ec0
simplify and doc
ywang96 Jan 2, 2025
ba071c6
refactor logic
ywang96 Jan 3, 2025
ba2f399
format
ywang96 Jan 3, 2025
ff4cdea
Merge branch 'vllm-project:main' into v1-llava-ov
ywang96 Jan 3, 2025
2eebfd9
switch order
ywang96 Jan 3, 2025
20dd84d
refactor
ywang96 Jan 3, 2025
34ec194
typing
ywang96 Jan 3, 2025
9f19629
hasher
ywang96 Jan 3, 2025
66484aa
consolidate mm hasher
ywang96 Jan 4, 2025
1423f5f
typing
ywang96 Jan 4, 2025
ba17100
Merge branch 'vllm-project:main' into v1-llava-ov
ywang96 Jan 4, 2025
b3c41ce
Merge branch 'main' into v1-llava-ov
ywang96 Jan 5, 2025
14481fd
fix length check
ywang96 Jan 5, 2025
6f435cf
update profiling
ywang96 Jan 5, 2025
16e5b04
update dummy data for llava-ov
ywang96 Jan 5, 2025
612880b
preserve modality order
ywang96 Jan 5, 2025
3022754
format
ywang96 Jan 5, 2025
20d6a67
simplify
ywang96 Jan 5, 2025
3dd2db2
typo
ywang96 Jan 5, 2025
5ce6f7a
clarify
ywang96 Jan 5, 2025
4113e51
add test
ywang96 Jan 5, 2025
3ca30fc
fix test
ywang96 Jan 5, 2025
ef8c6d1
add note
ywang96 Jan 5, 2025
87f4216
Merge branch 'v1-llava-ov' of https://github.com/ywang96/vllm into v1…
ywang96 Jan 5, 2025
bc1debd
comment
ywang96 Jan 5, 2025
56a7ef0
typo
ywang96 Jan 5, 2025
568a586
rename
ywang96 Jan 5, 2025
6ca99a3
remove redundant constants
ywang96 Jan 5, 2025
6c8ff3b
update interface with note
ywang96 Jan 5, 2025
293b3fe
update doc
ywang96 Jan 5, 2025
14482bf
address review comments
ywang96 Jan 6, 2025
eeee402
use namedtuple
ywang96 Jan 6, 2025
7f4815e
add comment
ywang96 Jan 6, 2025
1ba40e9
update
ywang96 Jan 6, 2025
2eb4cf1
format
ywang96 Jan 6, 2025
fe71431
format
ywang96 Jan 6, 2025
1a7b39c
remove unneeded check
ywang96 Jan 6, 2025
61991b6
Merge branch 'main' into v1-llava-ov
ywang96 Jan 6, 2025
ceec26e
remove unused import
ywang96 Jan 6, 2025
7879952
restrict mm_hash to V1
ywang96 Jan 6, 2025
72ae769
fix test and reorder code for readability
ywang96 Jan 6, 2025
48811b6
typo
ywang96 Jan 6, 2025
b31fd4f
format
ywang96 Jan 6, 2025
be54b2c
Fix dummy requests
DarkLight1337 Jan 6, 2025
b2cbc5a
Pass sanity check
DarkLight1337 Jan 6, 2025
3400d07
format
DarkLight1337 Jan 6, 2025
2461f0f
Merge branch 'main' into v1-llava-ov
DarkLight1337 Jan 6, 2025
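For context, a minimal sketch of the mixed-modality request this PR enables: a single prompt carrying both an image and a video. The checkpoint name, chat template, and placeholder tokens below are illustrative assumptions rather than something taken from this diff; the new capability is simply that "image" and "video" can be passed together in `multi_modal_data`.

```python
# Hedged sketch of a mixed-modality request; model name and prompt format are
# assumptions for illustration, not taken from this PR.
import numpy as np
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

image = Image.open("example.jpg")                    # any RGB image
video = np.zeros((16, 384, 384, 3), dtype=np.uint8)  # dummy 16-frame clip

prompt = ("<|im_start|>user <image><video>\n"
          "Describe the image, then summarize the video.<|im_end|>\n"
          "<|im_start|>assistant\n")

outputs = llm.generate(
    {
        "prompt": prompt,
        # Both modalities in one request -- the capability added by this PR.
        "multi_modal_data": {"image": image, "video": video},
    },
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)
```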
33 changes: 13 additions & 20 deletions vllm/model_executor/models/llava_onevision.py
@@ -38,8 +38,8 @@
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)

# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
# Ref: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA_OneVision.md?plain=1#L14
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 2304

# For profile run
_MAX_FRAMES_PER_VIDEO = 16
@@ -366,9 +366,11 @@ def input_processor_for_llava_onevision(ctx: InputContext,
and "image" not in multi_modal_data):
return inputs
if "image" in multi_modal_data:
return input_processor_when_multimodal_input_image(ctx, inputs)
inputs = input_processor_when_multimodal_input_image(ctx, inputs)
if "video" in multi_modal_data:
return input_processor_when_multimodal_input_video(ctx, inputs)
else:
return inputs

msg = "Unsupported multi data type"
raise NotImplementedError(msg)
@@ -832,21 +834,18 @@ def get_multimodal_embeddings(
if not modalities:
return None

# We make a tuple of each embedding with its modality string. This is a
# temporary workaround for models to handle mixed modalities when
# get_multimodal_embeddings and get_input_embeddings are called
# separately.
# TODO(ywang96): Add support for mixed-modality inference for v1.
multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
# The resulting multimodal_embeddings is a tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()

if "images" in modalities:
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings.append((vision_embeddings, "image"))
multimodal_embeddings += tuple(vision_embeddings)
if "videos" in modalities:
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
multimodal_embeddings.append((video_embeddings, "video"))
multimodal_embeddings += tuple(video_embeddings)

return multimodal_embeddings

@@ -858,15 +857,9 @@ def get_input_embeddings(
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
for embeddings, modality in multimodal_embeddings:
if modality == "image":
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, embeddings,
self.config.image_token_index)
if modality == "video":
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, embeddings,
self.config.video_token_index)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_index, self.config.video_token_index])
return inputs_embeds

def forward(
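The net effect of the changes above: `get_multimodal_embeddings` now returns a flat tuple of per-item tensors (in prompt order) instead of `(embeddings, modality)` pairs, and `get_input_embeddings` merges them against both placeholder token ids in a single `merge_multimodal_embeddings` call. A minimal sketch of that merge-by-placeholder idea (not the actual vLLM implementation), assuming one embedding row per placeholder token:

```python
import torch


def merge_by_placeholder(
    input_ids: torch.Tensor,                  # [seq_len]
    inputs_embeds: torch.Tensor,              # [seq_len, hidden_size]
    mm_embeddings: tuple[torch.Tensor, ...],  # one [n_tokens_i, hidden_size] per item
    placeholder_token_ids: list[int],         # e.g. [image_token_index, video_token_index]
) -> torch.Tensor:
    # Positions holding any multimodal placeholder token (image or video).
    is_mm = torch.isin(
        input_ids, torch.tensor(placeholder_token_ids, device=input_ids.device))
    # Items arrive in prompt order, so concatenation keeps placeholder
    # positions and embedding rows aligned.
    flat = torch.cat(mm_embeddings, dim=0)
    assert int(is_mm.sum()) == flat.shape[0], "placeholder/embedding count mismatch"
    out = inputs_embeds.clone()
    out[is_mm] = flat.to(dtype=inputs_embeds.dtype)
    return out
```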
3 changes: 2 additions & 1 deletion vllm/multimodal/__init__.py
@@ -1,6 +1,6 @@
from .base import MultiModalPlaceholderMap, MultiModalPlugin
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalKwargs,
MultiModalDataDict, MultiModalHashDict, MultiModalKwargs,
MultiModalPlaceholderDict, NestedTensors)
from .registry import MultiModalRegistry

@@ -18,6 +18,7 @@
"ModalityData",
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalHashDict",
"MultiModalKwargs",
"MultiModalPlaceholderDict",
"MultiModalPlaceholderMap",
7 changes: 6 additions & 1 deletion vllm/multimodal/inputs.py
@@ -491,6 +491,11 @@ def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
A dictionary containing placeholder ranges.
"""

MultiModalHashDict = Mapping[str, Sequence[str]]
"""
A dictionary containing hashes for items in each modality.
"""


class MultiModalInputsV2(TypedDict):
"""
Expand All @@ -513,7 +518,7 @@ class MultiModalInputsV2(TypedDict):
mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching."""

mm_hashes: NotRequired[list[str]]
mm_hashes: NotRequired[MultiModalHashDict]
"""The hashes of the multi-modal data."""

mm_placeholders: MultiModalPlaceholderDict
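The new `mm_hashes` field mirrors the structure of the multimodal data: one hash string per data item, keyed by modality. An illustrative example (the digests below are placeholders, not real values):

```python
mm_hashes: MultiModalHashDict = {
    "image": ["9c1e...e3", "47ab...02"],  # two images in the request
    "video": ["d6f0...91"],               # one video
}
```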
80 changes: 21 additions & 59 deletions vllm/multimodal/processing.py
@@ -1,4 +1,3 @@
import pickle
import re
from abc import ABC, abstractmethod
from collections import defaultdict
@@ -9,8 +8,6 @@

import numpy as np
import numpy.typing as npt
import torch
from blake3 import blake3
from PIL import Image
from transformers import BatchFeature, ProcessorMixin

@@ -23,6 +20,7 @@
MultiModalInputsV2, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser
from .utils import hash_kwargs

logger = init_logger(__name__)

@@ -492,56 +490,6 @@ def _maybe_log_cache_stats(self) -> None:
logger.debug("ProcessingCache: hit_ratio = %.2f",
cache_stats.hit_ratio)

def _serialize_item(self, obj: object) -> bytes:
# Simple cases
if isinstance(obj, str):
return obj.encode("utf-8")
if isinstance(obj, bytes):
return obj
if isinstance(obj, Image.Image):
return obj.tobytes()

# Convertible to NumPy arrays
if isinstance(obj, torch.Tensor):
obj = obj.numpy()
if isinstance(obj, (int, float)):
obj = np.array(obj)
if isinstance(obj, np.ndarray):
return obj.tobytes()

logger.warning(
"No serialization method found for %s. "
"Falling back to pickle.", type(obj))

return pickle.dumps(obj)

def _item_to_bytes(
self,
key: str,
obj: object,
) -> Iterable[tuple[bytes, bytes]]:
# Recursive cases
if isinstance(obj, (list, tuple)):
for i, elem in enumerate(obj):
yield from self._item_to_bytes(f"{key}.{i}", elem)
elif isinstance(obj, dict):
for k, v in obj.items():
yield from self._item_to_bytes(f"{key}.{k}", v)
else:
key_bytes = self._serialize_item(key)
value_bytes = self._serialize_item(obj)
yield key_bytes, value_bytes

def _hash_kwargs(self, **kwargs: object) -> str:
hasher = blake3()

for k, v in kwargs.items():
for k_bytes, v_bytes in self._item_to_bytes(k, v):
hasher.update(k_bytes)
hasher.update(v_bytes)

return hasher.hexdigest()

def get(
self,
model_id: str,
@@ -560,9 +508,9 @@ def get(
"""
self._maybe_log_cache_stats()

cache_key = self._hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
cache_key = hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
return self._cache.get(cache_key)

def put(
Expand All @@ -577,9 +525,9 @@ def put(
Put a processed multi-modal item into the cache
according to its dependencies (see :meth:`get`).
"""
cache_key = self._hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
cache_key = hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
self._cache.put(cache_key, output_kwargs)


@@ -998,6 +946,19 @@ def apply(
"""
mm_items = self._to_mm_items(mm_data)

# Create MM hashes
# TODO: Use these hash keys for caching operations in apply_hf_processor
# instead of rehashing.
model_id = self.ctx.model_config.model
mm_hashes = {
modality: [
hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs) for item in items
]
for modality, items in mm_items.items()
}

prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
prompt_text,
mm_items,
@@ -1058,6 +1019,7 @@ def apply(
prompt=prompt_text,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholders,
)

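The per-item serialization and blake3 hashing removed from `ProcessingCache` above is consolidated into the module-level `hash_kwargs` imported from `.utils` at the top of this file. A sketch of what that consolidated helper plausibly looks like, reconstructed from the removed methods; the actual code in `vllm/multimodal/utils.py` may differ in details:

```python
import pickle
from collections.abc import Iterable

import numpy as np
import torch
from blake3 import blake3
from PIL import Image


def _serialize_item(obj: object) -> bytes:
    # Simple cases
    if isinstance(obj, str):
        return obj.encode("utf-8")
    if isinstance(obj, bytes):
        return obj
    if isinstance(obj, Image.Image):
        return obj.tobytes()

    # Convertible to NumPy arrays
    if isinstance(obj, torch.Tensor):
        obj = obj.numpy()
    if isinstance(obj, (int, float)):
        obj = np.array(obj)
    if isinstance(obj, np.ndarray):
        return obj.tobytes()

    # Fallback for anything without a dedicated serialization path.
    return pickle.dumps(obj)


def _item_to_bytes(key: str, obj: object) -> Iterable[tuple[bytes, bytes]]:
    # Recurse into containers so nested kwargs hash deterministically.
    if isinstance(obj, (list, tuple)):
        for i, elem in enumerate(obj):
            yield from _item_to_bytes(f"{key}.{i}", elem)
    elif isinstance(obj, dict):
        for k, v in obj.items():
            yield from _item_to_bytes(f"{key}.{k}", v)
    else:
        yield _serialize_item(key), _serialize_item(obj)


def hash_kwargs(**kwargs: object) -> str:
    """Hash arbitrary keyword arguments into a hex digest."""
    hasher = blake3()
    for k, v in kwargs.items():
        for k_bytes, v_bytes in _item_to_bytes(k, v):
            hasher.update(k_bytes)
            hasher.update(v_bytes)
    return hasher.hexdigest()


# Example mirroring the per-item hashing added in apply():
#   hash_kwargs(model_id=model_id, image=pil_image, **hf_processor_mm_kwargs)
```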