Skip to content

Commit

Permalink
[DLIGHT][GPU] Improve matmul schedule for adreno (#17430)
Browse files Browse the repository at this point in the history
Improved the matmul schedule using a layout-transpose approach, which improves
prefill throughput as follows:
| Model      | prefill baseline | prefill optimized |
|------------|------------------|-------------------|
| Llama-2-7b | 51 tok/sec       | 86 tok/sec        |
| Llama-3-8b | 48 tok/sec       | 79 tok/sec        |
| gemma-2b   | 140 tok/sec      | 245 tok/sec       |

---------
  • Loading branch information
krishnaraj36 authored Sep 30, 2024
1 parent d9ee637 commit e808010
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 126 deletions.
108 changes: 61 additions & 47 deletions python/tvm/dlight/gpu/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tvm.tir import IterVar, PrimExpr, Var
from tvm.tir.analysis import undefined_vars
from tvm.tir.schedule.schedule import BlockRV
from tvm.script import tir as T

from ..base import analysis, BlockInfo, IterInfo
from .base import GPUScheduleRule
Expand Down Expand Up @@ -945,14 +946,14 @@ def get_configs(self, target: Target) -> Config:
):
return Matmul.Config(
block_size_x=32,
block_size_y=8,
block_size_y=4,
vthread_x=1,
vthread_y=1,
micro_size_x=8,
micro_size_y=2,
micro_size_k=16,
vector_size=8,
unroll=4,
unroll=16,
use_shared=False,
storage_align=False,
inner_x=True,
Expand Down Expand Up @@ -1147,7 +1148,7 @@ def get_max_factor(n, factors):
if not (
isinstance(sch.get(n).extent, tir.IntImm)
and isinstance(sch.get(mb).extent, tir.IntImm)
and isinstance(sch.get(ms).extent, tir.Var)
and not isinstance(sch.get(ms).extent, tir.IntImm)
):
return None

Expand All @@ -1157,6 +1158,7 @@ def get_max_factor(n, factors):
config.vector_size,
config.unroll,
)

VecSize = min(get_max_factor(sch.get(n).extent // Threads_X, [1, 2, 4, 8]), VecSize)
dequant_block = None
matmul_block = reduction_block
Expand All @@ -1169,61 +1171,73 @@ def get_max_factor(n, factors):
elif blk is not matmul_block:
sch.compute_inline(blk)

m = sch.fuse(mb, ms)

sch.pad_einsum(matmul_block, [1, Threads_Y * Unroll_M, Threads_X * VecSize, 1])

rmat_block, wmat_block = (
block = sch.reindex(reduction_block, ("read", 0))
sch.pad_einsum(reduction_block, [1, Unroll_M, 1, 1])
sch.compute_inline(block)
trans_block, matmul_reindex = (
sch.get_producers(matmul_block)[0],
sch.get_consumers(matmul_block)[0],
)
mo, mi, mu = sch.split(m, [None, Threads_Y, Unroll_M])
no, ni, nv = sch.split(n, [None, Threads_X, VecSize])
k0, k1, k2, k3 = sch.split(k, [None, (Threads_X * VecSize) // 32, 4, 8])
sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv)

sch.compute_at(rmat_block, k0)
if dequant_block is not None:
sch.compute_at(dequant_block, k3)
sch.reverse_compute_at(wmat_block, mi)
sch.set_scope(rmat_block, 0, "shared")
sch.set_scope(matmul_block, 0, "local")
if epilogue_block is not None:
sch.compute_inline(matmul_reindex)
matmul_reindex = epilogue_block

if dequant_block is not None:
sch.set_scope(dequant_block, 0, "local")
sch.transform_layout(
trans_block,
("write", 0),
T.index_map(lambda i0, i1, i2: (i0, i1 // Unroll_M, i2, i1 % Unroll_M)),
)

sch.bind(mo, "blockIdx.y")
sch.bind(no, "blockIdx.x")
sch.bind(mi, "threadIdx.y")
sch.bind(ni, "threadIdx.x")
sch.vectorize(sch.get_loops(matmul_block)[-1])
# transpose block schedules
# sch.set_scope(trans_block, 0, "global.texture-1d")
tb, tn, tk = sch.get_loops(trans_block)
tbx, ttx = sch.split(tk, [None, Threads_X])
tby, tty, tc = sch.split(tn, [None, Threads_Y, Unroll_M])
sch.bind(tb, "blockIdx.z")
sch.bind(tby, "blockIdx.y")
sch.bind(tbx, "blockIdx.x")
sch.bind(tty, "threadIdx.y")
sch.bind(ttx, "threadIdx.x")
sch.reorder(tb, tby, tbx, tty, ttx, tc)
sch.vectorize(tc)

mb, ms, n, k = sch.get_loops(matmul_block)
m = sch.fuse(mb, ms)
bx, tx, vec = sch.split(n, [None, Threads_X, VecSize])
by, ty, unr = sch.split(m, [None, Threads_Y, Unroll_M])
k1, k2, k3 = sch.split(k, [None, 4, 8])
sch.reorder(bx, by, tx, ty, k1, k2, k3, unr, vec)
sch.set_scope(matmul_block, 0, "local")
if dequant_block is not None:
sch.vectorize(sch.get_loops(dequant_block)[-1])
sch.compute_at(dequant_block, k3)
sch.set_scope(dequant_block, 0, "local")
sch.bind(by, "blockIdx.y")
sch.bind(bx, "blockIdx.x")
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
sch.vectorize(vec)

# Co-operative Memory Fetch
ro, rv = sch.split(sch.get_loops(rmat_block)[-1], [None, VecSize])
sch.bind(ro, "threadIdx.x")
sch.vectorize(rv)
inp = sch.cache_read(matmul_block, read_buffer_index=0, storage_scope="local")
sch.compute_at(inp, k3, preserve_unit_loops=True)
sch.vectorize(sch.get_loops(inp)[-1])

wv = sch.get_loops(wmat_block)[-1]
sch.vectorize(wv)
sch.unroll(unr)
sch.unroll(k3)

# Scale and Quant Cache
if dequant_block is not None:
qb = sch.cache_read(dequant_block, 0, "local")
sb = sch.cache_read(dequant_block, 1, "local")
sch.compute_at(sb, k1)
sch.compute_at(qb, k2)
sch.set_scope(sb, 0, "local")
sch.set_scope(qb, 0, "local")
sch.vectorize(sch.get_loops(qb)[-1])
sch.vectorize(sch.get_loops(sb)[-1])
Aq_local = sch.cache_read(dequant_block, read_buffer_index=0, storage_scope="local")
sch.compute_at(Aq_local, k2, preserve_unit_loops=True)
sch.vectorize(sch.get_loops(Aq_local)[-1])
As_local = sch.cache_read(dequant_block, read_buffer_index=1, storage_scope="local")
sch.compute_at(As_local, k1, preserve_unit_loops=True)
sch.vectorize(sch.get_loops(As_local)[-1])
sch.vectorize(sch.get_loops(dequant_block)[-1])

if epilogue_block is not None:
sch.reverse_compute_at(epilogue_block, mi, preserve_unit_loops=True)
sch.set_scope(wmat_block, 0, "local")
sch.compute_inline(wmat_block)
sch.vectorize(sch.get_loops(epilogue_block)[-1])
sch.reverse_compute_at(matmul_reindex, ty)
o_ur, o_vec = sch.get_loops(matmul_reindex)[-2:]
sch.vectorize(o_vec)
sch.unroll(o_ur)
sch.decompose_reduction(matmul_block, k1)

sch.decompose_reduction(matmul_block, k0)
return sch
Loading

0 comments on commit e808010

Please sign in to comment.