Continue generation using input embeds and cache
yaswanth19 committed Feb 5, 2025
1 parent 52f9394 commit 8b718dc
Showing 3 changed files with 67 additions and 2 deletions.
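For context, a minimal sketch of the generation pattern this commit unblocks (the `model` and `inputs` setup is assumed, not part of the commit): generate once from `inputs_embeds` with the cache enabled, then feed the returned `past_key_values`, the extended embeddings, and a padded attention mask into a second `generate` call.

```python
import torch

# Assumed setup: `model` is an IDEFICS model (e.g. IdeficsForVisionText2Text) and
# `inputs` is a dict of processor outputs containing at least `input_ids` and
# `attention_mask` (plus image tensors when images are used).
input_ids = inputs.pop("input_ids")
inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)

# First call: decode a few tokens and keep the key/value cache.
first = model.generate(
    **inputs, max_new_tokens=3, do_sample=False, use_cache=True, return_dict_in_generate=True
)

# Second call: append embeddings for the newly generated tokens, pad the attention
# mask to the new total length, and resume from the cached key/values.
inputs["inputs_embeds"] = torch.cat(
    [inputs["inputs_embeds"], model.get_input_embeddings()(first.sequences)], dim=1
)
new_len = input_ids.shape[1] + first.sequences.shape[-1]
inputs["attention_mask"] = torch.nn.functional.pad(
    inputs["attention_mask"], (0, new_len - inputs["attention_mask"].shape[1]), value=1
)
inputs["past_key_values"] = first.past_key_values
continued = model.generate(
    **inputs, max_new_tokens=1, do_sample=False, return_dict_in_generate=True
)
```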
7 changes: 6 additions & 1 deletion src/transformers/models/idefics/modeling_idefics.py
@@ -1690,8 +1690,13 @@ def prepare_inputs_for_generation(
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

# If past_key_values are present, slice the position ids to cover only the unprocessed tokens.
if past_key_values:
    position_ids = position_ids[:, -input_ids.shape[1] :]
    if inputs_embeds is not None and input_ids.shape[1] == 0:
        position_ids = position_ids[:, -inputs_embeds.shape[1] :]
    else:
        position_ids = position_ids[:, -input_ids.shape[1] :]

# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
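A small standalone illustration of the change (plain torch, not taken from the repository): when generation continues purely from `inputs_embeds`, `input_ids` arrives with sequence length 0, so the old slice `-input_ids.shape[1]` keeps every cached position instead of one per unprocessed token; slicing by `-inputs_embeds.shape[1]` restores the expected shape.

```python
import torch

# Toy shapes: 4 tokens already in the cache, 3 new tokens passed only as embeddings.
attention_mask = torch.ones(1, 7, dtype=torch.long)
input_ids = torch.empty(1, 0, dtype=torch.long)   # empty when only `inputs_embeds` is given
inputs_embeds = torch.zeros(1, 3, 16)              # 3 unprocessed tokens, hidden size 16

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

# Old behaviour: `-input_ids.shape[1]` is -0, so the slice keeps all 7 positions,
# which no longer lines up with the 3 embedded tokens being processed.
# New behaviour: fall back to the embeddings' length when `input_ids` is empty.
if inputs_embeds is not None and input_ids.shape[1] == 0:
    position_ids = position_ids[:, -inputs_embeds.shape[1] :]
else:
    position_ids = position_ids[:, -input_ids.shape[1] :]

print(position_ids)  # tensor([[4, 5, 6]])
```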
59 changes: 59 additions & 0 deletions tests/models/idefics/test_modeling_idefics.py
@@ -755,6 +755,65 @@ def test_generate_without_input_ids(self):
)
self.assertIsNotNone(output_ids_generate)

@pytest.mark.generate
def test_generate_continue_from_inputs_embeds(self):
"""Overwrite for IDEFICS: Ensure image attention mask is processed while continuing from `inputs_embeds`."""

    for model_class in self.all_generative_model_classes:
        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

        model = model_class(config).to(torch_device).eval()

        model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
        model.generation_config.forced_eos_token_id = None
        model.generation_config.use_cache = True

        input_ids = inputs.pop("input_ids")
        input_embeds = model.get_input_embeddings()(input_ids)

        generation_kwargs = {
            "return_dict_in_generate": True,
            "do_sample": False,
        }

inputs["inputs_embeds"] = input_embeds

# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
outputs = model.generate(**inputs, max_new_tokens=4, **generation_kwargs)
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
# inputs may need to be tweaked across `generate` calls (like the attention mask).
initial_output = model.generate(**inputs, max_new_tokens=3, **generation_kwargs)
inputs["past_key_values"] = initial_output.past_key_values

new_attention_len = input_ids.shape[1] + initial_output.sequences.shape[-1]
continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1)
inputs["inputs_embeds"] = continued_embeds

if "attention_mask" in inputs:
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"],
(0, new_attention_len - inputs["attention_mask"].shape[1]),
mode="constant",
value=1,
)
if "image_attention_mask" in inputs:
inputs["image_attention_mask"] = inputs["image_attention_mask"][..., -1:, :]

cached_output = model.generate(**inputs, max_new_tokens=1, **generation_kwargs)

# Verify that the combined outputs match the full generation.
combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1)
self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist())
for layer_idx in range(len(cached_output.past_key_values)):
for kv_idx in range(len(cached_output.past_key_values[layer_idx])):
self.assertTrue(
torch.allclose(
outputs.past_key_values[layer_idx][kv_idx],
cached_output.past_key_values[layer_idx][kv_idx],
)
)

def _check_attentions_for_generate(
    self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
3 changes: 2 additions & 1 deletion tests/models/moshi/test_modeling_moshi.py
@@ -358,7 +358,7 @@ def test_disk_offload_bin(self):
def test_disk_offload_safetensors(self):
    pass

@unittest.skip(reason="Test becomes too complex with Moshi requiring multiple modalities input.")
@unittest.skip(reason="Test becomes too complex with Moshi requiring multiple input modalities.")
def test_generate_continue_from_inputs_embeds(self):
    pass

@@ -828,6 +828,7 @@ def test_generate_without_input_ids(self):
output_ids_generate = model.generate(
    do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
)
self.assertIsNotNone(output_ids_generate)

@unittest.skip(reason="The audio encoder has no gradients.")
