
Commit

Add pipeline tests and add copied from post process function
yonigozlan committed Oct 31, 2024
1 parent ba8f85f commit 00174e8
Showing 18 changed files with 60 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/transformers/models/blip_2/processing_blip_2.py
@@ -179,6 +179,7 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
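
For reference, the `# Copied from` marker is consumed by the repository's `utils/check_copies.py` consistency check, which keeps each of these methods in sync with the BLIP original. A minimal sketch of the method being propagated, assuming the BLIP implementation is a thin wrapper around batch decoding:

    def post_process_image_text_to_text(self, generated_outputs):
        """
        Post-process the output of the model to decode the text.
        """
        # Decode every sequence in the batch; skip_special_tokens drops
        # padding/EOS markers so only the generated text remains.
        return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
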
1 change: 1 addition & 0 deletions src/transformers/models/chameleon/processing_chameleon.py
@@ -159,6 +159,7 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
1 change: 1 addition & 0 deletions src/transformers/models/donut/processing_donut.py
@@ -214,6 +214,7 @@ def token2json(self, tokens, is_inner_value=False, added_vocab=None):
         else:
             return [] if is_inner_value else {"text_sequence": tokens}

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
1 change: 1 addition & 0 deletions src/transformers/models/git/processing_git.py
@@ -144,6 +144,7 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
1 change: 1 addition & 0 deletions src/transformers/models/idefics/processing_idefics.py
@@ -530,6 +530,7 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
1 change: 1 addition & 0 deletions src/transformers/models/idefics2/processing_idefics2.py
@@ -270,6 +270,7 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
1 change: 1 addition & 0 deletions src/transformers/models/idefics3/processing_idefics3.py
@@ -342,6 +342,7 @@ def decode(self, *args, **kwargs):
         decode_output = self.tokenizer.decode(*args, **kwargs)
         return self._regex_to_remove_extra_special_tokens.sub("<image>", decode_output)

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
@@ -178,6 +178,7 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
1 change: 1 addition & 0 deletions src/transformers/models/llava/processing_llava.py
@@ -182,6 +182,7 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
@@ -222,6 +222,7 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
@@ -261,6 +261,7 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
1 change: 1 addition & 0 deletions src/transformers/models/paligemma/processing_paligemma.py
@@ -339,6 +339,7 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
@@ -157,6 +157,7 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
1 change: 1 addition & 0 deletions src/transformers/models/pixtral/processing_pixtral.py
@@ -275,6 +275,7 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
1 change: 1 addition & 0 deletions src/transformers/models/udop/processing_udop.py
@@ -208,6 +208,7 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    # Copied from transformers.models.blip.processing_blip.BlipProcessor.post_process_image_text_to_text
     def post_process_image_text_to_text(self, generated_outputs):
         """
         Post-process the output of the model to decode the text.
10 changes: 6 additions & 4 deletions src/transformers/pipelines/image_text_to_text.py
@@ -95,12 +95,14 @@ def retrieve_images_in_chat(chat: dict, images: Optional[Union[str, List[str], "
                     idx_images += 1
                 else:
                     raise ValueError(
-                        "The number of images in the chat should be the same as the number of images passed."
+                        "The number of images in the chat should be the same as the number of images passed to the pipeline."
                     )

     # The number of images passed should be consistent with the number of images in the chat without an image key
     if idx_images != len(images):
-        raise ValueError("The number of images in the chat should be the same as the number of images passed.")
+        raise ValueError(
+            "The number of images in the chat should be the same as the number of images passed to the pipeline."
+        )

     return retrieved_images
@@ -287,9 +289,9 @@ def __call__(
             return super().__call__(chats, **kwargs)

         # encourage the user to use the chat format if supported
-        if hasattr(self.processor, "chat_template") and self.processor.chat_template is not None:
+        if getattr(self.processor, "chat_template", None) is not None:
             logger.warning_once(
-                "The pipeline detected no chat format in the prompt, but this model supports chat format. "
+                "The input data was not formatted as a chat with dicts containing 'role' and 'content' keys, even though this model supports chat. "
                 "Consider using the chat format for better results. For more information, see https://huggingface.co/docs/transformers/en/chat_templating"
             )
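
Since the reworded warning nudges users toward the chat format, a usage sketch of what that format looks like; the checkpoint name and the "url" content key are illustrative assumptions, not taken from this commit:

    from transformers import pipeline

    # Illustrative checkpoint; any image-text-to-text chat model would do.
    pipe = pipeline("image-text-to-text", model="llava-hf/llava-1.5-7b-hf")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": "https://example.com/cat.png"},
                {"type": "text", "text": "What do you see in this image?"},
            ],
        }
    ]

    # Each entry is a dict with "role" and "content" keys, so the pipeline
    # applies the model's chat template instead of emitting the warning above.
    outputs = pipe(text=messages, max_new_tokens=20)
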
28 changes: 14 additions & 14 deletions tests/pipelines/test_pipelines_image_text_to_text.py
@@ -41,27 +41,27 @@ def open(*args, **kwargs):
 class ImageTextToTextPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING

-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
-        pipe = ImageTextToTextPipeline(
-            model=model, tokenizer=tokenizer, image_processor=processor, torch_dtype=torch_dtype
-        )
+    def get_test_pipeline(self, model, tokenizer, processor, image_processor, torch_dtype="float32"):
+        pipe = ImageTextToTextPipeline(model=model, processor=processor, torch_dtype=torch_dtype)
         image_token = processor.image_token if hasattr(processor, "image_token") else "<image>"
-        examples = {
-            "images": [
-                Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
-                "./tests/fixtures/tests_samples/COCO/000000039769.png",
-            ],
-            "text": [f"{image_token} This is a ", f"{image_token} Here I see a "],
-        }
+        examples = [
+            {
+                "images": Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
+                "text": f"{image_token} This is a ",
+            },
+            {
+                "images": "./tests/fixtures/tests_samples/COCO/000000039769.png",
+                "text": f"{image_token} Here I see a ",
+            },
+        ]
         return pipe, examples

     def run_pipeline_test(self, pipe, examples):
-        outputs = pipe(examples.get("images"), text=examples.get("text"), max_new_tokens=20)
+        outputs = pipe(examples[0].get("images"), text=examples[0].get("text"), max_new_tokens=20)
         self.assertEqual(
             outputs,
             [
-                [{"input_text": ANY(str), "generated_text": ANY(str)}],
-                [{"input_text": ANY(str), "generated_text": ANY(str)}],
+                {"input_text": ANY(str), "generated_text": ANY(str)},
             ],
         )
27 changes: 25 additions & 2 deletions tests/test_pipeline_mixin.py
@@ -492,8 +492,19 @@ def data(n):
                     yield copy.deepcopy(random.choice(examples))

             out = []
-            for item in pipeline(data(10), batch_size=4):
-                out.append(item)
+            # check if pipeline call signature has more than one argument
+            # in this case, we can't use a single generator, we need to collate the data first
+            input_signature = {
+                k: v for k, v in inspect.signature(pipeline.__call__).parameters.items() if k != "kwargs"
+            }
+            if len(input_signature) > 1:
+                data_list = list(data(10))
+                # collate data_list
+                data_dict = {k: [d[k] for d in data_list] for k in data_list[0]}
+                out = pipeline(**data_dict, batch_size=4)
+            else:
+                for item in pipeline(data(10), batch_size=4):
+                    out.append(item)
             self.assertEqual(len(out), 10)

         run_batch_test(pipeline, examples)
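
This branch is also why `get_test_pipeline` in the test file above now returns a list of per-example dicts: when a pipeline's call signature takes more than one argument, the mixin collates those dicts into keyword arguments. The collation expression is a plain list-of-dicts to dict-of-lists transpose:

    # Two examples shaped like those returned by get_test_pipeline
    data_list = [
        {"images": "img0.png", "text": "This is a "},
        {"images": "img1.png", "text": "Here I see a "},
    ]

    # Transpose: one key per pipeline argument, one list entry per example.
    data_dict = {k: [d[k] for d in data_list] for k in data_list[0]}
    # -> {"images": ["img0.png", "img1.png"], "text": ["This is a ", "Here I see a "]}
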
@@ -588,6 +599,18 @@ def test_pipeline_image_segmentation(self):
     def test_pipeline_image_segmentation_fp16(self):
         self.run_task_tests(task="image-segmentation", torch_dtype="float16")

+    @is_pipeline_test
+    @require_vision
+    @require_torch
+    def test_pipeline_image_text_to_text(self):
+        self.run_task_tests(task="image-text-to-text")
+
+    @is_pipeline_test
+    @require_vision
+    @require_torch
+    def test_pipeline_image_text_to_text_fp16(self):
+        self.run_task_tests(task="image-text-to-text", torch_dtype="float16")
+
     @is_pipeline_test
     @require_vision
     def test_pipeline_image_to_text(self):
