diff --git a/tests/kernels/test_cascade_flash_attn.py b/tests/kernels/test_cascade_flash_attn.py
index c9c4eee3b4fe7..8edfde42ede74 100755
--- a/tests/kernels/test_cascade_flash_attn.py
+++ b/tests/kernels/test_cascade_flash_attn.py
@@ -93,9 +93,9 @@ def test_cascade(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if is_fa_version_supported(fa_version):
-        pytest.skip("Flash attention version not supported due to: " + \
-            fa_version_unsupported_reason(fa_version))
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
 
     current_platform.seed_everything(0)
 
diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py
index 401449f7e61d2..0ee0bf6c6a374 100644
--- a/tests/kernels/test_flash_attn.py
+++ b/tests/kernels/test_flash_attn.py
@@ -97,8 +97,8 @@ def test_flash_attn_with_paged_kv(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if is_fa_version_supported(fa_version):
-        pytest.skip(f"Flash attention version {fa_version} not supported due "
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
                     f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
 
     current_platform.seed_everything(0)
@@ -183,8 +183,8 @@ def test_varlen_with_paged_kv(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if is_fa_version_supported(fa_version):
-        pytest.skip(f"Flash attention version {fa_version} not supported due "
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
                     f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
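
For context, the guard was inverted before this change: the tests were skipped when the requested FlashAttention version was supported and ran when it was not. Below is a minimal, self-contained sketch of the corrected skip pattern, not the actual tests from the PR; it assumes the helpers are importable from vllm.vllm_flash_attn as in the files touched above, and test_skip_guard is a hypothetical name.

# Sketch only: assumes the import path below matches the vLLM version in the diff.
import pytest
import torch

from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
                                  is_fa_version_supported)


@pytest.mark.parametrize("fa_version", [2, 3])
def test_skip_guard(fa_version: int) -> None:
    torch.set_default_device("cuda")
    # Skip only when the requested version is NOT supported; the pre-fix
    # condition was inverted and skipped the supported versions instead.
    if not is_fa_version_supported(fa_version):
        pytest.skip(f"Flash attention version {fa_version} not supported due "
                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
    # ... kernel checks for the supported version would follow here ...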