diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py
index 53e082c9ec8e..8dd482673829 100644
--- a/python/tvm/testing/tir.py
+++ b/python/tvm/testing/tir.py
@@ -79,7 +79,7 @@ def mma_schedule(
     mma_store_intrin,
     shared_scope="shared",
 ):
-    """Create a tensorized schedule for 4k GEMM with MMA intrinsics."""
+    """Create a tensorized schedule for GEMM with MMA intrinsics."""
     ir_module = tvm.IRModule({"main": workload})
     sch = tvm.tir.Schedule(ir_module)
 
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index f89dbacee33d..fddda05eb5b0 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -1034,10 +1034,16 @@ def test_error_missing_annotation():
     _check_error(simple_compute_missing_annotation)
 
 
+@tvm.testing.requires_cuda
 def test_three_stage_gemm():
     N = K = M = 4096
     i_factors, j_factors, k_factors = [4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 2, 1]
 
+    def is_ampere_or_newer():
+        arch = tvm.contrib.nvcc.get_target_compute_version()
+        major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+        return major >= 8
+
     def index_map(i, j):
         return (
             i // 16,
@@ -1071,17 +1077,18 @@ def index_map(i, j):
     sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3])
     sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
 
-    f = tvm.build(sch.mod["main"], target="cuda")
-
-    dev = tvm.device("cuda", 0)
-    a_np = np.random.uniform(size=(N, K)).astype("float16")
-    b_np = np.random.uniform(size=(K, M)).astype("float16")
-    c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
-    a = tvm.nd.array(a_np, dev)
-    b = tvm.nd.array(b_np, dev)
-    c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev)
-    f(a, b, c)
-    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+    if is_ampere_or_newer():
+        f = tvm.build(sch.mod["main"], target="cuda")
+
+        dev = tvm.device("cuda", 0)
+        a_np = np.random.uniform(size=(N, K)).astype("float16")
+        b_np = np.random.uniform(size=(K, M)).astype("float16")
+        c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+        a = tvm.nd.array(a_np, dev)
+        b = tvm.nd.array(b_np, dev)
+        c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev)
+        f(a, b, c)
+        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
 
 
 if __name__ == "__main__":