feat: make rule more general
Celve committed Jan 5, 2024
1 parent 0f030cd commit 84ba006
Showing 2 changed files with 81 additions and 77 deletions.
python/tvm/dlight/gpu/rmsnorm.py — 12 changes: 3 additions & 9 deletions
@@ -107,15 +107,9 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         if not identify_rsqrt_block(sch.get(rsqrt)):
             return None
 
-        for name in [read, sqr, redsum, norm]:
-            sch.transform_block_layout(
-                block=name,
-                index_map=lambda v_ax0, v_ax1, v_ax2: (
-                    v_ax1,
-                    v_ax2,
-                ),
-            )
-        sch.transform_block_layout(block=rsqrt, index_map=lambda v_ax0, v_ax1: (v_ax1,))
+        for name in [read, sqr, redsum, rsqrt, norm, write]:
+            loops = sch.get_loops(name)
+            sch.fuse(*loops[:-1])
 
         block_loop, loops = sch.get_loops(block=read)
         thread_loop, _, _ = sch.split(
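For context on the change above: instead of flattening each block with a hard-coded 3-D index map via transform_block_layout, the rule now fuses every loop of a block except the innermost one, so blocks with any number of leading spatial axes are handled uniformly. Below is a minimal standalone sketch of that fuse pattern; the toy copy kernel, block name, and shapes are illustrative assumptions, not part of this commit:

import tvm
from tvm.script import tir as T


@T.prim_func
def toy(A: T.Buffer((2, 4, 8), "float32"), B: T.Buffer((2, 4, 8), "float32")):
    for i, j, k in T.grid(2, 4, 8):
        with T.block("copy"):
            vi, vj, vk = T.axis.remap("SSS", [i, j, k])
            B[vi, vj, vk] = A[vi, vj, vk]


sch = tvm.tir.Schedule(toy)
block = sch.get_block("copy")
loops = sch.get_loops(block)
# Fuse all loops but the last: i and j collapse into a single loop of
# extent 8, leaving the innermost axis k untouched. The same two lines
# work no matter how many leading axes the block has.
sch.fuse(*loops[:-1])
print(sch.mod.script())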
tests/python/dlight/test_gpu_rmsnorm.py — 146 changes: 78 additions & 68 deletions
@@ -109,50 +109,55 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T
         rsqrt_shared = T.alloc_buffer((1, n), scope="shared")
         T_rms_norm_local = T.alloc_buffer((1, n, 4096), scope="local")
         data_local = T.alloc_buffer((1, n, 4096), "float16", scope="local")
-        for ax0 in T.thread_binding(n, thread="blockIdx.x"):
-            for ax1_0 in T.thread_binding(512, thread="threadIdx.x"):
-                for ax1_1 in range(1):
-                    for ax1_2 in T.vectorized(8):
+        for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x"):
+            for ax2_0 in T.thread_binding(512, thread="threadIdx.x"):
+                for ax2_1 in range(1):
+                    for ax2_2 in T.vectorized(8):
                         with T.block("data_local"):
-                            v0 = T.axis.spatial(n, ax0)
-                            v1 = T.axis.spatial(4096, ax1_0 * 8 + ax1_1 * 8 + ax1_2)
-                            T.reads(data[0, v0, v1])
-                            T.writes(data_local[0, v0, v1])
-                            data_local[0, v0, v1] = data[0, v0, v1]
-                for ax0_1 in range(8):
+                            v0 = T.axis.spatial(1, 0)
+                            v1 = T.axis.spatial(n, ax0_ax1_fused)
+                            v2 = T.axis.spatial(4096, ax2_0 * 8 + ax2_1 * 8 + ax2_2)
+                            T.reads(data[v0, v1, v2])
+                            T.writes(data_local[v0, v1, v2])
+                            data_local[v0, v1, v2] = data[v0, v1, v2]
+                for ax0 in range(8):
                     with T.block("T_multiply"):
-                        v0 = T.axis.spatial(n, ax0)
-                        v1 = T.axis.spatial(4096, ax1_0 * 8 + ax0_1)
-                        T.reads(data_local[0, v0, v1])
-                        T.writes(T_multiply_local[0, v0, v1])
-                        T_multiply_local[0, v0, v1] = T.Cast("float32", data_local[0, v0, v1]) * T.Cast("float32", data_local[0, v0, v1])
-                for ax0_1 in range(8):
+                        v_ax0 = T.axis.spatial(1, 0)
+                        v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                        v_ax2 = T.axis.spatial(4096, ax2_0 * 8 + ax0)
+                        T.reads(data_local[v_ax0, v_ax1, v_ax2])
+                        T.writes(T_multiply_local[v_ax0, v_ax1, v_ax2])
+                        T_multiply_local[v_ax0, v_ax1, v_ax2] = T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) * T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2])
+                for ax0 in range(8):
                     with T.block("T_multiply_red"):
-                        v0 = T.axis.spatial(n, ax0)
-                        v1 = T.axis.reduce(4096, ax1_0 * 8 + ax0_1)
-                        T.reads(T_multiply_local[0, v0, v1])
-                        T.writes(T_multiply_red_local[0, v0])
+                        v_ax0 = T.axis.spatial(1, 0)
+                        v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                        v_k2 = T.axis.reduce(4096, ax2_0 * 8 + ax0)
+                        T.reads(T_multiply_local[v_ax0, v_ax1, v_k2])
+                        T.writes(T_multiply_red_local[v_ax0, v_ax1])
                         with T.init():
-                            T_multiply_red_local[0, v0] = T.float32(0)
-                        T_multiply_red_local[0, v0] = T_multiply_red_local[0, v0] + T_multiply_local[0, v0, v1]
+                            T_multiply_red_local[v_ax0, v_ax1] = T.float32(0)
+                        T_multiply_red_local[v_ax0, v_ax1] = T_multiply_red_local[v_ax0, v_ax1] + T_multiply_local[v_ax0, v_ax1, v_k2]
                 with T.block("rsqrt"):
-                    v0 = T.axis.spatial(n, ax0)
-                    T.reads(T_multiply_red_local[0, v0])
-                    T.writes(rsqrt_shared[0, v0])
-                    rsqrt_shared[0, v0] = T.rsqrt(T_multiply_red_local[0, v0] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))
+                    v_ax0 = T.axis.spatial(1, 0)
+                    v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                    T.reads(T_multiply_red_local[v_ax0, v_ax1])
+                    T.writes(rsqrt_shared[v_ax0, v_ax1])
+                    rsqrt_shared[v_ax0, v_ax1] = T.rsqrt(T_multiply_red_local[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))
             for ax0_0 in T.thread_binding(512, thread="threadIdx.x"):
                 for ax0_1, ax0_2 in T.grid(1, 8):
                     with T.block("T_rms_norm"):
-                        v0 = T.axis.spatial(n, ax0)
-                        v1 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2)
-                        T.reads(rsqrt_shared[0, v0], data_local[0, v0, v1], weight[v1])
-                        T.writes(T_rms_norm_local[0, v0, v1])
-                        T_rms_norm_local[0, v0, v1] = rsqrt_shared[0, v0] * T.Cast("float32", data_local[0, v0, v1]) * T.Cast("float32", weight[v1])
-                for ax0_1 in T.vectorized(8):
+                        v_ax0 = T.axis.spatial(1, 0)
+                        v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                        v_ax2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2)
+                        T.reads(rsqrt_shared[v_ax0, v_ax1], data_local[v_ax0, v_ax1, v_ax2], weight[v_ax2])
+                        T.writes(T_rms_norm_local[v_ax0, v_ax1, v_ax2])
+                        T_rms_norm_local[v_ax0, v_ax1, v_ax2] = rsqrt_shared[v_ax0, v_ax1] * T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) * T.Cast("float32", weight[v_ax2])
+                for ax0 in T.vectorized(8):
                     with T.block("T_cast_local"):
                         v0 = T.axis.spatial(1, 0)
-                        v1 = T.axis.spatial(n, ax0)
-                        v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1)
+                        v1 = T.axis.spatial(n, ax0_ax1_fused)
+                        v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0)
                         T.reads(T_rms_norm_local[v0, v1, v2])
                         T.writes(T_cast[v0, v1, v2])
                         T_cast[v0, v1, v2] = T.Cast("float16", T_rms_norm_local[v0, v1, v2])
@@ -222,50 +227,55 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T
         rsqrt_shared = T.alloc_buffer((1, n), scope="shared")
         T_rms_norm_local = T.alloc_buffer((1, n, 4096), scope="local")
         data_local = T.alloc_buffer((1, n, 4096), scope="local")
-        for ax0 in T.thread_binding(n, thread="blockIdx.x"):
-            for ax1_0 in T.thread_binding(512, thread="threadIdx.x"):
-                for ax1_1 in range(1):
-                    for ax1_2 in T.vectorized(8):
+        for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x"):
+            for ax2_0 in T.thread_binding(512, thread="threadIdx.x"):
+                for ax2_1 in range(1):
+                    for ax2_2 in T.vectorized(8):
                         with T.block("data_local"):
-                            v0 = T.axis.spatial(n, ax0)
-                            v1 = T.axis.spatial(4096, ax1_0 * 8 + ax1_1 * 8 + ax1_2)
-                            T.reads(data[0, v0, v1])
-                            T.writes(data_local[0, v0, v1])
-                            data_local[0, v0, v1] = data[0, v0, v1]
-                for ax0_1 in range(8):
+                            v0 = T.axis.spatial(1, 0)
+                            v1 = T.axis.spatial(n, ax0_ax1_fused)
+                            v2 = T.axis.spatial(4096, ax2_0 * 8 + ax2_1 * 8 + ax2_2)
+                            T.reads(data[v0, v1, v2])
+                            T.writes(data_local[v0, v1, v2])
+                            data_local[v0, v1, v2] = data[v0, v1, v2]
+                for ax0 in range(8):
                     with T.block("T_multiply"):
-                        v0 = T.axis.spatial(n, ax0)
-                        v1 = T.axis.spatial(4096, ax1_0 * 8 + ax0_1)
-                        T.reads(data_local[0, v0, v1])
-                        T.writes(T_multiply_local[0, v0, v1])
-                        T_multiply_local[0, v0, v1] = data_local[0, v0, v1] * data_local[0, v0, v1]
-                for ax0_1 in range(8):
+                        v_ax0 = T.axis.spatial(1, 0)
+                        v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                        v_ax2 = T.axis.spatial(4096, ax2_0 * 8 + ax0)
+                        T.reads(data_local[v_ax0, v_ax1, v_ax2])
+                        T.writes(T_multiply_local[v_ax0, v_ax1, v_ax2])
+                        T_multiply_local[v_ax0, v_ax1, v_ax2] = data_local[v_ax0, v_ax1, v_ax2] * data_local[v_ax0, v_ax1, v_ax2]
+                for ax0 in range(8):
                     with T.block("T_multiply_red"):
-                        v0 = T.axis.spatial(n, ax0)
-                        v1 = T.axis.reduce(4096, ax1_0 * 8 + ax0_1)
-                        T.reads(T_multiply_local[0, v0, v1])
-                        T.writes(T_multiply_red_local[0, v0])
+                        v_ax0 = T.axis.spatial(1, 0)
+                        v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                        v_k2 = T.axis.reduce(4096, ax2_0 * 8 + ax0)
+                        T.reads(T_multiply_local[v_ax0, v_ax1, v_k2])
+                        T.writes(T_multiply_red_local[v_ax0, v_ax1])
                         with T.init():
-                            T_multiply_red_local[0, v0] = T.float32(0)
-                        T_multiply_red_local[0, v0] = T_multiply_red_local[0, v0] + T_multiply_local[0, v0, v1]
+                            T_multiply_red_local[v_ax0, v_ax1] = T.float32(0)
+                        T_multiply_red_local[v_ax0, v_ax1] = T_multiply_red_local[v_ax0, v_ax1] + T_multiply_local[v_ax0, v_ax1, v_k2]
                 with T.block("rsqrt"):
-                    v0 = T.axis.spatial(n, ax0)
-                    T.reads(T_multiply_red_local[0, v0])
-                    T.writes(rsqrt_shared[0, v0])
-                    rsqrt_shared[0, v0] = T.rsqrt(T_multiply_red_local[0, v0] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))
+                    v_ax0 = T.axis.spatial(1, 0)
+                    v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                    T.reads(T_multiply_red_local[v_ax0, v_ax1])
+                    T.writes(rsqrt_shared[v_ax0, v_ax1])
+                    rsqrt_shared[v_ax0, v_ax1] = T.rsqrt(T_multiply_red_local[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))
             for ax0_0 in T.thread_binding(512, thread="threadIdx.x"):
                 for ax0_1, ax0_2 in T.grid(1, 8):
                     with T.block("T_rms_norm"):
-                        v0 = T.axis.spatial(n, ax0)
-                        v1 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2)
-                        T.reads(rsqrt_shared[0, v0], data_local[0, v0, v1], weight[v1])
-                        T.writes(T_rms_norm_local[0, v0, v1])
-                        T_rms_norm_local[0, v0, v1] = rsqrt_shared[0, v0] * data_local[0, v0, v1] * weight[v1]
-                for ax0_1 in T.vectorized(8):
+                        v_ax0 = T.axis.spatial(1, 0)
+                        v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                        v_ax2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2)
+                        T.reads(rsqrt_shared[v_ax0, v_ax1], data_local[v_ax0, v_ax1, v_ax2], weight[v_ax2])
+                        T.writes(T_rms_norm_local[v_ax0, v_ax1, v_ax2])
+                        T_rms_norm_local[v_ax0, v_ax1, v_ax2] = rsqrt_shared[v_ax0, v_ax1] * data_local[v_ax0, v_ax1, v_ax2] * weight[v_ax2]
+                for ax0 in T.vectorized(8):
                     with T.block("T_cast_local"):
                         v0 = T.axis.spatial(1, 0)
-                        v1 = T.axis.spatial(n, ax0)
-                        v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1)
+                        v1 = T.axis.spatial(n, ax0_ax1_fused)
+                        v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0)
                         T.reads(T_rms_norm_local[v0, v1, v2])
                         T.writes(T_cast[v0, v1, v2])
                         T_cast[v0, v1, v2] = T_rms_norm_local[v0, v1, v2]
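The expected TIR above computes a standard RMSNorm: the constant 0.000244140625 is exactly 1/4096 (the mean over the hidden dimension) and 9.9999999999999995e-07 is the epsilon 1e-6 printed at full double precision. As a sketch of the same computation for eyeballing correctness, here is an assumed NumPy reference (function name, shapes, and default epsilon are illustrative, not part of the test file):

import numpy as np


def rmsnorm_ref(data: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """data: (1, n, 4096), weight: (4096,); dtypes float16 or float32."""
    x = data.astype(np.float32)
    # sum(x * x) / 4096, i.e. the mean of squares over the hidden dimension
    mean_sq = np.mean(x * x, axis=-1, keepdims=True)
    # normalize by rsqrt(mean_sq + eps), then apply the learned scale
    out = x * (1.0 / np.sqrt(mean_sq + eps)) * weight.astype(np.float32)
    return out.astype(data.dtype)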
