From 012c8e46ce4d89ae5794d023234d2bba28c59737 Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Sun, 12 Feb 2023 21:14:13 +0800 Subject: [PATCH 01/21] support async copy for if_then_else --- src/target/source/ptx.cc | 35 +++++ src/target/source/ptx.h | 16 +++ src/tir/transforms/inject_ptx_async_copy.cc | 151 ++++++++++++-------- 3 files changed, 139 insertions(+), 63 deletions(-) diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc index 886242efe08c..035d466bb6ed 100644 --- a/src/target/source/ptx.cc +++ b/src/target/source/ptx.cc @@ -659,5 +659,40 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr, return asm_code; } +std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, + const std::string& bytes, + const std::string& predicate_value) { + std::string predicated_asm_code = R"( + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)({smem_addr})) + ); + __asm__ __volatile__( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %0, 0;\n" + "\t@p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;\n" + "}\n" + :: "r"((int){pred_guard}), + "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}) + ); + } +)"; + Replacer replacer; + replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); + replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); + replacer.register_rule("{bytes}", bytes); + replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca"); + replacer.register_rule("{pred_guard}", predicate_value); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + } // namespace codegen } // namespace tvm diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h index c811a1b9c1d6..1e49b57c1790 100644 --- a/src/target/source/ptx.h +++ b/src/target/source/ptx.h @@ -92,6 +92,22 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr, const std::string& global_ptr, const std::string& global_elem_offset, const std::string& bytes); +/*! + * \brief Print predicated ptx cp.async assembly string given parameters. + * \param shared_ptr: The pointer to the destination shared memory. + * \param shared_elem_offset: The offset into the shared memory. + * \param global_ptr: The pointer to the global memory. + * \param global_elem_offset: The offset into the global memory. + * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16. + * \param predicate_value: The value of predicate `@p`. 
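+ * \note The copy is predicated: data is fetched from global memory only when `predicate_value` evaluates to a non-zero value.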
+ */ +std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, + const std::string& bytes, + const std::string& predicate_value); + } // namespace codegen } // namespace tvm diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index 8ee0d054e56d..ebd1728bf5f4 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -47,74 +47,99 @@ class PTXAsyncCopyInjector : public StmtMutator { return StmtMutator::VisitStmt_(attr); } - Stmt VisitStmt_(const BufferStoreNode* store) { - if (in_async && (store->buffer.scope() == "shared" || store->buffer.scope() == "shared.dyn")) { - if (auto* load = store->value.as()) { - if (load->buffer.scope() == "global") { - ICHECK(load->indices.size() == 1 && store->indices.size() == 1); - ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes()); - - const int indices_lanes = load->indices[0]->dtype.lanes(); - const int bytes = indices_lanes * load->buffer->dtype.bytes(); - - if (bytes == 4 || bytes == 8 || bytes == 16) { - auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation); - auto src_elem_type = GetPointerType(load->buffer->data->type_annotation); - ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) - << "Both store and load buffer should have a pointer type annotation."; - - int index_factor = 1; - if (dst_elem_type.value() != src_elem_type.value()) { - // The only case where src and dst have different dtypes is when the dst shared memory - // is a byte buffer generated by merging dynamic shared memory. - ICHECK(store->buffer.scope() == "shared.dyn"); - ICHECK(dst_elem_type.value() == DataType::UInt(8)); - // BufferStore/Load have the "pointer reinterpret" semantics according to their - // "value" dtype. Their "indices" are supposed to be applied after such pointer cast, - // for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value; - // To replace BufferStore/Load with cp.async, we need to multiply the store index by - // the byte size of the "value" dtype, to get the correct offset into the byte buffer. - index_factor = src_elem_type->bytes(); - } + Stmt injectPTX(const BufferLoadNode* load, const BufferStoreNode* store, bool predicated = false, PrimExpr predicate_value = PrimExpr()) { + if (load->buffer.scope() == "global") { + ICHECK(load->indices.size() == 1 && store->indices.size() == 1); + // std::cout << "[BufferLoadNode]: " << load->indices << " " << store->indices << std::endl; + ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes()); + + const int indices_lanes = load->indices[0]->dtype.lanes(); + const int bytes = indices_lanes * load->buffer->dtype.bytes(); + + if (bytes == 4 || bytes == 8 || bytes == 16) { + auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation); + auto src_elem_type = GetPointerType(load->buffer->data->type_annotation); + ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) + << "Both store and load buffer should have a pointer type annotation."; + + int index_factor = 1; + if (dst_elem_type.value() != src_elem_type.value()) { + // The only case where src and dst have different dtypes is when the dst shared memory + // is a byte buffer generated by merging dynamic shared memory. 
+ ICHECK(store->buffer.scope() == "shared.dyn"); + ICHECK(dst_elem_type.value() == DataType::UInt(8)); + // BufferStore/Load have the "pointer reinterpret" semantics according to their + // "value" dtype. Their "indices" are supposed to be applied after such pointer cast, + // for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value; + // To replace BufferStore/Load with cp.async, we need to multiply the store index by + // the byte size of the "value" dtype, to get the correct offset into the byte buffer. + index_factor = src_elem_type->bytes(); + } - if (indices_lanes == 1) { - auto src_offset = load->indices[0]; - auto dst_offset = store->indices[0]; - return Evaluate( - Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), - {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), - load->buffer->data, src_offset, PrimExpr(bytes)})); + if (indices_lanes == 1) { + auto src_offset = load->indices[0]; + auto dst_offset = store->indices[0]; + Array args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes)}; + if (predicated) args.push_back(predicate_value); + return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args)); + } + + // Predicated load don't support vectorized indexing. + if (!predicated) { + // Only some vectorized indexing patterns are supported for now. + auto src_offset = [=]() -> PrimExpr { + if (load->indices[0]->IsInstance()) { + return load->indices[0].as()->base; + } + return PrimExpr(); + }(); + + auto dst_offset = [=]() -> PrimExpr { + if (store->indices[0].as()) { + return store->indices[0].as()->base; + } else if (store->indices[0].as()) { + // The case where the dst buffer is a byte buffer generated by merging dynamic + // shared memory. + // A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)] + auto* add = store->indices[0].as(); + if (!add->a->IsInstance()) return PrimExpr(); + if (!add->b->IsInstance()) return PrimExpr(); + return tir::Add(add->a.as()->base, add->b.as()->value); } + return PrimExpr(); + }(); + + if (src_offset.defined() && dst_offset.defined()) { + Array args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes)}; + if (predicated) args.push_back(predicate_value); + return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args)); + } + } + } + } + return StmtMutator::VisitStmt_(store); + } - // Only some vectorized indexing patterns are supported for now. - auto src_offset = [=]() -> PrimExpr { - if (load->indices[0]->IsInstance()) { - return load->indices[0].as()->base; - } - return PrimExpr(); - }(); - - auto dst_offset = [=]() -> PrimExpr { - if (store->indices[0].as()) { - return store->indices[0].as()->base; - } else if (store->indices[0].as()) { - // The case where the dst buffer is a byte buffer generated by merging dynamic - // shared memory. 
- // A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)] - auto* add = store->indices[0].as(); - if (!add->a->IsInstance()) return PrimExpr(); - if (!add->b->IsInstance()) return PrimExpr(); - return tir::Add(add->a.as()->base, add->b.as()->value); + Stmt VisitStmt_(const BufferStoreNode* store) { + if (in_async && (store->buffer.scope() == "shared" || store->buffer.scope() == "shared.dyn")) { + if (auto* load = store->value.as()) { + return injectPTX(load, store); + } else if (auto* call = store->value.as()) { + // tir.if_then_else is a call to tir::builtin::if_then_else( + if (call->op.same_as(builtin::if_then_else()) && call->args.size() == 3) { + if (auto* load = call->args[1].as()) { + bool else_value_is_zero = false; + if (auto* b = call->args[2].as()) { + if (auto* f = b->value.as()) { + else_value_is_zero = f->value == 0.0f; } - return PrimExpr(); - }(); - - if (src_offset.defined() && dst_offset.defined()) { - return Evaluate( - Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), - {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), - load->buffer->data, src_offset, PrimExpr(bytes)})); } + if (auto* f = call->args[2].as()) { + else_value_is_zero = f->value == 0.0f; + } + if (else_value_is_zero) return injectPTX(load, store, true, call->args[0]); } } } From d700b21472f5379f1c59bb3e2a6175da16361849 Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Sun, 12 Feb 2023 21:22:55 +0800 Subject: [PATCH 02/21] add comment and trigger PrintPredicatedCpAsyncAssembly in codegen_cuda.cc --- src/target/source/codegen_cuda.cc | 7 ++++++- src/tir/transforms/inject_ptx_async_copy.cc | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index c891ec5a28cf..4cee77641e3e 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -914,7 +914,12 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string src = this->PrintExpr(op->args[2]); std::string src_offset = this->PrintExpr(op->args[3]); std::string size = this->PrintExpr(op->args[4]); - this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); + // use size of argument list to indicate whether or not to use predicated cp.async + if (op->args.size() == 5) + this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); + else + this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, + size, this->PrintExpr(op->args[5])); } else if (op->op.same_as(builtin::ptx_commit_group())) { this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n"; } else if (op->op.same_as(builtin::ptx_wait_group())) { diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index ebd1728bf5f4..9e5f31487c89 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -81,6 +81,7 @@ class PTXAsyncCopyInjector : public StmtMutator { auto dst_offset = store->indices[0]; Array args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes)}; + // use arguments size to indicate whether or not to use predicated cp.async if (predicated) args.push_back(predicate_value); return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args)); } @@ -113,7 +114,6 @@ class PTXAsyncCopyInjector : public StmtMutator { if (src_offset.defined() && 
dst_offset.defined()) { Array args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes)}; - if (predicated) args.push_back(predicate_value); return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args)); } } From f055ae549661b2ac2082d6e7fea20443f275f63e Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Mon, 13 Feb 2023 00:49:43 +0800 Subject: [PATCH 03/21] reformat code --- src/tir/transforms/inject_ptx_async_copy.cc | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index 9e5f31487c89..b3d23df71ab2 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -47,10 +47,10 @@ class PTXAsyncCopyInjector : public StmtMutator { return StmtMutator::VisitStmt_(attr); } - Stmt injectPTX(const BufferLoadNode* load, const BufferStoreNode* store, bool predicated = false, PrimExpr predicate_value = PrimExpr()) { + Stmt injectPTX(const BufferLoadNode* load, const BufferStoreNode* store, + bool predicated = false, PrimExpr predicate_value = PrimExpr()) { if (load->buffer.scope() == "global") { ICHECK(load->indices.size() == 1 && store->indices.size() == 1); - // std::cout << "[BufferLoadNode]: " << load->indices << " " << store->indices << std::endl; ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes()); const int indices_lanes = load->indices[0]->dtype.lanes(); @@ -79,7 +79,8 @@ class PTXAsyncCopyInjector : public StmtMutator { if (indices_lanes == 1) { auto src_offset = load->indices[0]; auto dst_offset = store->indices[0]; - Array args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), + Array args = {store->buffer->data, + tir::Mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes)}; // use arguments size to indicate whether or not to use predicated cp.async if (predicated) args.push_back(predicate_value); @@ -112,9 +113,10 @@ class PTXAsyncCopyInjector : public StmtMutator { }(); if (src_offset.defined() && dst_offset.defined()) { - Array args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), - load->buffer->data, src_offset, PrimExpr(bytes)}; - return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args)); + return Evaluate( + Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), + {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes)})); } } } @@ -127,7 +129,7 @@ class PTXAsyncCopyInjector : public StmtMutator { if (auto* load = store->value.as()) { return injectPTX(load, store); } else if (auto* call = store->value.as()) { - // tir.if_then_else is a call to tir::builtin::if_then_else( + // tir.if_then_else is a call to tir::builtin::if_then_else() if (call->op.same_as(builtin::if_then_else()) && call->args.size() == 3) { if (auto* load = call->args[1].as()) { bool else_value_is_zero = false; From 96224f2c047dc9c0b3049bb50346a1979566014d Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Mon, 13 Feb 2023 13:06:30 +0800 Subject: [PATCH 04/21] add zfill support & comment --- src/target/source/ptx.cc | 10 +++------- src/tir/transforms/inject_ptx_async_copy.cc | 3 +++ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc index 035d466bb6ed..b5299b4e4b2a 100644 
--- a/src/target/source/ptx.cc +++ b/src/target/source/ptx.cc @@ -673,14 +673,10 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, : "=r"(addr) : "l"((void *)({smem_addr})) ); + int src_bytes = {pred_guard} ? {bytes} : 0; __asm__ __volatile__( - "{\n" - "\t.reg .pred p;\n" - "\tsetp.ne.b32 p, %0, 0;\n" - "\t@p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;\n" - "}\n" - :: "r"((int){pred_guard}), - "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}) + "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;" + :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes) ); } )"; diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index b3d23df71ab2..fbced7e180a9 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -132,6 +132,9 @@ class PTXAsyncCopyInjector : public StmtMutator { // tir.if_then_else is a call to tir::builtin::if_then_else() if (call->op.same_as(builtin::if_then_else()) && call->args.size() == 3) { if (auto* load = call->args[1].as()) { + // Only default value of 0 is supported since 0 is the default value used by cp.async ptx. + // @see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations + // section 9.7.8.22.3. bool else_value_is_zero = false; if (auto* b = call->args[2].as()) { if (auto* f = b->value.as()) { From d22a8ca8b9a4223e3982b0473f9c0a0c0a8cff63 Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Mon, 13 Feb 2023 16:29:18 +0000 Subject: [PATCH 05/21] add unittest --- .../unittest/test_cp_async_in_if_then_else.py | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 tests/python/unittest/test_cp_async_in_if_then_else.py diff --git a/tests/python/unittest/test_cp_async_in_if_then_else.py b/tests/python/unittest/test_cp_async_in_if_then_else.py new file mode 100644 index 000000000000..9ae2df3f3db7 --- /dev/null +++ b/tests/python/unittest/test_cp_async_in_if_then_else.py @@ -0,0 +1,94 @@ +import tvm +import numpy as np +from tvm.script import tir as T + +@tvm.script.ir_module +class Module: + @T.prim_func + def main(A: T.Buffer[(1012, 1014), "float32"], B: T.Buffer[(1014, 1017), "float32"], Y: T.Buffer[(1012, 1017), "float32"]): + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + Y_reindex_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local") + A_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") + B_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") + A_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local") + B_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local") + for ax0_0_ax1_0_fused in T.thread_binding(128, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":1024, "pragma_unroll_explicit":1}): + for ax0_1_ax1_1_fused in T.thread_binding(4, thread="vthread.x"): + for ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused in T.thread_binding(64, thread="threadIdx.x"): + for ax0_3_init, ax1_3_init, ax0_4_init, ax1_4_init in T.grid(4, 4, 2, 1): + with T.block("Y_init"): + v0 = T.axis.spatial(1024, ax0_0_ax1_0_fused // 8 * 64 + ax0_1_ax1_1_fused // 2 * 32 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + ax0_3_init * 2 + ax0_4_init) + v1 = T.axis.spatial(1024, ax1_4_init + 
ax0_0_ax1_0_fused % 8 * 128 + ax0_1_ax1_1_fused % 2 * 64 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + ax1_3_init) + T.reads() + T.writes(Y_reindex_local[v0, v1]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + Y_reindex_local[v0, v1] = T.float32(0) + for ax2_0_fused in T.serial(256, annotations={"software_pipeline_async_stages":[0, 1], "software_pipeline_order":[0, 1, 3, 2, 4], "software_pipeline_stage":[0, 0, 2, 3, 3]}): + for ax0_ax1_fused_0 in T.serial(4): + for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("A_reindex_shared"): + v0 = T.axis.spatial(1024, ax0_0_ax1_0_fused // 8 * 64 + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) // 4) + v1 = T.axis.spatial(1024, ax2_0_fused * 4 + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) % 4) + T.reads(A[v0, v1]) + T.writes(A_reindex_shared[v1, v0 // 32 * 32 + v0 % 8 // 4 * 16 + v0 % 32 // 8 * 4 + v0 % 4]) + A_reindex_shared[v1, v0 // 32 * 32 + v0 % 8 // 4 * 16 + v0 % 32 // 8 * 4 + v0 % 4] = T.if_then_else(v0 < 1012 and v1 < 1014, A[v0, v1], T.float32(0), dtype="float32") + for ax0_ax1_fused_0 in T.serial(8): + for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("B_reindex_shared"): + v0 = T.axis.spatial(1024, ax2_0_fused * 4 + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) // 128) + v1 = T.axis.spatial(1024, ax0_0_ax1_0_fused % 8 * 128 + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) % 128) + T.reads(B[v0, v1]) + T.writes(B_reindex_shared[v0, v1 // 64 * 64 + v1 % 8 // 4 * 32 + v1 % 64 // 8 * 4 + v1 % 4]) + B_reindex_shared[v0, v1 // 64 * 64 + v1 % 8 // 4 * 32 + v1 % 64 // 8 * 4 + v1 % 4] = T.if_then_else(v0 < 1014 and v1 < 1017, B[v0, v1], T.float32(0), dtype="float32") + for ax2_1_fused in T.unroll(4, annotations={"software_pipeline_order":[0, 1, 2], "software_pipeline_stage":[0, 0, 1]}): + for ax0_ax1_fused_0 in T.unroll(2): + for ax0_ax1_fused_1 in T.vectorized(4): + with T.block("A_reindex_shared_local"): + v0 = T.axis.spatial(1024, ax2_0_fused * 4 + ax2_1_fused) + v1 = T.axis.spatial(1024, ax0_0_ax1_0_fused // 8 * 64 + ax0_1_ax1_1_fused // 2 * 32 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + ax0_ax1_fused_0 * 4 + ax0_ax1_fused_1) + T.reads(A_reindex_shared[v0, v1 // 32 * 32 + v1 % 8 // 4 * 16 + v1 % 32 // 8 * 4 + v1 % 4]) + T.writes(A_reindex_shared_local[v0, v1]) + A_reindex_shared_local[v0, v1] = A_reindex_shared[v0, v1 // 32 * 32 + v1 % 8 // 4 * 16 + v1 % 32 // 8 * 4 + v1 % 4] + for ax0_ax1_fused_0 in T.unroll(2): + for ax0_ax1_fused_1 in T.vectorized(2): + with T.block("B_reindex_shared_local"): + v0 = T.axis.spatial(1024, ax2_0_fused * 4 + ax2_1_fused) + v1 = T.axis.spatial(1024, ax0_0_ax1_0_fused % 8 * 128 + ax0_1_ax1_1_fused % 2 * 64 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + ax0_ax1_fused_0 * 2 + ax0_ax1_fused_1) + T.reads(B_reindex_shared[v0, v1 // 64 * 64 + v1 % 8 // 4 * 32 + v1 % 64 // 8 * 4 + v1 % 4]) + T.writes(B_reindex_shared_local[v0, v1]) + B_reindex_shared_local[v0, v1] = B_reindex_shared[v0, v1 // 64 * 64 + v1 % 8 // 4 * 32 + v1 % 64 // 8 * 4 + v1 % 4] + for ax0_3, ax1_3, ax2_2, ax0_4, ax1_4 in T.grid(4, 4, 1, 2, 1): + with T.block("Y_update"): + v0 = T.axis.spatial(1024, ax0_0_ax1_0_fused // 8 * 64 + ax0_1_ax1_1_fused // 2 * 32 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + ax0_3 * 2 + ax0_4) 
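+ # v0/v1 recover the output row/column of the padded 1024x1024 result from the fused block, vthread and thread indices; v2 below is the reduction axis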
+ v1 = T.axis.spatial(1024, ax1_4 + ax0_0_ax1_0_fused % 8 * 128 + ax0_1_ax1_1_fused % 2 * 64 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + ax1_3) + v2 = T.axis.reduce(1024, ax2_0_fused * 4 + ax2_1_fused + ax2_2) + T.reads(Y_reindex_local[v0, v1], A_reindex_shared_local[v2, v0], B_reindex_shared_local[v2, v1]) + T.writes(Y_reindex_local[v0, v1]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + Y_reindex_local[v0, v1] = Y_reindex_local[v0, v1] + A_reindex_shared_local[v2, v0] * B_reindex_shared_local[v2, v1] + for ax0, ax1 in T.grid(8, 4): + with T.block("Y_reindex_local"): + T.where(ax0_0_ax1_0_fused // 8 * 64 + ax0_1_ax1_1_fused // 2 * 32 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + ax0 < 1012 and ax0_0_ax1_0_fused % 8 * 128 + ax0_1_ax1_1_fused % 2 * 64 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + ax1 < 1017) + v0 = T.axis.spatial(1024, ax0_0_ax1_0_fused // 8 * 64 + ax0_1_ax1_1_fused // 2 * 32 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + ax0) + v1 = T.axis.spatial(1024, ax0_0_ax1_0_fused % 8 * 128 + ax0_1_ax1_1_fused % 2 * 64 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + ax1) + T.reads(Y_reindex_local[v0, v1]) + T.writes(Y[v0, v1]) + Y[v0, v1] = Y_reindex_local[v0, v1] + +with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + rt_mod = tvm.build(Module, target="cuda") + +M, N, K = 1012, 1017, 1014 +a_tvm = tvm.nd.array(np.random.rand(M, K).astype("float32"), device=tvm.cuda(0)) +b_tvm = tvm.nd.array(np.random.rand(K, N).astype("float32"), device=tvm.cuda(0)) +c_tvm = tvm.nd.array(np.empty((M, N)).astype("float32"), device=tvm.cuda(0)) +rt_mod(a_tvm, b_tvm, c_tvm) + +time_f = rt_mod.time_evaluator(rt_mod.entry_name, dev=tvm.cuda(0), number=10) +time = time_f(a_tvm, b_tvm, c_tvm).mean + +flop = (M * N * K + M * N) * 2 +print("GFLOPS: %.2f" % (flop / time / 1e9)) From 516af3be6ac7bede562e4471fba6fca8c6f420a7 Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Tue, 14 Feb 2023 04:18:33 +0000 Subject: [PATCH 06/21] reformat unittest --- .../unittest/test_cp_async_in_if_then_else.py | 294 +++++++++++++++--- 1 file changed, 252 insertions(+), 42 deletions(-) diff --git a/tests/python/unittest/test_cp_async_in_if_then_else.py b/tests/python/unittest/test_cp_async_in_if_then_else.py index 9ae2df3f3db7..5fdecad9c15a 100644 --- a/tests/python/unittest/test_cp_async_in_if_then_else.py +++ b/tests/python/unittest/test_cp_async_in_if_then_else.py @@ -2,10 +2,15 @@ import numpy as np from tvm.script import tir as T + @tvm.script.ir_module class Module: @T.prim_func - def main(A: T.Buffer[(1012, 1014), "float32"], B: T.Buffer[(1014, 1017), "float32"], Y: T.Buffer[(1012, 1017), "float32"]): + def main( + A: T.Buffer[(1012, 1014), "float32"], + B: T.Buffer[(1014, 1017), "float32"], + Y: T.Buffer[(1012, 1017), "float32"], + ): # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body @@ -15,80 +20,285 @@ def main(A: T.Buffer[(1012, 1014), "float32"], B: T.Buffer[(1014, 1017), "float3 B_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") A_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local") B_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local") - for 
ax0_0_ax1_0_fused in T.thread_binding(128, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":1024, "pragma_unroll_explicit":1}): + for ax0_0_ax1_0_fused in T.thread_binding( + 128, + thread="blockIdx.x", + annotations={"pragma_auto_unroll_max_step": 1024, "pragma_unroll_explicit": 1}, + ): for ax0_1_ax1_1_fused in T.thread_binding(4, thread="vthread.x"): - for ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused in T.thread_binding(64, thread="threadIdx.x"): + for ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused in T.thread_binding( + 64, thread="threadIdx.x" + ): for ax0_3_init, ax1_3_init, ax0_4_init, ax1_4_init in T.grid(4, 4, 2, 1): with T.block("Y_init"): - v0 = T.axis.spatial(1024, ax0_0_ax1_0_fused // 8 * 64 + ax0_1_ax1_1_fused // 2 * 32 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + ax0_3_init * 2 + ax0_4_init) - v1 = T.axis.spatial(1024, ax1_4_init + ax0_0_ax1_0_fused % 8 * 128 + ax0_1_ax1_1_fused % 2 * 64 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + ax1_3_init) + v0 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused // 8 * 64 + + ax0_1_ax1_1_fused // 2 * 32 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + + ax0_3_init * 2 + + ax0_4_init, + ) + v1 = T.axis.spatial( + 1024, + ax1_4_init + + ax0_0_ax1_0_fused % 8 * 128 + + ax0_1_ax1_1_fused % 2 * 64 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + + ax1_3_init, + ) T.reads() T.writes(Y_reindex_local[v0, v1]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + T.block_attr( + { + "meta_schedule.thread_extent_high_inclusive": 1024, + "meta_schedule.thread_extent_low_inclusive": 32, + "meta_schedule.tiling_structure": "SSSRRSRS", + } + ) Y_reindex_local[v0, v1] = T.float32(0) - for ax2_0_fused in T.serial(256, annotations={"software_pipeline_async_stages":[0, 1], "software_pipeline_order":[0, 1, 3, 2, 4], "software_pipeline_stage":[0, 0, 2, 3, 3]}): + for ax2_0_fused in T.serial( + 256, + annotations={ + "software_pipeline_async_stages": [0, 1], + "software_pipeline_order": [0, 1, 3, 2, 4], + "software_pipeline_stage": [0, 0, 2, 3, 3], + }, + ): for ax0_ax1_fused_0 in T.serial(4): for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): with T.block("A_reindex_shared"): - v0 = T.axis.spatial(1024, ax0_0_ax1_0_fused // 8 * 64 + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) // 4) - v1 = T.axis.spatial(1024, ax2_0_fused * 4 + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) % 4) + v0 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused // 8 * 64 + + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) // 4, + ) + v1 = T.axis.spatial( + 1024, + ax2_0_fused * 4 + + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) % 4, + ) T.reads(A[v0, v1]) - T.writes(A_reindex_shared[v1, v0 // 32 * 32 + v0 % 8 // 4 * 16 + v0 % 32 // 8 * 4 + v0 % 4]) - A_reindex_shared[v1, v0 // 32 * 32 + v0 % 8 // 4 * 16 + v0 % 32 // 8 * 4 + v0 % 4] = T.if_then_else(v0 < 1012 and v1 < 1014, A[v0, v1], T.float32(0), dtype="float32") + T.writes( + A_reindex_shared[ + v1, + v0 // 32 * 32 + + v0 % 8 // 4 * 16 + + v0 % 32 // 8 * 4 + + v0 % 4, + ] + ) + A_reindex_shared[ + v1, + v0 // 32 * 32 + + v0 % 8 // 4 * 16 + + v0 % 32 // 8 * 4 + + v0 % 4, + ] = T.if_then_else( + v0 < 1012 and v1 < 1014, + A[v0, v1], + T.float32(0), + dtype="float32", + ) for ax0_ax1_fused_0 in T.serial(8): for ax0_ax1_fused_1 in T.thread_binding(64, 
thread="threadIdx.x"): with T.block("B_reindex_shared"): - v0 = T.axis.spatial(1024, ax2_0_fused * 4 + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) // 128) - v1 = T.axis.spatial(1024, ax0_0_ax1_0_fused % 8 * 128 + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) % 128) + v0 = T.axis.spatial( + 1024, + ax2_0_fused * 4 + + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) // 128, + ) + v1 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused % 8 * 128 + + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) % 128, + ) T.reads(B[v0, v1]) - T.writes(B_reindex_shared[v0, v1 // 64 * 64 + v1 % 8 // 4 * 32 + v1 % 64 // 8 * 4 + v1 % 4]) - B_reindex_shared[v0, v1 // 64 * 64 + v1 % 8 // 4 * 32 + v1 % 64 // 8 * 4 + v1 % 4] = T.if_then_else(v0 < 1014 and v1 < 1017, B[v0, v1], T.float32(0), dtype="float32") - for ax2_1_fused in T.unroll(4, annotations={"software_pipeline_order":[0, 1, 2], "software_pipeline_stage":[0, 0, 1]}): + T.writes( + B_reindex_shared[ + v0, + v1 // 64 * 64 + + v1 % 8 // 4 * 32 + + v1 % 64 // 8 * 4 + + v1 % 4, + ] + ) + B_reindex_shared[ + v0, + v1 // 64 * 64 + + v1 % 8 // 4 * 32 + + v1 % 64 // 8 * 4 + + v1 % 4, + ] = T.if_then_else( + v0 < 1014 and v1 < 1017, + B[v0, v1], + T.float32(0), + dtype="float32", + ) + for ax2_1_fused in T.unroll( + 4, + annotations={ + "software_pipeline_order": [0, 1, 2], + "software_pipeline_stage": [0, 0, 1], + }, + ): for ax0_ax1_fused_0 in T.unroll(2): for ax0_ax1_fused_1 in T.vectorized(4): with T.block("A_reindex_shared_local"): v0 = T.axis.spatial(1024, ax2_0_fused * 4 + ax2_1_fused) - v1 = T.axis.spatial(1024, ax0_0_ax1_0_fused // 8 * 64 + ax0_1_ax1_1_fused // 2 * 32 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + ax0_ax1_fused_0 * 4 + ax0_ax1_fused_1) - T.reads(A_reindex_shared[v0, v1 // 32 * 32 + v1 % 8 // 4 * 16 + v1 % 32 // 8 * 4 + v1 % 4]) + v1 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused // 8 * 64 + + ax0_1_ax1_1_fused // 2 * 32 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused + // 32 + * 16 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + + ax0_ax1_fused_0 * 4 + + ax0_ax1_fused_1, + ) + T.reads( + A_reindex_shared[ + v0, + v1 // 32 * 32 + + v1 % 8 // 4 * 16 + + v1 % 32 // 8 * 4 + + v1 % 4, + ] + ) T.writes(A_reindex_shared_local[v0, v1]) - A_reindex_shared_local[v0, v1] = A_reindex_shared[v0, v1 // 32 * 32 + v1 % 8 // 4 * 16 + v1 % 32 // 8 * 4 + v1 % 4] + A_reindex_shared_local[v0, v1] = A_reindex_shared[ + v0, + v1 // 32 * 32 + + v1 % 8 // 4 * 16 + + v1 % 32 // 8 * 4 + + v1 % 4, + ] for ax0_ax1_fused_0 in T.unroll(2): for ax0_ax1_fused_1 in T.vectorized(2): with T.block("B_reindex_shared_local"): v0 = T.axis.spatial(1024, ax2_0_fused * 4 + ax2_1_fused) - v1 = T.axis.spatial(1024, ax0_0_ax1_0_fused % 8 * 128 + ax0_1_ax1_1_fused % 2 * 64 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + ax0_ax1_fused_0 * 2 + ax0_ax1_fused_1) - T.reads(B_reindex_shared[v0, v1 // 64 * 64 + v1 % 8 // 4 * 32 + v1 % 64 // 8 * 4 + v1 % 4]) + v1 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused % 8 * 128 + + ax0_1_ax1_1_fused % 2 * 64 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused + % 32 + // 2 + * 4 + + ax0_ax1_fused_0 * 2 + + ax0_ax1_fused_1, + ) + T.reads( + B_reindex_shared[ + v0, + v1 // 64 * 64 + + v1 % 8 // 4 * 32 + + v1 % 64 // 8 * 4 + + v1 % 4, + ] + ) T.writes(B_reindex_shared_local[v0, v1]) - B_reindex_shared_local[v0, v1] = B_reindex_shared[v0, v1 // 64 * 64 + v1 % 8 // 4 * 32 + v1 % 64 // 8 * 4 + v1 % 4] + B_reindex_shared_local[v0, v1] = B_reindex_shared[ + v0, + v1 // 64 * 64 + + v1 % 8 // 4 * 
32 + + v1 % 64 // 8 * 4 + + v1 % 4, + ] for ax0_3, ax1_3, ax2_2, ax0_4, ax1_4 in T.grid(4, 4, 1, 2, 1): with T.block("Y_update"): - v0 = T.axis.spatial(1024, ax0_0_ax1_0_fused // 8 * 64 + ax0_1_ax1_1_fused // 2 * 32 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + ax0_3 * 2 + ax0_4) - v1 = T.axis.spatial(1024, ax1_4 + ax0_0_ax1_0_fused % 8 * 128 + ax0_1_ax1_1_fused % 2 * 64 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + ax1_3) + v0 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused // 8 * 64 + + ax0_1_ax1_1_fused // 2 * 32 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + + ax0_3 * 2 + + ax0_4, + ) + v1 = T.axis.spatial( + 1024, + ax1_4 + + ax0_0_ax1_0_fused % 8 * 128 + + ax0_1_ax1_1_fused % 2 * 64 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused + % 32 + // 2 + * 4 + + ax1_3, + ) v2 = T.axis.reduce(1024, ax2_0_fused * 4 + ax2_1_fused + ax2_2) - T.reads(Y_reindex_local[v0, v1], A_reindex_shared_local[v2, v0], B_reindex_shared_local[v2, v1]) + T.reads( + Y_reindex_local[v0, v1], + A_reindex_shared_local[v2, v0], + B_reindex_shared_local[v2, v1], + ) T.writes(Y_reindex_local[v0, v1]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) - Y_reindex_local[v0, v1] = Y_reindex_local[v0, v1] + A_reindex_shared_local[v2, v0] * B_reindex_shared_local[v2, v1] + T.block_attr( + { + "meta_schedule.thread_extent_high_inclusive": 1024, + "meta_schedule.thread_extent_low_inclusive": 32, + "meta_schedule.tiling_structure": "SSSRRSRS", + } + ) + Y_reindex_local[v0, v1] = ( + Y_reindex_local[v0, v1] + + A_reindex_shared_local[v2, v0] + * B_reindex_shared_local[v2, v1] + ) for ax0, ax1 in T.grid(8, 4): with T.block("Y_reindex_local"): - T.where(ax0_0_ax1_0_fused // 8 * 64 + ax0_1_ax1_1_fused // 2 * 32 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + ax0 < 1012 and ax0_0_ax1_0_fused % 8 * 128 + ax0_1_ax1_1_fused % 2 * 64 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + ax1 < 1017) - v0 = T.axis.spatial(1024, ax0_0_ax1_0_fused // 8 * 64 + ax0_1_ax1_1_fused // 2 * 32 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + ax0) - v1 = T.axis.spatial(1024, ax0_0_ax1_0_fused % 8 * 128 + ax0_1_ax1_1_fused % 2 * 64 + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + ax1) + T.where( + ax0_0_ax1_0_fused // 8 * 64 + + ax0_1_ax1_1_fused // 2 * 32 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + + ax0 + < 1012 + and ax0_0_ax1_0_fused % 8 * 128 + + ax0_1_ax1_1_fused % 2 * 64 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + + ax1 + < 1017 + ) + v0 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused // 8 * 64 + + ax0_1_ax1_1_fused // 2 * 32 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + + ax0, + ) + v1 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused % 8 * 128 + + ax0_1_ax1_1_fused % 2 * 64 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + + ax1, + ) T.reads(Y_reindex_local[v0, v1]) T.writes(Y[v0, v1]) Y[v0, v1] = Y_reindex_local[v0, v1] - -with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): - rt_mod = tvm.build(Module, target="cuda") 
-M, N, K = 1012, 1017, 1014 -a_tvm = tvm.nd.array(np.random.rand(M, K).astype("float32"), device=tvm.cuda(0)) -b_tvm = tvm.nd.array(np.random.rand(K, N).astype("float32"), device=tvm.cuda(0)) -c_tvm = tvm.nd.array(np.empty((M, N)).astype("float32"), device=tvm.cuda(0)) -rt_mod(a_tvm, b_tvm, c_tvm) -time_f = rt_mod.time_evaluator(rt_mod.entry_name, dev=tvm.cuda(0), number=10) -time = time_f(a_tvm, b_tvm, c_tvm).mean +def test_matmul(): + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + rt_mod = tvm.build(Module, target="cuda") + + M, N, K = 1012, 1017, 1014 + a_tvm = tvm.nd.array(np.random.rand(M, K).astype("float32"), device=tvm.cuda(0)) + b_tvm = tvm.nd.array(np.random.rand(K, N).astype("float32"), device=tvm.cuda(0)) + c_tvm = tvm.nd.array(np.empty((M, N)).astype("float32"), device=tvm.cuda(0)) + rt_mod(a_tvm, b_tvm, c_tvm) + + time_f = rt_mod.time_evaluator(rt_mod.entry_name, dev=tvm.cuda(0), number=10) + time = time_f(a_tvm, b_tvm, c_tvm).mean + + flop = (M * N * K + M * N) * 2 + print("GFLOPS: %.2f" % (flop / time / 1e9)) + -flop = (M * N * K + M * N) * 2 -print("GFLOPS: %.2f" % (flop / time / 1e9)) +if __name__ == "__main__": + test_matmul() From bca2c63a824d8b7ac6d21d42ec9b27f1d5811d4c Mon Sep 17 00:00:00 2001 From: Tian Xia <74357442+Rainy-Memory@users.noreply.github.com> Date: Tue, 14 Feb 2023 12:26:53 +0800 Subject: [PATCH 07/21] Update src/target/source/codegen_cuda.cc Co-authored-by: Junru Shao --- src/target/source/codegen_cuda.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 4cee77641e3e..057f78e7fdfe 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -915,11 +915,13 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string src_offset = this->PrintExpr(op->args[3]); std::string size = this->PrintExpr(op->args[4]); // use size of argument list to indicate whether or not to use predicated cp.async - if (op->args.size() == 5) + if (op->args.size() == 5) { this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); - else + } + else { this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, size, this->PrintExpr(op->args[5])); + } } else if (op->op.same_as(builtin::ptx_commit_group())) { this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n"; } else if (op->op.same_as(builtin::ptx_wait_group())) { From 6d6f072a7ccecc98f0d6f2bfe1a6a02ae15d31d8 Mon Sep 17 00:00:00 2001 From: Tian Xia <74357442+Rainy-Memory@users.noreply.github.com> Date: Tue, 14 Feb 2023 12:27:02 +0800 Subject: [PATCH 08/21] Update src/tir/transforms/inject_ptx_async_copy.cc Co-authored-by: Junru Shao --- src/tir/transforms/inject_ptx_async_copy.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index fbced7e180a9..2d23a4da5d59 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -47,7 +47,7 @@ class PTXAsyncCopyInjector : public StmtMutator { return StmtMutator::VisitStmt_(attr); } - Stmt injectPTX(const BufferLoadNode* load, const BufferStoreNode* store, + Stmt InjectPTX(const BufferLoadNode* load, const BufferStoreNode* store, bool predicated = false, PrimExpr predicate_value = PrimExpr()) { if (load->buffer.scope() == "global") { ICHECK(load->indices.size() == 1 && store->indices.size() == 1); From 
bfee9f8ef7d72f6f0f3cdd76a88bad990ffe5ab8 Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Tue, 14 Feb 2023 04:34:04 +0000 Subject: [PATCH 09/21] add compute version check --- src/target/source/codegen_cuda.cc | 3 +-- tests/python/unittest/test_cp_async_in_if_then_else.py | 6 ++++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 057f78e7fdfe..ace91126faa7 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -917,8 +917,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { // use size of argument list to indicate whether or not to use predicated cp.async if (op->args.size() == 5) { this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); - } - else { + } else { this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, size, this->PrintExpr(op->args[5])); } diff --git a/tests/python/unittest/test_cp_async_in_if_then_else.py b/tests/python/unittest/test_cp_async_in_if_then_else.py index 5fdecad9c15a..64a565518e6f 100644 --- a/tests/python/unittest/test_cp_async_in_if_then_else.py +++ b/tests/python/unittest/test_cp_async_in_if_then_else.py @@ -283,7 +283,13 @@ def main( Y[v0, v1] = Y_reindex_local[v0, v1] +@tvm.testing.requires_cuda def test_matmul(): + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # At least sm80 is required + return with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): rt_mod = tvm.build(Module, target="cuda") From 7de4d880455d0b63c63a46f21fd53799b407c7ef Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Tue, 14 Feb 2023 04:52:02 +0000 Subject: [PATCH 10/21] fix & reformat inject_ptx_async_copy.cc --- src/tir/transforms/inject_ptx_async_copy.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index 2d23a4da5d59..f6ff947158ec 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -83,7 +83,9 @@ class PTXAsyncCopyInjector : public StmtMutator { tir::Mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes)}; // use arguments size to indicate whether or not to use predicated cp.async - if (predicated) args.push_back(predicate_value); + if (predicated) { + args.push_back(predicate_value); + } return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args)); } @@ -127,7 +129,7 @@ class PTXAsyncCopyInjector : public StmtMutator { Stmt VisitStmt_(const BufferStoreNode* store) { if (in_async && (store->buffer.scope() == "shared" || store->buffer.scope() == "shared.dyn")) { if (auto* load = store->value.as()) { - return injectPTX(load, store); + return InjectPTX(load, store); } else if (auto* call = store->value.as()) { // tir.if_then_else is a call to tir::builtin::if_then_else() if (call->op.same_as(builtin::if_then_else()) && call->args.size() == 3) { @@ -144,7 +146,9 @@ class PTXAsyncCopyInjector : public StmtMutator { if (auto* f = call->args[2].as()) { else_value_is_zero = f->value == 0.0f; } - if (else_value_is_zero) return injectPTX(load, store, true, call->args[0]); + if (else_value_is_zero) { + return InjectPTX(load, store, true, call->args[0]); + } } } } From 055289150b0dd16b345db266835b03b6e15107d6 Mon Sep 17 00:00:00 2001 From: 
Rainy-Memory <2630737606@qq.com> Date: Tue, 14 Feb 2023 05:27:52 +0000 Subject: [PATCH 11/21] update unittest using a small example --- .../unittest/test_cp_async_in_if_then_else.py | 471 +++++++----------- 1 file changed, 183 insertions(+), 288 deletions(-) diff --git a/tests/python/unittest/test_cp_async_in_if_then_else.py b/tests/python/unittest/test_cp_async_in_if_then_else.py index 64a565518e6f..59488cd5b663 100644 --- a/tests/python/unittest/test_cp_async_in_if_then_else.py +++ b/tests/python/unittest/test_cp_async_in_if_then_else.py @@ -1,310 +1,205 @@ import tvm import numpy as np + from tvm.script import tir as T +import tvm.testing +expected_cuda_script = r""" +#ifdef _WIN32 + using uint = unsigned int; + using uchar = unsigned char; + using ushort = unsigned short; + using int64_t = long long; + using uint64_t = unsigned long long; +#else + #define uint unsigned int + #define uchar unsigned char + #define ushort unsigned short + #define int64_t long long + #define uint64_t unsigned long long +#endif +extern "C" __global__ void __launch_bounds__(16) main_kernel0(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) { + __shared__ float A_shared[64]; + __shared__ float B_shared[64]; + A_shared[((int)threadIdx.x)] = 0.000000e+00f; + B_shared[((int)threadIdx.x)] = 0.000000e+00f; +__asm__ __volatile__("cp.async.commit_group;"); -@tvm.script.ir_module -class Module: - @T.prim_func - def main( - A: T.Buffer[(1012, 1014), "float32"], - B: T.Buffer[(1014, 1017), "float32"], - Y: T.Buffer[(1012, 1017), "float32"], - ): - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - Y_reindex_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local") - A_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") - B_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") - A_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local") - B_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local") - for ax0_0_ax1_0_fused in T.thread_binding( - 128, - thread="blockIdx.x", - annotations={"pragma_auto_unroll_max_step": 1024, "pragma_unroll_explicit": 1}, - ): - for ax0_1_ax1_1_fused in T.thread_binding(4, thread="vthread.x"): - for ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused in T.thread_binding( - 64, thread="threadIdx.x" - ): - for ax0_3_init, ax1_3_init, ax0_4_init, ax1_4_init in T.grid(4, 4, 2, 1): - with T.block("Y_init"): - v0 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused // 8 * 64 - + ax0_1_ax1_1_fused // 2 * 32 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 - + ax0_3_init * 2 - + ax0_4_init, - ) - v1 = T.axis.spatial( - 1024, - ax1_4_init - + ax0_0_ax1_0_fused % 8 * 128 - + ax0_1_ax1_1_fused % 2 * 64 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 - + ax1_3_init, - ) - T.reads() - T.writes(Y_reindex_local[v0, v1]) - T.block_attr( - { - "meta_schedule.thread_extent_high_inclusive": 1024, - "meta_schedule.thread_extent_low_inclusive": 32, - "meta_schedule.tiling_structure": "SSSRRSRS", - } - ) - Y_reindex_local[v0, v1] = T.float32(0) - for ax2_0_fused in T.serial( - 256, - annotations={ - "software_pipeline_async_stages": [0, 1], - "software_pipeline_order": [0, 1, 3, 2, 4], - "software_pipeline_stage": [0, 0, 2, 3, 3], - }, - ): - for ax0_ax1_fused_0 in T.serial(4): - for ax0_ax1_fused_1 in T.thread_binding(64, 
thread="threadIdx.x"): - with T.block("A_reindex_shared"): - v0 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused // 8 * 64 - + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) // 4, - ) - v1 = T.axis.spatial( - 1024, - ax2_0_fused * 4 - + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) % 4, - ) - T.reads(A[v0, v1]) - T.writes( - A_reindex_shared[ - v1, - v0 // 32 * 32 - + v0 % 8 // 4 * 16 - + v0 % 32 // 8 * 4 - + v0 % 4, - ] - ) - A_reindex_shared[ - v1, - v0 // 32 * 32 - + v0 % 8 // 4 * 16 - + v0 % 32 // 8 * 4 - + v0 % 4, - ] = T.if_then_else( - v0 < 1012 and v1 < 1014, - A[v0, v1], - T.float32(0), - dtype="float32", - ) - for ax0_ax1_fused_0 in T.serial(8): - for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): - with T.block("B_reindex_shared"): - v0 = T.axis.spatial( - 1024, - ax2_0_fused * 4 - + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) // 128, - ) - v1 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused % 8 * 128 - + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) % 128, - ) - T.reads(B[v0, v1]) - T.writes( - B_reindex_shared[ - v0, - v1 // 64 * 64 - + v1 % 8 // 4 * 32 - + v1 % 64 // 8 * 4 - + v1 % 4, - ] - ) - B_reindex_shared[ - v0, - v1 // 64 * 64 - + v1 % 8 // 4 * 32 - + v1 % 64 // 8 * 4 - + v1 % 4, - ] = T.if_then_else( - v0 < 1014 and v1 < 1017, - B[v0, v1], - T.float32(0), - dtype="float32", - ) - for ax2_1_fused in T.unroll( - 4, - annotations={ - "software_pipeline_order": [0, 1, 2], - "software_pipeline_stage": [0, 0, 1], - }, - ): - for ax0_ax1_fused_0 in T.unroll(2): - for ax0_ax1_fused_1 in T.vectorized(4): - with T.block("A_reindex_shared_local"): - v0 = T.axis.spatial(1024, ax2_0_fused * 4 + ax2_1_fused) - v1 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused // 8 * 64 - + ax0_1_ax1_1_fused // 2 * 32 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused - // 32 - * 16 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 - + ax0_ax1_fused_0 * 4 - + ax0_ax1_fused_1, - ) - T.reads( - A_reindex_shared[ - v0, - v1 // 32 * 32 - + v1 % 8 // 4 * 16 - + v1 % 32 // 8 * 4 - + v1 % 4, - ] - ) - T.writes(A_reindex_shared_local[v0, v1]) - A_reindex_shared_local[v0, v1] = A_reindex_shared[ - v0, - v1 // 32 * 32 - + v1 % 8 // 4 * 16 - + v1 % 32 // 8 * 4 - + v1 % 4, - ] - for ax0_ax1_fused_0 in T.unroll(2): - for ax0_ax1_fused_1 in T.vectorized(2): - with T.block("B_reindex_shared_local"): - v0 = T.axis.spatial(1024, ax2_0_fused * 4 + ax2_1_fused) - v1 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused % 8 * 128 - + ax0_1_ax1_1_fused % 2 * 64 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused - % 32 - // 2 - * 4 - + ax0_ax1_fused_0 * 2 - + ax0_ax1_fused_1, - ) - T.reads( - B_reindex_shared[ - v0, - v1 // 64 * 64 - + v1 % 8 // 4 * 32 - + v1 % 64 // 8 * 4 - + v1 % 4, - ] - ) - T.writes(B_reindex_shared_local[v0, v1]) - B_reindex_shared_local[v0, v1] = B_reindex_shared[ - v0, - v1 // 64 * 64 - + v1 % 8 // 4 * 32 - + v1 % 64 // 8 * 4 - + v1 % 4, - ] - for ax0_3, ax1_3, ax2_2, ax0_4, ax1_4 in T.grid(4, 4, 1, 2, 1): - with T.block("Y_update"): - v0 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused // 8 * 64 - + ax0_1_ax1_1_fused // 2 * 32 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 - + ax0_3 * 2 - + ax0_4, - ) - v1 = T.axis.spatial( - 1024, - ax1_4 - + ax0_0_ax1_0_fused % 8 * 128 - + ax0_1_ax1_1_fused % 2 * 64 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused - % 32 - // 2 - * 4 - + ax1_3, - ) - v2 = T.axis.reduce(1024, ax2_0_fused * 4 + ax2_1_fused + ax2_2) - T.reads( - Y_reindex_local[v0, v1], - A_reindex_shared_local[v2, v0], - 
B_reindex_shared_local[v2, v1], - ) - T.writes(Y_reindex_local[v0, v1]) - T.block_attr( - { - "meta_schedule.thread_extent_high_inclusive": 1024, - "meta_schedule.thread_extent_low_inclusive": 32, - "meta_schedule.tiling_structure": "SSSRRSRS", - } - ) - Y_reindex_local[v0, v1] = ( - Y_reindex_local[v0, v1] - + A_reindex_shared_local[v2, v0] - * B_reindex_shared_local[v2, v1] - ) - for ax0, ax1 in T.grid(8, 4): - with T.block("Y_reindex_local"): - T.where( - ax0_0_ax1_0_fused // 8 * 64 - + ax0_1_ax1_1_fused // 2 * 32 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 - + ax0 - < 1012 - and ax0_0_ax1_0_fused % 8 * 128 - + ax0_1_ax1_1_fused % 2 * 64 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 - + ax1 - < 1017 - ) - v0 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused // 8 * 64 - + ax0_1_ax1_1_fused // 2 * 32 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 - + ax0, - ) - v1 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused % 8 * 128 - + ax0_1_ax1_1_fused % 2 * 64 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 - + ax1, - ) - T.reads(Y_reindex_local[v0, v1]) - T.writes(Y[v0, v1]) - Y[v0, v1] = Y_reindex_local[v0, v1] + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(A_shared + (((int)threadIdx.x) + 16))) + ); + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2;" + :: "r"(addr), "l"((void*)(A + (((int)threadIdx.x) * 14))), "n"(4) + ); + } + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(B_shared + (((int)threadIdx.x) + 16))) + ); + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2;" + :: "r"(addr), "l"((void*)(B + (((int)threadIdx.x) * 14))), "n"(4) + ); + } +__asm__ __volatile__("cp.async.commit_group;"); + + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(A_shared + (((int)threadIdx.x) + 32))) + ); + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2;" + :: "r"(addr), "l"((void*)(A + ((((int)threadIdx.x) * 14) + 1))), "n"(4) + ); + } + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(B_shared + (((int)threadIdx.x) + 32))) + ); + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2;" + :: "r"(addr), "l"((void*)(B + ((((int)threadIdx.x) * 14) + 1))), "n"(4) + ); + } +__asm__ __volatile__("cp.async.commit_group;"); + + for (int i = 0; i < 13; ++i) { + bool cse_var_1 = (i < 12); + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)))) + ); + int src_bytes = cse_var_1 ? 
4 : 0; + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2, %3;" + :: "r"(addr), "l"((void*)(A + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(src_bytes) + ); + } +__asm__ __volatile__("cp.async.commit_group;"); + +__asm__ __volatile__("cp.async.wait_group 5;"); + + __syncthreads(); + C[((((int)threadIdx.x) * 16) + i)] = (A_shared[(((i & 3) * 16) + ((int)threadIdx.x))] + B_shared[(((i & 3) * 16) + ((int)threadIdx.x))]); + __syncthreads(); + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)))) + ); + int src_bytes = cse_var_1 ? 4 : 0; + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2, %3;" + :: "r"(addr), "l"((void*)(B + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(src_bytes) + ); + } +__asm__ __volatile__("cp.async.commit_group;"); + + } +__asm__ __volatile__("cp.async.wait_group 2;"); + + __syncthreads(); + C[((((int)threadIdx.x) * 16) + 13)] = (A_shared[(((int)threadIdx.x) + 16)] + B_shared[(((int)threadIdx.x) + 16)]); +__asm__ __volatile__("cp.async.wait_group 1;"); + + __syncthreads(); + C[((((int)threadIdx.x) * 16) + 14)] = (A_shared[(((int)threadIdx.x) + 32)] + B_shared[(((int)threadIdx.x) + 32)]); +__asm__ __volatile__("cp.async.wait_group 0;"); + + __syncthreads(); + C[((((int)threadIdx.x) * 16) + 15)] = (A_shared[(((int)threadIdx.x) + 48)] + B_shared[(((int)threadIdx.x) + 48)]); +} + +""" @tvm.testing.requires_cuda -def test_matmul(): +def test_cp_async_in_if_then_else(): arch = tvm.contrib.nvcc.get_target_compute_version() major, _ = tvm.contrib.nvcc.parse_compute_version(arch) if major < 8: # At least sm80 is required return + + @T.prim_func + def simple_compute( + A: T.Buffer((16, 14), "float32"), + B: T.Buffer((16, 14), "float32"), + C: T.Buffer((16, 16), "float32"), + ): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 16, + annotations={ + "software_pipeline_stage": [0, 0, 3], + "software_pipeline_order": [0, 2, 1], + "software_pipeline_async_stages": [0], + }, + ): + with T.block("compute"): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + A_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + B_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(A_shared[tx, 0]) + A_shared[tx, 0] = T.if_then_else( + 1 <= i and i < 15, A[tx, i - 1], T.float32(0), dtype="float32" + ) + with T.block(): + T.reads(B[tx, i]) + T.writes(B_shared[tx, 0]) + B_shared[tx, 0] = T.if_then_else( + 1 <= i and i < 15, B[tx, i - 1], T.float32(0), dtype="float32" + ) + with T.block(): + T.reads(A_shared[tx, 0], B_shared[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = A_shared[tx, 0] + B_shared[tx, 0] + + mod = tvm.IRModule.from_expr(simple_compute) with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): - rt_mod = tvm.build(Module, target="cuda") + rt_mod = tvm.build(mod, target="cuda") + + assert rt_mod.imported_modules[0].get_source() == expected_cuda_script - M, N, K = 1012, 1017, 1014 - a_tvm = tvm.nd.array(np.random.rand(M, K).astype("float32"), device=tvm.cuda(0)) - b_tvm = tvm.nd.array(np.random.rand(K, N).astype("float32"), device=tvm.cuda(0)) - c_tvm = tvm.nd.array(np.empty((M, N)).astype("float32"), device=tvm.cuda(0)) + a_tvm = tvm.nd.array(np.random.rand(16, 14).astype("float32"), device=tvm.cuda(0)) + b_tvm = 
tvm.nd.array(np.random.rand(16, 14).astype("float32"), device=tvm.cuda(0)) + c_tvm = tvm.nd.array(np.empty((16, 16)).astype("float32"), device=tvm.cuda(0)) rt_mod(a_tvm, b_tvm, c_tvm) - time_f = rt_mod.time_evaluator(rt_mod.entry_name, dev=tvm.cuda(0), number=10) + time_f = rt_mod.time_evaluator(rt_mod.entry_name, dev=tvm.cuda(0), number=100) time = time_f(a_tvm, b_tvm, c_tvm).mean - flop = (M * N * K + M * N) * 2 - print("GFLOPS: %.2f" % (flop / time / 1e9)) + print(time) if __name__ == "__main__": - test_matmul() + test_cp_async_in_if_then_else() From 494503f70e0731bb05aa0b7862bc5bdbde715456 Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Tue, 14 Feb 2023 05:54:02 +0000 Subject: [PATCH 12/21] add gemm integration test --- .../test_gemm_cp_async_in_if_then_else.py | 313 ++++++++++++++++++ 1 file changed, 313 insertions(+) create mode 100644 tests/python/integration/test_gemm_cp_async_in_if_then_else.py diff --git a/tests/python/integration/test_gemm_cp_async_in_if_then_else.py b/tests/python/integration/test_gemm_cp_async_in_if_then_else.py new file mode 100644 index 000000000000..ef827197e689 --- /dev/null +++ b/tests/python/integration/test_gemm_cp_async_in_if_then_else.py @@ -0,0 +1,313 @@ +import tvm +import numpy as np + +from tvm.script import tir as T +import tvm.testing + + +@tvm.script.ir_module +class Module: + @T.prim_func + def main( + A: T.Buffer((1012, 1014), "float32"), + B: T.Buffer((1014, 1017), "float32"), + Y: T.Buffer((1012, 1017), "float32"), + ): + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + Y_reindex_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local") + A_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") + B_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") + A_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local") + B_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local") + for ax0_0_ax1_0_fused in T.thread_binding( + 128, + thread="blockIdx.x", + annotations={"pragma_auto_unroll_max_step": 1024, "pragma_unroll_explicit": 1}, + ): + for ax0_1_ax1_1_fused in T.thread_binding(4, thread="vthread.x"): + for ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused in T.thread_binding( + 64, thread="threadIdx.x" + ): + for ax0_3_init, ax1_3_init, ax0_4_init, ax1_4_init in T.grid(4, 4, 2, 1): + with T.block("Y_init"): + v0 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused // 8 * 64 + + ax0_1_ax1_1_fused // 2 * 32 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + + ax0_3_init * 2 + + ax0_4_init, + ) + v1 = T.axis.spatial( + 1024, + ax1_4_init + + ax0_0_ax1_0_fused % 8 * 128 + + ax0_1_ax1_1_fused % 2 * 64 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + + ax1_3_init, + ) + T.reads() + T.writes(Y_reindex_local[v0, v1]) + T.block_attr( + { + "meta_schedule.thread_extent_high_inclusive": 1024, + "meta_schedule.thread_extent_low_inclusive": 32, + "meta_schedule.tiling_structure": "SSSRRSRS", + } + ) + Y_reindex_local[v0, v1] = T.float32(0) + for ax2_0_fused in T.serial( + 256, + annotations={ + "software_pipeline_async_stages": [0, 1], + "software_pipeline_order": [0, 1, 3, 2, 4], + "software_pipeline_stage": [0, 0, 2, 3, 3], + }, + ): + for ax0_ax1_fused_0 in T.serial(4): + for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("A_reindex_shared"): + v0 
= T.axis.spatial( + 1024, + ax0_0_ax1_0_fused // 8 * 64 + + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) // 4, + ) + v1 = T.axis.spatial( + 1024, + ax2_0_fused * 4 + + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) % 4, + ) + T.reads(A[v0, v1]) + T.writes( + A_reindex_shared[ + v1, + v0 // 32 * 32 + + v0 % 8 // 4 * 16 + + v0 % 32 // 8 * 4 + + v0 % 4, + ] + ) + A_reindex_shared[ + v1, + v0 // 32 * 32 + + v0 % 8 // 4 * 16 + + v0 % 32 // 8 * 4 + + v0 % 4, + ] = T.if_then_else( + v0 < 1012 and v1 < 1014, + A[v0, v1], + T.float32(0), + dtype="float32", + ) + for ax0_ax1_fused_0 in T.serial(8): + for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("B_reindex_shared"): + v0 = T.axis.spatial( + 1024, + ax2_0_fused * 4 + + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) // 128, + ) + v1 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused % 8 * 128 + + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) % 128, + ) + T.reads(B[v0, v1]) + T.writes( + B_reindex_shared[ + v0, + v1 // 64 * 64 + + v1 % 8 // 4 * 32 + + v1 % 64 // 8 * 4 + + v1 % 4, + ] + ) + B_reindex_shared[ + v0, + v1 // 64 * 64 + + v1 % 8 // 4 * 32 + + v1 % 64 // 8 * 4 + + v1 % 4, + ] = T.if_then_else( + v0 < 1014 and v1 < 1017, + B[v0, v1], + T.float32(0), + dtype="float32", + ) + for ax2_1_fused in T.unroll( + 4, + annotations={ + "software_pipeline_order": [0, 1, 2], + "software_pipeline_stage": [0, 0, 1], + }, + ): + for ax0_ax1_fused_0 in T.unroll(2): + for ax0_ax1_fused_1 in T.vectorized(4): + with T.block("A_reindex_shared_local"): + v0 = T.axis.spatial(1024, ax2_0_fused * 4 + ax2_1_fused) + v1 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused // 8 * 64 + + ax0_1_ax1_1_fused // 2 * 32 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused + // 32 + * 16 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + + ax0_ax1_fused_0 * 4 + + ax0_ax1_fused_1, + ) + T.reads( + A_reindex_shared[ + v0, + v1 // 32 * 32 + + v1 % 8 // 4 * 16 + + v1 % 32 // 8 * 4 + + v1 % 4, + ] + ) + T.writes(A_reindex_shared_local[v0, v1]) + A_reindex_shared_local[v0, v1] = A_reindex_shared[ + v0, + v1 // 32 * 32 + + v1 % 8 // 4 * 16 + + v1 % 32 // 8 * 4 + + v1 % 4, + ] + for ax0_ax1_fused_0 in T.unroll(2): + for ax0_ax1_fused_1 in T.vectorized(2): + with T.block("B_reindex_shared_local"): + v0 = T.axis.spatial(1024, ax2_0_fused * 4 + ax2_1_fused) + v1 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused % 8 * 128 + + ax0_1_ax1_1_fused % 2 * 64 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused + % 32 + // 2 + * 4 + + ax0_ax1_fused_0 * 2 + + ax0_ax1_fused_1, + ) + T.reads( + B_reindex_shared[ + v0, + v1 // 64 * 64 + + v1 % 8 // 4 * 32 + + v1 % 64 // 8 * 4 + + v1 % 4, + ] + ) + T.writes(B_reindex_shared_local[v0, v1]) + B_reindex_shared_local[v0, v1] = B_reindex_shared[ + v0, + v1 // 64 * 64 + + v1 % 8 // 4 * 32 + + v1 % 64 // 8 * 4 + + v1 % 4, + ] + for ax0_3, ax1_3, ax2_2, ax0_4, ax1_4 in T.grid(4, 4, 1, 2, 1): + with T.block("Y_update"): + v0 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused // 8 * 64 + + ax0_1_ax1_1_fused // 2 * 32 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + + ax0_3 * 2 + + ax0_4, + ) + v1 = T.axis.spatial( + 1024, + ax1_4 + + ax0_0_ax1_0_fused % 8 * 128 + + ax0_1_ax1_1_fused % 2 * 64 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused + % 32 + // 2 + * 4 + + ax1_3, + ) + v2 = T.axis.reduce(1024, ax2_0_fused * 4 + ax2_1_fused + ax2_2) + T.reads( + Y_reindex_local[v0, v1], + A_reindex_shared_local[v2, v0], + B_reindex_shared_local[v2, v1], + ) + T.writes(Y_reindex_local[v0, v1]) + 
T.block_attr( + { + "meta_schedule.thread_extent_high_inclusive": 1024, + "meta_schedule.thread_extent_low_inclusive": 32, + "meta_schedule.tiling_structure": "SSSRRSRS", + } + ) + Y_reindex_local[v0, v1] = ( + Y_reindex_local[v0, v1] + + A_reindex_shared_local[v2, v0] + * B_reindex_shared_local[v2, v1] + ) + for ax0, ax1 in T.grid(8, 4): + with T.block("Y_reindex_local"): + T.where( + ax0_0_ax1_0_fused // 8 * 64 + + ax0_1_ax1_1_fused // 2 * 32 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + + ax0 + < 1012 + and ax0_0_ax1_0_fused % 8 * 128 + + ax0_1_ax1_1_fused % 2 * 64 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + + ax1 + < 1017 + ) + v0 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused // 8 * 64 + + ax0_1_ax1_1_fused // 2 * 32 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 + + ax0, + ) + v1 = T.axis.spatial( + 1024, + ax0_0_ax1_0_fused % 8 * 128 + + ax0_1_ax1_1_fused % 2 * 64 + + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 + + ax1, + ) + T.reads(Y_reindex_local[v0, v1]) + T.writes(Y[v0, v1]) + Y[v0, v1] = Y_reindex_local[v0, v1] + + +@tvm.testing.requires_cuda +def test_matmul(): + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # At least sm80 is required + return + + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + rt_mod = tvm.build(Module, target="cuda") + + M, N, K = 1012, 1017, 1014 + a_tvm = tvm.nd.array(np.random.rand(M, K).astype("float32"), device=tvm.cuda(0)) + b_tvm = tvm.nd.array(np.random.rand(K, N).astype("float32"), device=tvm.cuda(0)) + c_tvm = tvm.nd.array(np.empty((M, N)).astype("float32"), device=tvm.cuda(0)) + rt_mod(a_tvm, b_tvm, c_tvm) + + time_f = rt_mod.time_evaluator(rt_mod.entry_name, dev=tvm.cuda(0), number=10) + time = time_f(a_tvm, b_tvm, c_tvm).mean + + flop = (M * N * K + M * N) * 2 + print("GFLOPS: %.2f" % (flop / time / 1e9)) + + +if __name__ == "__main__": + test_matmul() From 59c87dddded52856d1aa92406725f289298be707 Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Tue, 14 Feb 2023 07:22:36 +0000 Subject: [PATCH 13/21] add correctness test for integration test --- .../integration/test_gemm_cp_async_in_if_then_else.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/python/integration/test_gemm_cp_async_in_if_then_else.py b/tests/python/integration/test_gemm_cp_async_in_if_then_else.py index ef827197e689..a184e30f818c 100644 --- a/tests/python/integration/test_gemm_cp_async_in_if_then_else.py +++ b/tests/python/integration/test_gemm_cp_async_in_if_then_else.py @@ -297,10 +297,14 @@ def test_matmul(): rt_mod = tvm.build(Module, target="cuda") M, N, K = 1012, 1017, 1014 - a_tvm = tvm.nd.array(np.random.rand(M, K).astype("float32"), device=tvm.cuda(0)) - b_tvm = tvm.nd.array(np.random.rand(K, N).astype("float32"), device=tvm.cuda(0)) + a_np = np.random.rand(M, K).astype("float32") + b_np = np.random.rand(K, N).astype("float32") + c_np = a_np @ b_np + a_tvm = tvm.nd.array(a_np, device=tvm.cuda(0)) + b_tvm = tvm.nd.array(b_np, device=tvm.cuda(0)) c_tvm = tvm.nd.array(np.empty((M, N)).astype("float32"), device=tvm.cuda(0)) rt_mod(a_tvm, b_tvm, c_tvm) + assert np.allclose(c_tvm.numpy(), c_np) time_f = rt_mod.time_evaluator(rt_mod.entry_name, dev=tvm.cuda(0), number=10) time = time_f(a_tvm, b_tvm, c_tvm).mean From 
8cf7c797b471eea863429f42acb4380a9935013a Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Thu, 16 Feb 2023 05:51:21 +0000 Subject: [PATCH 14/21] update test script to support device < sm80 --- .../test_gemm_cp_async_in_if_then_else.py | 317 ------------------ .../unittest/test_cp_async_in_if_then_else.py | 39 ++- 2 files changed, 30 insertions(+), 326 deletions(-) delete mode 100644 tests/python/integration/test_gemm_cp_async_in_if_then_else.py diff --git a/tests/python/integration/test_gemm_cp_async_in_if_then_else.py b/tests/python/integration/test_gemm_cp_async_in_if_then_else.py deleted file mode 100644 index a184e30f818c..000000000000 --- a/tests/python/integration/test_gemm_cp_async_in_if_then_else.py +++ /dev/null @@ -1,317 +0,0 @@ -import tvm -import numpy as np - -from tvm.script import tir as T -import tvm.testing - - -@tvm.script.ir_module -class Module: - @T.prim_func - def main( - A: T.Buffer((1012, 1014), "float32"), - B: T.Buffer((1014, 1017), "float32"), - Y: T.Buffer((1012, 1017), "float32"), - ): - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - Y_reindex_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local") - A_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") - B_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") - A_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local") - B_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local") - for ax0_0_ax1_0_fused in T.thread_binding( - 128, - thread="blockIdx.x", - annotations={"pragma_auto_unroll_max_step": 1024, "pragma_unroll_explicit": 1}, - ): - for ax0_1_ax1_1_fused in T.thread_binding(4, thread="vthread.x"): - for ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused in T.thread_binding( - 64, thread="threadIdx.x" - ): - for ax0_3_init, ax1_3_init, ax0_4_init, ax1_4_init in T.grid(4, 4, 2, 1): - with T.block("Y_init"): - v0 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused // 8 * 64 - + ax0_1_ax1_1_fused // 2 * 32 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 - + ax0_3_init * 2 - + ax0_4_init, - ) - v1 = T.axis.spatial( - 1024, - ax1_4_init - + ax0_0_ax1_0_fused % 8 * 128 - + ax0_1_ax1_1_fused % 2 * 64 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 - + ax1_3_init, - ) - T.reads() - T.writes(Y_reindex_local[v0, v1]) - T.block_attr( - { - "meta_schedule.thread_extent_high_inclusive": 1024, - "meta_schedule.thread_extent_low_inclusive": 32, - "meta_schedule.tiling_structure": "SSSRRSRS", - } - ) - Y_reindex_local[v0, v1] = T.float32(0) - for ax2_0_fused in T.serial( - 256, - annotations={ - "software_pipeline_async_stages": [0, 1], - "software_pipeline_order": [0, 1, 3, 2, 4], - "software_pipeline_stage": [0, 0, 2, 3, 3], - }, - ): - for ax0_ax1_fused_0 in T.serial(4): - for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): - with T.block("A_reindex_shared"): - v0 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused // 8 * 64 - + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) // 4, - ) - v1 = T.axis.spatial( - 1024, - ax2_0_fused * 4 - + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) % 4, - ) - T.reads(A[v0, v1]) - T.writes( - A_reindex_shared[ - v1, - v0 // 32 * 32 - + v0 % 8 // 4 * 16 - + v0 % 32 // 8 * 4 - + v0 % 4, - ] - ) - A_reindex_shared[ - v1, - v0 // 32 * 32 - + v0 % 8 // 4 * 16 - + v0 % 32 // 8 * 4 - + v0 % 4, - ] = 
T.if_then_else( - v0 < 1012 and v1 < 1014, - A[v0, v1], - T.float32(0), - dtype="float32", - ) - for ax0_ax1_fused_0 in T.serial(8): - for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): - with T.block("B_reindex_shared"): - v0 = T.axis.spatial( - 1024, - ax2_0_fused * 4 - + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) // 128, - ) - v1 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused % 8 * 128 - + (ax0_ax1_fused_0 * 64 + ax0_ax1_fused_1) % 128, - ) - T.reads(B[v0, v1]) - T.writes( - B_reindex_shared[ - v0, - v1 // 64 * 64 - + v1 % 8 // 4 * 32 - + v1 % 64 // 8 * 4 - + v1 % 4, - ] - ) - B_reindex_shared[ - v0, - v1 // 64 * 64 - + v1 % 8 // 4 * 32 - + v1 % 64 // 8 * 4 - + v1 % 4, - ] = T.if_then_else( - v0 < 1014 and v1 < 1017, - B[v0, v1], - T.float32(0), - dtype="float32", - ) - for ax2_1_fused in T.unroll( - 4, - annotations={ - "software_pipeline_order": [0, 1, 2], - "software_pipeline_stage": [0, 0, 1], - }, - ): - for ax0_ax1_fused_0 in T.unroll(2): - for ax0_ax1_fused_1 in T.vectorized(4): - with T.block("A_reindex_shared_local"): - v0 = T.axis.spatial(1024, ax2_0_fused * 4 + ax2_1_fused) - v1 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused // 8 * 64 - + ax0_1_ax1_1_fused // 2 * 32 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused - // 32 - * 16 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 - + ax0_ax1_fused_0 * 4 - + ax0_ax1_fused_1, - ) - T.reads( - A_reindex_shared[ - v0, - v1 // 32 * 32 - + v1 % 8 // 4 * 16 - + v1 % 32 // 8 * 4 - + v1 % 4, - ] - ) - T.writes(A_reindex_shared_local[v0, v1]) - A_reindex_shared_local[v0, v1] = A_reindex_shared[ - v0, - v1 // 32 * 32 - + v1 % 8 // 4 * 16 - + v1 % 32 // 8 * 4 - + v1 % 4, - ] - for ax0_ax1_fused_0 in T.unroll(2): - for ax0_ax1_fused_1 in T.vectorized(2): - with T.block("B_reindex_shared_local"): - v0 = T.axis.spatial(1024, ax2_0_fused * 4 + ax2_1_fused) - v1 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused % 8 * 128 - + ax0_1_ax1_1_fused % 2 * 64 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused - % 32 - // 2 - * 4 - + ax0_ax1_fused_0 * 2 - + ax0_ax1_fused_1, - ) - T.reads( - B_reindex_shared[ - v0, - v1 // 64 * 64 - + v1 % 8 // 4 * 32 - + v1 % 64 // 8 * 4 - + v1 % 4, - ] - ) - T.writes(B_reindex_shared_local[v0, v1]) - B_reindex_shared_local[v0, v1] = B_reindex_shared[ - v0, - v1 // 64 * 64 - + v1 % 8 // 4 * 32 - + v1 % 64 // 8 * 4 - + v1 % 4, - ] - for ax0_3, ax1_3, ax2_2, ax0_4, ax1_4 in T.grid(4, 4, 1, 2, 1): - with T.block("Y_update"): - v0 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused // 8 * 64 - + ax0_1_ax1_1_fused // 2 * 32 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 - + ax0_3 * 2 - + ax0_4, - ) - v1 = T.axis.spatial( - 1024, - ax1_4 - + ax0_0_ax1_0_fused % 8 * 128 - + ax0_1_ax1_1_fused % 2 * 64 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused - % 32 - // 2 - * 4 - + ax1_3, - ) - v2 = T.axis.reduce(1024, ax2_0_fused * 4 + ax2_1_fused + ax2_2) - T.reads( - Y_reindex_local[v0, v1], - A_reindex_shared_local[v2, v0], - B_reindex_shared_local[v2, v1], - ) - T.writes(Y_reindex_local[v0, v1]) - T.block_attr( - { - "meta_schedule.thread_extent_high_inclusive": 1024, - "meta_schedule.thread_extent_low_inclusive": 32, - "meta_schedule.tiling_structure": "SSSRRSRS", - } - ) - Y_reindex_local[v0, v1] = ( - Y_reindex_local[v0, v1] - + A_reindex_shared_local[v2, v0] - * B_reindex_shared_local[v2, v1] - ) - for ax0, ax1 in T.grid(8, 4): - with T.block("Y_reindex_local"): - T.where( - ax0_0_ax1_0_fused // 8 * 64 - + ax0_1_ax1_1_fused // 2 * 32 - + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 - + ax0 - < 1012 - and ax0_0_ax1_0_fused % 8 * 128 - + ax0_1_ax1_1_fused % 2 * 64 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 - + ax1 - < 1017 - ) - v0 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused // 8 * 64 - + ax0_1_ax1_1_fused // 2 * 32 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8 - + ax0, - ) - v1 = T.axis.spatial( - 1024, - ax0_0_ax1_0_fused % 8 * 128 - + ax0_1_ax1_1_fused % 2 * 64 - + ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4 - + ax1, - ) - T.reads(Y_reindex_local[v0, v1]) - T.writes(Y[v0, v1]) - Y[v0, v1] = Y_reindex_local[v0, v1] - - -@tvm.testing.requires_cuda -def test_matmul(): - arch = tvm.contrib.nvcc.get_target_compute_version() - major, _ = tvm.contrib.nvcc.parse_compute_version(arch) - if major < 8: - # At least sm80 is required - return - - with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): - rt_mod = tvm.build(Module, target="cuda") - - M, N, K = 1012, 1017, 1014 - a_np = np.random.rand(M, K).astype("float32") - b_np = np.random.rand(K, N).astype("float32") - c_np = a_np @ b_np - a_tvm = tvm.nd.array(a_np, device=tvm.cuda(0)) - b_tvm = tvm.nd.array(b_np, device=tvm.cuda(0)) - c_tvm = tvm.nd.array(np.empty((M, N)).astype("float32"), device=tvm.cuda(0)) - rt_mod(a_tvm, b_tvm, c_tvm) - assert np.allclose(c_tvm.numpy(), c_np) - - time_f = rt_mod.time_evaluator(rt_mod.entry_name, dev=tvm.cuda(0), number=10) - time = time_f(a_tvm, b_tvm, c_tvm).mean - - flop = (M * N * K + M * N) * 2 - print("GFLOPS: %.2f" % (flop / time / 1e9)) - - -if __name__ == "__main__": - test_matmul() diff --git a/tests/python/unittest/test_cp_async_in_if_then_else.py b/tests/python/unittest/test_cp_async_in_if_then_else.py index 59488cd5b663..0eb1cf42c96d 100644 --- a/tests/python/unittest/test_cp_async_in_if_then_else.py +++ b/tests/python/unittest/test_cp_async_in_if_then_else.py @@ -138,13 +138,32 @@ """ +generated_code = "" +support_async = True + + +@tvm.register_func +def tvm_callback_cuda_postproc(code): + global generated_code + generated_code = code + # return a dummy code so that device < sm80 could build correctly + if not support_async: + return ( + 'extern "C" __global__ void __launch_bounds__(32) ' + "main_kernel0(float* __restrict__ inputs, " + "float* __restrict__ weight, float* __restrict__ conv2d_nhwc) {}" + ) + return code + + @tvm.testing.requires_cuda def test_cp_async_in_if_then_else(): arch = tvm.contrib.nvcc.get_target_compute_version() major, _ = tvm.contrib.nvcc.parse_compute_version(arch) if major < 8: # At least sm80 is required - return + global support_async + support_async = False @T.prim_func def simple_compute( @@ -188,17 +207,19 @@ def simple_compute( with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): rt_mod = tvm.build(mod, target="cuda") - assert rt_mod.imported_modules[0].get_source() == expected_cuda_script + assert generated_code == expected_cuda_script + print(generated_code) - a_tvm = tvm.nd.array(np.random.rand(16, 14).astype("float32"), device=tvm.cuda(0)) - b_tvm = tvm.nd.array(np.random.rand(16, 14).astype("float32"), device=tvm.cuda(0)) - c_tvm = tvm.nd.array(np.empty((16, 16)).astype("float32"), device=tvm.cuda(0)) - rt_mod(a_tvm, b_tvm, c_tvm) + if support_async: + a_tvm = tvm.nd.array(np.random.rand(16, 14).astype("float32"), device=tvm.cuda(0)) + b_tvm = tvm.nd.array(np.random.rand(16, 
14).astype("float32"), device=tvm.cuda(0)) + c_tvm = tvm.nd.array(np.empty((16, 16)).astype("float32"), device=tvm.cuda(0)) + rt_mod(a_tvm, b_tvm, c_tvm) - time_f = rt_mod.time_evaluator(rt_mod.entry_name, dev=tvm.cuda(0), number=100) - time = time_f(a_tvm, b_tvm, c_tvm).mean + time_f = rt_mod.time_evaluator(rt_mod.entry_name, dev=tvm.cuda(0), number=100) + time = time_f(a_tvm, b_tvm, c_tvm).mean - print(time) + print(time) if __name__ == "__main__": From 07fcc4f5929350dfa5a03a43ad5f9f3f43149c52 Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Thu, 16 Feb 2023 09:27:38 +0000 Subject: [PATCH 15/21] fix unittest --- tests/python/unittest/test_cp_async_in_if_then_else.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/python/unittest/test_cp_async_in_if_then_else.py b/tests/python/unittest/test_cp_async_in_if_then_else.py index 0eb1cf42c96d..2f316d77df6e 100644 --- a/tests/python/unittest/test_cp_async_in_if_then_else.py +++ b/tests/python/unittest/test_cp_async_in_if_then_else.py @@ -220,6 +220,10 @@ def simple_compute( time = time_f(a_tvm, b_tvm, c_tvm).mean print(time) + else: + global support_async + # avoid return dummy code to other tests + support_async = True if __name__ == "__main__": From 1c298639db97e0e41548c1692218346e549efcd4 Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Thu, 16 Feb 2023 09:44:03 +0000 Subject: [PATCH 16/21] fix unittest dummy code function name --- .../unittest/test_cp_async_in_if_then_else.py | 35 +++++++------------ 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/tests/python/unittest/test_cp_async_in_if_then_else.py b/tests/python/unittest/test_cp_async_in_if_then_else.py index 2f316d77df6e..86e0391a65ca 100644 --- a/tests/python/unittest/test_cp_async_in_if_then_else.py +++ b/tests/python/unittest/test_cp_async_in_if_then_else.py @@ -148,11 +148,13 @@ def tvm_callback_cuda_postproc(code): generated_code = code # return a dummy code so that device < sm80 could build correctly if not support_async: - return ( - 'extern "C" __global__ void __launch_bounds__(32) ' - "main_kernel0(float* __restrict__ inputs, " - "float* __restrict__ weight, float* __restrict__ conv2d_nhwc) {}" - ) + ret = '' + for line in code.split('\n'): + ret += line + '\n' + if line.startswith('extern "C" __global__'): + break + ret += '}' + return ret return code @@ -205,25 +207,14 @@ def simple_compute( mod = tvm.IRModule.from_expr(simple_compute) with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): - rt_mod = tvm.build(mod, target="cuda") + tvm.build(mod, target="cuda") assert generated_code == expected_cuda_script - print(generated_code) - - if support_async: - a_tvm = tvm.nd.array(np.random.rand(16, 14).astype("float32"), device=tvm.cuda(0)) - b_tvm = tvm.nd.array(np.random.rand(16, 14).astype("float32"), device=tvm.cuda(0)) - c_tvm = tvm.nd.array(np.empty((16, 16)).astype("float32"), device=tvm.cuda(0)) - rt_mod(a_tvm, b_tvm, c_tvm) - - time_f = rt_mod.time_evaluator(rt_mod.entry_name, dev=tvm.cuda(0), number=100) - time = time_f(a_tvm, b_tvm, c_tvm).mean - - print(time) - else: - global support_async - # avoid return dummy code to other tests - support_async = True + + if not support_async: + global support_async + # avoid return dummy code to other tests + support_async = True if __name__ == "__main__": From f5f606fcf7535bd697388c0a4adf166e31338b24 Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Thu, 16 Feb 2023 09:49:48 +0000 Subject: [PATCH 17/21] reformat unittest --- 
tests/python/unittest/test_cp_async_in_if_then_else.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_cp_async_in_if_then_else.py b/tests/python/unittest/test_cp_async_in_if_then_else.py index 86e0391a65ca..d839ca093005 100644 --- a/tests/python/unittest/test_cp_async_in_if_then_else.py +++ b/tests/python/unittest/test_cp_async_in_if_then_else.py @@ -148,12 +148,12 @@ def tvm_callback_cuda_postproc(code): generated_code = code # return a dummy code so that device < sm80 could build correctly if not support_async: - ret = '' - for line in code.split('\n'): - ret += line + '\n' + ret = "" + for line in code.split("\n"): + ret += line + "\n" if line.startswith('extern "C" __global__'): break - ret += '}' + ret += "}" return ret return code From 28b866fb68ee7f108b9f67916d8dcbd0df6b0e6e Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Thu, 16 Feb 2023 13:08:21 +0000 Subject: [PATCH 18/21] add license and doc string --- .../unittest/test_cp_async_in_if_then_else.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/python/unittest/test_cp_async_in_if_then_else.py b/tests/python/unittest/test_cp_async_in_if_then_else.py index d839ca093005..539b4956f943 100644 --- a/tests/python/unittest/test_cp_async_in_if_then_else.py +++ b/tests/python/unittest/test_cp_async_in_if_then_else.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""test the correctness of inject async memory copy from an if_then_else load""" import tvm import numpy as np From 1de76256ef972d99b066434bb0ffa69b0587fbd5 Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Thu, 16 Feb 2023 13:35:10 +0000 Subject: [PATCH 19/21] reformat inject_ptx_async_copy.cc --- src/tir/transforms/inject_ptx_async_copy.cc | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index f6ff947158ec..2e3c906e89c1 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -47,8 +47,8 @@ class PTXAsyncCopyInjector : public StmtMutator { return StmtMutator::VisitStmt_(attr); } - Stmt InjectPTX(const BufferLoadNode* load, const BufferStoreNode* store, - bool predicated = false, PrimExpr predicate_value = PrimExpr()) { + Stmt InjectPTX(const BufferLoadNode* load, const BufferStoreNode* store, bool predicated = false, + PrimExpr predicate_value = PrimExpr()) { if (load->buffer.scope() == "global") { ICHECK(load->indices.size() == 1 && store->indices.size() == 1); ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes()); @@ -79,8 +79,7 @@ class PTXAsyncCopyInjector : public StmtMutator { if (indices_lanes == 1) { auto src_offset = load->indices[0]; auto dst_offset = store->indices[0]; - Array args = {store->buffer->data, - tir::Mul(dst_offset, PrimExpr(index_factor)), + Array args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes)}; // use arguments size to indicate whether or not to use predicated cp.async if (predicated) { @@ -115,10 +114,9 @@ class PTXAsyncCopyInjector : public StmtMutator { }(); if (src_offset.defined() && dst_offset.defined()) { - return Evaluate( - Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), - {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), - load->buffer->data, src_offset, PrimExpr(bytes)})); + return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), + {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes)})); } } } @@ -134,9 +132,9 @@ class PTXAsyncCopyInjector : public StmtMutator { // tir.if_then_else is a call to tir::builtin::if_then_else() if (call->op.same_as(builtin::if_then_else()) && call->args.size() == 3) { if (auto* load = call->args[1].as()) { - // Only default value of 0 is supported since 0 is the default value used by cp.async ptx. - // @see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations - // section 9.7.8.22.3. + // Only default value of 0 is supported since 0 is the default value used by cp.async + // ptx. @see section 9.7.8.22.3. 
of + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations bool else_value_is_zero = false; if (auto* b = call->args[2].as()) { if (auto* f = b->value.as()) { From 57f019304e863563fcbcf8147a17053f31d2ad1f Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Thu, 16 Feb 2023 15:09:15 +0000 Subject: [PATCH 20/21] reformat codegen_cuda.cc --- src/target/source/codegen_cuda.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index ace91126faa7..9bf0109cace1 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -918,8 +918,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { if (op->args.size() == 5) { this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); } else { - this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, - size, this->PrintExpr(op->args[5])); + this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, size, + this->PrintExpr(op->args[5])); } } else if (op->op.same_as(builtin::ptx_commit_group())) { this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n"; From 202fead6a17352fcbb8c283455d8a937caca19dd Mon Sep 17 00:00:00 2001 From: Rainy-Memory <2630737606@qq.com> Date: Thu, 16 Feb 2023 16:54:36 +0000 Subject: [PATCH 21/21] fix global syntax error in unittest --- tests/python/unittest/test_cp_async_in_if_then_else.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_cp_async_in_if_then_else.py b/tests/python/unittest/test_cp_async_in_if_then_else.py index 539b4956f943..08de5ba34da1 100644 --- a/tests/python/unittest/test_cp_async_in_if_then_else.py +++ b/tests/python/unittest/test_cp_async_in_if_then_else.py @@ -162,6 +162,7 @@ @tvm.register_func def tvm_callback_cuda_postproc(code): global generated_code + global support_async generated_code = code # return a dummy code so that device < sm80 could build correctly if not support_async: @@ -177,11 +178,11 @@ def tvm_callback_cuda_postproc(code): @tvm.testing.requires_cuda def test_cp_async_in_if_then_else(): + global support_async arch = tvm.contrib.nvcc.get_target_compute_version() major, _ = tvm.contrib.nvcc.parse_compute_version(arch) if major < 8: # At least sm80 is required - global support_async support_async = False @T.prim_func @@ -229,7 +230,6 @@ def simple_compute( assert generated_code == expected_cuda_script if not support_async: - global support_async # avoid return dummy code to other tests support_async = True
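
Note: the expected CUDA script earlier in this series does not use a `@p` guard for the guarded loads; it relies on the optional source-size operand of `cp.async`. When the `if_then_else` condition fails, the copy is issued with a source size of 0 and the hardware zero-fills the destination bytes, which is what makes the `T.float32(0)` else branch legal. The following stand-alone kernel is an illustrative sketch only (it is not part of the patch; the kernel and buffer names are invented, and it assumes an sm_80+ device) showing that same pattern in isolation:

// Illustrative sketch of the zero-fill predication used by the generated code above.
// Not part of the patch; compile with nvcc -arch=sm_80 or newer and launch with 32 threads.
__global__ void guarded_cp_async(const float* __restrict__ src, float* __restrict__ dst, int n) {
  __shared__ float smem[32];
  int tx = threadIdx.x;
  bool guard = tx < n;  // plays the role of the if_then_else condition
  unsigned int addr;
  __asm__ __volatile__(
    "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
    : "=r"(addr)
    : "l"((void *)(smem + tx))
  );
  // cp-size is 4 bytes; src-size drops to 0 when the guard fails, so the 4 destination
  // bytes are filled with zeros instead of being read from global memory.
  int src_bytes = guard ? 4 : 0;
  __asm__ __volatile__(
    "cp.async.ca.shared.global [%0], [%1], %2, %3;"
    :: "r"(addr), "l"((void*)(src + tx)), "n"(4), "r"(src_bytes)
  );
  __asm__ __volatile__("cp.async.commit_group;");
  __asm__ __volatile__("cp.async.wait_group 0;");
  __syncthreads();
  dst[tx] = smem[tx];
}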