diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
index e5583d782d88..ff000f075971 100644
--- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -853,7 +853,7 @@ def forward(
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
             # Merge text and images in prefill stage
-            if past_key_values is None:
+            if input_ids is not None and inputs_embeds.shape[1] != 1:
                 # First merge image tokens if there are any
                 if pixel_values is not None and pixel_values.size(0) > 0:
                     image_features = self._get_image_features(pixel_values, image_sizes)
@@ -910,7 +910,7 @@ def forward(
                     pass
 
             # generation with cache, decoding stage
-            elif past_key_values is not None and (pixel_values is not None or pixel_values_videos is not None):
+            elif pixel_values is not None or pixel_values_videos is not None:
                 # Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0
                 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
                 # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index 330ef62e56fb..deb0eddfde8e 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -653,9 +653,6 @@ def prepare_inputs_for_generation(
             if cache_length < past_length and attention_mask is not None:
                 attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
 
-            pixel_values_videos = None
-            pixel_values_images = None
-
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py
index 2f8ee3229ff2..06904795adff 100644
--- a/tests/models/blip/test_modeling_blip.py
+++ b/tests/models/blip/test_modeling_blip.py
@@ -21,6 +21,7 @@
 
 import numpy as np
 import requests
+from parameterized import parameterized
 
 from transformers import BlipConfig, BlipTextConfig, BlipVisionConfig
 from transformers.testing_utils import (
@@ -1106,6 +1107,7 @@ def test_model_from_pretrained(self):
 @require_torch
 class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (BlipForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (BlipForConditionalGeneration,) if is_torch_available() else ()
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
@@ -1116,6 +1118,18 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
     def setUp(self):
         self.model_tester = BlipTextImageModelsModelTester(self)
 
+    @parameterized.expand([(True,), (False,)])
+    def test_greedy_generation(self, use_cache: bool):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            out = model.generate(**inputs_dict, min_new_tokens=20, max_new_tokens=20, use_cache=use_cache)
+            self.assertTrue(out.shape[1] == inputs_dict["input_ids"].shape[1] + 19)
+
     def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py
index 28ed3a79cae5..ae9ab97fee03 100644
--- a/tests/models/blip_2/test_modeling_blip_2.py
+++ b/tests/models/blip_2/test_modeling_blip_2.py
@@ -20,6 +20,7 @@
 
 import numpy as np
 import requests
+from parameterized import parameterized
 
 from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
 from transformers.testing_utils import (
@@ -314,7 +315,7 @@ def __init__(
         hidden_act="gelu",
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
-        max_position_embeddings=20,
+        max_position_embeddings=256,
         eos_token_id=2,
         pad_token_id=1,
         bos_token_id=0,
@@ -436,8 +437,9 @@ def prepare_config_and_inputs_for_common(self):
 
 
 @require_torch
-class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else ()
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
@@ -448,6 +450,18 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
     def setUp(self):
         self.model_tester = Blip2ForConditionalGenerationDecoderOnlyModelTester(self)
 
+    @parameterized.expand([(True,), (False,)])
+    def test_greedy_generation(self, use_cache: bool):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            out = model.generate(**inputs_dict, min_new_tokens=20, max_new_tokens=20, use_cache=use_cache)
+            self.assertTrue(out.shape[1] == 21)  # BLIP is special, so should be 21
+
     def test_for_conditional_generation(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs)
diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py
index 1aaa8e1a8b68..e428f72ccf09 100644
--- a/tests/models/instructblip/test_modeling_instructblip.py
+++ b/tests/models/instructblip/test_modeling_instructblip.py
@@ -20,6 +20,7 @@
 
 import numpy as np
 import requests
+from parameterized import parameterized
 
 from transformers import (
     CONFIG_MAPPING,
@@ -38,7 +39,6 @@
 )
 from transformers.utils import is_torch_available, is_vision_available
 
-from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import (
     ModelTesterMixin,
@@ -319,7 +319,7 @@ def __init__(
         hidden_act="gelu",
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
-        max_position_embeddings=20,
+        max_position_embeddings=256,
         eos_token_id=2,
         pad_token_id=1,
         bos_token_id=0,
@@ -452,8 +452,9 @@ def prepare_config_and_inputs_for_common(self):
 
 
 @require_torch
-class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (InstructBlipForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (InstructBlipForConditionalGeneration,) if is_torch_available() else ()
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
@@ -464,6 +465,19 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
     def setUp(self):
         self.model_tester = InstructBlipForConditionalGenerationDecoderOnlyModelTester(self)
 
+    @parameterized.expand([(True,), (False,)])
+    def test_greedy_generation(self, use_cache: bool):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            model.config.text_config.architectures = ["OptForCausalLM"]
+
+            out = model.generate(**inputs_dict, min_new_tokens=20, max_new_tokens=20, use_cache=use_cache)
+            self.assertTrue(out.shape[1] == 21)  # BLIP is special, therefore 21
+
     def test_for_conditional_generation(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs)
diff --git a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py
index 1265db3a2a2e..a3fcca23bf1c 100644
--- a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py
+++ b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py
@@ -20,6 +20,7 @@
 
 import numpy as np
 from huggingface_hub import hf_hub_download
+from parameterized import parameterized
 
 from transformers import (
     CONFIG_MAPPING,
@@ -38,7 +39,6 @@
 )
 from transformers.utils import is_torch_available, is_vision_available
 
-from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import (
     ModelTesterMixin,
@@ -333,7 +333,7 @@ def __init__(
         hidden_act="gelu",
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
-        max_position_embeddings=100,
+        max_position_embeddings=256,
         eos_token_id=2,
         pad_token_id=1,
         bos_token_id=0,
@@ -471,10 +471,9 @@ def prepare_config_and_inputs_for_common(self):
 
 
 @require_torch
-class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
-    ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
-):
+class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (InstructBlipVideoForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (InstructBlipVideoForConditionalGeneration,) if is_torch_available() else ()
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
@@ -485,6 +484,19 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
     def setUp(self):
         self.model_tester = InstructBlipVideoForConditionalGenerationDecoderOnlyModelTester(self)
 
+    @parameterized.expand([(True,), (False,)])
+    def test_greedy_generation(self, use_cache: bool):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            model.config.text_config.architectures = ["OptForCausalLM"]
+
+            out = model.generate(**inputs_dict, min_new_tokens=20, max_new_tokens=20, use_cache=use_cache)
+            self.assertTrue(out.shape[1] == 21)  # BLIP is special, therefore 21
+
     def test_for_conditional_generation(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs)
diff --git a/tests/models/kosmos2/test_modeling_kosmos2.py b/tests/models/kosmos2/test_modeling_kosmos2.py
index 6f34689004ef..639ac2eb4358 100644
--- a/tests/models/kosmos2/test_modeling_kosmos2.py
+++ b/tests/models/kosmos2/test_modeling_kosmos2.py
@@ -281,6 +281,17 @@ def setUp(self):
         self.model_tester = Kosmos2ModelTester(self)
         self.config_tester = ConfigTester(self, config_class=Kosmos2Config, hidden_size=37)
 
+    def test_greedy_generation(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            out = model.generate(**inputs_dict, min_new_tokens=20, max_new_tokens=20)
+            self.assertTrue(out.shape[1] == inputs_dict["input_ids"].shape[1] + 20)
+
     # overwrite from common to skip `image_to_text_projection.latent_query`
     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py
index ce13ab6738af..89bed80f36dd 100644
--- a/tests/models/llava/test_modeling_llava.py
+++ b/tests/models/llava/test_modeling_llava.py
@@ -18,6 +18,7 @@
 import unittest
 
 import requests
+from parameterized import parameterized
 
 from transformers import (
     AutoProcessor,
@@ -80,7 +81,7 @@ def __init__(
             "initializer_range": 0.02,
             "num_labels": 3,
             "num_choices": 4,
-            "pad_token_id": 0,
+            "pad_token_id": 1,
         },
         is_training=True,
         vision_config={
@@ -148,6 +149,8 @@ def prepare_config_and_inputs_for_common(self):
         config, pixel_values = config_and_inputs
         input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
         attention_mask = input_ids.ne(1).to(torch_device)
+        # set to random non-image token to prevent flakiness
+        input_ids[input_ids == config.image_token_index] = 1
         # we are giving 3 images let's make sure we pass in 3 image tokens
         input_ids[:, 1] = config.image_token_index
         inputs_dict = {
@@ -178,6 +181,7 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase
     """
 
     all_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else ()
     pipeline_model_mapping = {"image-to-text": LlavaForConditionalGeneration} if is_torch_available() else {}
     test_pruning = False
     test_head_masking = False
@@ -186,6 +190,24 @@ def setUp(self):
         self.model_tester = LlavaVisionText2TextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=LlavaConfig, has_text_modality=False)
 
+    @parameterized.expand([(True,), (False,)])
+    def test_greedy_generation(self, use_cache: bool):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            out = model.generate(
+                **inputs_dict,
+                min_new_tokens=20,
+                max_new_tokens=20,
+                use_cache=use_cache,
+                bad_words_ids=[[config.image_token_index]],
+            )
+            self.assertTrue(out.shape[1] == inputs_dict["input_ids"].shape[1] + 20)
+
     @unittest.skip(
         reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )
diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py
index 70d91002a91b..f7cf39b6bf78 100644
--- a/tests/models/llava_next/test_modeling_llava_next.py
+++ b/tests/models/llava_next/test_modeling_llava_next.py
@@ -19,6 +19,7 @@
 
 import requests
 from huggingface_hub import hf_hub_download
+from parameterized import parameterized
 
 from transformers import (
     AutoProcessor,
@@ -34,7 +35,6 @@
     torch_device,
 )
 
-from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import (
     ModelTesterMixin,
@@ -86,7 +86,7 @@ def __init__(
             "initializer_range": 0.02,
             "num_labels": 3,
             "num_choices": 4,
-            "pad_token_id": 0,
+            "pad_token_id": 1,
         },
         is_training=True,
         vision_config={
@@ -157,6 +157,8 @@ def prepare_config_and_inputs_for_common(self):
         config, pixel_values = config_and_inputs
         input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
         attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
+        # set to random non-image token to prevent flakiness
+        input_ids[input_ids == config.image_token_index] = 2
         # we are giving 3 images let's make sure we pass in 3 image tokens
         input_ids[:, 1] = config.image_token_index
         labels = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
@@ -208,12 +210,13 @@ def create_and_check_llava_next_model_fp16_autocast_forward(
 
 
 @require_torch
-class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
     """
     Model tester for `LlavaNextForConditionalGeneration`.
     """
 
     all_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else ()
     test_pruning = False
     test_head_masking = False
 
@@ -237,6 +240,24 @@ def test_initialization(self):
                         msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                     )
 
+    @parameterized.expand([(True,), (False,)])
+    def test_greedy_generation(self, use_cache: bool):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            out = model.generate(
+                **inputs_dict,
+                min_new_tokens=20,
+                max_new_tokens=20,
+                use_cache=use_cache,
+                bad_words_ids=[[config.image_token_index]],
+            )
+            self.assertTrue(out.shape[1] == inputs_dict["input_ids"].shape[1] + 20)
+
     @unittest.skip(
         reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )
diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py
index 9ba7ef869ddf..d657dcc7fdda 100644
--- a/tests/models/llava_next_video/test_modeling_llava_next_video.py
+++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py
@@ -19,6 +19,7 @@
 
 import numpy as np
 from huggingface_hub import hf_hub_download
+from parameterized import parameterized
 
 from transformers import (
     AutoProcessor,
@@ -34,7 +35,6 @@
     torch_device,
 )
 
-from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import (
     ModelTesterMixin,
@@ -86,7 +86,7 @@ def __init__(
             "initializer_range": 0.02,
             "num_labels": 3,
             "num_choices": 4,
-            "pad_token_id": 0,
+            "pad_token_id": 1,
         },
         is_training=True,
         vision_config={
@@ -167,6 +167,9 @@ def prepare_config_and_inputs_for_common(self):
         config, pixel_values, pixel_values_videos = self.prepare_config_and_inputs()
         input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
         attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
+        # set to random non-image token to prevent flakiness
+        input_ids[input_ids == config.image_token_index] = 2
+        input_ids[input_ids == config.video_token_index] = 2
         # we are giving 3 images and videos let's make sure we pass in 3 special tokens
         input_ids[:, 1] = config.image_token_index
         input_ids[:, 2] = config.video_token_index
@@ -223,12 +226,13 @@ def create_and_check_llava_next_video_model_fp16_autocast_forward(
 
 
 @require_torch
-class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
     """
     Model tester for `LlavaNextVideoForConditionalGeneration`.
     """
 
     all_model_classes = (LlavaNextVideoForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (LlavaNextVideoForConditionalGeneration,) if is_torch_available() else ()
     test_pruning = False
     test_head_masking = False
 
@@ -274,6 +278,24 @@ def test_inputs_embeds(self):
             with torch.no_grad():
                 model(**inputs)
 
+    @parameterized.expand([(True,), (False,)])
+    def test_greedy_generation(self, use_cache: bool):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            out = model.generate(
+                **inputs_dict,
+                min_new_tokens=20,
+                max_new_tokens=20,
+                use_cache=use_cache,
+                bad_words_ids=[[config.image_token_index], [config.video_token_index]],
+            )
+            self.assertTrue(out.shape[1] == inputs_dict["input_ids"].shape[1] + 20)
+
     @unittest.skip(
         reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )
diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py
index 7753ae073dd3..43a6f4826b0c 100644
--- a/tests/models/paligemma/test_modeling_paligemma.py
+++ b/tests/models/paligemma/test_modeling_paligemma.py
@@ -176,6 +176,7 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, unittest.Test
     """
 
     all_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else ()
     fx_compatible = False
     test_pruning = False
     test_torchscript = False
@@ -185,6 +186,18 @@ def setUp(self):
         self.model_tester = PaliGemmaVisionText2TextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=PaliGemmaConfig, has_text_modality=False)
 
+    @parameterized.expand([(True,), (False,)])
+    def test_greedy_generation(self, use_cache: bool):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            out = model.generate(**inputs_dict, min_new_tokens=20, max_new_tokens=20, use_cache=use_cache)
+            self.assertTrue(out.shape[1] == inputs_dict["input_ids"].shape[1] + 20)
+
     @unittest.skip(
         reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )
diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py
index fe3eea97dcf3..fd71be29e708 100644
--- a/tests/models/video_llava/test_modeling_video_llava.py
+++ b/tests/models/video_llava/test_modeling_video_llava.py
@@ -20,6 +20,7 @@
 import numpy as np
 import requests
 from huggingface_hub import hf_hub_download
+from parameterized import parameterized
 
 from transformers import (
     VideoLlavaConfig,
@@ -30,7 +31,6 @@
 )
 from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
 
-from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
 
@@ -75,7 +75,7 @@ def __init__(
             "initializer_range": 0.02,
             "num_labels": 3,
             "num_choices": 4,
-            "pad_token_id": 0,
+            "pad_token_id": 1,
         },
         is_training=True,
         vision_config={
@@ -158,10 +158,11 @@ def prepare_config_and_inputs_for_common(self):
         config, pixel_values_images, pixel_values_videos = config_and_inputs
         input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
         attention_mask = input_ids.ne(1).to(torch_device)
+        # set to random non-image token to prevent flakiness
+        input_ids[input_ids == config.image_token_index] = 2
+        input_ids[input_ids == config.video_token_index] = 2
 
-        # we are giving 3 videos and 3 images. Need to pass in image and video tokens, both
-        # also need to make sure no other special tokens are set
-        input_ids[(input_ids == 0) | (input_ids == 1)] = 3
+        # we are giving 3 videos and 3 images. Need to pass in image and video tokens
         input_ids[:, 0] = config.video_token_index
         input_ids[:, 1:2] = config.image_token_index
         inputs_dict = {
@@ -190,12 +191,13 @@ def prepare_config_and_inputs_for_batched_test(self):
 
 
 @require_torch
-class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
     """
     Model tester for `VideoLlavaForConditionalGeneration`.
     """
 
     all_model_classes = (VideoLlavaForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (VideoLlavaForConditionalGeneration,) if is_torch_available() else ()
     fx_compatible = False
     test_pruning = False
     test_resize_embeddings = True
@@ -205,6 +207,24 @@ def setUp(self):
         self.model_tester = VideoLlavaVisionText2TextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=VideoLlavaConfig, has_text_modality=False)
 
+    @parameterized.expand([(True,), (False,)])
+    def test_greedy_generation(self, use_cache: bool):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            out = model.generate(
+                **inputs_dict,
+                min_new_tokens=20,
+                max_new_tokens=20,
+                use_cache=use_cache,
+                bad_words_ids=[[config.image_token_index], [config.video_token_index]],
+            )
+            self.assertTrue(out.shape[1] == inputs_dict["input_ids"].shape[1] + 20)
+
     @unittest.skip(
         reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )
diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py
index a4e89d3f9ddf..d65c195e5bdb 100644
--- a/tests/models/vipllava/test_modeling_vipllava.py
+++ b/tests/models/vipllava/test_modeling_vipllava.py
@@ -18,6 +18,7 @@
 import unittest
 
 import requests
+from parameterized import parameterized
 
 from transformers import (
     AutoProcessor,
@@ -73,7 +74,7 @@ def __init__(
             "initializer_range": 0.02,
             "num_labels": 3,
             "num_choices": 4,
-            "pad_token_id": 0,
+            "pad_token_id": 1,
         },
         is_training=True,
         vision_config={
@@ -140,6 +141,8 @@ def prepare_config_and_inputs_for_common(self):
         config, pixel_values = config_and_inputs
         input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
         attention_mask = input_ids.ne(1).to(torch_device)
+        # set to random non-image token to prevent flakiness
+        input_ids[input_ids == config.image_token_index] = 2
         # we are giving 3 images let's make sure we pass in 3 image tokens
         input_ids[:, 1] = config.image_token_index
         inputs_dict = {
@@ -158,6 +161,7 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestC
     """
 
     all_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else ()
     fx_compatible = False
     test_pruning = False
     test_resize_embeddings = True
@@ -167,6 +171,24 @@ def setUp(self):
         self.model_tester = VipLlavaVisionText2TextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=VipLlavaConfig, has_text_modality=False)
 
+    @parameterized.expand([(True,), (False,)])
+    def test_greedy_generation(self, use_cache: bool):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            out = model.generate(
+                **inputs_dict,
+                min_new_tokens=20,
+                max_new_tokens=20,
+                use_cache=use_cache,
+                bad_words_ids=[[config.image_token_index]],
+            )
+            self.assertTrue(out.shape[1] == inputs_dict["input_ids"].shape[1] + 20)
+
     @unittest.skip(
         reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )