diff --git a/python/perf-kernels/streamk/streamk_kernel.py b/python/perf-kernels/streamk/streamk_kernel.py
index bc207dbb5d08..138e6540e203 100644
--- a/python/perf-kernels/streamk/streamk_kernel.py
+++ b/python/perf-kernels/streamk/streamk_kernel.py
@@ -3,11 +3,11 @@


 @triton.jit()
-def get_new_pid(current_pid, num_sms):
+def get_new_pid(current_pid, num_cus):
     # Number of XCDs
     num_xcds = 8
     # Number of pids per XCD in the new arrangement
-    pids_per_xcd = num_sms // num_xcds
+    pids_per_xcd = num_cus // num_xcds
     # Compute current XCD and local pid within the XCD
     xcd = current_pid % num_xcds
     local_pid = current_pid // num_xcds
@@ -22,7 +22,7 @@ def get_tiles_config(
         M,
         N,
         K,
-        num_sms,
+        num_cus,
         BLOCK_SIZE_M: tl.constexpr,
         BLOCK_SIZE_N: tl.constexpr,
         BLOCK_SIZE_K: tl.constexpr,
@@ -32,14 +32,14 @@ def get_tiles_config(
     iters_per_tile = tl.cdiv(K, BLOCK_SIZE_K)
     total_tiles = total_blocks_M * total_blocks_N

-    if num_sms > 0 and total_tiles > num_sms:  # Stream-K
-        total_streamk_tiles = total_tiles % num_sms
+    if num_cus > 0 and total_tiles > num_cus:  # Stream-K
+        total_streamk_tiles = total_tiles % num_cus
         total_full_tiles = total_tiles - total_streamk_tiles
         total_streamk_iters = total_streamk_tiles * iters_per_tile
         # iterations related to full waves
-        streamk_iters_pcu = total_streamk_iters // num_sms
+        streamk_iters_pcu = total_streamk_iters // num_cus
         # iterations related to last (partial) wave
-        streamk_remainder_iters = total_streamk_iters % num_sms
+        streamk_remainder_iters = total_streamk_iters % num_cus
     else:  # all tiles are computed using classical blocking
         total_full_tiles = total_tiles
@@ -56,39 +56,36 @@ def streamk_gemm(
         A,
         B,
         C,
-        bias_ptr,
         P,
         locks,
         M,
         N,
         K,
-        num_sms,
+        num_cus,
         stride_am,
         stride_ak,
         stride_bk,
         stride_bn,
         stride_cm,
         stride_cn,
-        stride_bias,
         BLOCK_SIZE_M: tl.constexpr,
         BLOCK_SIZE_N: tl.constexpr,
         BLOCK_SIZE_K: tl.constexpr,
         GROUP_SIZE_M: tl.constexpr,
-        BIAS: tl.constexpr,
         EVEN_K: tl.constexpr,
 ):
     pid = tl.program_id(0)
-    pid = get_new_pid(pid, num_sms)
+    pid = get_new_pid(pid, num_cus)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
     iters_per_tile, total_full_tiles, total_streamk_tiles, streamk_iters_pcu, streamk_remainder_iters = get_tiles_config(
-        M, N, K, num_sms, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)
+        M, N, K, num_cus, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)

     acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32
     rk = tl.arange(0, BLOCK_SIZE_K)

-    for tile_id in range(pid, total_full_tiles, num_sms):
+    for tile_id in range(pid, total_full_tiles, num_cus):
         if GROUP_SIZE_M == 1:
             pid_m = tile_id // num_pid_n
             pid_n = tile_id % num_pid_n
@@ -107,10 +104,6 @@ def streamk_gemm(
         A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
         B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn

-        if BIAS:
-            bias_ = bias_ptr + rm * stride_bias
-            bias = tl.load(bias_, mask=rm < M, other=0.0)
-
         acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
         for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
             if EVEN_K:
@@ -124,8 +117,6 @@ def streamk_gemm(
             B_BASE += BLOCK_SIZE_K * stride_bk

         c = acc.to(C.type.element_ty)
-        if BIAS:
-            c += bias[:, None]

         rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
         rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -160,10 +151,6 @@ def streamk_gemm(
         A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_SIZE_K * stride_ak * remainder
         B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_SIZE_K * stride_bk * remainder

-        if BIAS:
-            bias_ = bias_ptr + rm * stride_bias
-            bias = tl.load(bias_, mask=rm < M, other=0.0)
-
         acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
         for current_iter in range(start_iter, end_iter):
             if EVEN_K:
@@ -183,9 +170,9 @@ def streamk_gemm(
             tile_iter_end = tile_iter + iters_per_tile
             next_pid = pid + 1
             end = end_iter
-            while (end < tile_iter_end and next_pid < num_sms):
+            while (end < tile_iter_end and next_pid < num_cus):
+                # todo: try use tl.load once cache modifier landed upstream
                 while tl.atomic_cas(locks + next_pid, 1, 1) != 1:
-                    # while tl.load(locks + next_pid, cache_modifier='.ca') != 1:
                     pass
                 rm1 = tl.arange(0, BLOCK_SIZE_M)
                 rn1 = tl.arange(0, BLOCK_SIZE_N)
@@ -198,8 +185,6 @@ def streamk_gemm(
                 next_pid += 1

             c = acc.to(C.type.element_ty)
-            if BIAS:
-                c += bias[:, None]

             rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
             rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
diff --git a/python/perf-kernels/streamk/tune_streamk.py b/python/perf-kernels/streamk/tune_streamk.py
index 1fb5bcc2042f..3b0fbdb960c7 100644
--- a/python/perf-kernels/streamk/tune_streamk.py
+++ b/python/perf-kernels/streamk/tune_streamk.py
@@ -157,7 +157,7 @@ def read_config(config):
     return block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack


-def gen_kernel_and_configStr_from_config(M, N, K, num_sms, EVEN_K, config, dtype_a, dtype_b, dtype_c, dtype_p,
+def gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, dtype_a, dtype_b, dtype_c, dtype_p,
                                          dtype_lock):
     block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(config)
     torch_dtype_a = 'fp16'
@@ -178,13 +178,13 @@ def gen_kernel_and_configStr_from_config(M, N, K, num_sms, EVEN_K, config, dtype
     configStr = f"M{M}_N{N}_K{K}_BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}"

     matmul_def_str = f"""
-def matmul_{configStr}(a, b, c, P, locks, M, N, K, num_sms, am, ak, bk, bn, cm, cn, warmup=False):
-    grid = num_sms
+def matmul_{configStr}(a, b, c, P, locks, M, N, K, num_cus, am, ak, bk, bn, cm, cn, warmup=False):
+    grid = num_cus
     #print(f'config: streamk_gemm_{configStr}', flush=True)
     if warmup:
         streamk_gemm_{configStr}.warmup(
             {torch_dtype_a}, {torch_dtype_b}, {torch_dtype_c}, {torch_dtype_p}, {torch_dtype_lock},
-            M, N, K, num_sms,
+            M, N, K, num_cus,
             am, ak, bk, bn, cm, cn,
             BLOCK_SIZE_M = {block_m},
             BLOCK_SIZE_N = {block_n},
@@ -202,7 +202,7 @@ def matmul_{configStr}(a, b, c, P, locks, M, N, K, num_sms, am, ak, bk, bn, cm,
     else:
         streamk_gemm_{configStr}[grid,](
             a, b, c, P, locks,
-            M, N, K, num_sms,
+            M, N, K, num_cus,
             am, ak, bk, bn, cm, cn,
             BLOCK_SIZE_M = {block_m},
             BLOCK_SIZE_N = {block_n},
@@ -217,9 +217,9 @@ def matmul_{configStr}(a, b, c, P, locks, M, N, K, num_sms, am, ak, bk, bn, cm,
         )
     return c

-def try_config_{configStr}(M, N, K, num_sms, am, ak, bk, bn, cm, cn):
+def try_config_{configStr}(M, N, K, num_cus, am, ak, bk, bn, cm, cn):
     try:
-        matmul_{configStr}(None, None, None, None, None, M, N, K, num_sms, am, ak, bk, bn, cm, cn, True)
+        matmul_{configStr}(None, None, None, None, None, M, N, K, num_cus, am, ak, bk, bn, cm, cn, True)
         return True
     except Exception as e:
         print(f'invalid config(compilation): {configStr}: ', e, flush=True)
@@ -241,7 +241,7 @@ def generated_kernel_name(M, N, K, gpu_id):
 # 4. test_gemm to invoke
 #    4.1 run try_config in parallel
 #    4.2 matmul in a loop of 10 iterations
-def generate_kernel(M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs,
+def generate_kernel(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs,
                     jobs, iters, run_bench):
     filenames = []
     for i in range(jobs):
@@ -268,7 +268,7 @@ def generate_kernel(M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c, d
     for config in configs:
         file_idx = idx % jobs
         EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False
-        configStr, matmul_def_str = gen_kernel_and_configStr_from_config(M, N, K, num_sms, EVEN_K, config, dtype_a,
+        configStr, matmul_def_str = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, dtype_a,
                                                                          dtype_b, dtype_c, dtype_p, dtype_lock)
         # Copy the streamk_gemm with name replaced
         streamk_gemm_config = streamk_gemm_code.replace("streamk_gemm", f"streamk_gemm_{configStr}")
@@ -282,12 +282,12 @@ def generate_kernel(M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c, d
     # pre string
     block_m = config.get('BLOCK_SIZE_M')
     block_n = config.get('BLOCK_SIZE_N')
-    test_gemm_pre_str = f"""def test_gemm(M, N, K, num_sms, num_threads):
+    test_gemm_pre_str = f"""def test_gemm(M, N, K, num_cus, num_threads):
     thread_pool = multiprocessing.Pool(processes=num_threads)
     a, a_fp16 = gen_input(M, K, '{dtype_a}', {col_a}, 1, '{init_type}', device='cuda')
     b, b_fp16 = gen_input(K, N, '{dtype_b}', {col_b}, 2, '{init_type}', device='cuda')
     c = torch.zeros((M, N), device=a.device, dtype={tl_to_torch_types[name_to_tl_types[dtype_c]]})
-    task_args = (M, N, K, num_sms,
+    task_args = (M, N, K, num_cus,
                  a.stride(0), a.stride(1),
                  b.stride(0), b.stride(1),
                  c.stride(0), c.stride(1))
@@ -303,7 +303,7 @@ def generate_kernel(M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c, d
     idx = 0
     for config in configs:
         EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False
-        configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_sms, EVEN_K, config, None, None, None, None,
+        configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, None, None, None, None,
                                                             None)
         task_str = f"    results += [thread_pool.apply_async(try_config_{configStr}, args=task_args)]\n" + \
                    f"    config_names += ['{configStr}']\n"
@@ -336,7 +336,7 @@ def generate_kernel(M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c, d
     runs = iters if run_bench else 200
     for config in configs:
         EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False
-        configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_sms, EVEN_K, config, None, None, None, None,
+        configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, None, None, None, None,
                                                             None)
         block_m = config.get('BLOCK_SIZE_M')
         block_n = config.get('BLOCK_SIZE_N')
@@ -344,9 +344,9 @@ def generate_kernel(M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c, d
     if '{configStr}' not in failed_configs:
         print(f"{configStr}")
         for i in range({runs}):
-            locks = torch.zeros((num_sms,), device = "cuda", dtype = torch.int32)
-            P = torch.zeros((num_sms, {block_m}*{block_n}), device="cuda", dtype=torch.float32)
-            d = matmul_{configStr}(a, b, c, P, locks, M, N, K, num_sms, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))"""
+            locks = torch.zeros((num_cus,), device = "cuda", dtype = torch.int32)
+            P = torch.zeros((num_cus, {block_m}*{block_n}), device="cuda", dtype=torch.float32)
+            d = matmul_{configStr}(a, b, c, P, locks, M, N, K, num_cus, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))"""
         f_kernel[idx % jobs].write(matmul_call_str + "\n")
         idx += 1
     # post string
@@ -362,9 +362,9 @@ def main():
     parser.add_argument("-n", type=int, default=1, help='number of threads')
     args = parser.parse_args()
     numThreads = args.n
-    num_sms = 304
+    num_cus = 304
 """
-    test_gemm_call_str = f'test_gemm({M}, {N}, {K}, num_sms, numThreads)'
+    test_gemm_call_str = f'test_gemm({M}, {N}, {K}, num_cus, numThreads)'
     for fi in range(jobs):
         f_kernel[fi].write(def_main_str)
         f_kernel[fi].write(test_gemm_call_str + "\n\n")
@@ -373,7 +373,7 @@ def main():
         f_kernel[fi].close()


-def extract_kernel_time(M, N, K, num_sms, EVEN_K, config, df):
+def extract_kernel_time(M, N, K, num_cus, EVEN_K, config, df):
     # Correct the header by removing 'sig' and 'obj' to reduce number from 21 to 19
     # once the bug is fixed, we should not need below two lines
     cols = [
@@ -382,7 +382,7 @@ def extract_kernel_time(M, N, K, num_sms, EVEN_K, config, df):
     ]
     df.columns = cols

-    configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_sms, EVEN_K, config, None, None, None, None, None)
+    configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, None, None, None, None, None)

     filtered_df = df[df['KernelName'].str.contains(configStr, na=False)].copy()
     filtered_df['DurationNs'] = filtered_df['EndNs'] - filtered_df['BeginNs']
@@ -390,7 +390,7 @@ def extract_kernel_time(M, N, K, num_sms, EVEN_K, config, df):
     return config, meanTime


-def profile_batch_kernels(M, N, K, num_sms, gpuid, gpus, jobs, verbose):
+def profile_batch_kernels(M, N, K, num_cus, gpuid, gpus, jobs, verbose):
     ngpus = len(gpus)
     gpuIdx = gpus.index(gpuid)
     if gpuIdx + 1 > jobs:
@@ -406,10 +406,10 @@ def profile_batch_kernels(M, N, K, num_sms, gpuid, gpus, jobs, verbose):
         jobId += ngpus


-def tune_gemm_config(M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs,
+def tune_gemm_config(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs,
                      run_bench, jobs, iters, skipWarmup, verbose=0, num_threads=16, gpus=[0]):
     # Generate kernel out of all configs
-    generate_kernel(M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs,
+    generate_kernel(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs,
                     jobs, iters, run_bench)

     # remove any compiled kernel in the cache
@@ -427,7 +427,7 @@ def tune_gemm_config(M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c,

     # profile generated kernels
     running = [
-        multiprocessing.Process(target=profile_batch_kernels, args=(M, N, K, num_sms, gpu_id, gpus, jobs, verbose))
+        multiprocessing.Process(target=profile_batch_kernels, args=(M, N, K, num_cus, gpu_id, gpus, jobs, verbose))
         for gpu_id in gpus
     ]
     for p in running:
@@ -454,7 +454,7 @@ def tune_gemm_config(M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c,
         EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False
         file_idx = idx % jobs
         tasks += [
-            thread_pool.apply_async(extract_kernel_time, args=(M, N, K, num_sms, EVEN_K, config, df_prof[file_idx]))
+            thread_pool.apply_async(extract_kernel_time, args=(M, N, K, num_cus, EVEN_K, config, df_prof[file_idx]))
         ]
         idx += 1
     thread_pool.close()
@@ -525,7 +525,7 @@ def init_by_size_and_type(size, dtype, init_type):
     return input, input_f16


-def matmul(a, b, c, P, locks, num_sms, block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu,
+def matmul(a, b, c, P, locks, num_cus, block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu,
           mfmaInstrSize, kpack, EVEN_K):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
@@ -535,17 +535,17 @@ def matmul(a, b, c, P, locks, num_sms, block_m, block_n, block_k, group_m, num_w
     K, N = b.shape

     # 1D launch kernel where each block gets its own program.
-    grid = num_sms
+    grid = num_cus
     streamk_gemm[
         grid,
-    ](a, b, c, P, locks, M, N, K, num_sms, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
+    ](a, b, c, P, locks, M, N, K, num_cus, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
       BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, num_warps=num_warps,
       num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack, EVEN_K=EVEN_K)

     return c


-def test_correctness(M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, verbose):
+def test_correctness(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, verbose):
     block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(config)
     torch.manual_seed(0)
     #a = torch.randn((M, K), device='cuda', dtype=datatype)
@@ -556,9 +556,9 @@ def test_correctness(M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c,
     print(f"{block_k}")
     EVEN_K = K % block_k == 0
     c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]])
-    locks = torch.zeros((num_sms, ), device="cuda", dtype=torch.int32)
-    P = torch.zeros((num_sms, block_m * block_n), device="cuda", dtype=torch.float32)
-    triton_output = matmul(a, b, c, P, locks, num_sms, block_m, block_n, block_k, group_m, num_warps, num_stages,
+    locks = torch.zeros((num_cus, ), device="cuda", dtype=torch.int32)
+    P = torch.zeros((num_cus, block_m * block_n), device="cuda", dtype=torch.float32)
+    triton_output = matmul(a, b, c, P, locks, num_cus, block_m, block_n, block_k, group_m, num_warps, num_stages,
                            waves_per_eu, mfmaInstrSize, kpack, EVEN_K)
     torch_output = torch.matmul(a_fp16, b_fp16)
     # print(f"triton_output={triton_output}")
@@ -698,7 +698,7 @@ def main():
     jobs = args.jobs
     iters = args.iters
     skipWarmup = args.no_warmup
-    num_sms = 304
+    num_cus = 304

     # Get GPU ids
     ngpus = args.ngpus
@@ -748,7 +748,7 @@ def main():
     # Check correctness from given configs
     if args.compare_wo_tuning:
         for (M, N, K, col_a, col_b, myConfig) in mnks:
-            test_correctness(M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, myConfig, True)
+            test_correctness(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, myConfig, True)
         return

     configs_full = get_full_tuning_space()
@@ -783,7 +783,7 @@ def main():
         if args.verbose:
             verbose_level = 2
         minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(
-            M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, pruned_configs,
+            M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, pruned_configs,
             run_bench, jobs, iters, skipWarmup, num_threads=args.num_threads, gpus=gpus, verbose=verbose_level)

         EVEN_K = True if K % bestConfig.get('BLOCK_SIZE_K') == 0 else False
@@ -795,7 +795,7 @@ def main():
         if not run_bench:
             print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True)

-        bestConfig_compact_str, _ = gen_kernel_and_configStr_from_config(M, N, K, num_sms, EVEN_K, bestConfig, None,
+        bestConfig_compact_str, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, bestConfig, None,
                                                                          None, None, None, None)
         if not run_bench:
             print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True)
@@ -823,7 +823,7 @@ def main():

         # Check correctness if asked to
         if args.compare:
             print("correctness: ", end=" ", flush=True)
-            test_correctness(M, N, K, num_sms, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, bestConfig, False)
+            test_correctness(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, bestConfig, False)
         elif not run_bench:
             print("", flush=True)
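
Caller-side sketch (illustrative, not part of the patch): after this change the kernel takes num_cus in place of num_sms and no longer accepts bias_ptr, stride_bias, or the BIAS constexpr. The snippet below shows one way a standalone caller could launch the renamed kernel, assuming the streamk_gemm import path and the locks/P buffer conventions used in tune_streamk.py; the tile sizes are placeholder values, and the host-side bias epilogue is an assumption for callers that still need bias, since the fused path was removed.

import torch

from streamk_kernel import streamk_gemm  # assumed import path

def streamk_matmul(a, b, num_cus=304, block_m=256, block_n=256, block_k=64, group_m=8, bias=None):
    M, K = a.shape
    _, N = b.shape
    c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
    # one lock flag and one partial-tile accumulator slot per CU, as in tune_streamk.py
    locks = torch.zeros((num_cus, ), device="cuda", dtype=torch.int32)
    P = torch.zeros((num_cus, block_m * block_n), device="cuda", dtype=torch.float32)
    # grid = num_cus: one persistent program per CU, matching the patched matmul() wrapper
    streamk_gemm[(num_cus, )](
        a, b, c, P, locks,
        M, N, K, num_cus,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
        BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k,
        GROUP_SIZE_M=group_m, EVEN_K=(K % block_k == 0))
    if bias is not None:
        # the kernel no longer fuses the bias add, so apply it on the host (assumption)
        c += bias[:, None]
    return c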