Skip to content

Commit

Permalink
[CI/Build] Update pixtral tests to use JSON (#8436)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored Sep 13, 2024
1 parent 3f79bc3 commit 8427550
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ exclude = [

[tool.codespell]
ignore-words-list = "dout, te, indicies, subtile"
skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"

[tool.isort]
use_parentheses = true
Expand Down
1 change: 1 addition & 0 deletions tests/models/fixtures/pixtral_chat.json

Large diffs are not rendered by default.

Binary file removed tests/models/fixtures/pixtral_chat.pickle
Binary file not shown.
1 change: 1 addition & 0 deletions tests/models/fixtures/pixtral_chat_engine.json

Large diffs are not rendered by default.

Binary file removed tests/models/fixtures/pixtral_chat_engine.pickle
Binary file not shown.
56 changes: 39 additions & 17 deletions tests/models/test_pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
Run `pytest tests/models/test_mistral.py`.
"""
import pickle
import json
import uuid
from typing import Any, Dict, List
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Tuple

import pytest
from mistral_common.protocol.instruct.messages import ImageURLChunk
Expand All @@ -14,6 +15,7 @@

from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
from vllm.multimodal import MultiModalDataBuiltins
from vllm.sequence import Logprob, SampleLogprobs

from .utils import check_logprobs_close

Expand Down Expand Up @@ -81,13 +83,33 @@ def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
LIMIT_MM_PER_PROMPT = dict(image=4)

MAX_MODEL_LEN = [8192, 65536]
FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"
FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.json"
FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.json"

OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]]

def load_logprobs(filename: str) -> Any:
with open(filename, 'rb') as f:
return pickle.load(f)

# For the test author to store golden output in JSON
def _dump_outputs_w_logprobs(outputs: OutputsLogprobs, filename: str) -> None:
json_data = [(tokens, text,
[{k: asdict(v)
for k, v in token_logprobs.items()}
for token_logprobs in (logprobs or [])])
for tokens, text, logprobs in outputs]

with open(filename, "w") as f:
json.dump(json_data, f)


def load_outputs_w_logprobs(filename: str) -> OutputsLogprobs:
with open(filename, "rb") as f:
json_data = json.load(f)

return [(tokens, text,
[{int(k): Logprob(**v)
for k, v in token_logprobs.items()}
for token_logprobs in logprobs])
for tokens, text, logprobs in json_data]


@pytest.mark.skip(
Expand All @@ -103,7 +125,7 @@ def test_chat(
model: str,
dtype: str,
) -> None:
EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
with vllm_runner(
model,
dtype=dtype,
Expand All @@ -120,10 +142,10 @@ def test_chat(
outputs.extend(output)

logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
check_logprobs_close(outputs_0_lst=logprobs,
outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
name_0="output",
name_1="h100_ref")
check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
outputs_1_lst=logprobs,
name_0="h100_ref",
name_1="output")


@pytest.mark.skip(
Expand All @@ -133,7 +155,7 @@ def test_chat(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
args = EngineArgs(
model=model,
tokenizer_mode="mistral",
Expand Down Expand Up @@ -162,7 +184,7 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
break

logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
check_logprobs_close(outputs_0_lst=logprobs,
outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
name_0="output",
name_1="h100_ref")
check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
outputs_1_lst=logprobs,
name_0="h100_ref",
name_1="output")

0 comments on commit 8427550

Please sign in to comment.