Skip to content

Commit

Permalink
Fix Llava-NeXT / Llava-NeXT Video / Llava-OneVision's token unpadding mismatch (huggingface#35779)
Browse files Browse the repository at this point in the history

* Fix Llava OneVision's token padding

* Fix Llava next and Llava next video's token unpadding for consistency
  • Loading branch information
sheryc authored and bursteratom committed Jan 28, 2025
1 parent 5a6f086 commit b4a24a5
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/transformers/models/llava_next/processing_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,11 @@ def _get_unpadded_features(self, height, width, patches_height, patches_width, s
original_aspect_ratio = width / height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
- new_height = (height * current_width) // width
+ new_height = int(round(height * (current_width / width), 7))
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
- new_width = (width * current_height) // height
+ new_width = int(round(width * (current_height / height), 7))
padding = (current_width - new_width) // 2
current_width -= padding * 2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,11 @@ def _get_unpadded_features(self, height, width, patches_height, patches_width, s
original_aspect_ratio = width / height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
- new_height = (height * current_width) // width
+ new_height = int(round(height * (current_width / width), 7))
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
- new_width = (width * current_height) // height
+ new_width = int(round(width * (current_height / height), 7))
padding = (current_width - new_width) // 2
current_width -= padding * 2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def _get_number_of_features(self, orig_height: int, orig_width: int, height: int
num_image_tokens = unpadded_features + newline_features + base_features
return num_image_tokens

+ # Adapted from transformers.models.llava_next.processing_llava_next.LlavaNextProcessor._get_unpadded_features
def _get_unpadded_features(self, height, width, patches_height, patches_width, scale_height, scale_width):
"""
Get number of features for a given image with height/width. LLaVA-NeXT is different from LLaVA
Expand All @@ -237,11 +238,11 @@ def _get_unpadded_features(self, height, width, patches_height, patches_width, s
original_aspect_ratio = width / height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
- new_height = int(height * (current_width / width))
+ new_height = int(round(height * (current_width / width), 7))
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
- new_width = int(width * (current_height / height))
+ new_width = int(round(width * (current_height / height), 7))
padding = (current_width - new_width) // 2
current_width -= padding * 2

Expand Down

0 comments on commit b4a24a5

Please sign in to comment.