Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Improve Nullable kv Arg Parsing #8525

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion tests/engine/test_arg_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from argparse import ArgumentTypeError

import pytest

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.arg_utils import EngineArgs, nullable_kvs
from vllm.utils import FlexibleArgumentParser


Expand All @@ -13,6 +15,10 @@
"image": 16,
"video": 2
}),
("Image=16, Video=2", {
"image": 16,
"video": 2
}),
])
def test_limit_mm_per_prompt_parser(arg, expected):
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
Expand All @@ -22,3 +28,15 @@ def test_limit_mm_per_prompt_parser(arg, expected):
args = parser.parse_args(["--limit-mm-per-prompt", arg])

assert args.limit_mm_per_prompt == expected


@pytest.mark.parametrize(
("arg"),
[
"image", # Missing =
"image=4,image=5", # Conflicting values
"image=video=4" # Too many = in tokenized arg
])
def test_bad_nullable_kvs(arg):
with pytest.raises(ArgumentTypeError):
nullable_kvs(arg)
28 changes: 21 additions & 7 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,36 @@ def nullable_str(val: str):


def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
"""Parses a string containing comma separate key [str] to value [int]
pairs into a dictionary.

Args:
val: String value to be parsed.

Returns:
Dictionary with parsed values.
"""
if len(val) == 0:
return None

out_dict: Dict[str, int] = {}
for item in val.split(","):
try:
key, value = item.split("=")
except TypeError as exc:
msg = "Each item should be in the form KEY=VALUE"
raise ValueError(msg) from exc
kv_parts = [part.lower().strip() for part in item.split("=")]
if len(kv_parts) != 2:
raise argparse.ArgumentTypeError(
"Each item should be in the form KEY=VALUE")
key, value = kv_parts

try:
out_dict[key] = int(value)
parsed_value = int(value)
except ValueError as exc:
msg = f"Failed to parse value of item {key}={value}"
raise ValueError(msg) from exc
raise argparse.ArgumentTypeError(msg) from exc

if key in out_dict and out_dict[key] != parsed_value:
raise argparse.ArgumentTypeError(
f"Conflicting values specified for key: {key}")
out_dict[key] = parsed_value

return out_dict

Expand Down
Loading