Pixtral #8377

Merged: 16 commits, Sep 11, 2024
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.rst
@@ -262,6 +262,11 @@ Multimodal Language Models
- Audio\ :sup:`E+`
- :code:`fixie-ai/ultravox-v0_3`
-
* - :code:`PixtralForConditionalGeneration`
- Pixtral
- Image\ :sup:`E+`
- :code:`mistralai/Pixtral-12B-2409`
-

| :sup:`E` Pre-computed embeddings can be inputted for this modality.
| :sup:`+` Multiple items can be inputted per text prompt for this modality.
164 changes: 164 additions & 0 deletions examples/offline_inference_pixtral.py
@@ -0,0 +1,164 @@
# ruff: noqa
import argparse

from vllm import LLM
from vllm.sampling_params import SamplingParams

# This script is an offline demo for running Pixtral.
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
# vllm serve mistralai/Pixtral-12B-2409 --tokenizer_mode mistral --limit_mm_per_prompt 'image=4' --max_num_batched_tokens 16384
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
#     --header 'Content-Type: application/json' \
#     --header 'Authorization: Bearer token' \
#     --data '{
#         "model": "mistralai/Pixtral-12B-2409",
#         "messages": [
#             {
#                 "role": "user",
#                 "content": [
#                     {"type" : "text", "text": "Describe this image in detail please."},
#                     {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
#                     {"type" : "text", "text": "and this one as well. Answer in French."},
#                     {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
#                 ]
#             }
#         ]
#     }'
# ```
#
# Usage:
#     python offline_inference_pixtral.py simple
#     python offline_inference_pixtral.py advanced


def run_simple_demo():
    model_name = "mistralai/Pixtral-12B-2409"
    sampling_params = SamplingParams(max_tokens=8192)

    llm = LLM(model=model_name, tokenizer_mode="mistral")

    prompt = "Describe this image in one sentence."
    image_url = "https://picsum.photos/id/237/200/300"

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": image_url
                    }
                },
            ],
        },
    ]
    outputs = llm.chat(messages, sampling_params=sampling_params)

    print(outputs[0].outputs[0].text)


def run_advanced_demo():
    model_name = "mistralai/Pixtral-12B-2409"
    max_img_per_msg = 5
    max_tokens_per_img = 4096

    sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
    llm = LLM(
        model=model_name,
        tokenizer_mode="mistral",
        limit_mm_per_prompt={"image": max_img_per_msg},
        max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
    )

    prompt = "Describe the following image."

    url_1 = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png"
    url_2 = "https://picsum.photos/seed/picsum/200/300"
    url_3 = "https://picsum.photos/id/32/512/512"

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": url_1
                    }
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": url_2
                    }
                },
            ],
        },
        {
            "role": "assistant",
            "content": "The images show nature.",
        },
        {
            "role": "user",
            "content": "More details please and answer only in French!.",
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": url_3
                    }
                },
            ],
        },
    ]

    outputs = llm.chat(messages=messages, sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)


def main():
    parser = argparse.ArgumentParser(
        description="Run a demo in simple or advanced mode.")

    parser.add_argument(
        "mode",
        choices=["simple", "advanced"],
        help="Specify the demo mode: 'simple' or 'advanced'",
    )

    args = parser.parse_args()

    if args.mode == "simple":
        print("Running simple demo...")
        run_simple_demo()
    elif args.mode == "advanced":
        print("Running advanced demo...")
        run_advanced_demo()


if __name__ == "__main__":
    main()
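For reference, the curl request shown in the comment block of the example above can also be issued with the OpenAI Python client, since `vllm serve` exposes an OpenAI-compatible endpoint. The following is a minimal sketch and not part of this PR: it assumes the server from the example is running locally on port 8000, that the `openai` package is installed, and that the image URL and bearer token are placeholders to be adjusted.

```python
# Sketch only: query a locally running `vllm serve mistralai/Pixtral-12B-2409`
# endpoint through the OpenAI-compatible chat API. The base URL, API key, and
# image URL below are placeholder assumptions, not values defined by this PR.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="token")

response = client.chat.completions.create(
    model="mistralai/Pixtral-12B-2409",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image in one sentence."},
            {
                "type": "image_url",
                "image_url": {"url": "https://picsum.photos/id/237/200/300"},
            },
        ],
    }],
    max_tokens=256,
)
print(response.choices[0].message.content)
```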
2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -25,7 +25,7 @@ pyzmq
msgspec
gguf == 0.9.1
importlib_metadata
mistral_common >= 1.3.4
mistral_common >= 1.4.0
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
einops # Required for Qwen2-VL.
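For an existing environment, the bumped dependency can also be upgraded directly rather than reinstalling from the requirements files; a minimal sketch mirroring the pin above:

```bash
pip install "mistral_common>=1.4.0"
```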
58 changes: 58 additions & 0 deletions tests/models/test_pixtral.py
@@ -0,0 +1,58 @@
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.

Run `pytest tests/models/test_mistral.py`.
"""
import pytest

from vllm.sampling_params import SamplingParams

MODELS = ["mistralai/Pixtral-12B-2409"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    image_urls = [
        "https://picsum.photos/id/237/200/300",
        "https://picsum.photos/seed/picsum/200/300"
    ]
    expected = [
        "The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.", # noqa
        "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset." # noqa
    ]
    prompt = "Describe the image in one short sentence."

    sampling_params = SamplingParams(max_tokens=512, temperature=0.0)

    with vllm_runner(model, dtype=dtype,
                     tokenizer_mode="mistral") as vllm_model:

        for i, image_url in enumerate(image_urls):
            messages = [
                {
                    "role": "user",
                    "content": [{
                        "type": "text",
                        "text": prompt
                    }, {
                        "type": "image_url",
                        "image_url": {
                            "url": image_url
                        }
                    }]
                },
            ]

            outputs = vllm_model.model.chat(messages,
                                            sampling_params=sampling_params)
            assert outputs[0].outputs[0].text == expected[i]
3 changes: 2 additions & 1 deletion vllm/entrypoints/chat_utils.py
@@ -148,7 +148,8 @@ def _placeholder_str(self, modality: ModalityStr,
return f"<|image_{current_count}|>"
if model_type == "minicpmv":
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
"pixtral"):
# These models do not use image tokens in the prompt
return None
if model_type == "qwen":
2 changes: 2 additions & 0 deletions vllm/model_executor/models/__init__.py
@@ -92,6 +92,8 @@
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"PixtralForConditionalGeneration": ("pixtral",
"PixtralForConditionalGeneration"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"),
}