From 84ba0065b8dda6fa618f12eee7b48a7705a47a32 Mon Sep 17 00:00:00 2001 From: Celve Date: Fri, 5 Jan 2024 16:43:26 +0000 Subject: [PATCH] feat: make rule more general --- python/tvm/dlight/gpu/rmsnorm.py | 12 +- tests/python/dlight/test_gpu_rmsnorm.py | 146 +++++++++++++----------- 2 files changed, 81 insertions(+), 77 deletions(-) diff --git a/python/tvm/dlight/gpu/rmsnorm.py b/python/tvm/dlight/gpu/rmsnorm.py index 4f6960f3aef4..f8b2bb4a172d 100644 --- a/python/tvm/dlight/gpu/rmsnorm.py +++ b/python/tvm/dlight/gpu/rmsnorm.py @@ -107,15 +107,9 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if not identify_rsqrt_block(sch.get(rsqrt)): return None - for name in [read, sqr, redsum, norm]: - sch.transform_block_layout( - block=name, - index_map=lambda v_ax0, v_ax1, v_ax2: ( - v_ax1, - v_ax2, - ), - ) - sch.transform_block_layout(block=rsqrt, index_map=lambda v_ax0, v_ax1: (v_ax1,)) + for name in [read, sqr, redsum, rsqrt, norm, write]: + loops = sch.get_loops(name) + sch.fuse(*loops[:-1]) block_loop, loops = sch.get_loops(block=read) thread_loop, _, _ = sch.split( diff --git a/tests/python/dlight/test_gpu_rmsnorm.py b/tests/python/dlight/test_gpu_rmsnorm.py index f128c48c06b3..301dac5c66ac 100644 --- a/tests/python/dlight/test_gpu_rmsnorm.py +++ b/tests/python/dlight/test_gpu_rmsnorm.py @@ -109,50 +109,55 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T rsqrt_shared = T.alloc_buffer((1, n), scope="shared") T_rms_norm_local = T.alloc_buffer((1, n, 4096), scope="local") data_local = T.alloc_buffer((1, n, 4096), "float16", scope="local") - for ax0 in T.thread_binding(n, thread="blockIdx.x"): - for ax1_0 in T.thread_binding(512, thread="threadIdx.x"): - for ax1_1 in range(1): - for ax1_2 in T.vectorized(8): + for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x"): + for ax2_0 in T.thread_binding(512, thread="threadIdx.x"): + for ax2_1 in range(1): + for ax2_2 in T.vectorized(8): with T.block("data_local"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.spatial(4096, ax1_0 * 8 + ax1_1 * 8 + ax1_2) - T.reads(data[0, v0, v1]) - T.writes(data_local[0, v0, v1]) - data_local[0, v0, v1] = data[0, v0, v1] - for ax0_1 in range(8): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(n, ax0_ax1_fused) + v2 = T.axis.spatial(4096, ax2_0 * 8 + ax2_1 * 8 + ax2_2) + T.reads(data[v0, v1, v2]) + T.writes(data_local[v0, v1, v2]) + data_local[v0, v1, v2] = data[v0, v1, v2] + for ax0 in range(8): with T.block("T_multiply"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.spatial(4096, ax1_0 * 8 + ax0_1) - T.reads(data_local[0, v0, v1]) - T.writes(T_multiply_local[0, v0, v1]) - T_multiply_local[0, v0, v1] = T.Cast("float32", data_local[0, v0, v1]) * T.Cast("float32", data_local[0, v0, v1]) - for ax0_1 in range(8): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_ax2 = T.axis.spatial(4096, ax2_0 * 8 + ax0) + T.reads(data_local[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_local[v_ax0, v_ax1, v_ax2]) + T_multiply_local[v_ax0, v_ax1, v_ax2] = T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) * T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) + for ax0 in range(8): with T.block("T_multiply_red"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.reduce(4096, ax1_0 * 8 + ax0_1) - T.reads(T_multiply_local[0, v0, v1]) - T.writes(T_multiply_red_local[0, v0]) + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_k2 = T.axis.reduce(4096, ax2_0 * 8 + ax0) + T.reads(T_multiply_local[v_ax0, v_ax1, v_k2]) + T.writes(T_multiply_red_local[v_ax0, v_ax1]) with T.init(): - T_multiply_red_local[0, v0] = T.float32(0) - T_multiply_red_local[0, v0] = T_multiply_red_local[0, v0] + T_multiply_local[0, v0, v1] + T_multiply_red_local[v_ax0, v_ax1] = T.float32(0) + T_multiply_red_local[v_ax0, v_ax1] = T_multiply_red_local[v_ax0, v_ax1] + T_multiply_local[v_ax0, v_ax1, v_k2] with T.block("rsqrt"): - v0 = T.axis.spatial(n, ax0) - T.reads(T_multiply_red_local[0, v0]) - T.writes(rsqrt_shared[0, v0]) - rsqrt_shared[0, v0] = T.rsqrt(T_multiply_red_local[0, v0] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + T.reads(T_multiply_red_local[v_ax0, v_ax1]) + T.writes(rsqrt_shared[v_ax0, v_ax1]) + rsqrt_shared[v_ax0, v_ax1] = T.rsqrt(T_multiply_red_local[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) for ax0_0 in T.thread_binding(512, thread="threadIdx.x"): for ax0_1, ax0_2 in T.grid(1, 8): with T.block("T_rms_norm"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2) - T.reads(rsqrt_shared[0, v0], data_local[0, v0, v1], weight[v1]) - T.writes(T_rms_norm_local[0, v0, v1]) - T_rms_norm_local[0, v0, v1] = rsqrt_shared[0, v0] * T.Cast("float32", data_local[0, v0, v1]) * T.Cast("float32", weight[v1]) - for ax0_1 in T.vectorized(8): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_ax2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2) + T.reads(rsqrt_shared[v_ax0, v_ax1], data_local[v_ax0, v_ax1, v_ax2], weight[v_ax2]) + T.writes(T_rms_norm_local[v_ax0, v_ax1, v_ax2]) + T_rms_norm_local[v_ax0, v_ax1, v_ax2] = rsqrt_shared[v_ax0, v_ax1] * T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) * T.Cast("float32", weight[v_ax2]) + for ax0 in T.vectorized(8): with T.block("T_cast_local"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(n, ax0) - v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1) + v1 = T.axis.spatial(n, ax0_ax1_fused) + v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0) T.reads(T_rms_norm_local[v0, v1, v2]) T.writes(T_cast[v0, v1, v2]) T_cast[v0, v1, v2] = T.Cast("float16", T_rms_norm_local[v0, v1, v2]) @@ -222,50 +227,55 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T rsqrt_shared = T.alloc_buffer((1, n), scope="shared") T_rms_norm_local = T.alloc_buffer((1, n, 4096), scope="local") data_local = T.alloc_buffer((1, n, 4096), scope="local") - for ax0 in T.thread_binding(n, thread="blockIdx.x"): - for ax1_0 in T.thread_binding(512, thread="threadIdx.x"): - for ax1_1 in range(1): - for ax1_2 in T.vectorized(8): + for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x"): + for ax2_0 in T.thread_binding(512, thread="threadIdx.x"): + for ax2_1 in range(1): + for ax2_2 in T.vectorized(8): with T.block("data_local"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.spatial(4096, ax1_0 * 8 + ax1_1 * 8 + ax1_2) - T.reads(data[0, v0, v1]) - T.writes(data_local[0, v0, v1]) - data_local[0, v0, v1] = data[0, v0, v1] - for ax0_1 in range(8): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(n, ax0_ax1_fused) + v2 = T.axis.spatial(4096, ax2_0 * 8 + ax2_1 * 8 + ax2_2) + T.reads(data[v0, v1, v2]) + T.writes(data_local[v0, v1, v2]) + data_local[v0, v1, v2] = data[v0, v1, v2] + for ax0 in range(8): with T.block("T_multiply"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.spatial(4096, ax1_0 * 8 + ax0_1) - T.reads(data_local[0, v0, v1]) - T.writes(T_multiply_local[0, v0, v1]) - T_multiply_local[0, v0, v1] = data_local[0, v0, v1] * data_local[0, v0, v1] - for ax0_1 in range(8): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_ax2 = T.axis.spatial(4096, ax2_0 * 8 + ax0) + T.reads(data_local[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_local[v_ax0, v_ax1, v_ax2]) + T_multiply_local[v_ax0, v_ax1, v_ax2] = data_local[v_ax0, v_ax1, v_ax2] * data_local[v_ax0, v_ax1, v_ax2] + for ax0 in range(8): with T.block("T_multiply_red"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.reduce(4096, ax1_0 * 8 + ax0_1) - T.reads(T_multiply_local[0, v0, v1]) - T.writes(T_multiply_red_local[0, v0]) + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_k2 = T.axis.reduce(4096, ax2_0 * 8 + ax0) + T.reads(T_multiply_local[v_ax0, v_ax1, v_k2]) + T.writes(T_multiply_red_local[v_ax0, v_ax1]) with T.init(): - T_multiply_red_local[0, v0] = T.float32(0) - T_multiply_red_local[0, v0] = T_multiply_red_local[0, v0] + T_multiply_local[0, v0, v1] + T_multiply_red_local[v_ax0, v_ax1] = T.float32(0) + T_multiply_red_local[v_ax0, v_ax1] = T_multiply_red_local[v_ax0, v_ax1] + T_multiply_local[v_ax0, v_ax1, v_k2] with T.block("rsqrt"): - v0 = T.axis.spatial(n, ax0) - T.reads(T_multiply_red_local[0, v0]) - T.writes(rsqrt_shared[0, v0]) - rsqrt_shared[0, v0] = T.rsqrt(T_multiply_red_local[0, v0] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + T.reads(T_multiply_red_local[v_ax0, v_ax1]) + T.writes(rsqrt_shared[v_ax0, v_ax1]) + rsqrt_shared[v_ax0, v_ax1] = T.rsqrt(T_multiply_red_local[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) for ax0_0 in T.thread_binding(512, thread="threadIdx.x"): for ax0_1, ax0_2 in T.grid(1, 8): with T.block("T_rms_norm"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2) - T.reads(rsqrt_shared[0, v0], data_local[0, v0, v1], weight[v1]) - T.writes(T_rms_norm_local[0, v0, v1]) - T_rms_norm_local[0, v0, v1] = rsqrt_shared[0, v0] * data_local[0, v0, v1] * weight[v1] - for ax0_1 in T.vectorized(8): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_ax2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2) + T.reads(rsqrt_shared[v_ax0, v_ax1], data_local[v_ax0, v_ax1, v_ax2], weight[v_ax2]) + T.writes(T_rms_norm_local[v_ax0, v_ax1, v_ax2]) + T_rms_norm_local[v_ax0, v_ax1, v_ax2] = rsqrt_shared[v_ax0, v_ax1] * data_local[v_ax0, v_ax1, v_ax2] * weight[v_ax2] + for ax0 in T.vectorized(8): with T.block("T_cast_local"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(n, ax0) - v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1) + v1 = T.axis.spatial(n, ax0_ax1_fused) + v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0) T.reads(T_rms_norm_local[v0, v1, v2]) T.writes(T_cast[v0, v1, v2]) T_cast[v0, v1, v2] = T_rms_norm_local[v0, v1, v2]