diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 45558bd22a4e..f0f7f2b0b6b5 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -381,9 +381,13 @@ def prepare_inputs_for_generation(
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
         # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
         #              (we can't check exception 3 while compiling)
+        # Exception 4: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed
+        #              tokens and generate the first token for each sequence; the generated `input_ids` are then used for continuation.
         if past_key_values is not None:
             model_inputs["past_key_values"] = past_key_values
-            if (
+            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
+                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+            elif (
                 inputs_embeds is not None  # Exception 1
                 or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
             ):
@@ -393,9 +397,9 @@ def prepare_inputs_for_generation(

         # 3. Prepare base model inputs
         input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step for every prompt.
         if not self.config.is_encoder_decoder:
-            if inputs_embeds is not None and cache_position[0] == 0:
+            if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
                 model_inputs[input_ids_key] = None
                 model_inputs["inputs_embeds"] = inputs_embeds
             else:
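The core change above is the dispatch rule: instead of feeding `inputs_embeds` only when `cache_position[0] == 0`, they are fed whenever `len(cache_position) == inputs_embeds.shape[1]`, so a follow-up `generate` call that reuses a cache can also start from embeddings, with Exception 4 trimming the embeddings down to the unprocessed positions. The standalone sketch below walks through the shapes involved; the batch size, prompt length, and hidden size are illustrative values, not numbers taken from the diff.

import torch

# Prefill from embeddings: no token ids yet, cache_position spans the whole embedded prompt.
inputs_embeds = torch.randn(1, 5, 8)              # batch=1, 5 prompt positions, hidden=8 (illustrative)
input_ids = torch.empty(1, 0, dtype=torch.long)   # empty: the prompt was provided as embeddings only
cache_position = torch.arange(5)                  # [0, 1, 2, 3, 4] -> nothing is cached yet

if input_ids.shape[1] == 0:                       # Exception 4
    inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
assert len(cache_position) == inputs_embeds.shape[1]   # -> feed inputs_embeds, keep input_ids at None

# Continuation: a previous generate() call produced 3 tokens, so the reused cache holds the 5 prompt
# positions plus the first 2 generated tokens; the 3rd generated token is still unprocessed. The caller
# passes all 8 embedded positions, and Exception 4 keeps only the one unprocessed embedding.
inputs_embeds = torch.randn(1, 8, 8)
cache_position = torch.arange(7, 8)               # [7] -> a single unprocessed position
if input_ids.shape[1] == 0:                       # Exception 4 again
    inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
assert len(cache_position) == inputs_embeds.shape[1]   # -> the sliced embeddings are fed once more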
diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index 19ca679ad0df..9a03a793a613 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -895,8 +895,12 @@ def prepare_inputs_for_generation(
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
         # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
         #              (we can't check exception 3 while compiling)
+        # Exception 4: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed
+        #              tokens and generate the first token for each sequence; the generated `input_ids` are then used for continuation.
         if past_key_values is not None:
-            if (
+            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
+                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+            elif (
                 inputs_embeds is not None  # Exception 1
                 or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
             ):
@@ -905,7 +909,7 @@ def prepare_inputs_for_generation(
             input_ids = input_ids[:, cache_position]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
+        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
             model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
             # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the
diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index 1e088fcaba00..3bc6a43d6f56 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -1654,8 +1654,12 @@ def prepare_inputs_for_generation(
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
         # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
         #              (we can't check exception 3 while compiling)
+        # Exception 4: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed
+        #              tokens and generate the first token for each sequence; the generated `input_ids` are then used for continuation.
         if past_key_values is not None:
-            if (
+            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
+                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+            elif (
                 inputs_embeds is not None  # Exception 1
                 or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
             ):
@@ -1668,10 +1672,13 @@ def prepare_inputs_for_generation(
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+                if inputs_embeds is not None and input_ids.shape[1] == 0:
+                    position_ids = position_ids[:, -inputs_embeds.shape[1] :]
+                else:
+                    position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
+        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index 4dbe4ad4c7f9..6857fb624c0f 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -1674,10 +1674,13 @@ def prepare_inputs_for_generation(
         else:
             model_inputs["pixel_values"] = pixel_values

-        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+        # If we have cache: let's slice `input_ids` or `inputs_embeds` through `cache_position`, to keep only the unprocessed tokens
         if past_key_values is not None:
             if inputs_embeds is not None:
-                input_ids = input_ids[:, -cache_position.shape[0] :]
+                if input_ids.shape[1] == 0:
+                    inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+                else:
+                    input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:
                 input_ids = input_ids[:, cache_position]
         if image_attention_mask is not None:
@@ -1687,14 +1690,19 @@ def prepare_inputs_for_generation(
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
+
+            # If past_key_values are present, slice the position ids to keep only the unprocessed tokens.
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+                if inputs_embeds is not None and input_ids.shape[1] == 0:
+                    position_ids = position_ids[:, -inputs_embeds.shape[1] :]
+                else:
+                    position_ids = position_ids[:, -input_ids.shape[1] :]
             # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
             position_ids = position_ids.clone(memory_format=torch.contiguous_format)

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
+        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
             model_inputs.update({"inputs_embeds": inputs_embeds, "input_ids": None})
         else:
             # The clone here is for the same reason as for `position_ids`.
diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py
index 01d2ff1940fc..6bde89f9aab5 100644
--- a/src/transformers/models/moshi/modeling_moshi.py
+++ b/src/transformers/models/moshi/modeling_moshi.py
@@ -1901,8 +1901,7 @@ def forward(


 @add_start_docstrings(
-    "The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, "
-    "for speech-to-speech.",
+    "The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, for speech-to-speech.",
     MOSHI_START_DOCSTRING,
 )
 class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
@@ -2458,18 +2457,57 @@ def prepare_inputs_for_generation(
         blank_user_audio_codes: Optional[torch.FloatTensor] = None,
         **kwargs,
     ):
-        # Overwritten -- Moshi has custom post-processing
-        # 1. Do usual operations done on LLMs like Gemma - because we pre-processed inputs, the first pass always has inputs_embeds
-        model_inputs = super().prepare_inputs_for_generation(
-            input_ids=input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            inputs_embeds=inputs_embeds,
-            cache_position=cache_position,
-            position_ids=position_ids,
-            use_cache=use_cache,
-            logits_to_keep=logits_to_keep,
-            **kwargs,
+        # Overwritten -- Moshi has custom post-processing on the prepared inputs.
+
+        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+        # Exception 1: when passing input_embeds, input_ids may be missing entries
+        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
+
+        if past_key_values is not None:
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+                input_ids = input_ids[:, cache_position]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+        else:
+            model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
+
+        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = inputs_embeds.shape
+                device = inputs_embeds.device
+            else:
+                batch_size, sequence_length = input_ids.shape
+                device = input_ids.device
+
+            attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
+                attention_mask,
+                sequence_length=sequence_length,
+                target_length=past_key_values.get_max_cache_shape(),
+                dtype=self.lm_head.weight.dtype,
+                device=device,
+                cache_position=cache_position,
+                batch_size=batch_size,
+                config=self.config,
+                past_key_values=past_key_values,
+            )
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+            }
         )

         # 2. Now that everything is prepared, generate audio_codes using the depth decoder
diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 242622d293a2..82b112ad3665 100644
--- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -1261,7 +1261,7 @@ def _update_causal_mask(
         if (
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
-            and attention_mask.device.type == "cuda"
+            and attention_mask.device.type in ["cuda", "xpu"]
             and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
@@ -1872,8 +1872,12 @@ def prepare_inputs_for_generation(
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
         # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
         #              (we can't check exception 3 while compiling)
+        # Exception 4: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed
+        #              tokens and generate the first token for each sequence; the generated `input_ids` are then used for continuation.
         if past_key_values is not None:
-            if (
+            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
+                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+            elif (
                 inputs_embeds is not None  # Exception 1
                 or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
             ):
@@ -1886,7 +1890,7 @@ def prepare_inputs_for_generation(
             pixel_values_videos = None

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
+        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
             model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
             model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
index 87216988b717..601ad373771c 100644
--- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
@@ -770,8 +770,12 @@ def prepare_inputs_for_generation(
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
         # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
         #              (we can't check exception 3 while compiling)
+        # Exception 4: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed
+        #              tokens and generate the first token for each sequence; the generated `input_ids` are then used for continuation.
         if past_key_values is not None:
-            if (
+            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
+                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+            elif (
                 inputs_embeds is not None  # Exception 1
                 or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
             ):
@@ -784,7 +788,7 @@ def prepare_inputs_for_generation(
             pixel_values_videos = None

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
+        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
             model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
             model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index d94daa39a729..51d8fe9430b5 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1735,8 +1735,12 @@ def prepare_inputs_for_generation(
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
         # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
         #              (we can't check exception 3 while compiling)
+        # Exception 4: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed
+        #              tokens and generate the first token for each sequence; the generated `input_ids` are then used for continuation.
         if past_key_values is not None:
-            if (
+            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
+                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+            elif (
                 inputs_embeds is not None  # Exception 1
                 or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
             ):
@@ -1749,7 +1753,7 @@ def prepare_inputs_for_generation(
             pixel_values_videos = None

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
+        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
             model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
             model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py
index f2d7d21a743e..8f00780f341d 100644
--- a/src/transformers/models/zamba2/modeling_zamba2.py
+++ b/src/transformers/models/zamba2/modeling_zamba2.py
@@ -1557,7 +1557,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
         if (
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
-            and attention_mask.device.type == "cuda"
+            and attention_mask.device.type in ["cuda", "xpu"]
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 321803a2179b..c7c8c7f8c108 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1857,6 +1857,83 @@ def test_generate_continue_from_past_key_values(self):
                         )
                     )

+    @pytest.mark.generate
+    def test_generate_continue_from_inputs_embeds(self):
+        """Tests that we can continue generation from `inputs_embeds` and past key values returned from a previous `generate` call."""
+        for model_class in self.all_generative_model_classes:
+            if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
+                self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
+            if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
+                self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
+
+            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+
+            if "token_type_ids" in inputs_dict:
+                del inputs_dict["token_type_ids"]
+
+            if config.is_encoder_decoder:
+                self.skipTest(reason="This model is encoder-decoder")
+            if not hasattr(config, "use_cache"):
+                self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
+
+            model = model_class(config).to(torch_device).eval()
+
+            if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
+                self.skipTest(reason="This model does not support `inputs_embeds` in generation")
+
+            # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
+            outputs = model(**inputs_dict)
+            if "past_key_values" not in outputs:
+                self.skipTest(reason="This model doesn't return `past_key_values`")
+
+            pixel_values_is_mutually_exclusive = any(
+                model_name in model_class.__name__.lower()
+                for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3"]
+            )
+            if pixel_values_is_mutually_exclusive:
+                inputs_dict.pop("pixel_values", None)
+                inputs_dict.pop("pixel_values_videos", None)
+                inputs_dict.pop("pixel_values_images", None)
+
+            input_ids = inputs_dict.pop("input_ids")
+
+            model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
+            model.generation_config.forced_eos_token_id = None
+            model.config.is_decoder = True
+            model.generation_config.use_cache = True
+
+            generation_kwargs = {
+                "return_dict_in_generate": True,
+                "do_sample": False,
+            }
+
+            # Traditional way of generating text, with `return_dict_in_generate` to return the past key values.
+            input_embeds = model.get_input_embeddings()(input_ids)
+            outputs = model.generate(inputs_embeds=input_embeds, max_new_tokens=4, **generation_kwargs)
+
+            # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens)
+            initial_output = model.generate(inputs_embeds=input_embeds, max_new_tokens=3, **generation_kwargs)
+            continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1)
+            cached_output = model.generate(
+                inputs_embeds=continued_embeds,
+                max_new_tokens=1,
+                past_key_values=initial_output.past_key_values,
+                **generation_kwargs,
+            )
+
+            # Combine the (3 + 1) generated tokens and verify it matches with full generation.
+            combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1)
+            self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist())
+            # The two sets of past kv should be equal to each other
+            for layer_idx in range(len(cached_output.past_key_values)):
+                for kv_idx in range(len(cached_output.past_key_values[layer_idx])):
+                    self.assertTrue(
+                        torch.allclose(
+                            outputs.past_key_values[layer_idx][kv_idx],
+                            cached_output.past_key_values[layer_idx][kv_idx],
+                        )
+                    )
+
     @parameterized.expand([("offloaded",)])  # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
     @require_torch_gpu
     @pytest.mark.generate
diff --git a/tests/models/clvp/test_modeling_clvp.py b/tests/models/clvp/test_modeling_clvp.py
index 84a0101f6f28..334f01004936 100644
--- a/tests/models/clvp/test_modeling_clvp.py
+++ b/tests/models/clvp/test_modeling_clvp.py
@@ -334,6 +334,10 @@ def test_training_gradient_checkpointing(self):
             loss = model(**inputs).loss
             loss.backward()

+    @unittest.skip(reason="Clvp `prepare_inputs_for_generation` function doesn't have cache position.")
+    def test_generate_continue_from_inputs_embeds(self):
+        pass
+

 class ClvpModelForConditionalGenerationTester:
     def __init__(self, parent, is_training=False):
diff --git a/tests/models/cohere2/test_modeling_cohere2.py b/tests/models/cohere2/test_modeling_cohere2.py
index 436f1f965e90..81ea53b49f88 100644
--- a/tests/models/cohere2/test_modeling_cohere2.py
+++ b/tests/models/cohere2/test_modeling_cohere2.py
@@ -131,6 +131,10 @@ def test_generate_with_static_cache(self):
     def test_generate_from_inputs_embeds_with_static_cache(self):
         pass

+    @unittest.skip("Cohere2 has HybridCache and doesn't support progressive generation from `inputs_embeds`.")
+    def test_generate_continue_from_inputs_embeds(self):
+        pass
+
     # overwrite because HybridCache has fixed length for key/values
     def _check_attentions_for_generate(
         self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
diff --git a/tests/models/fuyu/test_modeling_fuyu.py b/tests/models/fuyu/test_modeling_fuyu.py
index 0444ad14f269..634dfcf61565 100644
--- a/tests/models/fuyu/test_modeling_fuyu.py
+++ b/tests/models/fuyu/test_modeling_fuyu.py
@@ -325,6 +325,10 @@ def test_disk_offload_safetensors(self):
     def test_model_parallelism(self):
         super().test_model_parallelism()

+    @unittest.skip(reason="Fuyu `prepare_inputs_for_generation` function doesn't have cache position.")
+    def test_generate_continue_from_inputs_embeds(self):
+        pass
+

 @slow
 @require_torch_accelerator
diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index 1fb7bdfa8994..a0563aed90cb 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -146,6 +146,10 @@ def test_generate_with_static_cache(self):
     def test_generate_from_inputs_embeds_with_static_cache(self):
         pass

+    @unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support it.")
+    def test_generate_continue_from_inputs_embeds(self):
+        pass
+
     # overwrite because HybridCache has fixed length for key/values
     def _check_attentions_for_generate(
         self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
index 1ac2db408123..c854a7e71167 100644
--- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
+++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
@@ -450,6 +450,10 @@ def test_disk_offload(self):
     def test_past_key_values_format(self):
         pass

+    @unittest.skip(reason="GPTBigCode has a non-standard KV cache format and breaks this test.")
+    def test_generate_continue_from_inputs_embeds(self):
+        pass
+
     def test_gpt_bigcode_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt_bigcode_model(*config_and_inputs)
diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py
index 01871e81b30e..cc9efc967db2 100644
--- a/tests/models/idefics/test_modeling_idefics.py
+++ b/tests/models/idefics/test_modeling_idefics.py
@@ -755,6 +755,65 @@ def test_generate_without_input_ids(self):
         )
         self.assertIsNotNone(output_ids_generate)

+    @pytest.mark.generate
+    def test_generate_continue_from_inputs_embeds(self):
+        """Overwrite for IDEFICS: Ensure image attention mask is processed while continuing from `inputs_embeds`."""
+
+        for model_class in self.all_generative_model_classes:
+            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+            print(inputs)
+
+            model = model_class(config).to(torch_device).eval()
+
+            model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
+            model.generation_config.forced_eos_token_id = None
+            model.generation_config.use_cache = True
+
+            input_ids = inputs.pop("input_ids")
+            input_embeds = model.get_input_embeddings()(input_ids)
+
+            generation_kwargs = {
+                "return_dict_in_generate": True,
+                "do_sample": False,
+            }
+
+            inputs["inputs_embeds"] = input_embeds
+
+            # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
+            outputs = model.generate(**inputs, max_new_tokens=4, **generation_kwargs)
+            # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
+            # inputs may need to be tweaked across `generate` calls (like the attention mask).
+            initial_output = model.generate(**inputs, max_new_tokens=3, **generation_kwargs)
+            inputs["past_key_values"] = initial_output.past_key_values
+
+            new_attention_len = input_ids.shape[1] + initial_output.sequences.shape[-1]
+            continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1)
+            inputs["inputs_embeds"] = continued_embeds
+
+            if "attention_mask" in inputs:
+                inputs["attention_mask"] = torch.nn.functional.pad(
+                    inputs["attention_mask"],
+                    (0, new_attention_len - inputs["attention_mask"].shape[1]),
+                    mode="constant",
+                    value=1,
+                )
+            if "image_attention_mask" in inputs:
+                inputs["image_attention_mask"] = inputs["image_attention_mask"][..., -1:, :]
+
+            cached_output = model.generate(**inputs, max_new_tokens=1, **generation_kwargs)
+
+            # Verify that the combined outputs match the full generation.
+            combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1)
+            self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist())
+            for layer_idx in range(len(cached_output.past_key_values)):
+                for kv_idx in range(len(cached_output.past_key_values[layer_idx])):
+                    self.assertTrue(
+                        torch.allclose(
+                            outputs.past_key_values[layer_idx][kv_idx],
+                            cached_output.past_key_values[layer_idx][kv_idx],
+                        )
+                    )
+
     def _check_attentions_for_generate(
         self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
     ):
diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py
index adaf0fcc34ac..09278f0d24c4 100644
--- a/tests/models/moshi/test_modeling_moshi.py
+++ b/tests/models/moshi/test_modeling_moshi.py
@@ -358,6 +358,10 @@ def test_disk_offload_bin(self):
     def test_disk_offload_safetensors(self):
         pass

+    @unittest.skip(reason="Test becomes too complex with Moshi requiring multiple input modalities.")
+    def test_generate_continue_from_inputs_embeds(self):
+        pass
+
     @is_flaky(max_attempts=5, description="flaky on some models.")
     def test_save_load(self):
         super().test_save_load()
@@ -824,6 +828,7 @@ def test_generate_without_input_ids(self):
         output_ids_generate = model.generate(
             do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
         )
+        print(output_ids_generate)
         self.assertIsNotNone(output_ids_generate)

     @unittest.skip(reason="The audio encoder has no gradients.")
@@ -919,6 +924,10 @@ def test_disk_offload_bin(self):
     def test_disk_offload_safetensors(self):
         pass

+    @unittest.skip(reason="Test becomes too complex with Moshi requiring multiple input modalities.")
+    def test_generate_continue_from_inputs_embeds(self):
+        pass
+
     @is_flaky(max_attempts=5, description="flaky on some models.")
     def test_save_load(self):
         super().test_save_load()
diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py
index 2bd6732514c6..c876e598e867 100644
--- a/tests/models/zamba2/test_modeling_zamba2.py
+++ b/tests/models/zamba2/test_modeling_zamba2.py
@@ -333,6 +333,10 @@ def test_past_key_values_format(self):
         """
         pass

+    @unittest.skip(reason="Zamba2 has a hybrid cache.")
+    def test_generate_continue_from_inputs_embeds(self):
+        pass
+
     @unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
     def test_multi_gpu_data_parallel_forward(self):
         pass
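The new `test_generate_continue_from_inputs_embeds` boils down to the usage pattern below: generate a few tokens from `inputs_embeds`, then call `generate` again with the returned cache and the embeddings of everything produced so far. A minimal sketch of that flow, assuming a small decoder-only checkpoint ("gpt2" is only an illustrative choice; any model whose `prepare_inputs_for_generation` accepts `inputs_embeds` and a cache should behave similarly) and greedy decoding:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "gpt2"  # assumed checkpoint, chosen only for illustration
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).eval()

input_ids = tokenizer("An increasing sequence: one,", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

# First call: prefill from embeddings and keep the cache (returned because `return_dict_in_generate=True`).
first = model.generate(
    inputs_embeds=inputs_embeds, max_new_tokens=3, do_sample=False, return_dict_in_generate=True
)

# Second call: append the embeddings of the freshly generated tokens and reuse the cache.
# Exception 4 above trims these embeddings down to the single unprocessed position.
continued_embeds = torch.cat([inputs_embeds, model.get_input_embeddings()(first.sequences)], dim=1)
second = model.generate(
    inputs_embeds=continued_embeds,
    past_key_values=first.past_key_values,
    max_new_tokens=1,
    do_sample=False,
    return_dict_in_generate=True,
)

# The 3 + 1 tokens should match a single 4-token generation, which is what the new test asserts.
print(tokenizer.decode(torch.cat([first.sequences, second.sequences], dim=1)[0]))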