From 7a3a83e3b87f50fe9c0985a5c5bcc1d4cf2e95cd Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 11 Jan 2025 13:50:05 +0800 Subject: [PATCH] [CI/Build] Move model-specific multi-modal processing tests (#11934) Signed-off-by: DarkLight1337 --- .buildkite/test-pipeline.yaml | 1 + .../processing => multimodal}/__init__.py | 0 .../models/multimodal/processing/__init__.py | 0 .../multimodal/processing/test_common.py | 201 +++++++++++++++ .../processing/test_idefics3.py | 4 +- .../processing/test_internvl.py | 4 +- .../processing/test_llava_next.py | 2 +- .../processing/test_llava_onevision.py | 2 +- .../processing/test_phi3v.py | 4 +- .../processing/test_qwen.py | 4 +- .../processing/test_qwen2_vl.py | 4 +- tests/multimodal/test_processing.py | 232 +----------------- tests/multimodal/utils.py | 33 +++ 13 files changed, 251 insertions(+), 240 deletions(-) rename tests/models/{decoder_only/vision_language/processing => multimodal}/__init__.py (100%) create mode 100644 tests/models/multimodal/processing/__init__.py create mode 100644 tests/models/multimodal/processing/test_common.py rename tests/models/{decoder_only/vision_language => multimodal}/processing/test_idefics3.py (98%) rename tests/models/{decoder_only/vision_language => multimodal}/processing/test_internvl.py (98%) rename tests/models/{decoder_only/vision_language => multimodal}/processing/test_llava_next.py (99%) rename tests/models/{decoder_only/vision_language => multimodal}/processing/test_llava_onevision.py (99%) rename tests/models/{decoder_only/vision_language => multimodal}/processing/test_phi3v.py (95%) rename tests/models/{decoder_only/vision_language => multimodal}/processing/test_qwen.py (98%) rename tests/models/{decoder_only/vision_language => multimodal}/processing/test_qwen2_vl.py (96%) create mode 100644 tests/multimodal/utils.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d3bd809cfdf24..cf82210f96ee3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -368,6 +368,7 @@ steps: - tests/models/encoder_decoder/vision_language commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' - pytest -v -s models/embedding/vision_language -m core_model diff --git a/tests/models/decoder_only/vision_language/processing/__init__.py b/tests/models/multimodal/__init__.py similarity index 100% rename from tests/models/decoder_only/vision_language/processing/__init__.py rename to tests/models/multimodal/__init__.py diff --git a/tests/models/multimodal/processing/__init__.py b/tests/models/multimodal/processing/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py new file mode 100644 index 0000000000000..0a38779e0e4f0 --- /dev/null +++ b/tests/models/multimodal/processing/test_common.py @@ -0,0 +1,201 @@ +from functools import partial + +import numpy as np +import pytest +from PIL import Image + +from vllm.config import ModelConfig +from vllm.inputs import InputProcessingContext +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.processing import ProcessingCache +from vllm.multimodal.utils import cached_get_tokenizer + +from ....multimodal.utils import random_audio, random_image, random_video + + +def _test_processing_correctness( + model_id: str, + modalities: dict[str, bool], + hit_rate: float, + num_batches: int, + simplify_rate: float, +): + if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3": + hf_overrides = {"architectures": ["MantisForConditionalGeneration"]} + else: + hf_overrides = {} + + limit_mm_per_prompt = { + modality: 3 if supports_multi else 1 + for modality, supports_multi in modalities.items() + } + + model_config = ModelConfig( + model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=True, + seed=0, + dtype="float16", + revision=None, + hf_overrides=hf_overrides, + limit_mm_per_prompt=limit_mm_per_prompt, + ) + + model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) + factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] + ctx = InputProcessingContext( + model_config, + tokenizer=cached_get_tokenizer(model_config.tokenizer), + ) + # Ensure that it can fit all of the data + cache = ProcessingCache(capacity=1 << 30) + + baseline_processor = factories.build_processor(ctx, cache=None) + cached_processor = factories.build_processor(ctx, cache=cache) + dummy_inputs = baseline_processor.dummy_inputs + tokenizer = baseline_processor.info.get_tokenizer() + + rng = np.random.RandomState(0) + + input_to_hit = { + "image": Image.new("RGB", size=(128, 128)), + "video": np.zeros((4, 128, 128, 3), dtype=np.uint8), + "audio": (np.zeros((512, )), 16000), + } + input_factory = { + "image": + partial(random_image, rng, min_wh=128, max_wh=256), + "video": + partial(random_video, + rng, + min_frames=2, + max_frames=8, + min_wh=128, + max_wh=256), + "audio": + partial(random_audio, rng, min_len=512, max_len=1024, sr=16000), + } + + for batch_idx in range(num_batches): + mm_data = { + k: + [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]()) + for _ in range(rng.randint(limit_mm_per_prompt[k]))] + for k in modalities + } + + mm_counts = {k: len(vs) for k, vs in mm_data.items()} + prompt = dummy_inputs.get_dummy_processor_inputs( + model_config.max_model_len, + mm_counts, + ).prompt_text + + # Drop unnecessary keys and test single -> multi conversion + if rng.rand() < simplify_rate: + for k in list(mm_data.keys()): + if not mm_data[k]: + del mm_data[k] + elif len(mm_data[k]) == 1: + mm_data[k] = mm_data[k][0] + + baseline_result = baseline_processor.apply( + prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + cached_result = cached_processor.apply( + prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + + assert baseline_result == cached_result, ( + f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") + + baseline_tokenized_result = baseline_processor.apply( + tokenizer.encode(prompt), + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + + assert baseline_result == baseline_tokenized_result, ( + f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") + + cached_tokenized_result = cached_processor.apply( + tokenizer.encode(prompt), + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + + assert cached_result == cached_tokenized_result, ( + f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") + + +# yapf: disable +# True if the model supports multiple data items of the modality per request +@pytest.mark.parametrize(("model_id", "modalities"), [ + ("rhymes-ai/Aria", {"image": True}), + ("Salesforce/blip2-opt-2.7b", {"image": False}), + ("facebook/chameleon-7b", {"image": False}), + ("adept/fuyu-8b", {"image": False}), + ("llava-hf/llava-1.5-7b-hf", {"image": True}), + ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}), + ("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}), + ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}), # noqa: E501 + ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}), + ("mistral-community/pixtral-12b", {"image": True}), + ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}), + ("Qwen/Qwen2-Audio-7B-Instruct", {"audio": True}), + ("fixie-ai/ultravox-v0_3", {"audio": True}), +]) +@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) +@pytest.mark.parametrize("num_batches", [32]) +@pytest.mark.parametrize("simplify_rate", [1.0]) +# yapf: enable +def test_processing_correctness( + model_id: str, + modalities: dict[str, bool], + hit_rate: float, + num_batches: int, + simplify_rate: float, +): + _test_processing_correctness( + model_id, + modalities, + hit_rate=hit_rate, + num_batches=num_batches, + simplify_rate=simplify_rate, + ) + + +# yapf: disable +@pytest.mark.parametrize(("model_id", "modalities"), [ + ("microsoft/Phi-3-vision-128k-instruct", {"image": True}), +]) +@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) +@pytest.mark.parametrize("num_batches", [32]) +@pytest.mark.parametrize("simplify_rate", [1.0]) +# yapf: enable +def test_processing_correctness_phi3v( + model_id: str, + modalities: dict[str, bool], + hit_rate: float, + num_batches: int, + simplify_rate: float, +): + # HACK - this is an attempted workaround for the following bug + # https://github.com/huggingface/transformers/issues/34307 + from transformers import AutoImageProcessor # noqa: F401 + from transformers import AutoProcessor # noqa: F401 + + AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True) + + _test_processing_correctness( + model_id, + modalities, + hit_rate=hit_rate, + num_batches=num_batches, + simplify_rate=simplify_rate, + ) diff --git a/tests/models/decoder_only/vision_language/processing/test_idefics3.py b/tests/models/multimodal/processing/test_idefics3.py similarity index 98% rename from tests/models/decoder_only/vision_language/processing/test_idefics3.py rename to tests/models/multimodal/processing/test_idefics3.py index c71a2d359043d..69b91ad4a5df8 100644 --- a/tests/models/decoder_only/vision_language/processing/test_idefics3.py +++ b/tests/models/multimodal/processing/test_idefics3.py @@ -8,8 +8,8 @@ from vllm.inputs import InputContext, token_inputs from vllm.multimodal import MultiModalRegistry -from .....conftest import _ImageAssets -from ....utils import build_model_context +from ....conftest import _ImageAssets +from ...utils import build_model_context models = ["HuggingFaceM4/Idefics3-8B-Llama3"] diff --git a/tests/models/decoder_only/vision_language/processing/test_internvl.py b/tests/models/multimodal/processing/test_internvl.py similarity index 98% rename from tests/models/decoder_only/vision_language/processing/test_internvl.py rename to tests/models/multimodal/processing/test_internvl.py index af0c2aa211998..d6c60595ca5ea 100644 --- a/tests/models/decoder_only/vision_language/processing/test_internvl.py +++ b/tests/models/multimodal/processing/test_internvl.py @@ -7,8 +7,8 @@ from vllm.inputs import InputContext, token_inputs from vllm.multimodal import MultiModalRegistry -from .....conftest import _ImageAssets -from ....utils import build_model_context +from ....conftest import _ImageAssets +from ...utils import build_model_context models = ["OpenGVLab/InternVL2-2B"] diff --git a/tests/models/decoder_only/vision_language/processing/test_llava_next.py b/tests/models/multimodal/processing/test_llava_next.py similarity index 99% rename from tests/models/decoder_only/vision_language/processing/test_llava_next.py rename to tests/models/multimodal/processing/test_llava_next.py index 689d17be81889..1eec35d9c3c72 100644 --- a/tests/models/decoder_only/vision_language/processing/test_llava_next.py +++ b/tests/models/multimodal/processing/test_llava_next.py @@ -10,7 +10,7 @@ from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.utils import cached_get_tokenizer -from ....utils import build_model_context +from ...utils import build_model_context def _validate_image_prompt_replacements_one( diff --git a/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py b/tests/models/multimodal/processing/test_llava_onevision.py similarity index 99% rename from tests/models/decoder_only/vision_language/processing/test_llava_onevision.py rename to tests/models/multimodal/processing/test_llava_onevision.py index a033354f0e9b8..94ea604c58b43 100644 --- a/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py +++ b/tests/models/multimodal/processing/test_llava_onevision.py @@ -10,7 +10,7 @@ from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.utils import cached_get_tokenizer -from ....utils import build_model_context +from ...utils import build_model_context def _validate_image_prompt_replacements_one( diff --git a/tests/models/decoder_only/vision_language/processing/test_phi3v.py b/tests/models/multimodal/processing/test_phi3v.py similarity index 95% rename from tests/models/decoder_only/vision_language/processing/test_phi3v.py rename to tests/models/multimodal/processing/test_phi3v.py index c5b77260c6544..7f82a8f18f0ca 100644 --- a/tests/models/decoder_only/vision_language/processing/test_phi3v.py +++ b/tests/models/multimodal/processing/test_phi3v.py @@ -4,8 +4,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer -from .....conftest import _ImageAssets -from ....utils import build_model_context +from ....conftest import _ImageAssets +from ...utils import build_model_context @pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"]) diff --git a/tests/models/decoder_only/vision_language/processing/test_qwen.py b/tests/models/multimodal/processing/test_qwen.py similarity index 98% rename from tests/models/decoder_only/vision_language/processing/test_qwen.py rename to tests/models/multimodal/processing/test_qwen.py index 163220c91a27d..af0ace711ba3e 100644 --- a/tests/models/decoder_only/vision_language/processing/test_qwen.py +++ b/tests/models/multimodal/processing/test_qwen.py @@ -9,8 +9,8 @@ from vllm.multimodal import MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer -from .....conftest import IMAGE_ASSETS -from ....utils import build_model_context +from ....conftest import IMAGE_ASSETS +from ...utils import build_model_context ### Multimodal preprocessing tests SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image diff --git a/tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py b/tests/models/multimodal/processing/test_qwen2_vl.py similarity index 96% rename from tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py rename to tests/models/multimodal/processing/test_qwen2_vl.py index 0d54802f2b733..de14fbbffe5b7 100644 --- a/tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py +++ b/tests/models/multimodal/processing/test_qwen2_vl.py @@ -3,8 +3,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer -from .....conftest import _ImageAssets -from ....utils import build_model_context +from ....conftest import _ImageAssets +from ...utils import build_model_context @pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index d18909a4197b6..54269c3ef7ce0 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -1,30 +1,25 @@ from contextlib import nullcontext -from functools import partial from typing import cast from unittest.mock import MagicMock import numpy as np import pytest -from PIL import Image from vllm.config import ModelConfig -from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY -# yapf conflicts with isort for this block -# yapf: disable -from vllm.multimodal.processing import (PlaceholderInfo, ProcessingCache, - PromptReplacement, +from vllm.multimodal.processing import (PlaceholderInfo, PromptReplacement, find_mm_placeholders, find_text_matches, find_token_matches, iter_token_matches, replace_text_matches, replace_token_matches) -# yapf: enable from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.utils import cached_get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import full_groupby +from .utils import random_image + # yapf: disable @pytest.mark.parametrize( @@ -531,37 +526,6 @@ def test_find_mm_placeholders( assert result == expected -def _rand_img(rng: np.random.RandomState, min_wh: int, max_wh: int): - w, h = rng.randint(min_wh, max_wh, size=(2, )) - arr = rng.randint(0, 255, size=(w, h, 3), dtype=np.uint8) - return Image.fromarray(arr) - - -def _rand_video( - rng: np.random.RandomState, - min_frames: int, - max_frames: int, - min_wh: int, - max_wh: int, -): - # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 - num_frames = rng.randint(min_frames, max_frames) - num_frames = (num_frames // 2) * 2 - - w, h = rng.randint(min_wh, max_wh, size=(2, )) - return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8) - - -def _rand_audio( - rng: np.random.RandomState, - min_len: int, - max_len: int, - sr: int, -): - audio_len = rng.randint(min_len, max_len) - return rng.rand(audio_len), sr - - @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize( ("limit", "num_supported", "is_valid"), @@ -628,7 +592,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): ) rng = np.random.RandomState(0) - image = _rand_img(rng, min_wh=128, max_wh=256) + image = random_image(rng, min_wh=128, max_wh=256) if num_images == 0: mm_data = {} elif num_images == 1: @@ -647,191 +611,3 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): mm_data=mm_data, hf_processor_mm_kwargs={}, ) - - -def _test_processing_correctness( - model_id: str, - modalities: dict[str, bool], - hit_rate: float, - num_batches: int, - simplify_rate: float, -): - if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3": - hf_overrides = {"architectures": ["MantisForConditionalGeneration"]} - else: - hf_overrides = {} - - limit_mm_per_prompt = { - modality: 3 if supports_multi else 1 - for modality, supports_multi in modalities.items() - } - - model_config = ModelConfig( - model_id, - task="auto", - tokenizer=model_id, - tokenizer_mode="auto", - trust_remote_code=True, - seed=0, - dtype="float16", - revision=None, - hf_overrides=hf_overrides, - limit_mm_per_prompt=limit_mm_per_prompt, - ) - - model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) - factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] - ctx = InputProcessingContext( - model_config, - tokenizer=cached_get_tokenizer(model_config.tokenizer), - ) - # Ensure that it can fit all of the data - cache = ProcessingCache(capacity=1 << 30) - - baseline_processor = factories.build_processor(ctx, cache=None) - cached_processor = factories.build_processor(ctx, cache=cache) - dummy_inputs = baseline_processor.dummy_inputs - tokenizer = baseline_processor.info.get_tokenizer() - - rng = np.random.RandomState(0) - - input_to_hit = { - "image": Image.new("RGB", size=(128, 128)), - "video": np.zeros((4, 128, 128, 3), dtype=np.uint8), - "audio": (np.zeros((512, )), 16000), - } - input_factory = { - "image": - partial(_rand_img, rng, min_wh=128, max_wh=256), - "video": - partial(_rand_video, - rng, - min_frames=2, - max_frames=8, - min_wh=128, - max_wh=256), - "audio": - partial(_rand_audio, rng, min_len=512, max_len=1024, sr=16000), - } - - for batch_idx in range(num_batches): - mm_data = { - k: - [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]()) - for _ in range(rng.randint(limit_mm_per_prompt[k]))] - for k in modalities - } - - mm_counts = {k: len(vs) for k, vs in mm_data.items()} - prompt = dummy_inputs.get_dummy_processor_inputs( - model_config.max_model_len, - mm_counts, - ).prompt_text - - # Drop unnecessary keys and test single -> multi conversion - if rng.rand() < simplify_rate: - for k in list(mm_data.keys()): - if not mm_data[k]: - del mm_data[k] - elif len(mm_data[k]) == 1: - mm_data[k] = mm_data[k][0] - - baseline_result = baseline_processor.apply( - prompt, - mm_data=mm_data, - hf_processor_mm_kwargs={}, - ) - cached_result = cached_processor.apply( - prompt, - mm_data=mm_data, - hf_processor_mm_kwargs={}, - ) - - assert baseline_result == cached_result, ( - f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") - - baseline_tokenized_result = baseline_processor.apply( - tokenizer.encode(prompt), - mm_data=mm_data, - hf_processor_mm_kwargs={}, - ) - - assert baseline_result == baseline_tokenized_result, ( - f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") - - cached_tokenized_result = cached_processor.apply( - tokenizer.encode(prompt), - mm_data=mm_data, - hf_processor_mm_kwargs={}, - ) - - assert cached_result == cached_tokenized_result, ( - f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") - - -# yapf: disable -# True if the model supports multiple data items of the modality per request -@pytest.mark.parametrize(("model_id", "modalities"), [ - ("rhymes-ai/Aria", {"image": True}), - ("Salesforce/blip2-opt-2.7b", {"image": False}), - ("facebook/chameleon-7b", {"image": False}), - ("adept/fuyu-8b", {"image": False}), - ("llava-hf/llava-1.5-7b-hf", {"image": True}), - ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}), - ("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}), - ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}), # noqa: E501 - ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}), - ("mistral-community/pixtral-12b", {"image": True}), - ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}), - ("Qwen/Qwen2-Audio-7B-Instruct", {"audio": True}), - ("fixie-ai/ultravox-v0_3", {"audio": True}), -]) -@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) -@pytest.mark.parametrize("num_batches", [32]) -@pytest.mark.parametrize("simplify_rate", [1.0]) -# yapf: enable -def test_processing_correctness( - model_id: str, - modalities: dict[str, bool], - hit_rate: float, - num_batches: int, - simplify_rate: float, -): - _test_processing_correctness( - model_id, - modalities, - hit_rate=hit_rate, - num_batches=num_batches, - simplify_rate=simplify_rate, - ) - - -# yapf: disable -@pytest.mark.parametrize(("model_id", "modalities"), [ - ("microsoft/Phi-3-vision-128k-instruct", {"image": True}), -]) -@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) -@pytest.mark.parametrize("num_batches", [32]) -@pytest.mark.parametrize("simplify_rate", [1.0]) -# yapf: enable -def test_processing_correctness_phi3v( - model_id: str, - modalities: dict[str, bool], - hit_rate: float, - num_batches: int, - simplify_rate: float, -): - # HACK - this is an attempted workaround for the following bug - # https://github.com/huggingface/transformers/issues/34307 - from transformers import AutoImageProcessor # noqa: F401 - from transformers import AutoProcessor # noqa: F401 - - AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True) - - _test_processing_correctness( - model_id, - modalities, - hit_rate=hit_rate, - num_batches=num_batches, - simplify_rate=simplify_rate, - ) diff --git a/tests/multimodal/utils.py b/tests/multimodal/utils.py new file mode 100644 index 0000000000000..29aeca605109b --- /dev/null +++ b/tests/multimodal/utils.py @@ -0,0 +1,33 @@ +import numpy as np +from PIL import Image + + +def random_image(rng: np.random.RandomState, min_wh: int, max_wh: int): + w, h = rng.randint(min_wh, max_wh, size=(2, )) + arr = rng.randint(0, 255, size=(w, h, 3), dtype=np.uint8) + return Image.fromarray(arr) + + +def random_video( + rng: np.random.RandomState, + min_frames: int, + max_frames: int, + min_wh: int, + max_wh: int, +): + # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 + num_frames = rng.randint(min_frames, max_frames) + num_frames = (num_frames // 2) * 2 + + w, h = rng.randint(min_wh, max_wh, size=(2, )) + return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8) + + +def random_audio( + rng: np.random.RandomState, + min_len: int, + max_len: int, + sr: int, +): + audio_len = rng.randint(min_len, max_len) + return rng.rand(audio_len), sr