VLMs: fix number of image tokens #34332

Merged: 10 commits, Oct 30, 2024
Changes from 6 commits
2 changes: 1 addition & 1 deletion src/transformers/models/chameleon/modeling_chameleon.py

@@ -1288,7 +1288,7 @@ def forward(
         if pixel_values is not None:
             image_tokens = self.get_image_tokens(pixel_values)
             n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
-            n_image_features = image_tokens.shape[0]
+            n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
             if n_image_tokens_in_text != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
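As context for the fix above: a minimal sketch (not from the PR) of the corrected arithmetic, assuming get_image_tokens returns one row of discrete tokens per image, i.e. a (num_images, tokens_per_image) tensor. The token id and shapes below are made up; Chameleon's real values differ.

import torch

image_token_id = 8711  # hypothetical placeholder id, not Chameleon's real one
num_images, tokens_per_image = 2, 1024

# Assumed shape of get_image_tokens(pixel_values): (num_images, tokens_per_image)
image_tokens = torch.zeros(num_images, tokens_per_image, dtype=torch.long)

# Text side: one placeholder token per expected image feature.
input_ids = torch.ones(1, num_images * tokens_per_image + 16, dtype=torch.long)
input_ids[0, : num_images * tokens_per_image] = image_token_id

n_image_tokens_in_text = (input_ids == image_token_id).sum().item()

n_features_old = image_tokens.shape[0]                          # 2: counts images, not tokens
n_features_new = image_tokens.shape[0] * image_tokens.shape[1]  # 2048: all tokens in the batch

assert n_image_tokens_in_text == n_features_new  # the fixed check passes
assert n_image_tokens_in_text != n_features_old  # the old count would have raised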
5 changes: 3 additions & 2 deletions src/transformers/models/llava/modeling_llava.py

@@ -523,8 +523,9 @@ def forward(

         # TODO: @raushan retain only the new behavior after v4.47
         else:
-            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-            n_image_features = image_features.shape[1]
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

Review thread on the new `.sum().item()` line:

Collaborator: I think this slows down inference, no? `.item()` induces a CUDA-CPU sync! Anyway, that's not the point; thanks for the fix. Does n_image_features take padding into account? (Are the image features not padded to the batch?)

Member (Author): Yes, if the image is padded it is usually unpadded before we reach this point, e.g. in llava-next. Hmm, I don't think the slowdown will be drastic, especially since we only need to run the check once per input, in the prefill stage.
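For readers unfamiliar with the sync concern raised in the thread: a minimal sketch (not from the PR) of why .item() is a synchronization point. The token id and shapes are made up; the sketch falls back to CPU when no GPU is present.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
input_ids = torch.randint(0, 32000, (2, 1024), device=device)
image_token_index = 31999  # hypothetical placeholder id

# On CUDA, the comparison and reduction are queued asynchronously...
n_image_tokens = (input_ids == image_token_index).sum()

# ...but .item() blocks until the queued kernels finish, then copies the
# scalar to the host. Since the check runs once per input during prefill,
# the cost is a single round-trip rather than a per-decode-step overhead.
n = n_image_tokens.item()
print(f"{n} image placeholder tokens found")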
1 change: 1 addition & 0 deletions src/transformers/models/llava_next/modeling_llava_next.py

@@ -917,6 +917,7 @@ def forward(
         else:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
+
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -1020,6 +1020,7 @@ def forward(
         if image_features is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
+
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -533,6 +533,7 @@ def forward(
         if image_features is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
+
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -679,6 +679,7 @@ def forward(
             )
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
+
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

@@ -704,6 +705,7 @@ def forward(
             )
             video_features = torch.cat((video_features, image_newline), dim=1)
             video_features = video_features.flatten(0, 1)
+
             n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
             n_video_features = video_features.shape[0]
             if n_video_tokens != n_video_features:
8 changes: 4 additions & 4 deletions src/transformers/models/video_llava/modeling_video_llava.py

@@ -623,8 +623,8 @@ def forward(
         # TODO: @raushan retain only the new behavior after v4.47
         else:
             if pixel_values_images is not None:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-                n_image_features = image_features.shape[1]
+                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                n_image_features = image_features.shape[0] * image_features.shape[1]
                 if n_image_tokens != n_image_features:
                     raise ValueError(
                         f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

@@ -639,8 +639,8 @@ def forward(
                 inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

             if pixel_values_videos is not None:
-                n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
-                n_video_features = video_features.shape[1]
+                n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+                n_video_features = video_features.shape[0] * video_features.shape[1]
                 if n_video_tokens != n_video_features:
                     raise ValueError(
                         f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
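The video branch counts the same way. A short sketch under the assumption that video features are laid out as (num_videos * num_frames, tokens_per_frame, hidden); the exact Video-LLaVA layout may differ.

import torch

video_token_index = 32001  # hypothetical placeholder id
num_videos, num_frames, tokens_per_frame = 1, 8, 256

# Assumed layout: one row per frame, flattened over the batch.
video_features = torch.zeros(num_videos * num_frames, tokens_per_frame, 64)

# Text side: one placeholder token per video feature.
n_expected = num_videos * num_frames * tokens_per_frame
input_ids = torch.ones(1, n_expected + 8, dtype=torch.long)
input_ids[0, :n_expected] = video_token_index

n_video_tokens = (input_ids == video_token_index).sum().item()
n_video_features = video_features.shape[0] * video_features.shape[1]
assert n_video_tokens == n_video_features  # 2048 == 2048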
5 changes: 3 additions & 2 deletions src/transformers/models/vipllava/modeling_vipllava.py

@@ -513,8 +513,9 @@ def forward(

         # TODO: @raushan retain only the new behavior after v4.47
         else:
-            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-            n_image_features = image_features.shape[1]
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
29 changes: 29 additions & 0 deletions tests/models/llava/test_modeling_llava.py

@@ -235,6 +235,35 @@ def test_inputs_embeds_matches_input_ids(self):
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(torch.allclose(out_embeds, out_ids))

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs throw an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in the text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate the multi-image case by concatenating inputs where each has exactly one image/image token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
     @unittest.skip(
         reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )
32 changes: 32 additions & 0 deletions tests/models/llava_next/test_modeling_llava_next.py

@@ -283,6 +283,38 @@ def test_inputs_embeds_matches_input_ids(self):
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(torch.allclose(out_embeds, out_ids))

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs throw an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in the text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate the multi-image case by concatenating inputs where each has exactly one image/image token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            image_sizes = input_dict["image_sizes"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
     @unittest.skip(
         reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )
32 changes: 32 additions & 0 deletions tests/models/llava_next_video/test_modeling_llava_next_video.py

@@ -303,6 +303,38 @@ def test_inputs_embeds_matches_input_ids(self):
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(torch.allclose(out_embeds, out_ids))

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs throw an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in the text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate the multi-image case by concatenating inputs where each has exactly one image/image token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            image_sizes = input_dict["image_sizes"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
     @unittest.skip(
         reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )
30 changes: 30 additions & 0 deletions tests/models/paligemma/test_modeling_paligemma.py

@@ -236,6 +236,36 @@ def test_inputs_embeds_matches_input_ids(self):
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(torch.allclose(out_embeds, out_ids))

+    # Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs throw an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in the text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate the multi-image case by concatenating inputs where each has exactly one image/image token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
     @unittest.skip(
         reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )
36 changes: 35 additions & 1 deletion tests/models/qwen2_vl/test_modeling_qwen2_vl.py

@@ -58,7 +58,7 @@ class Qwen2VLVisionText2TextModelTester:
     def __init__(
         self,
         parent,
-        batch_size=2,
+        batch_size=3,
         seq_length=7,
         num_channels=3,
         ignore_index=-100,

@@ -245,6 +245,40 @@ def test_initialization(self):
                 msg=f"Parameter {name} of model {model_class} seems not properly initialized",
             )

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs throw an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in the text
+            patch_size = config.vision_config.patch_size
+            one_img_length = (self.model_tester.image_size**2) // (patch_size**2)
+            input_dict["pixel_values"] = input_dict["pixel_values"][-one_img_length:, ...]
+            input_dict["image_grid_thw"] = input_dict["image_grid_thw"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate the multi-image case by concatenating inputs where each has exactly one image/image token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:one_img_length]
+            image_grid_thw = input_dict["image_grid_thw"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
+
     @unittest.skip(
         reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )
33 changes: 31 additions & 2 deletions tests/models/video_llava/test_modeling_video_llava.py

@@ -116,9 +116,9 @@ def __init__(
         self.batch_size = 5
         self.num_channels = 3
         self.image_size = 224
-        self.encoder_seq_length = 64
+        self.encoder_seq_length = 246
         self.num_image_tokens = 25
-        self.num_video_tokens = 26
+        self.num_video_tokens = 26 * self.num_frames
         self.seq_length = seq_length + self.num_image_tokens + self.num_video_tokens

@@ -396,6 +396,35 @@ def test_inputs_embeds_matches_input_ids(self):
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(torch.allclose(out_embeds, out_ids))

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs throw an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in the text
+            input_dict["pixel_values_images"] = input_dict["pixel_values_images"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate the multi-image case by concatenating inputs where each has exactly one image/image token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values_images"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values_images=pixel_values)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            _ = model(input_ids=input_ids, pixel_values_images=pixel_values)
+

 @require_torch
 class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):