Commit

Add and fix tests
DarkLight1337 committed Nov 5, 2024
1 parent 4a1b1e0 commit 0410d9f
Showing 6 changed files with 177 additions and 76 deletions.
3 changes: 1 addition & 2 deletions docs/source/models/vlm.rst
@@ -268,8 +268,7 @@ In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model.
.. code-block:: bash
vllm serve TIGER-Lab/VLM2Vec-Full --task embedding \
-    --trust-remote-code --max-model-len 4096 \
-    --chat-template examples/template_vlm2vec.jinja --chat-template-content-format openai
+    --trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec.jinja
.. important::

6 changes: 3 additions & 3 deletions docs/source/serving/openai_compatible_server.md
@@ -147,11 +147,11 @@ completion = client.chat.completions.create(
)
```
Most chat templates for LLMs expect the `content` field to be a string but there are some newer models like
-`meta-llama/Llama-Guard-3-1B` that expect the content to be according to the OpenAI schema in the request.
-vLLM provides best-effort support to detect this automatically, which is logged as a string like
+`meta-llama/Llama-Guard-3-1B` that expect the content to be formatted according to the OpenAI schema in the
+request. vLLM provides best-effort support to detect this automatically, which is logged as a string like
*"Detected the chat template content format to be..."*, and internally converts incoming requests to match
the detected format. If the result is not what you expect, you can use the `--chat-template-content-format`
-CLI argument to explicitly specify which format to use (`"string"` or `"openai"`).
+CLI argument to override which format to use (`"string"` or `"openai"`).
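
As an illustration (not part of this commit), the two content shapes look like this when using the OpenAI client against a running vLLM server; vLLM accepts either and converts incoming messages to whichever format the chat template expects:

```python
from openai import OpenAI

# Assumes a vLLM OpenAI-compatible server is already running locally.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# "string" content format: the chat template sees plain text.
client.chat.completions.create(
    model="meta-llama/Llama-Guard-3-1B",
    messages=[{"role": "user", "content": "Who are you?"}],
)

# "openai" content format: the chat template sees a list of typed parts.
client.chat.completions.create(
    model="meta-llama/Llama-Guard-3-1B",
    messages=[{
        "role": "user",
        "content": [{"type": "text", "text": "Who are you?"}],
    }],
)
```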


## Command line arguments for the server
1 change: 1 addition & 0 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -70,6 +70,7 @@ def test_serving_chat_should_set_correct_max_tokens():
BASE_MODEL_PATHS,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
lora_modules=None,
prompt_adapters=None,
request_logger=None)
83 changes: 81 additions & 2 deletions tests/entrypoints/test_chat_utils.py
@@ -6,15 +6,24 @@

from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig
-from vllm.entrypoints.chat_utils import (parse_chat_messages,
-                                         parse_chat_messages_futures)
+from vllm.entrypoints.chat_utils import (load_chat_template,
+                                         parse_chat_messages,
+                                         parse_chat_messages_futures,
+                                         resolve_chat_template_content_format)
from vllm.entrypoints.llm import apply_hf_chat_template
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import encode_image_base64
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

from ..utils import VLLM_PATH

EXAMPLES_DIR = VLLM_PATH / "examples"

PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_3"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"


@pytest.fixture(scope="function")
@@ -702,3 +711,73 @@ def get_conversation(is_hf: bool):
)

assert hf_result == vllm_result


# yapf: disable
@pytest.mark.parametrize(
("model", "expected_format"),
[(PHI3V_MODEL_ID, "string"),
(QWEN2VL_MODEL_ID, "openai"),
(ULTRAVOX_MODEL_ID, "string"),
(MLLAMA_MODEL_ID, "openai"),
(LLAMA_GUARD_MODEL_ID, "openai")],
)
# yapf: enable
def test_resolve_content_format_hf_defined(model, expected_format):
tokenizer_group = TokenizerGroup(
model,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
tokenizer = tokenizer_group.tokenizer

resolved_format = resolve_chat_template_content_format(
tokenizer.chat_template,
"auto",
tokenizer,
)

assert resolved_format == expected_format


# yapf: disable
@pytest.mark.parametrize(
("template_path", "expected_format"),
[("template_alpaca.jinja", "string"),
("template_baichuan.jinja", "string"),
("template_blip2.jinja", "string"),
("template_chatglm.jinja", "string"),
("template_chatglm2.jinja", "string"),
("template_chatml.jinja", "string"),
("template_falcon_180b.jinja", "string"),
("template_falcon.jinja", "string"),
("template_inkbot.jinja", "string"),
("template_llava.jinja", "string"),
("template_vlm2vec.jinja", "openai"),
("tool_chat_template_granite_20b_fc.jinja", "string"),
("tool_chat_template_hermes.jinja", "string"),
("tool_chat_template_internlm2_tool.jinja", "string"),
("tool_chat_template_llama3.1_json.jinja", "string"),
("tool_chat_template_llama3.2_json.jinja", "string"),
("tool_chat_template_mistral_parallel.jinja", "string"),
("tool_chat_template_mistral.jinja", "string")],
)
# yapf: enable
def test_resolve_content_format_examples(template_path, expected_format):
tokenizer_group = TokenizerGroup(
PHI3V_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
dummy_tokenizer = tokenizer_group.tokenizer
dummy_tokenizer.chat_template = None

resolved_format = resolve_chat_template_content_format(
load_chat_template(EXAMPLES_DIR / template_path),
"auto",
dummy_tokenizer,
)

assert resolved_format == expected_format
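
Outside the test suite, the same resolution can be exercised directly. A rough sketch mirroring the calls above (the model ID and path are only examples, and these are vLLM-internal helpers, so run it from a vLLM checkout):

```python
from vllm.entrypoints.chat_utils import (load_chat_template,
                                         resolve_chat_template_content_format)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

# Build a tokenizer the same way the tests above do.
tokenizer = TokenizerGroup(
    "microsoft/Phi-3.5-vision-instruct",
    enable_lora=False,
    max_num_seqs=5,
    max_input_length=None,
).tokenizer

# "auto" lets vLLM inspect the template; per the parametrized tests above,
# Phi-3.5-vision's own chat template resolves to "string".
print(resolve_chat_template_content_format(tokenizer.chat_template, "auto", tokenizer))

# The VLM2Vec example template, by contrast, resolves to "openai".
tokenizer.chat_template = None  # mirror the test: ignore the model's own template
vlm2vec = load_chat_template("examples/template_vlm2vec.jinja")  # path relative to the repo root
print(resolve_chat_template_content_format(vlm2vec, "auto", tokenizer))
```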
134 changes: 76 additions & 58 deletions vllm/entrypoints/chat_utils.py
@@ -8,7 +8,6 @@
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)

-import jinja2
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
# yapf conflicts with isort for this block
@@ -144,6 +143,80 @@ class ConversationMessage(TypedDict, total=False):
_ChatTemplateContentFormat = Literal["string", "openai"]


def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
if isinstance(node, jinja2.nodes.Name):
return node.ctx == "load" and node.name == varname

return False


def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
if isinstance(node, jinja2.nodes.Getitem):
return (node.ctx == "load" and _is_var_access(node.node, varname)
and isinstance(node.arg, jinja2.nodes.Const)
and node.arg.value == key)

if isinstance(node, jinja2.nodes.Getattr):
return (node.ctx == "load" and _is_var_access(node.node, varname)
and node.attr == key)

return False


def _iter_self_and_descendants(node: jinja2.nodes.Node):
yield node
yield from node.find_all(jinja2.nodes.Node)


def _iter_nodes_define_message(chat_template_ast: jinja2.nodes.Template):
# Search for {%- for message in messages -%} loops
for loop_ast in chat_template_ast.find_all(jinja2.nodes.For):
loop_target = loop_ast.target

# yapf: disable
if any(
_is_var_access(loop_iter_desc, "messages") for loop_iter_desc
in _iter_self_and_descendants(loop_ast.iter)
): # yapf: enable
assert isinstance(loop_target, jinja2.nodes.Name)
yield loop_ast, loop_target.name


def _iter_nodes_define_content_item(chat_template_ast: jinja2.nodes.Template):
for node, message_varname in _iter_nodes_define_message(chat_template_ast):
# Search for {%- for content in message['content'] -%} loops
for loop_ast in node.find_all(jinja2.nodes.For):
loop_target = loop_ast.target

# yapf: disable
if any(
_is_attr_access(loop_iter_desc, message_varname, "content")
for loop_iter_desc in _iter_self_and_descendants(loop_ast.iter)
): # yapf: enable
assert isinstance(loop_target, jinja2.nodes.Name)
yield loop_ast, loop_target.name


def _detect_content_format(
chat_template: str,
*,
default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
try:
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
jinja_ast = jinja_compiled.environment.parse(chat_template)
except Exception:
logger.exception("Error when compiling Jinja template")
return default

try:
next(_iter_nodes_define_content_item(jinja_ast))
except StopIteration:
return "string"
else:
return "openai"


def _resolve_chat_template_content_format(
chat_template: Optional[str],
given_format: ChatTemplateContentFormatOption,
@@ -164,7 +237,7 @@ def _resolve_chat_template_content_format(
jinja_text = load_chat_template(chat_template, is_literal=True)

detected_format = ("string" if jinja_text is None else
-                       _detect_chat_template_content_format(jinja_text))
+                       _detect_content_format(jinja_text, default="string"))

return detected_format if given_format == "auto" else given_format

@@ -183,7 +256,7 @@ def resolve_chat_template_content_format(

logger.info(
"Detected the chat template content format to be '%s'. "
"Set `--chat-template-content-format` to explicitly specify this.",
"You can set `--chat-template-content-format` to override this.",
detected_format,
)

@@ -738,61 +811,6 @@ def parse_chat_messages_futures(
return conversation, mm_tracker.all_mm_data()


def _iter_nodes_define_message(chat_template_ast: jinja2.nodes.Template):
# Search for {%- for message in messages -%} loops
for loop_ast in chat_template_ast.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
if not (isinstance(loop_iter, jinja2.nodes.Name)
and loop_iter.ctx == "load" and loop_iter.name == "messages"):
continue

loop_target = loop_ast.target
if not isinstance(loop_target, jinja2.nodes.Name):
continue

yield loop_ast, loop_target.name


def _iter_nodes_define_content_item(chat_template_ast: jinja2.nodes.Template):
for node, message_varname in _iter_nodes_define_message(chat_template_ast):
# Search for {%- for content in message['content'] -%} loops
for loop_ast in node.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
if not (isinstance(loop_iter, jinja2.nodes.Getitem)
and loop_iter.ctx == "load"):
continue

getitem_src = loop_iter.node
if not (isinstance(getitem_src, jinja2.nodes.Name)
and getitem_src.ctx == "load"
and getitem_src.name == message_varname):
continue

getitem_idx = loop_iter.arg
if not (isinstance(getitem_idx, jinja2.nodes.Const)
and getitem_idx.value == "content"):
continue

loop_target = loop_ast.target
if not isinstance(loop_target, jinja2.nodes.Name):
continue

yield loop_iter, loop_target.name


def _detect_chat_template_content_format(
chat_template: str) -> _ChatTemplateContentFormat:
jinjacompiled = hf_chat_utils._compile_jinja_template(chat_template)
jinja_ast = jinjacompiled.environment.parse(chat_template)

try:
next(_iter_nodes_define_content_item(jinja_ast))
except StopIteration:
return "string"
else:
return "openai"


def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: List[ConversationMessage],
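
For context on the helpers added above: the detection boils down to whether the chat template ever loops over `message['content']`. A toy sketch of how the new `_detect_content_format` helper behaves (the templates are hypothetical and the helper is internal, not public API):

```python
from vllm.entrypoints.chat_utils import _detect_content_format

# Renders the content as one opaque string -> no inner loop -> "string".
string_template = """
{%- for message in messages -%}
{{ message['role'] }}: {{ message['content'] }}
{%- endfor -%}
"""

# Iterates over the parts of message['content'] -> "openai".
openai_template = """
{%- for message in messages -%}
{%- for content in message['content'] -%}
{{ content['text'] }}
{%- endfor -%}
{%- endfor -%}
"""

assert _detect_content_format(string_template, default="string") == "string"
assert _detect_content_format(openai_template, default="string") == "openai"
```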
26 changes: 15 additions & 11 deletions vllm/entrypoints/llm.py
@@ -526,9 +526,11 @@ def chat(
to the OpenAI API.
Args:
messages: A list of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
messages: A list of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
@@ -539,16 +541,18 @@
chat_template: The template to use for structuring the chat.
If not provided, the model's default chat template will be used.
chat_template_content_format: The format to render message content.
- "string" will render the content as a string.
Example: "Hello World"
- "openai" will render the content as a list of dictionaries,
similar to OpenAI schema.
Example: [{"type": "text", "text": "Hello world!"}]
- "string" will render the content as a string.
Example: ``"Who are you?"``
- "openai" will render the content as a list of dictionaries,
similar to OpenAI schema.
Example: ``[{"type": "text", "text": "Who are you?"}]``
add_generation_prompt: If True, adds a generation template
to each message.
continue_final_message: If True, continues the final message in
-                the conversation instead of starting a new one. Cannot be `True`
-                if `add_generation_prompt` is also `True`.
+                the conversation instead of starting a new one. Cannot be
+                ``True`` if ``add_generation_prompt`` is also ``True``.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
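
As a companion to the updated docstring, a minimal offline sketch of the new parameter (model choice and sampling values are placeholders, not part of this commit):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-Guard-3-1B")

messages = [{
    "role": "user",
    "content": [{"type": "text", "text": "Who are you?"}],
}]

# Force OpenAI-style part lists instead of relying on "auto" detection.
outputs = llm.chat(
    messages,
    sampling_params=SamplingParams(temperature=0.0, max_tokens=64),
    chat_template_content_format="openai",
)
print(outputs[0].outputs[0].text)
```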
@@ -739,7 +743,7 @@ def encode(
generation, if any.
Returns:
-            A list of `EmbeddingRequestOutput` objects containing the
+            A list of ``EmbeddingRequestOutput`` objects containing the
generated embeddings in the same order as the input prompts.
Note:
