diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 324eedafb98a1..54407c46c881f 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -87,6 +87,23 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context)
       TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target";
     }
   }
+  if (Optional<String> opt_sm = context->target.value()->GetAttr<String>("arch")) {
+    std::string sm = opt_sm.value();
+    if (support::StartsWith(sm, "sm_")) {
+      sm = sm.substr(3);
+      try {
+        // Only sm_80 and higher support async memcpy.
+        if (std::stoi(sm) >= 80) {
+          // Only stage = 4 and 5 are tested. Any integer greater than 2 is
+          // theoretically feasible, but good performance is not guaranteed.
+          this->stages.insert(this->stages.end(), {4, 5});
+        }
+      } catch (const std::invalid_argument& e) {
+        LOG(WARNING) << "ValueError: Unable to parse `target.arch`: " << sm
+                     << ". Details: " << e.what();
+      }
+    }
+  }
   logger = context->logger;
 }
 
@@ -115,6 +132,8 @@ std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states
   states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); });
   states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); });
   states = SubRule(std::move(states), [&](State state) { return AddReadReuse(std::move(state)); });
+  states =
+      SubRule(std::move(states), [&](State state) { return AddAsyncPipeline(std::move(state)); });
   return states;
 }
 
@@ -280,6 +299,43 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
   return results;
 }
 
+std::vector<State> MultiLevelTilingNode::AddAsyncPipeline(State state) const {
+  // For archs that do not support the async pipeline, this->stages will be an empty vector.
+  if (r_indices_.size() < 1 || this->stages.empty()) {
+    return {state};
+  }
+  // Currently only the default config used by ScheduleRule::DefaultCUDA is supported.
+  // @see src/meta_schedule/schedule_rule/schedule_rule.cc
+  // Check that the reduction loop body contains exactly 3 for loops,
+  // so that it matches the annotation array sizes used below.
+  tir::StmtSRef r_loop_sref = state->sch->GetSRef(state->tiles[r_indices_[0]].back());
+  const tir::ForNode* r_for_loop = TVM_SREF_TO_FOR(r_loop_sref);
+  Array<tir::Stmt> seq = Downcast<tir::SeqStmt>(r_for_loop->body)->seq;
+  if (seq.size() != 3) {
+    return {state};
+  }
+  for (auto& stmt : seq) {
+    if (!stmt.as<tir::ForNode>()) {
+      return {state};
+    }
+  }
+
+  std::vector<State> ret;
+  ret.push_back(state);
+  for (int stage : this->stages) {
+    State new_state = state->Copy();
+    LoopRV r_loop_fused = new_state->sch->Fuse(new_state->tiles[r_indices_[0]]);
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_stage,
+                             Array<Integer>{0, 0, stage - 2});
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_order,
+                             Array<Integer>{0, 1, 2});
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_async_stages,
+                             Array<Integer>{0});
+    ret.push_back(std::move(new_state));
+  }
+  return ret;
+}
+
 void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
                                                        const tir::BlockRV& block) const {
   // Filter out invalid vector lanes according to the data type.
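Note for reviewers unfamiliar with TIR software pipelining: the three `Annotate` calls above map one-to-one onto the Python-side `tir.Schedule` API. The sketch below is illustrative only and is not part of the patch; the helper name, the loop-handle argument, and the default `stage = 4` are assumptions made for the example.

# Minimal sketch: reproduce AddAsyncPipeline's annotations by hand, assuming `sch`
# already has the MultiLevelTiling structure, i.e. outer reduction tile loops
# (e.g. rh_0, rw_0, rc_0) whose body is [shared copy A, shared copy B, compute].
from typing import List

from tvm import tir


def annotate_async_pipeline(
    sch: tir.Schedule, r_tile_loops: List[tir.schedule.LoopRV], stage: int = 4
) -> tir.schedule.LoopRV:
    # Mirrors sch->Fuse(state->tiles[r_indices_[0]]) in the C++ rule.
    fused = sch.fuse(*r_tile_loops)
    # Both shared-memory copies sit in pipeline stage 0; the compute block runs
    # `stage - 2` iterations behind them, matching Array<Integer>{0, 0, stage - 2}.
    sch.annotate(fused, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 2])
    sch.annotate(fused, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
    # Stage 0 (the global->shared copies) is executed asynchronously on sm_80+.
    sch.annotate(fused, ann_key="software_pipeline_async_stages", ann_val=[0])
    return fused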
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index d8725a3060b1e..ff38756ff06be 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -148,6 +148,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   std::vector<State> TileLoopNest(State state) const;
   // SubRule 3. add read cache
   std::vector<State> AddReadReuse(State state) const;
+  // SubRule 4. add async pipeline
+  std::vector<State> AddAsyncPipeline(State state) const;
 
   // Do nothing; Inherited from ScheduleRuleNode
   void InitializeWithTuneContext(const TuneContext& context) final;
@@ -192,6 +194,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   int thread_warp_size_;
   /*! \brief The maximum number of threads to be used size of a thread warp */
   int max_threads_per_block_;
+  /*! \brief All available async pipeline stages. */
+  std::vector<int> stages;
   /*! \brief The logging function */
   PackedFunc logger;
   /*! \brief The function to overwrite the default condition for applying MultiLevelTiling. */
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
index 66eb819122932..497915cd65646 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
@@ -365,7 +365,7 @@ def cuda_matmul_0(
     actual = generate_design_space(
         kind="cuda",
         mod=mod,
-        target=Target("nvidia/geforce-rtx-3080"),
+        target=Target("nvidia/geforce-rtx-2080"),  # disable async trace using sm75
         types=ms.schedule_rule.MultiLevelTiling,
     )
     check_sketches(
@@ -483,7 +483,7 @@ def cuda_matmul_relu_0(
     actual = generate_design_space(
         kind="cuda",
         mod=mod,
-        target=Target("nvidia/geforce-rtx-3080"),
+        target=Target("nvidia/geforce-rtx-2080"),  # disable async trace using sm75
         types=ms.schedule_rule.MultiLevelTiling,
     )
     check_sketches(
@@ -723,7 +723,7 @@ def cache_read_specify_consumer_0(
     space = generate_design_space(
         kind="cuda",
         mod=mod,
-        target=Target("nvidia/geforce-rtx-3080"),
+        target=Target("nvidia/geforce-rtx-2080"),  # disable async trace using sm75
         types=ms.schedule_rule.MultiLevelTiling,
     )
     check_sketches(
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 241fe63e1da00..bc674064d1d66 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -27,7 +27,7 @@
 
 
 def _target():
-    return Target("nvidia/geforce-rtx-3070")
+    return Target("nvidia/geforce-rtx-2080")  # disable async trace using sm75
 
 
 def _design_space(mod):
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda_async.py b/tests/python/unittest/test_meta_schedule_space_cuda_async.py
new file mode 100644
index 0000000000000..d31d62669687e
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_space_cuda_async.py
@@ -0,0 +1,340 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for MetaSchedule search space on CUDA""" +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, + print_sketches, +) +from tvm.meta_schedule.testing.te_workload import create_te_workload +from tvm.script import tir as T +from tvm.target import Target + + +def _target(): + return Target("nvidia/geforce-rtx-3070") + + +def _design_space(mod): + return generate_design_space( + kind="cuda", + mod=mod, + target=_target(), + types=ms.ScheduleRule, + ) + + +def get_c2d_prim_func(stage: int): + if stage == 0: + # fmt: off + @T.prim_func + def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.unroll_explicit": 1024}) + conv2d_nhwc_local = T.alloc_buffer((1, 112, 112, 64), scope="local") + PadInput_shared = T.alloc_buffer((1, 230, 230, 3), scope="shared") + weight_shared = T.alloc_buffer((7, 7, 3, 64), scope="shared") + for n_0_h_0_w_0_co_0_fused in T.thread_binding(112, thread="blockIdx.x"): + for n_1_h_1_w_1_co_1_fused in T.thread_binding(8, thread="vthread.x"): + for n_2_h_2_w_2_co_2_fused in T.thread_binding(64, thread="threadIdx.x"): + for rh_0, rw_0, rc_0 in T.grid(1, 1, 3): + for ax0_ax1_ax2_ax3_fused in range(693): + with T.block("PadInput_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused // 8 * 16 + ax0_ax1_ax2_ax3_fused // 33) + v2 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused % 8 * 28 + ax0_ax1_ax2_ax3_fused % 33) + v3 = T.axis.spatial(3, rc_0) + T.reads(inputs[v0, v1 - 3, v2 - 3, v3]) + T.writes(PadInput_shared[v0, v1, v2, v3]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused in range(3136): + with T.block("weight_shared"): + v0 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused // 448) + v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused % 448 // 64) + v2 = T.axis.spatial(3, rc_0) + v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) + T.reads(weight[v0, v1, v2, v3]) + T.writes(weight_shared[v0, v1, v2, v3]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] + for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(7, 1, 1, 1, 1, 14, 1, 1, 7, 1, 1, 1, 1, 1): + with T.block("conv2d_nhwc"): + v_n = T.axis.spatial(1, n_3 + n_4) + v_h = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + h_3 + h_4) + v_w = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + w_3 + w_4) + v_co = T.axis.spatial(64, co_3 + co_4 + n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16) + v_rh = T.axis.reduce(7, rh_0 * 7 + rh_1 + rh_2) + v_rw = T.axis.reduce(7, rw_0 * 7 + rw_1 * 7 + rw_2) + v_rc = 
T.axis.reduce(3, rc_1 + rc_2 + rc_0) + T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co]) + T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + with T.init(): + conv2d_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_nhwc_local[v_n, v_h, v_w, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight_shared[v_rh, v_rw, v_rc, v_co] + for ax0, ax1, ax2, ax3 in T.grid(1, 1, 14, 1): + with T.block("conv2d_nhwc_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + ax1) + v2 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + ax2) + v3 = T.axis.spatial(64, n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16 + ax3) + T.reads(conv2d_nhwc_local[v0, v1, v2, v3]) + T.writes(conv2d_nhwc[v0, v1, v2, v3]) + conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_local[v0, v1, v2, v3] + + # fmt: on + else: + # fmt: off + @T.prim_func + def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.unroll_explicit": 1024}) + conv2d_nhwc_local = T.alloc_buffer((1, 112, 112, 64), scope="local") + PadInput_shared = T.alloc_buffer((1, 230, 230, 3), scope="shared") + weight_shared = T.alloc_buffer((7, 7, 3, 64), scope="shared") + for n_0_h_0_w_0_co_0_fused in T.thread_binding(112, thread="blockIdx.x"): + for n_1_h_1_w_1_co_1_fused in T.thread_binding(8, thread="vthread.x"): + for n_2_h_2_w_2_co_2_fused in T.thread_binding(64, thread="threadIdx.x"): + for rh_0_rw_0_rc_0_fused in T.serial(3, annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, stage - 2]}): + for ax0_ax1_ax2_ax3_fused in range(693): + with T.block("PadInput_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused // 8 * 16 + ax0_ax1_ax2_ax3_fused // 33) + v2 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused % 8 * 28 + ax0_ax1_ax2_ax3_fused % 33) + v3 = T.axis.spatial(3, rh_0_rw_0_rc_0_fused) + T.reads(inputs[v0, v1 - 3, v2 - 3, v3]) + T.writes(PadInput_shared[v0, v1, v2, v3]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0)) + for ax0_ax1_ax2_ax3_fused in range(3136): + with T.block("weight_shared"): + v0 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused // 448) + v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused % 448 // 64) + v2 = T.axis.spatial(3, rh_0_rw_0_rc_0_fused) + v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) + T.reads(weight[v0, v1, v2, v3]) + T.writes(weight_shared[v0, v1, v2, v3]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] + for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(7, 1, 1, 1, 1, 14, 1, 1, 7, 1, 1, 1, 1, 1): + with T.block("conv2d_nhwc"): + v_n = T.axis.spatial(1, n_4 + n_3) + v_h = T.axis.spatial(112, h_3 + h_4 + 
n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16) + v_w = T.axis.spatial(112, w_4 + n_0_h_0_w_0_co_0_fused % 8 * 14 + w_3) + v_co = T.axis.spatial(64, co_4 + n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16 + co_3) + v_rh = T.axis.reduce(7, rh_2 + rh_1) + v_rw = T.axis.reduce(7, rw_1 * 7 + rw_2) + v_rc = T.axis.reduce(3, rc_2 + rh_0_rw_0_rc_0_fused + rc_1) + T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co]) + T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + with T.init(): + conv2d_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0) + conv2d_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_nhwc_local[v_n, v_h, v_w, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight_shared[v_rh, v_rw, v_rc, v_co] + for ax0, ax1, ax2, ax3 in T.grid(1, 1, 14, 1): + with T.block("conv2d_nhwc_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + ax1) + v2 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + ax2) + v3 = T.axis.spatial(64, n_1_h_1_w_1_co_1_fused % 4 * 16 + n_2_h_2_w_2_co_2_fused % 16 + ax3) + T.reads(conv2d_nhwc_local[v0, v1, v2, v3]) + T.writes(conv2d_nhwc[v0, v1, v2, v3]) + conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_local[v0, v1, v2, v3] + # fmt: on + return c2d + + +def test_cuda_c2d(): + c2d_decision = [ + ("SamplePerfectTile", [1, 1, 1, 1, 1]), + ("SamplePerfectTile", [14, 2, 4, 1, 1]), + ("SamplePerfectTile", [8, 1, 1, 14, 1]), + ("SamplePerfectTile", [1, 4, 16, 1, 1]), + ("SamplePerfectTile", [1, 7, 1]), + ("SamplePerfectTile", [1, 1, 7]), + ("SamplePerfectTile", [3, 1, 1]), + ("SampleCategorical", 3), + ("SampleCategorical", 2), + ("SampleCategorical", 4), + ] + + mod = create_te_workload("C2D", 0) + actual = _design_space(mod) + check_sketches( + mod, + sketches=actual, + expected_mods=[ + get_c2d_prim_func(stage=0), + get_c2d_prim_func(stage=4), + get_c2d_prim_func(stage=5), + ], + expected_decisions=[c2d_decision, c2d_decision, c2d_decision], + ) + + +def get_gmm_prim_func(stage: int): + if stage == 0: + # fmt: off + @T.prim_func + def gmm(A: T.Buffer((1, 1024, 1024), "float32"), B: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.unroll_explicit": 16}) + Y_local = T.alloc_buffer((1, 1024, 1024), scope="local") + A_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") + B_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") + for b_0_i_0_j_0_fused in T.thread_binding(256, thread="blockIdx.x"): + for b_1_i_1_j_1_fused in T.thread_binding(32, thread="vthread.x"): + for b_2_i_2_j_2_fused in T.thread_binding(64, thread="threadIdx.x"): + for k_0 in range(64): + for ax0_ax1_ax2_fused in range(1024): + with T.block("A_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + ax0_ax1_ax2_fused // 16) + v2 = T.axis.spatial(1024, k_0 * 16 + ax0_ax1_ax2_fused % 16) + T.reads(A[v0, v1, v2]) + T.writes(A_shared[v0, v1, v2]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + A_shared[v0, v1, v2] = A[v0, v1, v2] + for ax0_ax1_ax2_fused in 
range(1024): + with T.block("B_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1024, k_0 * 16 + ax0_ax1_ax2_fused // 64) + v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + ax0_ax1_ax2_fused % 64) + T.reads(B[v0, v1, v2]) + T.writes(B_shared[v0, v1, v2]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + B_shared[v0, v1, v2] = B[v0, v1, v2] + for k_1, b_3, i_3, j_3, k_2, b_4, i_4, j_4 in T.grid(2, 1, 1, 1, 8, 1, 1, 2): + with T.block("Y"): + v_b = T.axis.spatial(1, b_4 + b_3) + v_i = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + i_3 + i_4) + v_j = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + j_3 * 2 + j_4) + v_k = T.axis.reduce(1024, k_0 * 16 + k_1 * 8 + k_2) + T.reads(A_shared[v_b, v_i, v_k], B_shared[v_b, v_k, v_j]) + T.writes(Y_local[v_b, v_i, v_j]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + with T.init(): + Y_local[v_b, v_i, v_j] = T.float32(0) + Y_local[v_b, v_i, v_j] = Y_local[v_b, v_i, v_j] + A_shared[v_b, v_i, v_k] * B_shared[v_b, v_k, v_j] + for ax0, ax1, ax2 in T.grid(1, 1, 2): + with T.block("Y_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + ax1) + v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + ax2) + T.reads(Y_local[v0, v1, v2]) + T.writes(Y[v0, v1, v2]) + Y[v0, v1, v2] = Y_local[v0, v1, v2] + + # fmt: on + else: + # fmt: off + @T.prim_func + def gmm(A: T.Buffer((1, 1024, 1024), "float32"), B: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.unroll_explicit": 16}) + Y_local = T.alloc_buffer((1, 1024, 1024), scope="local") + A_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") + B_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") + for b_0_i_0_j_0_fused in T.thread_binding(256, thread="blockIdx.x"): + for b_1_i_1_j_1_fused in T.thread_binding(32, thread="vthread.x"): + for b_2_i_2_j_2_fused in T.thread_binding(64, thread="threadIdx.x"): + for k_0_fused in T.serial(64, annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, stage - 2]}): + for ax0_ax1_ax2_fused in range(1024): + with T.block("A_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + ax0_ax1_ax2_fused // 16) + v2 = T.axis.spatial(1024, k_0_fused * 16 + ax0_ax1_ax2_fused % 16) + T.reads(A[v0, v1, v2]) + T.writes(A_shared[v0, v1, v2]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + A_shared[v0, v1, v2] = A[v0, v1, v2] + for ax0_ax1_ax2_fused in range(1024): + with T.block("B_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1024, k_0_fused * 16 + ax0_ax1_ax2_fused // 64) + v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + ax0_ax1_ax2_fused % 64) + T.reads(B[v0, v1, v2]) + T.writes(B_shared[v0, v1, v2]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + B_shared[v0, v1, v2] = B[v0, v1, v2] + for k_1, b_3, i_3, j_3, k_2, b_4, i_4, j_4 in T.grid(2, 1, 1, 1, 8, 1, 1, 2): + with T.block("Y"): + v_b = T.axis.spatial(1, b_3 + b_4) + v_i = T.axis.spatial(1024, i_3 + i_4 + 
b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8) + v_j = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + j_3 * 2 + j_4) + v_k = T.axis.reduce(1024, k_0_fused * 16 + k_1 * 8 + k_2) + T.reads(A_shared[v_b, v_i, v_k], B_shared[v_b, v_k, v_j]) + T.writes(Y_local[v_b, v_i, v_j]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + with T.init(): + Y_local[v_b, v_i, v_j] = T.float32(0) + Y_local[v_b, v_i, v_j] = Y_local[v_b, v_i, v_j] + A_shared[v_b, v_i, v_k] * B_shared[v_b, v_k, v_j] + for ax0, ax1, ax2 in T.grid(1, 1, 2): + with T.block("Y_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + ax1) + v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + ax2) + T.reads(Y_local[v0, v1, v2]) + T.writes(Y[v0, v1, v2]) + Y[v0, v1, v2] = Y_local[v0, v1, v2] + + # fmt: on + return gmm + + +def test_cuda_gmm(): + gmm_decision = [ + ("SamplePerfectTile", [1, 1, 1, 1, 1]), + ("SamplePerfectTile", [16, 8, 8, 1, 1]), + ("SamplePerfectTile", [16, 4, 8, 1, 2]), + ("SamplePerfectTile", [64, 2, 8]), + ("SampleCategorical", 3), + ("SampleCategorical", 3), + ("SampleCategorical", 1), + ] + + mod = create_te_workload("GMM", 3) + actual = _design_space(mod) + check_sketches( + mod, + sketches=actual, + expected_mods=[ + get_gmm_prim_func(stage=0), + get_gmm_prim_func(stage=4), + get_gmm_prim_func(stage=5), + ], + expected_decisions=[gmm_decision, gmm_decision, gmm_decision], + ) + + +if __name__ == "__main__": + test_cuda_c2d() + test_cuda_gmm() diff --git a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py index 87a8fcac98006..e8ed3bb8b2a1e 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py @@ -27,7 +27,7 @@ def _target(): - return Target("nvidia/geforce-rtx-3070") + return Target("nvidia/geforce-rtx-2080") # disable async trace using sm75 def _design_space(mod):
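A closing note on the target swaps in the existing tests: the new sub-rule only activates when the target's `arch` attribute parses to sm_80 or higher, which is why tests that pin their expected sketches now use an sm_75 card while the new async test keeps nvidia/geforce-rtx-3070. The snippet below is a quick, illustrative way to confirm which arch a target tag carries; it is not part of the patch, and the exact printed strings are assumptions based on TVM's target tag registry.

from tvm.target import Target

# RTX 2080 is Turing (expected "sm_75"): below the sm_80 gate, so no async sketches.
# RTX 3070/3080 are Ampere (expected "sm_86"): stage-4 and stage-5 sketches are added.
for tag in ["nvidia/geforce-rtx-2080", "nvidia/geforce-rtx-3070"]:
    target = Target(tag)
    print(tag, target.attrs.get("arch"))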