[UnitTest][NVPTX] Avoid cascading failures from CUDA postproc #15136

Merged
97 changes: 51 additions & 46 deletions tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -14,11 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np

import tvm
import tvm.testing
from tvm.script import tir as T

import pytest
import numpy as np


def count_cp_async(stmt):
num_alloc = [0]
@@ -351,36 +354,54 @@ def test_inject_async_copy_shared_dyn():
"""


generated_code = ""
support_async = True
@pytest.fixture
def postproc_if_missing_async_support():
arch = tvm.contrib.nvcc.get_target_compute_version()
major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
support_async = major >= 8

func_name = "tvm_callback_cuda_postproc"
prev_postproc = tvm.get_global_func(func_name, allow_missing=True)

# Store the generated code prior to the post-processing. This
# way, even though the generated code doesn't compile on platforms
# that do not support async, the comparison against an expected
# output can still be performed. We cannot use
# `mod.get_source()`, as that contains the source after all
# post-processing.
original_code = None

def get_original_code():
nonlocal original_code
return original_code

@tvm.register_func(func_name, override=True)
def tvm_callback_cuda_postproc(code, _):
nonlocal original_code
original_code = code
if support_async:
return code
else:
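# Device lacks cp.async support (pre-sm80): keep only the lines up to and
# including the kernel signature, then close the body so the stub still compiles.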
ret = []
for line in code.split("\n"):
ret.append(line)
ret.append("\n")
if line.startswith('extern "C" __global__') and line.endswith("{"):
break
ret.append("}")
return "".join(ret)

yield get_original_code

@tvm.register_func
def tvm_callback_cuda_postproc(code, _):
global generated_code
global support_async
generated_code = code
# return a dummy code so that device < sm80 could build correctly
if not support_async:
ret = ""
for line in code.split("\n"):
ret += line + "\n"
if line.startswith('extern "C" __global__'):
break
ret += "}"
return ret
return code
# Restore previous postproc func to avoid impacting other tests
if prev_postproc is None:
tvm._ffi.registry.remove_global_func(func_name)
else:
tvm.register_func(func_name, prev_postproc, override=True)


@tvm.testing.requires_cuda
def test_cp_async_in_if_then_else():
global support_async
arch = tvm.contrib.nvcc.get_target_compute_version()
major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# At least sm80 is required
support_async = False

def test_cp_async_in_if_then_else(postproc_if_missing_async_support):
@T.prim_func
def simple_compute(
A: T.Buffer((16, 14), "float32"),
@@ -422,22 +443,12 @@ def simple_compute(
mod = tvm.IRModule.from_expr(simple_compute)
with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
tvm.build(mod, target="cuda")
generated_code = postproc_if_missing_async_support()
assert generated_code == expected_cuda_script

if not support_async:
# avoid return dummy code to other tests
support_async = True


@tvm.testing.requires_cuda
def test_vectorize_cp_async_in_if_then_else():
global support_async
arch = tvm.contrib.nvcc.get_target_compute_version()
major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# At least sm80 is required
support_async = False

def test_vectorize_cp_async_in_if_then_else(postproc_if_missing_async_support):
@T.prim_func
def complex_compute(
A: T.Buffer((2, 16, 16, 1280), "float16"),
@@ -887,16 +898,10 @@ def complex_compute(
mod = tvm.IRModule.from_expr(complex_compute)
with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
tvm.build(mod, target="cuda")
generated_code = postproc_if_missing_async_support()
# generated_code must contain " setp.ne.b32 p, %0, 0;"
assert "setp.ne.b32" in generated_code

if not support_async:
# avoid return dummy code to other tests
support_async = True


if __name__ == "__main__":
test_inject_async_copy()
test_inject_async_copy_shared_dyn()
test_cp_async_in_if_then_else()
test_vectorize_cp_async_in_if_then_else()
tvm.testing.main()
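
For readers who want to reuse the approach outside this test file, the core of the change is the save/override/restore pattern in the fixture above. Below is a minimal standalone sketch of that pattern; it uses the real tvm_callback_cuda_postproc hook and the same TVM registry calls that appear in the diff, while the fixture name and the callback body are illustrative only.

import pytest

import tvm


@pytest.fixture
def override_cuda_postproc():
    """Temporarily override the CUDA postproc hook, restoring it afterwards."""
    func_name = "tvm_callback_cuda_postproc"
    # Remember whatever callback was registered before this test (None if absent).
    prev = tvm.get_global_func(func_name, allow_missing=True)

    @tvm.register_func(func_name, override=True)
    def _postproc(code, _):
        # Test-specific post-processing of the generated CUDA source goes here.
        return code

    yield

    # Teardown: put the registry back so later tests see the original state.
    if prev is None:
        tvm._ffi.registry.remove_global_func(func_name)
    else:
        tvm.register_func(func_name, prev, override=True)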