Skip to content

Commit

Permalink
fix batched inference
Browse files Browse the repository at this point in the history
  • Loading branch information
yonigozlan committed Oct 31, 2024
1 parent fa25411 commit 9fe26c7
Showing 1 changed file with 42 additions and 23 deletions.
65 changes: 42 additions & 23 deletions src/transformers/pipelines/image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ def __init__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag
self.images = images
self.text = text

# Number of entries in this container, taken from the length of `self.images`.
# NOTE(review): `self.text` is not consulted here — presumably it has the same
# length as `self.images` whenever len() is used; confirm against callers.
def __len__(self):
return len(self.images)

# Return a new ImageText built from `self.images[idx]` and `self.text[idx]`.
# NOTE(review): assumes `images` and `text` are both indexable and aligned
# position by position (one text per image) — confirm against callers.
def __getitem__(self, idx):
return ImageText(self.images[idx], self.text[idx])


def retrieve_images_in_chat(chat: dict, images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]]):
"""
Expand Down Expand Up @@ -313,32 +319,35 @@ def __call__(
else:
image_token = IMAGE_TOKEN
# Count the occurrences of image_token in each text
nested_images = False
num_images_in_text = [text_single.count(image_token) for text_single in text]
if sum(num_images_in_text) > 0:
if any(num > 1 for num in num_images_in_text) and batch_size > 1:
raise ValueError(
"The pipeline does not support multiple images for a single prompt with batch_size > 1."
)
# Check if already nested images and consistency
if isinstance(images[0], (list, tuple)):
if len(images) != len(text):
raise ValueError("The number of nested image groups and prompts should be the same.")
num_images_in_images = [len(image) for image in images]
if num_images_in_text != num_images_in_images:
if any(num > 1 for num in num_images_in_text):
if batch_size > 1:
raise ValueError(
f"The number of images in each nested image group should be the same as the number of {image_token} tokens in the corresponding prompt."
"The pipeline does not support multiple images for a single prompt with batch_size > 1."
)
elif sum(num_images_in_text) != len(images):
raise ValueError(
f"The total number of {image_token} tokens in the prompts should be the same as the number of images passed."
)
else:
# Reorganize the images to match the prompts
images_reorganized = []
for num_images in num_images_in_text:
images_reorganized.append(images[:num_images])
images = images[num_images:]
images = images_reorganized
nested_images = True
# Check whether the images are already nested and, if so, that they are consistent with the prompts
if isinstance(images[0], (list, tuple)):
if len(images) != len(text):
raise ValueError("The number of nested image groups and prompts should be the same.")
num_images_in_images = [len(image) for image in images]
if num_images_in_text != num_images_in_images:
raise ValueError(
f"The number of images in each nested image group should be the same as the number of {image_token} tokens in the corresponding prompt."
)
elif sum(num_images_in_text) != len(images):
raise ValueError(
f"The total number of {image_token} tokens in the prompts should be the same as the number of images passed."
)
else:
# Reorganize the images to match the prompts
images_reorganized = []
for num_images in num_images_in_text:
images_reorganized.append(images[:num_images])
images = images[num_images:]
images = images_reorganized
elif len(text) == 1 and len(images) > 1:
logger.warning(
"The pipeline detected multiple images for one prompt, but no image tokens in the prompt. "
Expand All @@ -351,7 +360,17 @@ def __call__(
raise ValueError(
"Undefined behavior, please check the number of images and prompts, and nest the images to match the prompts."
)
return super().__call__([ImageText(image, text_single) for image, text_single in zip(images, text)], **kwargs)

if nested_images:
return super().__call__(
[ImageText(image, text_single) for image, text_single in zip(images, text)], **kwargs
)

# otherwise, we can flatten the images and text as we have a 1:1 relationship
if isinstance(images[0], (list, tuple)):
images = [img for img_list in images for img in img_list]

return super().__call__(ImageText(images, text), **kwargs)

def preprocess(
self, inputs=None, truncation=None, padding=False, max_length=None, timeout=None, continue_final_message=None
Expand Down

0 comments on commit 9fe26c7

Please sign in to comment.