Reformat to be fully compatible with chat templates
yonigozlan committed Oct 31, 2024
1 parent 0aabcb2 commit 5189f16
Showing 5 changed files with 133 additions and 40 deletions.
6 changes: 6 additions & 0 deletions docs/source/en/main_classes/pipelines.md
@@ -478,6 +478,12 @@ Pipelines available for multimodal tasks include the following.
     - __call__
     - all
 
+### ImageTextToTextPipeline
+
+[[autodoc]] ImageTextToTextPipeline
+    - __call__
+    - all
+
 ### MaskGenerationPipeline
 
 [[autodoc]] MaskGenerationPipeline
6 changes: 6 additions & 0 deletions docs/source/ja/main_classes/pipelines.md
@@ -481,6 +481,12 @@ my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
     - __call__
     - all
 
+### ImageTextToTextPipeline
+
+[[autodoc]] ImageTextToTextPipeline
+    - __call__
+    - all
+
 ### VisualQuestionAnsweringPipeline
 
 [[autodoc]] VisualQuestionAnsweringPipeline
6 changes: 6 additions & 0 deletions docs/source/zh/main_classes/pipelines.md
@@ -455,6 +455,12 @@ See [`TokenClassificationPipeline`] for all details.
     - __call__
     - all
 
+### ImageTextToTextPipeline
+
+[[autodoc]] ImageTextToTextPipeline
+    - __call__
+    - all
+
 ### MaskGenerationPipeline
 
 [[autodoc]] MaskGenerationPipeline
152 changes: 113 additions & 39 deletions src/transformers/pipelines/image_text_to_text.py
@@ -13,8 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import enum
 from typing import Dict, List, Optional, Union
 
+from ..tokenization_utils_base import AddedToken
 from ..utils import (
     add_end_docstrings,
     is_torch_available,
@@ -40,6 +42,12 @@
 IMAGE_TOKEN = "<image>"
 
 
+class ReturnType(enum.Enum):
+    TENSORS = 0
+    NEW_TEXT = 1
+    FULL_TEXT = 2
+
+
 class Chat:
     """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
     to this format because the rest of the pipeline code tends to assume that lists of messages are
@@ -64,20 +72,26 @@ def __init__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag
         self.text = text
 
 
-def retrieve_images_in_chat(chat, images):
+def retrieve_images_in_chat(chat: dict, images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]]):
     """
     Retrieve and combine images from the chat and the images passed as input.
     """
+    if images is None:
+        images = []
     idx_images = 0
    retrieved_images = []
    for message in chat:
        for content in message["content"]:
-            if content.get("type") == "image":
+            if isinstance(content, dict) and content.get("type") == "image":
                 if "image" in content:
                     retrieved_images.append(content["image"])
-                else:
+                elif idx_images < len(images):
                     retrieved_images.append(images[idx_images])
                     idx_images += 1
+                else:
+                    raise ValueError(
+                        "The number of images in the chat should be the same as the number of images passed."
+                    )
 
     # The number of images passed should be consistent with the number of images in the chat without an image key
     if idx_images != len(images):
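For context, a minimal sketch of the two image sources this helper merges — images embedded in a content dict versus images passed alongside the chat. All names and values here are hypothetical illustrations, not part of the diff:

```python
from PIL import Image

inline_image = Image.new("RGB", (64, 64))  # hypothetical image carried inside the chat
chat = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": inline_image},  # has an "image" key: used directly
            {"type": "image"},                         # no "image" key: filled from `images`
            {"type": "text", "text": "Compare these two images."},
        ],
    }
]
positional_image = Image.new("RGB", (64, 64))  # hypothetical image passed alongside the chat
# retrieve_images_in_chat(chat, [positional_image]) -> [inline_image, positional_image];
# a second bare {"type": "image"} entry would exhaust `images` and raise the ValueError above.
```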
@@ -90,6 +104,9 @@ def retrieve_images_in_chat(chat, images):
 class ImageTextToTextPipeline(Pipeline):
     """
     Image-text-to-text pipeline using an `AutoModelForImageTextToText`. This pipeline generates text given an image and text.
+    When the underlying model is a conversational model, it can also accept one or more chats,
+    in which case the pipeline will operate in chat mode and will continue the chat(s) by adding its response(s).
+    Each chat takes the form of a list of dicts, where each dict contains "role" and "content" keys.
 
     Example:
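The example body is collapsed in this view. A minimal usage sketch consistent with the docstring above — the checkpoint name, image URL, and output are illustrative assumptions, not content from this commit:

```python
>>> from transformers import pipeline

>>> pipe = pipeline(task="image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")  # assumed checkpoint
>>> messages = [
...     {
...         "role": "user",
...         "content": [
...             {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
...             {"type": "text", "text": "Describe this image."},
...         ],
...     }
... ]
>>> pipe(text=messages, max_new_tokens=20)  # returns the chat continued with an assistant reply
```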
@@ -125,11 +142,14 @@ def _sanitize_parameters(
         padding=None,
         max_length=None,
         timeout=None,
-        include_query_in_output=None,
+        return_full_text=None,
+        return_tensors=None,
+        return_type=None,
+        continue_final_message=None,
     ):
         forward_kwargs = {}
         preprocess_params = {}
-        post_process_params = {}
+        postprocess_params = {}
 
         if timeout is not None:
             preprocess_params["timeout"] = timeout
@@ -143,6 +163,9 @@
         if max_length is not None:
             preprocess_params["max_length"] = max_length
 
+        if continue_final_message is not None:
+            preprocess_params["continue_final_message"] = continue_final_message
+
         if generate_kwargs is not None:
             forward_kwargs["generate_kwargs"] = generate_kwargs
 
@@ -156,10 +179,18 @@
                 )
             forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens
 
-        if include_query_in_output is not None:
-            post_process_params["include_query_in_output"] = include_query_in_output
+        if return_full_text is not None and return_type is None:
+            if return_tensors is not None:
+                raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`")
+            return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
+        if return_tensors is not None and return_type is None:
+            return_type = ReturnType.TENSORS
+        if return_type is not None:
+            postprocess_params["return_type"] = return_type
+        if continue_final_message is not None:
+            postprocess_params["continue_final_message"] = continue_final_message
 
-        return preprocess_params, forward_kwargs, post_process_params
+        return preprocess_params, forward_kwargs, postprocess_params
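A sketch of how the new flags resolve to a `ReturnType`, mirroring the branches above (`pipe` is a hypothetical `ImageTextToTextPipeline` instance; `ReturnType` is the enum defined earlier in this module):

```python
# return_full_text maps onto FULL_TEXT / NEW_TEXT when no explicit return_type is given
_, _, postprocess_params = pipe._sanitize_parameters(return_full_text=True)
assert postprocess_params["return_type"] == ReturnType.FULL_TEXT

_, _, postprocess_params = pipe._sanitize_parameters(return_full_text=False)
assert postprocess_params["return_type"] == ReturnType.NEW_TEXT

# return_tensors maps onto TENSORS
_, _, postprocess_params = pipe._sanitize_parameters(return_tensors=True)
assert postprocess_params["return_type"] == ReturnType.TENSORS

# passing both raises: "`return_full_text` is mutually exclusive with `return_tensors`"
```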

     def __call__(
         self,
@@ -201,30 +232,34 @@ def __call__(
         if images is None and text is None:
             raise ValueError("You must at least provide either text or images.")
 
-        if not isinstance(images, (list, tuple)):
-            images = [images]
-
         if isinstance(text, (list, tuple, KeyDataset) if is_torch_available() else (list, tuple)) and isinstance(
             text[0], (list, tuple, dict)
         ):
             # We have one or more prompts in list-of-dicts format, so this is chat mode
 
             if isinstance(text[0], dict):
                 return super().__call__(Chat(text, images), **kwargs)
             else:
+                if images is None:
+                    images = [None] * len(text)
                 chats = [Chat(chat, image) for chat, image in zip(text, images)]  # 🐈 🐈 🐈
                 return super().__call__(chats, **kwargs)
 
         # If we are not in chat mode, we need both images and text
         if images is None or text is None:
             raise ValueError("You must provide both images and text when not using chat templates.")
 
+        if not isinstance(images, (list, tuple)):
+            images = [images]
         if isinstance(text, str):
             text = [text] * len(images)
         if not isinstance(text[0], str):
             raise ValueError("The pipeline does not support nested lists of prompts.")
 
         if hasattr(self.processor, "image_token") and self.processor.image_token is not None:
             image_token = self.processor.image_token
+            if isinstance(image_token, AddedToken):
+                image_token = image_token.content
         else:
             image_token = IMAGE_TOKEN
         # Check number of image_token token in each text
@@ -260,24 +295,37 @@ def __call__(
 
         return super().__call__([ImageText(image, text_single) for image, text_single in zip(images, text)], **kwargs)
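The dispatch above supports two calling conventions; a brief sketch under the assumption of a hypothetical `pipe` instance, image paths, and an `<image>` placeholder token:

```python
# 1. Plain mode: one prompt per image (a single string is broadcast to all images)
outputs = pipe(images=["cat.png", "dog.png"], text="<image> A photo of")

# 2. Chat mode: list-of-dicts prompts are wrapped in Chat and the model replies in turn
chat = [
    {
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": "What is shown here?"}],
    }
]
outputs = pipe(text=chat, images="cat.png")
```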

-    def preprocess(self, inputs=None, truncation=None, padding=False, max_length=None, timeout=None):
+    def preprocess(
+        self, inputs=None, truncation=None, padding=False, max_length=None, timeout=None, continue_final_message=None
+    ):
         kwargs = {
             "legacy": False,
             "truncation": truncation,
             "padding": padding,
             "max_length": max_length,
         }
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
 
         images = inputs.images
 
         if isinstance(inputs, Chat):
+            # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
+            # because very few models support multiple separate, consecutive assistant messages
+            if continue_final_message is None:
+                continue_final_message = inputs.messages[-1]["role"] == "assistant"
             text = self.processor.apply_chat_template(
                 inputs.messages,
-                add_generation_prompt=True,
+                add_generation_prompt=not continue_final_message,
+                continue_final_message=continue_final_message,
                 # return_dict=True,
                 return_tensors=self.framework,
                 **kwargs,
             )
+            inputs_text = inputs
         else:
             text = inputs.text
+            inputs_text = inputs.text
 
         if not isinstance(images, (list, tuple)):
             images = load_image(images, timeout=timeout)
         else:
@@ -293,52 +341,78 @@ def preprocess(self, inputs=None, truncation=None, padding=False, max_length=None
                 dtype=self.torch_dtype
             )
 
-        model_inputs["text"] = text
+        model_inputs["text"] = inputs_text
 
         return model_inputs
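A sketch of the assistant-prefill default introduced above (`pipe` and the image path are hypothetical):

```python
chat = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image."}]},
    # Trailing assistant message: treated as a prefill by default
    {"role": "assistant", "content": [{"type": "text", "text": "The image shows"}]},
]
# Because the last role is "assistant", continue_final_message defaults to True, the template is
# applied without a new generation prompt, and the model continues "The image shows" in place
# rather than opening a second assistant turn.
outputs = pipe(text=chat, images="cat.png")
```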

     def _forward(self, model_inputs, generate_kwargs=None):
         generate_kwargs = {} if generate_kwargs is None else generate_kwargs
-        input_text = model_inputs.pop("text")
+        prompt_text = model_inputs.pop("text")
         input_ids = (
             model_inputs["input_ids"] if "input_ids" in model_inputs else model_inputs["decoder_input_ids"]
         )  # for decoder-only models
-        model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
-        return {"outputs": model_outputs, "input_text": input_text, "input_ids": input_ids}
+        generated_sequence = self.model.generate(**model_inputs, **generate_kwargs)
+
+        return {"generated_sequence": generated_sequence, "prompt_text": prompt_text, "input_ids": input_ids}
 
-    def postprocess(self, model_outputs, include_query_in_output=False):
-        input_text = model_outputs["input_text"]
-        input_text = [input_text] if isinstance(input_text, str) else input_text
-        outputs = model_outputs["outputs"]
-        inputs_id = model_outputs["input_ids"]
+    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None):
+        input_texts = model_outputs["prompt_text"]
+        input_texts = [input_texts] if isinstance(input_texts, (str, Chat)) else input_texts
+        generated_sequence = model_outputs["generated_sequence"]
+        input_ids = model_outputs["input_ids"]
+        if return_type == ReturnType.TENSORS:
+            return [
+                {"input_text": input_texts[i], "generated_token_ids": generated_sequence[i]}
+                for i in range(len(input_texts))
+            ]
 
         # Decode inputs and outputs the same way to remove input text from generated text if present
-        generated_texts = self.processor.post_process_image_text_to_text(outputs)
-        decoded_inputs = self.processor.post_process_image_text_to_text(inputs_id)
-        generated_texts = [text.strip() for text in generated_texts]
-        decoded_inputs = [text.strip() for text in decoded_inputs]
+        generated_texts = self.processor.post_process_image_text_to_text(generated_sequence)
+        decoded_inputs = self.processor.post_process_image_text_to_text(input_ids)
 
-        # Force consistent behavior for including the input text in the output
-        if include_query_in_output:
-            # Add the input text to the generated text if the generated text does not start with the input
-            generated_texts = [
-                f"{decoded_inputs[i]} {text_generated}"
-                if not text_generated.startswith(decoded_inputs[i])
-                else text_generated
-                for i, text_generated in enumerate(generated_texts)
-            ]
-        else:
+        if return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
             # Remove the input text from the generated text if the generated text starts with the input text
             generated_texts = [
-                text_generated[len(decoded_inputs[i]) :].strip()
+                text_generated[len(decoded_inputs[i]) :]
                 if text_generated.startswith(decoded_inputs[i])
                 else text_generated
                 for i, text_generated in enumerate(generated_texts)
             ]
+        if return_type == ReturnType.FULL_TEXT:
+            full_texts = []
+            for prompt_text, generated_text in zip(input_texts, generated_texts):
+                if isinstance(prompt_text, str):
+                    generated_text = prompt_text + generated_text
+                elif isinstance(prompt_text, Chat):
+                    if continue_final_message is None:
+                        # If the user passes a chat ending in an assistant message, we treat it as a prefill by
+                        # default because very few models support multiple separate, consecutive assistant messages
+                        continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
+                    if continue_final_message:
+                        # With assistant prefill, concat onto the end of the last message
+                        new_text = dict(prompt_text.messages[-1]["content"][-1].items())
+                        new_text["text"] += generated_text
+                        generated_text = list(prompt_text.messages)[:-1] + [
+                            {
+                                "role": prompt_text.messages[-1]["role"],
+                                "content": prompt_text.messages[-1]["content"][:-1] + [new_text],
+                            }
+                        ]
+                    else:
+                        # When we're not starting from a prefill, the output is a new assistant message
+                        generated_text = list(prompt_text.messages) + [
+                            {"role": "assistant", "content": generated_text}
+                        ]
+                full_texts.append(generated_text)
+            generated_texts = full_texts
 
         records = [
-            {"input_text": input_text[i], "generated_text": generated_text}
-            for i, generated_text in enumerate(generated_texts)
+            {
+                "input_text": input_text.messages if isinstance(input_text, Chat) else input_text,
+                "generated_text": generated_text,
+            }
+            for input_text, generated_text in zip(input_texts, generated_texts)
         ]
 
         return records
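A sketch of consuming these records (`pipe`, inputs, and printed values are hypothetical):

```python
records = pipe(images="cat.png", text="<image> A photo of", return_full_text=False)
for record in records:
    # Each record pairs the prompt with the newly generated continuation
    print(record["input_text"])      # "<image> A photo of"
    print(record["generated_text"])  # e.g. " two cats sleeping."

# With a chat input and the default FULL_TEXT return type, "generated_text" is the original
# message list plus one new {"role": "assistant", "content": ...} reply (or, with prefill,
# the final assistant message extended in place).
```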
3 changes: 2 additions & 1 deletion tests/pipelines/test_pipelines_image_text_to_text.py
@@ -45,12 +45,13 @@ def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
         pipe = ImageTextToTextPipeline(
             model=model, tokenizer=tokenizer, image_processor=processor, torch_dtype=torch_dtype
         )
+        image_token = processor.image_token if hasattr(processor, "image_token") else "<image>"
         examples = {
             "images": [
                 Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
                 "./tests/fixtures/tests_samples/COCO/000000039769.png",
             ],
-            "text": ["<image> This is a ", "<image> Here I see a "],
+            "text": [f"{image_token} This is a ", f"{image_token} Here I see a "],
         }
         return pipe, examples
