diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 4c8a3076a20b..ac35c0b41e0e 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1525,6 +1525,12 @@ constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensori /*! \brief Mark that a block is a preprocessor block for layout rewrite. */ constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc"; +/*! + * \brief Mark that a block is executed by a warp. This implies the extent of threadIdx.x is + * the warp size. + */ +constexpr const char* warp_execution = "warp_execution"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 674359803880..57e58e6a79ff 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -25,9 +25,11 @@ namespace tir { class ThreadExtentChecker : private StmtVisitor { public: - static bool Check(const Stmt& stmt) { + static bool Check(const Stmt& stmt, int thread_warp_size) { try { - ThreadExtentChecker().VisitStmt(stmt); + ICHECK(thread_warp_size > 0); + ThreadExtentChecker checker(thread_warp_size); + checker.VisitStmt(stmt); return true; } catch (const dmlc::Error& e) { return false; @@ -35,6 +37,8 @@ class ThreadExtentChecker : private StmtVisitor { } private: + explicit ThreadExtentChecker(int thread_warp_size) : thread_warp_size_(thread_warp_size) {} + void VisitStmt_(const ForNode* loop) { runtime::ThreadScope thread_scope = GetThreadScope(loop); if (IsThreadIdx(thread_scope)) { @@ -64,6 +68,10 @@ class ThreadExtentChecker : private StmtVisitor { } void VisitStmt_(const BlockNode* block) { + int old_thread_idx_x = thread_idx_x; + if (block->annotations.count(attr::warp_execution)) { + thread_idx_x = thread_warp_size_; + } if (Optional<Integer> low_inclusive = GetAnn<Integer>(block, attr::meta_schedule_thread_extent_low_inclusive)) { if (Optional<Integer> high_inclusive = @@ -77,11 +85,13 @@ class ThreadExtentChecker : private StmtVisitor { } } StmtVisitor::VisitStmt_(block); + thread_idx_x = old_thread_idx_x; } int64_t thread_idx_x = 1; int64_t thread_idx_y = 1; int64_t thread_idx_z = 1; + int thread_warp_size_ = -1; }; } // namespace tir @@ -104,6 +114,7 @@ Integer Extract(const Target& target, const char* name) { class VerifyGPUCodeNode : public PostprocNode { public: Map<String, ObjectRef> target_constraints_{nullptr}; + int thread_warp_size_ = -1; void InitializeWithTuneContext(const TuneContext& context) final { ICHECK(context->target.defined()); @@ -114,6 +125,7 @@ class VerifyGPUCodeNode : public PostprocNode { {"max_vthread", Integer(8)}, {"max_vector_bytes", Integer(16)}, }; + thread_warp_size_ = Extract(target, "thread_warp_size"); } bool Verify(const IRModule& mod) const { @@ -133,7 +145,7 @@ class VerifyGPUCodeNode : public PostprocNode { const GlobalVar& g_var = kv.first; const BaseFunc& base_func = kv.second; if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) { - if (!tir::ThreadExtentChecker::Check(prim_func->body)) { + if (!tir::ThreadExtentChecker::Check(prim_func->body, thread_warp_size_)) { return false; } IRModule lowered{nullptr}; diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py index aacb889cb577..0b1e0f402b9d 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py +++
b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -393,58 +393,412 @@ def GmmCuda2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), " T.writes(Z[v0, v1, v2]) Z[v0, v1, v2] = Z_local[v0, v1, v2] -# fmt: on -# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant - - -def test_postproc_verify_gpu_0(): - mod = Conv2dCuda0 - ctx = _create_context(mod, target=_target()) - sch = tir.Schedule(mod, debug_mask="all") - assert ctx.postprocs[0].apply(sch) - -def test_postproc_verify_gpu_1(): - mod = Conv2dCuda1 - ctx = _create_context(mod, target=_target()) - sch = tir.Schedule(mod, debug_mask="all") - assert ctx.postprocs[0].apply(sch) - - -def test_postproc_verify_gpu_2(): - mod = Conv2dCuda2 - ctx = _create_context(mod, target=_target()) - sch = tir.Schedule(mod, debug_mask="all") - # Should fail due to too much local memory per block (large - # Apad_shared allocation). - assert not ctx.postprocs[0].apply(sch) +@T.prim_func +def GMMCUDATensorCore( + X: T.Buffer[(1024, 1024), "float16"], + Y: T.Buffer[(1024, 1024), "float16"], + Z: T.Buffer[(1024, 1024), "float32"], +) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + s0 = T.var("int32") + s0_1 = T.var("int32") + s0_2 = T.var("int32") + s1 = T.var("int32") + s1_1 = T.var("int32") + s1_2 = T.var("int32") + # body + # with T.block("root") + Z_wmma_accumulator = T.alloc_buffer([1024, 1024], dtype="float32", scope="wmma.accumulator") + X_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared") + Y_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared") + X_shared_wmma_matrix_a = T.alloc_buffer([1024, 1024], dtype="float16", scope="wmma.matrix_a") + Y_shared_wmma_matrix_b = T.alloc_buffer([1024, 1024], dtype="float16", scope="wmma.matrix_b") + for ax0_0_ax1_0_0_ax2_0_0_fused in T.thread_binding(64, thread="blockIdx.x"): + for ax0_1_ax1_0_1_ax2_0_1_fused in T.thread_binding(2, thread="blockIdx.y"): + for ax0_2_ax1_0_2_ax2_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax1_0_3_init, ax2_0_3_init, ax1_0_4_init, ax2_0_4_init in T.grid(2, 1, 2, 4): + with T.block("Z_o_init"): + v0 = T.axis.spatial(1, 0) + v1_o = T.axis.spatial( + 64, + ax0_0_ax1_0_0_ax2_0_0_fused % 64 // 16 * 16 + + ax0_1_ax1_0_1_ax2_0_1_fused % 2 * 8 + + ax0_2_ax1_0_2_ax2_0_2_fused % 2 * 4 + + ax1_0_3_init * 2 + + ax1_0_4_init, + ) + v2_o = T.axis.spatial( + 64, + (ax0_0_ax1_0_0_ax2_0_0_fused % 16 + 0 + 0 + ax2_0_3_init) * 4 + + ax2_0_4_init, + ) + T.reads() + T.writes( + Z_wmma_accumulator[ + v1_o * 16 : v1_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16 + ] + ) + T.block_attr( + { + "meta_schedule.thread_extent_high_inclusive": 1024, + "meta_schedule.thread_extent_low_inclusive": 32, + "warp_execution": 1, + } + ) + C = T.match_buffer( + Z_wmma_accumulator[ + v1_o * 16 : v1_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16 + ], + [16, 16], + dtype="float32", + scope="wmma.accumulator", + offset_factor=16, + ) + T.evaluate( + T.tvm_fill_fragment( + C.data, + 16, + 16, + 16, + C.elem_offset // 256 + C.elem_offset % 256 // 16, + T.float32(0), + dtype="handle", + ) + ) + for ax3_0_0 in T.serial(32): + for ax0_ax1_fused_0 in T.serial(16): + for ax0_ax1_fused_1 in T.thread_binding(2, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_fused_3 in T.vectorized(4): + with T.block("X_shared"): + v0 = T.axis.spatial( + 1024, + ax0_0_ax1_0_0_ax2_0_0_fused 
// 16 * 256 + + ax0_1_ax1_0_1_ax2_0_1_fused * 128 + + ( + ax0_ax1_fused_0 * 256 + + ax0_ax1_fused_1 * 128 + + ax0_ax1_fused_2 * 4 + + ax0_ax1_fused_3 + ) + // 32, + ) + v1 = T.axis.spatial( + 1024, + ax3_0_0 * 32 + + ( + ax0_ax1_fused_0 * 256 + + ax0_ax1_fused_1 * 128 + + ax0_ax1_fused_2 * 4 + + ax0_ax1_fused_3 + ) + % 32, + ) + T.reads(X[v0, v1]) + T.writes(X_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]}) + X_shared[v0, v1] = X[v0, v1] + for ax0_ax1_fused_0 in T.serial(8): + for ax0_ax1_fused_1 in T.thread_binding(2, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_fused_3 in T.vectorized(4): + with T.block("Y_shared"): + v0 = T.axis.spatial( + 1024, + ax3_0_0 * 32 + + ( + ax0_ax1_fused_0 * 256 + + ax0_ax1_fused_1 * 128 + + ax0_ax1_fused_2 * 4 + + ax0_ax1_fused_3 + ) + // 64, + ) + v1 = T.axis.spatial( + 1024, + ax0_0_ax1_0_0_ax2_0_0_fused % 16 * 64 + + ( + ax0_ax1_fused_0 * 256 + + ax0_ax1_fused_1 * 128 + + ax0_ax1_fused_2 * 4 + + ax0_ax1_fused_3 + ) + % 64, + ) + T.reads(Y[v0, v1]) + T.writes(Y_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]}) + Y_shared[v0, v1] = Y[v0, v1] + for ax3_0_1 in T.serial(2): + for ax0_0, ax1_0 in T.grid(4, 1): + with T.block("X_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial( + 64, + ax0_0_ax1_0_0_ax2_0_0_fused // 16 * 16 + + ax0_1_ax1_0_1_ax2_0_1_fused * 8 + + ax0_2_ax1_0_2_ax2_0_2_fused * 4 + + ax0_0, + ) + v1_o = T.axis.spatial(64, ax3_0_0 * 2 + ax3_0_1) + T.reads( + X_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16] + ) + T.writes( + X_shared_wmma_matrix_a[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16 + ] + ) + A = T.match_buffer( + X_shared[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16 + ], + [16, 16], + dtype="float16", + strides=[s1, s0], + scope="shared", + offset_factor=16, + ) + C_1 = T.match_buffer( + X_shared_wmma_matrix_a[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16 + ], + [16, 16], + dtype="float16", + scope="wmma.matrix_a", + offset_factor=16, + ) + T.evaluate( + T.tvm_load_matrix_sync( + C_1.data, + 16, + 16, + 16, + C_1.elem_offset // 256 + C_1.elem_offset % 256 // 16, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A.data, + A.elem_offset, + s1 * 16, + 1, + dtype="handle", + ), + s1, + "row_major", + dtype="handle", + ) + ) + for ax0_0, ax1_0 in T.grid(1, 4): + with T.block("Y_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(64, ax3_0_0 * 2 + ax3_0_1) + v1_o = T.axis.spatial( + 64, ax0_0_ax1_0_0_ax2_0_0_fused % 16 * 4 + ax1_0 + ) + T.reads( + Y_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16] + ) + T.writes( + Y_shared_wmma_matrix_b[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16 + ] + ) + A_1 = T.match_buffer( + Y_shared[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16 + ], + [16, 16], + dtype="float16", + strides=[s1_1, s0_1], + scope="shared", + offset_factor=16, + ) + C_2 = T.match_buffer( + Y_shared_wmma_matrix_b[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16 + ], + [16, 16], + dtype="float16", + scope="wmma.matrix_b", + offset_factor=16, + ) + T.evaluate( + T.tvm_load_matrix_sync( + C_2.data, + 16, + 16, + 16, + C_2.elem_offset // 256 + C_2.elem_offset % 256 // 16, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A_1.data, + A_1.elem_offset, + s1_1 * 16, + 1, + dtype="handle", + ), + s1_1, + "row_major", + dtype="handle", + ) + ) + for ax0_3, ax1_0_3, ax2_0_3, ax3_0_2, ax0_4, ax1_0_4, ax2_0_4 in T.grid( + 1, 2, 1, 1, 
1, 2, 4 + ): + with T.block("Z_o_update"): + v0 = T.axis.spatial(1, 0) + v1_o = T.axis.spatial( + 64, + ax0_0_ax1_0_0_ax2_0_0_fused % 64 // 16 * 16 + + ax0_1_ax1_0_1_ax2_0_1_fused % 2 * 8 + + ax0_2_ax1_0_2_ax2_0_2_fused % 2 * 4 + + ax1_0_3 * 2 + + ax1_0_4, + ) + v2_o = T.axis.spatial( + 64, + (ax0_0_ax1_0_0_ax2_0_0_fused % 16 + 0 + 0 + ax2_0_3) * 4 + + ax2_0_4, + ) + v3_o = T.axis.reduce(64, ax3_0_0 * 2 + ax3_0_1 + ax3_0_2) + T.reads( + Z_wmma_accumulator[ + v1_o * 16 : v1_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16 + ], + X_shared_wmma_matrix_a[ + v1_o * 16 : v1_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16 + ], + Y_shared_wmma_matrix_b[ + v3_o * 16 : v3_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16 + ], + ) + T.writes( + Z_wmma_accumulator[ + v1_o * 16 : v1_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16 + ] + ) + T.block_attr( + { + "meta_schedule.thread_extent_high_inclusive": 1024, + "meta_schedule.thread_extent_low_inclusive": 32, + "warp_execution": 1, + } + ) + A_2 = T.match_buffer( + X_shared_wmma_matrix_a[ + v1_o * 16 : v1_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16 + ], + [16, 16], + dtype="float16", + scope="wmma.matrix_a", + offset_factor=16, + ) + B = T.match_buffer( + Y_shared_wmma_matrix_b[ + v3_o * 16 : v3_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16 + ], + [16, 16], + dtype="float16", + scope="wmma.matrix_b", + offset_factor=16, + ) + C_3 = T.match_buffer( + Z_wmma_accumulator[ + v1_o * 16 : v1_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16 + ], + [16, 16], + dtype="float32", + scope="wmma.accumulator", + offset_factor=16, + ) + T.evaluate( + T.tvm_mma_sync( + C_3.data, + C_3.elem_offset // 256 + C_3.elem_offset % 256 // 16, + A_2.data, + A_2.elem_offset // 256, + B.data, + B.elem_offset // 256, + C_3.data, + C_3.elem_offset // 256 + C_3.elem_offset % 256 // 16, + dtype="handle", + ) + ) + for ax0_0, ax1_0 in T.grid(4, 4): + with T.block("Z_wmma.accumulator_o"): + v0_o = T.axis.spatial( + 64, + ax0_0_ax1_0_0_ax2_0_0_fused // 16 * 16 + + ax0_1_ax1_0_1_ax2_0_1_fused * 8 + + ax0_2_ax1_0_2_ax2_0_2_fused * 4 + + ax0_0, + ) + v1_o = T.axis.spatial(64, ax0_0_ax1_0_0_ax2_0_0_fused % 16 * 4 + ax1_0) + T.reads( + Z_wmma_accumulator[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16 + ] + ) + T.writes(Z[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + A_3 = T.match_buffer( + Z_wmma_accumulator[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16 + ], + [16, 16], + dtype="float32", + scope="wmma.accumulator", + offset_factor=16, + ) + C_4 = T.match_buffer( + Z[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], + [16, 16], + dtype="float32", + strides=[s1_2, s0_2], + offset_factor=16, + ) + T.evaluate( + T.tvm_store_matrix_sync( + A_3.data, + 16, + 16, + 16, + A_3.elem_offset // 256 + A_3.elem_offset % 256 // 16, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + C_4.data, + C_4.elem_offset, + s1_2 * 16, + 2, + dtype="handle", + ), + s1_2, + "row_major", + dtype="handle", + ) + ) -def test_postproc_verify_gpu_3(): - mod = Conv2dCuda3 - ctx = _create_context(mod, target=_target()) - sch = tir.Schedule(mod, debug_mask="all") - # Should fail due to too many threads per block (large - # threadIdx.x extent). 
- assert not ctx.postprocs[0].apply(sch) +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant -def test_postproc_verify_gpu_4(): - mod = GmmCuda0 +@pytest.mark.parametrize("mod", [Conv2dCuda0, Conv2dCuda1, GmmCuda0, GMMCUDATensorCore]) +def test_postproc_check_pass(mod): ctx = _create_context(mod, target=_target()) sch = tir.Schedule(mod, debug_mask="all") assert ctx.postprocs[0].apply(sch) -def test_postproc_verify_gpu_5(): - mod = GmmCuda1 - ctx = _create_context(mod, target=_target()) - sch = tir.Schedule(mod, debug_mask="all") - assert not ctx.postprocs[0].apply(sch) - - -def test_postproc_verify_gpu_6(): - mod = GmmCuda2 +@pytest.mark.parametrize( + "mod", + [ + Conv2dCuda2, # Should fail due to too much local memory per block (large Apad_shared allocation) + Conv2dCuda3, # Should fail due to too many threads per block (large threadIdx.x extent) + GmmCuda1, + GmmCuda2, + ], +) +def test_postproc_check_fail(mod): ctx = _create_context(mod, target=_target()) sch = tir.Schedule(mod, debug_mask="all") assert not ctx.postprocs[0].apply(sch)
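Usage sketch (not part of the patch): one way a schedule could attach the new "warp_execution" annotation so that ThreadExtentChecker assumes the extent of threadIdx.x equals the target's thread_warp_size for that block. The module and block name below are illustrative assumptions, not APIs or names introduced by this change.

    # Hypothetical example; GmmCuda0 is the test workload above, the block name is a placeholder.
    from tvm import tir

    sch = tir.Schedule(GmmCuda0)
    block = sch.get_block("Z")  # placeholder: the block that a single warp executes
    sch.annotate(block, ann_key="warp_execution", ann_val=1)
    # With this annotation, ThreadExtentChecker counts threadIdx.x as thread_warp_size
    # (32 on CUDA targets) when validating thread extents, even if no threadIdx.x loop
    # is bound inside the annotated block.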