Skip to content

Commit

Permalink
change num_sms to num_cus
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaohuguo2023 committed Jul 31, 2024
1 parent 480ad2c commit bdd1f8e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 65 deletions.
41 changes: 13 additions & 28 deletions python/perf-kernels/streamk/streamk_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@


@triton.jit()
def get_new_pid(current_pid, num_sms):
def get_new_pid(current_pid, num_cus):
# Number of XCDs
num_xcds = 8
# Number of pids per XCD in the new arrangement
pids_per_xcd = num_sms // num_xcds
pids_per_xcd = num_cus // num_xcds
# Compute current XCD and local pid within the XCD
xcd = current_pid % num_xcds
local_pid = current_pid // num_xcds
Expand All @@ -22,7 +22,7 @@ def get_tiles_config(
M,
N,
K,
num_sms,
num_cus,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
Expand All @@ -32,14 +32,14 @@ def get_tiles_config(
iters_per_tile = tl.cdiv(K, BLOCK_SIZE_K)

total_tiles = total_blocks_M * total_blocks_N
if num_sms > 0 and total_tiles > num_sms: # Stream-K
total_streamk_tiles = total_tiles % num_sms
if num_cus > 0 and total_tiles > num_cus: # Stream-K
total_streamk_tiles = total_tiles % num_cus
total_full_tiles = total_tiles - total_streamk_tiles
total_streamk_iters = total_streamk_tiles * iters_per_tile
# iterations related to full waves
streamk_iters_pcu = total_streamk_iters // num_sms
streamk_iters_pcu = total_streamk_iters // num_cus
# iterations related to last (partial) wave
streamk_remainder_iters = total_streamk_iters % num_sms
streamk_remainder_iters = total_streamk_iters % num_cus

else: # all tiles are computed using classical blocking
total_full_tiles = total_tiles
Expand All @@ -56,39 +56,36 @@ def streamk_gemm(
A,
B,
C,
bias_ptr,
P,
locks,
M,
N,
K,
num_sms,
num_cus,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_bias,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
BIAS: tl.constexpr,
EVEN_K: tl.constexpr,
):
pid = tl.program_id(0)
pid = get_new_pid(pid, num_sms)
pid = get_new_pid(pid, num_cus)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

iters_per_tile, total_full_tiles, total_streamk_tiles, streamk_iters_pcu, streamk_remainder_iters = get_tiles_config(
M, N, K, num_sms, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)
M, N, K, num_cus, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)

acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32
rk = tl.arange(0, BLOCK_SIZE_K)

for tile_id in range(pid, total_full_tiles, num_sms):
for tile_id in range(pid, total_full_tiles, num_cus):
if GROUP_SIZE_M == 1:
pid_m = tile_id // num_pid_n
pid_n = tile_id % num_pid_n
Expand All @@ -107,10 +104,6 @@ def streamk_gemm(
A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn

if BIAS:
bias_ = bias_ptr + rm * stride_bias
bias = tl.load(bias_, mask=rm < M, other=0.0)

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
if EVEN_K:
Expand All @@ -124,8 +117,6 @@ def streamk_gemm(
B_BASE += BLOCK_SIZE_K * stride_bk

c = acc.to(C.type.element_ty)
if BIAS:
c += bias[:, None]

rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
Expand Down Expand Up @@ -160,10 +151,6 @@ def streamk_gemm(
A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_SIZE_K * stride_ak * remainder
B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_SIZE_K * stride_bk * remainder

if BIAS:
bias_ = bias_ptr + rm * stride_bias
bias = tl.load(bias_, mask=rm < M, other=0.0)

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
for current_iter in range(start_iter, end_iter):
if EVEN_K:
Expand All @@ -183,9 +170,9 @@ def streamk_gemm(
tile_iter_end = tile_iter + iters_per_tile
next_pid = pid + 1
end = end_iter
while (end < tile_iter_end and next_pid < num_sms):
while (end < tile_iter_end and next_pid < num_cus):
# todo: try use tl.load once cache modifier landed upstream
while tl.atomic_cas(locks + next_pid, 1, 1) != 1:
# while tl.load(locks + next_pid, cache_modifier='.ca') != 1:
pass
rm1 = tl.arange(0, BLOCK_SIZE_M)
rn1 = tl.arange(0, BLOCK_SIZE_N)
Expand All @@ -198,8 +185,6 @@ def streamk_gemm(
next_pid += 1

c = acc.to(C.type.element_ty)
if BIAS:
c += bias[:, None]

rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
Expand Down
Loading

0 comments on commit bdd1f8e

Please sign in to comment.