diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py b/tests/python/unittest/test_meta_schedule_space_cpu.py
index 12aa150f576b..7d601a7b0b02 100644
--- a/tests/python/unittest/test_meta_schedule_space_cpu.py
+++ b/tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -1079,6 +1079,128 @@ def dil_2(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7,
     )
 
 
+def test_cpu_gmm():
+    # fmt: off
+    @T.prim_func
+    def gmm_0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
+            Z_global = T.alloc_buffer([1, 128, 128], dtype="float32")
+            for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1 in T.grid(1, 4, 2, 1, 1, 8):
+                for i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(128, 1, 16, 1, 1, 1, 2, 8):
+                    with T.block("Z"):
+                        b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
+                        i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3)
+                        j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3)
+                        k = T.axis.reduce(128, i3_1 + i3_0)
+                        T.reads(X[b, i, k], Y[b, k, j])
+                        T.writes(Z_global[b, i, j])
+                        T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                        with T.init():
+                            Z_global[b, i, j] = T.float32(0)
+                        Z_global[b, i, j] = Z_global[b, i, j] + X[b, i, k] * Y[b, k, j]
+                for ax0, ax1, ax2 in T.grid(1, 32, 8):
+                    with T.block("Z_global"):
+                        v0 = T.axis.spatial(1, ax0)
+                        v1 = T.axis.spatial(128, i1_0 * 32 + ax1)
+                        v2 = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + ax2)
+                        T.reads(Z_global[v0, v1, v2])
+                        T.writes(Z[v0, v1, v2])
+                        Z[v0, v1, v2] = Z_global[v0, v1, v2]
+    @T.prim_func
+    def gmm_1(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
+            Z_global = T.alloc_buffer([1, 128, 128], dtype="float32")
+            for i0_0, i1_0, i2_0 in T.grid(1, 4, 2):
+                for i0_1, i1_1, i2_1, i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 8, 128, 1, 16, 1, 1, 1, 2, 8):
+                    with T.block("Z"):
+                        b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
+                        i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3)
+                        j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3)
+                        k = T.axis.reduce(128, i3_1 + i3_0)
+                        T.reads(X[b, i, k], Y[b, k, j])
+                        T.writes(Z_global[b, i, j])
+                        T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                        with T.init():
+                            Z_global[b, i, j] = T.float32(0)
+                        Z_global[b, i, j] = Z_global[b, i, j] + X[b, i, k] * Y[b, k, j]
+                for ax0, ax1, ax2 in T.grid(1, 32, 64):
+                    with T.block("Z_global"):
+                        v0 = T.axis.spatial(1, ax0)
+                        v1 = T.axis.spatial(128, i1_0 * 32 + ax1)
+                        v2 = T.axis.spatial(128, i2_0 * 64 + ax2)
+                        T.reads(Z_global[v0, v1, v2])
+                        T.writes(Z[v0, v1, v2])
+                        Z[v0, v1, v2] = Z_global[v0, v1, v2]
+    @T.prim_func
+    def gmm_2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
+            for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1, i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(1, 4, 2, 1, 1, 8, 128, 1, 16, 1, 1, 1, 2, 8):
+                with T.block("Z"):
+                    b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
+                    i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3)
+                    j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3)
+                    k = T.axis.reduce(128, i3_1 + i3_0)
+                    T.reads(X[b, i, k], Y[b, k, j])
+                    T.writes(Z[b, i, j])
+                    T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                    with T.init():
+                        Z[b, i, j] = T.float32(0)
+                    Z[b, i, j] = Z[b, i, j] + X[b, i, k] * Y[b, k, j]
+    # fmt: on
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [4, 1, 16, 2]),
+        ("SamplePerfectTile", [2, 8, 1, 8]),
+        ("SamplePerfectTile", [128, 1]),
+        ("SampleCategorical", 1),
+    ]
+    decision_1 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [4, 1, 16, 2]),
+        ("SamplePerfectTile", [2, 8, 1, 8]),
+        ("SamplePerfectTile", [128, 1]),
+        ("SampleCategorical", 1),
+    ]
+    decision_2 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [4, 1, 16, 2]),
+        ("SamplePerfectTile", [2, 8, 1, 8]),
+        ("SamplePerfectTile", [128, 1]),
+        ("SampleCategorical", 1),
+    ]
+    mod = create_te_workload("GMM", 0)
+    actual = ms.TuneContext(
+        mod=mod,
+        target=_target(),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules="default",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[gmm_0, gmm_1, gmm_2],
+        expected_decisions=[decision_0, decision_1, decision_2],
+    )
+
+
 if __name__ == "__main__":
     test_cpu_c1d()
     test_cpu_c2d()
@@ -1086,3 +1208,4 @@ def dil_2(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7,
     test_cpu_cap()
     test_cpu_dep()
     test_cpu_dil()
+    test_cpu_gmm()
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 7323bc441fd8..3bf2666cdc01 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -572,6 +572,87 @@ def dil_0(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7,
     )
 
 
+def test_cuda_gmm():
+    # fmt: off
+    @T.prim_func
+    def gmm_0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.unroll_explicit":1024})
+            Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
+            X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
+            Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
+            for i0_0_i1_0_i2_0_fused in T.thread_binding(1, thread="blockIdx.x"):
+                for i0_1_i1_1_i2_1_fused in T.thread_binding(32, thread="vthread.x"):
+                    for i0_2_i1_2_i2_2_fused in T.thread_binding(2, thread="threadIdx.x"):
+                        for i3_0 in T.serial(1):
+                            for ax0_ax1_ax2_fused in T.serial(16384):
+                                with T.block("X_shared"):
+                                    v0 = T.axis.spatial(1, 0)
+                                    v1 = T.axis.spatial(128, ax0_ax1_ax2_fused // 128)
+                                    v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128)
+                                    T.reads(X[v0, v1, v2])
+                                    T.writes(X_shared[v0, v1, v2])
+                                    T.block_attr({"meta_schedule.cooperative_fetch":2})
+                                    X_shared[v0, v1, v2] = X[v0, v1, v2]
+                            for ax0_ax1_ax2_fused in T.serial(16384):
+                                with T.block("Y_shared"):
+                                    v0 = T.axis.spatial(1, 0)
+                                    v1 = T.axis.spatial(128, ax0_ax1_ax2_fused // 128)
+                                    v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128)
+                                    T.reads(Y[v0, v1, v2])
+                                    T.writes(Y_shared[v0, v1, v2])
+                                    T.block_attr({"meta_schedule.cooperative_fetch":1})
+                                    Y_shared[v0, v1, v2] = Y[v0, v1, v2]
+                            for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(32, 1, 2, 64, 4, 1, 2, 1):
+                                with T.block("Z"):
+                                    b = T.axis.spatial(1, i0_4 + i0_3)
+                                    i = T.axis.spatial(128, i0_1_i1_1_i2_1_fused * 4 + i1_3 * 2 + i1_4)
+                                    j = T.axis.spatial(128, i2_4 + i0_2_i1_2_i2_2_fused * 64 + i2_3)
+                                    k = T.axis.reduce(128, i3_0 * 128 + i3_1 * 4 + i3_2)
+                                    T.reads(X_shared[b, i, k], Y_shared[b, k, j])
+                                    T.writes(Z_local[b, i, j])
+                                    T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
+                                    with T.init():
+                                        Z_local[b, i, j] = T.float32(0)
+                                    Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
+                        for ax0, ax1, ax2 in T.grid(1, 4, 64):
+                            with T.block("Z_local"):
+                                v0 = T.axis.spatial(1, ax0)
+                                v1 = T.axis.spatial(128, i0_1_i1_1_i2_1_fused * 4 + ax1)
+                                v2 = T.axis.spatial(128, i0_2_i1_2_i2_2_fused * 64 + ax2)
+                                T.reads(Z_local[v0, v1, v2])
+                                T.writes(Z[v0, v1, v2])
+                                Z[v0, v1, v2] = Z_local[v0, v1, v2]
+    # fmt: on
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 32, 1, 2, 2]),
+        ("SamplePerfectTile", [1, 1, 2, 64, 1]),
+        ("SamplePerfectTile", [1, 32, 4]),
+        ("SampleCategorical", 1),
+        ("SampleCategorical", 0),
+        ("SampleCategorical", 4),
+    ]
+    mod = create_te_workload("GMM", 0)
+    actual = ms.TuneContext(
+        mod=mod,
+        target=_target(),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules="default",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[gmm_0],
+        expected_decisions=[decision_0],
+    )
+
+
 if __name__ == "__main__":
     test_cuda_c1d()
     test_cuda_c2d()
@@ -579,3 +660,4 @@ def dil_0(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7,
     test_cuda_cap()
     test_cuda_dep()
     test_cuda_dil()
+    test_cuda_gmm()