Skip to content

Commit

Permalink
[tests] add test_image_token_filling
Browse files Browse the repository at this point in the history
  • Loading branch information
laurentd-lunit committed Sep 4, 2024
1 parent e9a7b30 commit 86dc7ff
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion tests/models/llava_next/test_processor_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import torch
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available

Expand All @@ -39,3 +39,27 @@ def test_chat_template(self):

formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)

def test_image_token_filling(self):
processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"
image = torch.randint(0, 2, (3, 500, 316))
expected_image_tokens = 1526

messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
inputs = processor(
text=[processor.apply_chat_template(messages)],
images=[image],
return_tensors="pt",
)
image_tokens = (inputs["input_ids"] == 32000).sum().item()
self.assertEqual(expected_image_tokens, image_tokens)

0 comments on commit 86dc7ff

Please sign in to comment.