[2/2] Using xfail instead of skip for ROCm 6.2 tests (vllm-project#70)
* Enabling some basic tests for ROCm 6.2

Use strict xfail for ROCm 6.2 test repairs

* Use lenient xfail instead

---------

Co-authored-by: Alexei V. Ivanov <[email protected]>
mawong-amd and Alexei-V-Ivanov-AMD authored Jun 28, 2024
1 parent 596d58c commit cce6281
Showing 6 changed files with 40 additions and 0 deletions.
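For context on the skip-to-xfail change: pytest.mark.skip never executes a test, while pytest.mark.xfail still runs it and records the outcome, so both regressions and fixes stay visible. With strict=True an unexpected pass fails the run; the lenient default (strict=False) only reports it as XPASS. A minimal sketch of the two modes (illustrative only, not part of this commit):

import pytest

# Lenient xfail (what this commit settles on): the test still runs;
# a failure is recorded as XFAIL and an unexpected pass as XPASS,
# neither of which breaks the run.
@pytest.mark.xfail(reason="known issue on ROCm 6.2", strict=False)
def test_lenient():
    assert 1 + 1 == 3  # stand-in for a ROCm 6.2 failure

# Strict xfail (the first attempt in this PR): an unexpected pass
# fails the run, so a fixed test cannot silently keep its marker.
@pytest.mark.xfail(reason="known issue on ROCm 6.2", strict=True)
def test_strict():
    assert 1 + 1 == 3  # stand-in for a ROCm 6.2 failure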
7 changes: 7 additions & 0 deletions tests/core/block/e2e/test_correctness.py
@@ -4,9 +4,11 @@
 
 from vllm import SamplingParams
 
+from ....test_utils import xfail_if_rocm62
 from .conftest import get_token_ids_from_llm_generator
 
 
+@xfail_if_rocm62
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -79,6 +81,7 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
     assert baseline_token_ids == test_token_ids
 
 
+@xfail_if_rocm62
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -140,6 +143,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
     assert baseline_token_ids == test_token_ids
 
 
+@xfail_if_rocm62
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -232,6 +236,7 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
     assert baseline_token_ids == test_token_ids
 
 
+@xfail_if_rocm62
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [
@@ -302,6 +307,7 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
     assert baseline_token_ids == test_token_ids
 
 
+@xfail_if_rocm62
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -377,6 +383,7 @@ def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
     assert baseline_token_ids == test_token_ids
 
 
+@xfail_if_rocm62
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
3 changes: 3 additions & 0 deletions tests/core/block/e2e/test_correctness_sliding_window.py
@@ -5,13 +5,15 @@
 
 from vllm import LLM, SamplingParams
 
+from ....test_utils import xfail_if_rocm62
 from .conftest import get_text_from_llm_generator
 
 # relatively small model with 4k sliding window
 MODEL = "bigcode/starcoder2-3b"
 BLOCK_SIZE = 16
 
 
+@xfail_if_rocm62
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -73,6 +75,7 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
     assert sum(cmp) > 0.7 * len(cmp)
 
 
+@xfail_if_rocm62
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
7 changes: 7 additions & 0 deletions tests/metrics/test_metrics.py
@@ -8,11 +8,14 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
 
+from ..test_utils import xfail_if_rocm62
+
 MODELS = [
     "facebook/opt-125m",
 ]
 
 
+@xfail_if_rocm62
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
 @pytest.mark.parametrize("max_tokens", [128])
@@ -46,6 +49,7 @@ def test_metric_counter_prompt_tokens(
         f"metric: {metric_count!r}")
 
 
+@xfail_if_rocm62
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
 @pytest.mark.parametrize("max_tokens", [128])
@@ -78,6 +82,7 @@ def test_metric_counter_generation_tokens(
         f"metric: {metric_count!r}")
 
 
+@xfail_if_rocm62
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
 @pytest.mark.parametrize(
@@ -106,6 +111,7 @@ def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str,
         f"actual: {metrics_tag_content!r}")
 
 
+@xfail_if_rocm62
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [4])
@@ -141,6 +147,7 @@ async def test_async_engine_log_metrics_regression(
                    len(example_prompts))
 
 
+@xfail_if_rocm62
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [4])
17 changes: 17 additions & 0 deletions tests/test_utils.py
@@ -116,3 +116,20 @@ def dummy(*, old_arg: object = None, new_arg: object = None):
 
     with pytest.warns(DeprecationWarning, match="abcd"):
         dummy(old_arg=1)
+
+
+def is_rocm62():
+    import torch
+    return isinstance(torch.version.hip,
+                      str) and torch.version.hip.startswith("6.2")
+
+
+def xfail_if_rocm62(function=None,
+                    reason: str = "Tests are not yet ready for ROCm 6.2",
+                    strict: bool = False):
+    if function:
+        assert callable(function)
+        return pytest.mark.xfail(is_rocm62(), reason=reason,
+                                 strict=strict)(function)
+    else:
+        return pytest.mark.xfail(is_rocm62(), reason=reason, strict=strict)
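For reference, a sketch of the two decorator forms the helper supports, using the relative import style of the sibling test modules; the reason string below is invented for illustration:

from ..test_utils import xfail_if_rocm62

@xfail_if_rocm62                      # bare form: lenient xfail on ROCm 6.2
def test_bare():
    ...

@xfail_if_rocm62(reason="hypothetical: kernel mismatch on ROCm 6.2",
                 strict=True)
def test_configured():                # called form: custom reason, strict mode
    ...

In the bare form the function itself is passed in, so the marked function is returned directly; in the called form no function is passed, so a pytest.mark.xfail marker is returned and applied by the decorator syntax.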
3 changes: 3 additions & 0 deletions tests/worker/test_model_runner.py
@@ -8,6 +8,8 @@
 from vllm.utils import get_open_port
 from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
 
+from ..test_utils import xfail_if_rocm62
+
 
 def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
     engine_args = EngineArgs(model, *args, **kwargs)
@@ -138,6 +140,7 @@ def test_prepare_prompt(batch_size):
     torch.testing.assert_close(actual, expected)
 
 
+@xfail_if_rocm62
 @pytest.mark.parametrize("batch_size", list(range(1, 257)))
 def test_prepare_decode_cuda_graph(batch_size):
     model_runner = _create_model_runner(
3 changes: 3 additions & 0 deletions tests/worker/test_swap.py
@@ -5,7 +5,10 @@
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.worker import Worker
 
+from ..test_utils import xfail_if_rocm62
+
 
+@xfail_if_rocm62
 def test_swap() -> None:
     # Configure the engine.
     engine_args = EngineArgs(model="facebook/opt-125m",
