diff --git a/scripts/amd/gemm/README.md b/scripts/amd/gemm/README.md
index a81fbc416391d..230ebe3f9c6d8 100644
--- a/scripts/amd/gemm/README.md
+++ b/scripts/amd/gemm/README.md
@@ -162,13 +162,36 @@ so each element of the bias vector is added to all elements of the corresponding
 # GEMM Tuning Script v3.3
+### API changes
+
+No API changes.
+
 ### Implementation changes
-- A separate file, named `generate_all_kernels.py`, is generated for compilations only.
-It contains all the kernels in the tuning space and they will be compiled by 64 threads by default.
-Compiling all the kernels in a single file in parallel is faster than splitting them
-into multiple files. This can greatly reduce the compile time of the tuning process.
-- `configStr` does not contain gemm size anymore. This allows the same matmul_{configStr} kernel
-to be reused by different gemm sizes (this is not implemented yet).
+- Before the start of the tuning for loop, we generate a file named `myKernels.py`
+(its path is obtained from `get_filename_myKernels()`). It contains all the matmul_kernel
+definitions, with configs from the full, un-pruned tuning space. This file is shared by
+all gemm sizes in the tuning session.
+- Inside the tuning for loop, each iteration tunes one gemm size:
+  1. Compilation stage:
+    - We generate a single compilation driver file, named `compile_driver.py` (its path is
+    obtained from `get_filename_compile_driver()`), which contains the wrapper functions
+    for all the configs in the **pruned** tuning space for this gemm size.
+    - All the kernels are compiled by 32 threads by default. Compiling all the kernels in
+    a single file in parallel is faster than splitting them into multiple files. This
+    greatly reduces the compile time of the tuning process.
+    - Note that we no longer generate matmul_kernel in this file; kernels are imported
+    from `myKernels.py`.
+  2. Profiling stage:
+    - We generate one task file per job, named `profile_driver_MxNxK_{job_id}.py`
+    (its path is obtained from `get_filename_profile_driver()`). The only difference from
+    the previous version is that we no longer generate matmul_kernel in this file; kernels
+    are imported from `myKernels.py`.
+- `configStr` no longer contains the gemm size. This allows the same matmul_{configStr}
+kernel to be reused by different gemm sizes.
+- `configStr` no longer contains `_bias` when bias is provided, because we do not expect
+to compare the same kernel with and without bias. Therefore, bias is treated the same way
+as gemm sizes.
 - Add support for `EVEN_K` in the matmul kernel. Now the kernel supports `BLOCK_SIZE_K`
 values that do not divide `K`.
+- Now we use `rocprofv2` to measure kernel time.
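For illustration, here is a minimal sketch of the kernel naming scheme and the `EVEN_K` derivation described above. It mirrors `gen_configStr()` and the `EVEN_K` check in `utils/file_generator.py` from this patch; the example config values and gemm size are made up.

```python
# Minimal sketch of the config-based kernel naming and the EVEN_K flag used by
# the tuning scripts (see utils/file_generator.py). The example config is made up.
config = {
    'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64,
    'GROUP_SIZE_M': 1, 'SPLIT_K': 1, 'num_warps': 1, 'num_stages': 0,
    'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2,
}

# The kernel/wrapper suffix encodes only the config, not M/N/K or bias,
# so the same compiled kernel can be reused across gemm sizes.
configStr = (f"BM{config['BLOCK_SIZE_M']}_BN{config['BLOCK_SIZE_N']}_BK{config['BLOCK_SIZE_K']}"
             f"_GM{config['GROUP_SIZE_M']}_SK{config['SPLIT_K']}_nW{config['num_warps']}"
             f"_nS{config['num_stages']}_EU{config['waves_per_eu']}"
             f"_kP{config['kpack']}_mfma{config['matrix_instr_nonkdim']}")
print(f"matmul_kernel_{configStr}")
# -> matmul_kernel_BM256_BN256_BK64_GM1_SK1_nW1_nS0_EU0_kP2_mfma16

# EVEN_K records whether BLOCK_SIZE_K divides K exactly; when it is False,
# the kernel has to handle the tail of the K loop instead of assuming full blocks.
M, N, K = 4096, 4096, 4096
EVEN_K = K % config['BLOCK_SIZE_K'] == 0
```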
diff --git a/scripts/amd/gemm/tune_gemm.py b/scripts/amd/gemm/tune_gemm.py index ef998b74e500b..3116d6f31491a 100644 --- a/scripts/amd/gemm/tune_gemm.py +++ b/scripts/amd/gemm/tune_gemm.py @@ -16,6 +16,9 @@ import multiprocessing import pandas as pd +from utils.file_generator import * +from utils.name_utils import * + def is_hip_available(): try: __import__("hip") @@ -28,18 +31,18 @@ def is_hip_available(): def get_full_tuning_space(): configs = [] - block_mn_range = [16, 32, 64, 128, 256] - block_k_range = [16, 32, 64, 128, 256] + block_mn_range = [256] + block_k_range = [64] split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] num_warps_range = [1, 2, 4, 8] - group_m_range = [1, 4, 8, 16, 32] + group_m_range = [1] # For now we see better perf with num_stages=0 for all gemm configs we care # But keep this explicit so that we do not forget we may need to set it to # other values in the future num_stage_range = [0] waves_per_eu_range = [0] - matrix_instr_nonkdim_range = [16, 32] - kpack_range = [1, 2] + matrix_instr_nonkdim_range = [16] + kpack_range = [2] for block_m in block_mn_range: for block_n in block_mn_range: @@ -156,336 +159,14 @@ def run_bash_command(commandstring, capture=True): proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash') return None -def read_config(config): - block_m = config.get('BLOCK_SIZE_M') - block_n = config.get('BLOCK_SIZE_N') - block_k = config.get('BLOCK_SIZE_K') - group_m = config.get('GROUP_SIZE_M') - split_k = config.get('SPLIT_K') - num_warps = config.get('num_warps') - num_stages = config.get('num_stages') - waves_per_eu = config.get('waves_per_eu') - mfma_instr_size = config.get('matrix_instr_nonkdim') - kpack = config.get('kpack') - return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack - - -def gen_configStr(config, bias_size): - block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(config) - - ## {M}_{N}_{K} is removed since the same kernel can be used for differen gemm sizes - configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}" - - if bias_size > 0: - configStr += "_bias" - - return configStr - -## construct the configStr and generate the wrapper function matmul_{configStr}() -## If `warmup` is set, the generated kernel will be **compiled** -def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, warmup): - block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(config) - - configStr = gen_configStr(config, bias_size) - - use_bias = bias_size > 0 - - if warmup: - torch_dtype_a = 'fp16' - torch_dtype_b = 'fp16' - torch_dtype_c = 'fp16' - if dtype_a: - torch_dtype_a = tl_to_torch_types[name_to_tl_types[dtype_a]] - if dtype_b: - torch_dtype_b = tl_to_torch_types[name_to_tl_types[dtype_b]] - if dtype_c: - torch_dtype_c = tl_to_torch_types[name_to_tl_types[dtype_c]] - - matmul_def_str = f""" -def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn): - matmul_kernel_{configStr}.warmup( - {torch_dtype_a}, {torch_dtype_b}, {torch_dtype_c}, {torch_dtype_c}, - M, N, K, - am, ak, bk, bn, cm, cn, biasn, - BLOCK_SIZE_M = {block_m}, - BLOCK_SIZE_N = {block_n}, - BLOCK_SIZE_K = {block_k}, - GROUP_SIZE_M = {group_m}, - SPLIT_K = {split_k}, - num_warps = {num_warps}, - num_stages = {num_stages}, - waves_per_eu = 
{waves_per_eu}, - matrix_instr_nonkdim = {mfmaInstrSize}, - kpack = {kpack}, - BIAS={use_bias}, - EVEN_K={EVEN_K}, - grid=(1,), - ) - return None - -def try_compile_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn): - try: - matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn) - return True - except Exception as e: - print(f'invalid config(compilation): {configStr}: ', e, flush=True) - return False -""" - else: - matmul_def_str = f""" -def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn): - grid = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n}), {split_k} - matmul_kernel_{configStr}[grid]( - a, b, c, bias, - M, N, K, - am, ak, bk, bn, cm, cn, biasn, - BLOCK_SIZE_M = {block_m}, - BLOCK_SIZE_N = {block_n}, - BLOCK_SIZE_K = {block_k}, - GROUP_SIZE_M = {group_m}, - SPLIT_K = {split_k}, - num_warps = {num_warps}, - num_stages = {num_stages}, - waves_per_eu = {waves_per_eu}, - matrix_instr_nonkdim = {mfmaInstrSize}, - kpack = {kpack}, - BIAS = {use_bias}, - EVEN_K = {EVEN_K} - ) - return c -""" - return configStr, matmul_def_str - - -def generated_kernel_name(M, N, K, gpu_id): - path = os.path.dirname(os.path.abspath(__file__)) - return f"{path}/generated_kernel{M}-{N}-{K}-{gpu_id}.py" - - -def generate_compile_kernel_filename(): - path = os.path.dirname(os.path.abspath(__file__)) - return f"{path}/generated_all_kernels.py" - - -# Open {len(jobs)} files -# generated_kernelM-N-K-0.py, generated_kernelM-N-K-1.py, ..., generated_kernelM-N-K-{njobs-1}.py -# and generate -# 1. matmul kernels of all configs -# 2. wrapper function matmul to invoke all the generated kernels -# 3. test_gemm to invoke matmul in a loop of {iters} iterations -def generate_profile_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, iters, run_bench, rotating_buffer_size, bias_size, icache_flush): - filenames = [] - for i in range(jobs): - filenames.append(generated_kernel_name(M, N, K, i)) - f_kernel = [open(path, 'w') for path in filenames] - - # write imports - import_str = """import torch -import triton -import triton.language as tl -import argparse -import sys -import multiprocessing -from tune_gemm import gen_rotating_tensors -""" - if icache_flush: - import_str += """ -from icache_flush import icache_flush -""" - for fi in range(jobs): - f_kernel[fi].write(import_str + "\n") - - # write definitions of matmul_kernel_xxx - # and matmul_xxx and try_config - with open(os.path.dirname(os.path.abspath(__file__))+"/matmul_kernel.py") as file: - matmul_kernel_code = file.read() - idx = 0 - for config in configs: - file_idx = idx % jobs - EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False - configStr, matmul_def_str = gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, False) - # Copy the matmul_kernel with name replaced - matmul_kernel_config = matmul_kernel_code.replace("matmul_kernel", f"matmul_kernel_{configStr}") - matmul_kernel_config = matmul_kernel_config.replace("import triton.language as tl", "") - matmul_kernel_config = matmul_kernel_config.replace("import triton", "") - f_kernel[file_idx].write(matmul_kernel_config + "\n\n") - f_kernel[file_idx].write(matmul_def_str + "\n") - idx += 1 - - # write test_gemm - # pre string - test_gemm_pre_str = f"""def test_gemm(M, N, K, rotating_buffer_size, bias_size): - tensors = gen_rotating_tensors(M, N, K, '{dtype_a}', {col_a}, '{dtype_b}', {col_b}, '{dtype_c}', - 1, '{init_type}', rotating_buffer_size, bias_size, device='cuda') - - a = 
tensors['input_a'][0] - b = tensors['input_b'][0] - c = tensors['output_c'][0] - assert bias_size == M or bias_size == 0 - - stride_bias = tensors['bias'][0].stride(0) if bias_size > 0 else 0 - - try: - with open("{generate_compile_kernel_filename()}.failed_configs", "r") as f: - failed_configs = [cfg.strip() for cfg in f.readlines()] - except Exception: - failed_configs = [] -""" - for fi in range(jobs): - f_kernel[fi].write(test_gemm_pre_str + "\n") - - # call all matmul_xxx functions - idx = 0 - runs = iters if run_bench else 200 - call_icache_flush = 'icache_flush()' if icache_flush else '' - for config in configs: - configStr = gen_configStr(config, bias_size) - matmul_call_str = f""" - if '{configStr}' not in failed_configs: - rotating_num = tensors['rotating_num'] - for i in range({runs}): - a = tensors['input_a'][i % rotating_num] - b = tensors['input_b'][i % rotating_num] - c = tensors['output_c'][i % rotating_num] - bias = tensors['bias'][i % rotating_num] if bias_size > 0 else None""" - if icache_flush: - matmul_call_str += f""" - icache_flush()""" - matmul_call_str += f""" - d = matmul_{configStr}(a, b, c, bias, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), bias.stride(0))""" - f_kernel[idx % jobs].write(matmul_call_str + "\n") - idx += 1 - # post string - for fi in range(jobs): - f_kernel[fi].write(" return d\n") - - # def main and call test_gemm - def_main_str = f""" -def main(): - parser = argparse.ArgumentParser( - prog="tune a specific gemm size", - allow_abbrev=False,) - parser.add_argument("-n", type=int, default=1, help='number of threads') - parser.add_argument("-rotating_tensor", type=int, default={rotating_buffer_size}, help='size of rotating buffer (MB), default: {rotating_buffer_size}') - args = parser.parse_args() - numThreads = args.n - rotating_buffer_size = args.rotating_tensor - """ - test_gemm_call_str = f'test_gemm({M}, {N}, {K}, rotating_buffer_size, {M})' - for fi in range(jobs): - f_kernel[fi].write(def_main_str) - f_kernel[fi].write(test_gemm_call_str + "\n\n") - f_kernel[fi].write("""if __name__ == '__main__': - sys.exit(main())""") - f_kernel[fi].close() - - -## Generate a single file that contains all kernels in the tuning space. 
-## This file is used to **compile** the kernels in parallel -def generate_compile_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, iters, run_bench, rotating_buffer_size, bias_size, icache_flush): - - filename = generate_compile_kernel_filename() - f_kernel = open(filename, 'w') - - # write imports - import_str = """import torch -import triton -import triton.language as tl -import argparse -import sys -import multiprocessing -from tune_gemm import gen_rotating_tensors -""" - - f_kernel.write(import_str + "\n") - - # write definitions of matmul_kernel_xxx - # and matmul_xxx and try_compile_config - with open(os.path.dirname(os.path.abspath(__file__))+"/matmul_kernel.py") as file: - matmul_kernel_code = file.read() - for config in configs: - EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False - configStr, matmul_def_str = gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, True) - # Copy the matmul_kernel with name replaced - matmul_kernel_config = matmul_kernel_code.replace("matmul_kernel", f"matmul_kernel_{configStr}") - matmul_kernel_config = matmul_kernel_config.replace("import triton.language as tl", "") - matmul_kernel_config = matmul_kernel_config.replace("import triton", "") - f_kernel.write(matmul_kernel_config + "\n\n") - f_kernel.write(matmul_def_str + "\n") - - # write compile_kernels - # pre string - compile_kernels_pre_str = f"""def compile_kernels(M, N, K, rotating_buffer_size, bias_size, num_threads): - thread_pool = multiprocessing.Pool(processes=num_threads) - tensors = gen_rotating_tensors(M, N, K, '{dtype_a}', {col_a}, '{dtype_b}', {col_b}, '{dtype_c}', - 1, '{init_type}', rotating_buffer_size, bias_size, device='cuda') - - a = tensors['input_a'][0] - b = tensors['input_b'][0] - c = tensors['output_c'][0] - assert bias_size == M or bias_size == 0 - - stride_bias = tensors['bias'][0].stride(0) if bias_size > 0 else 0 - task_args = (M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), stride_bias) - - results = [] - config_names = [] -""" - f_kernel.write(compile_kernels_pre_str + "\n") - - # warm up call of all matmul functions in parallel - for config in configs: - configStr = gen_configStr(config, bias_size) - task_str = f" results += [thread_pool.apply_async(try_compile_config_{configStr}, args=task_args)]\n" + \ - f" config_names += ['{configStr}']\n" - f_kernel.write(task_str) - - threadpool_str = """ - failed_configs = [] - for i in range(len(results)): - results[i].wait() - res = results[i].get() - if not res: - failed_configs += [config_names[i]] - thread_pool.close() - thread_pool.join() - if failed_configs: - with open("{filename}.failed_configs", "w") as f: - for cfg in failed_configs: - f.write(cfg + "\\n") -""".format(filename=filename) - f_kernel.write(threadpool_str) - - # def main and call compile_kernels - def_main_str = f""" -def main(): - parser = argparse.ArgumentParser( - prog="tune a specific gemm size", - allow_abbrev=False,) - parser.add_argument("-n", type=int, default=32, help='number of threads') - parser.add_argument("-rotating_tensor", type=int, default={rotating_buffer_size}, help='size of rotating buffer (MB), default: {rotating_buffer_size}') - args = parser.parse_args() - numThreads = args.n - rotating_buffer_size = args.rotating_tensor - """ - compile_kernels_call_str = f'compile_kernels({M}, {N}, {K}, rotating_buffer_size, {M}, numThreads)' - - f_kernel.write(def_main_str) - f_kernel.write(compile_kernels_call_str 
+ "\n\n") - f_kernel.write("""if __name__ == '__main__': - sys.exit(main())""") - f_kernel.close() def extract_kernel_time(M, N, K, config, df, bias_size): # Correct the header by removing 'sig' and 'obj' to reduce number from 21 to 19 - # once the bug(https://github.com/ROCm/rocprofiler/issues/144) fixed, we should + # once the bug(https://github.com/ROCm/rocprofiler/issues/144) fixed, we should # not need below two lines cols = ['Index','KernelName','gpu-id','queue-id','queue-index','pid','tid','grd','wgr','lds','scr','arch_vgpr','accum_vgpr','sgpr','wave_size','DispatchNs','BeginNs','EndNs','CompleteNs'] df.columns = cols - configStr = gen_configStr(config, bias_size) + configStr = gen_configStr(config) filtered_df = df[df['KernelName'].str.contains(configStr, na=False)].copy() filtered_df['DurationNs'] = filtered_df['EndNs'] - filtered_df['BeginNs'] meanTime = filtered_df['DurationNs'].tail(100).mean() @@ -500,10 +181,10 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose): os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid) jobId = gpuIdx while jobId < jobs: - kernel_name = generated_kernel_name(M, N, K, jobId) + kernel_name = get_filename_profile_driver(M, N, K, jobId) if verbose: print(f"profiling {kernel_name} on GPU {gpuid}") - run_bash_command_wrapper(f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {generated_kernel_name(M, N, K, jobId)}", capture=(verbose < 2)) + run_bash_command_wrapper(f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {get_filename_profile_driver(M, N, K, jobId)}", capture=(verbose < 2)) jobId += ngpus @@ -515,19 +196,19 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type start_time = datetime.now() if not skipWarmup: # Generate kernel out of all configs - generate_compile_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, iters, run_bench, rotating_buffer_size, bias_size, icache_flush) + fname = generate_compile_driver(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, rotating_buffer_size, bias_size) # remove any compiled kernel in the cache run_bash_command("rm -rf ~/.triton/cache") - fname = generate_compile_kernel_filename() run_bash_command(f"python {fname} -n {num_threads}", capture=(verbose < 2)) compile_end = datetime.now() compile_time = compile_end - start_time if verbose: print(f"compile time: {compile_time}", flush=True) + # Generate kernels out of all configs - generate_profile_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, iters, run_bench, rotating_buffer_size, bias_size, icache_flush) + generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, iters, run_bench, rotating_buffer_size, bias_size, icache_flush) # profile generated kernels running = [multiprocessing.Process(target=profile_batch_kernels, args=(M, N, K, gpu_id, gpus, jobs, verbose)) for gpu_id in gpus] @@ -651,12 +332,12 @@ def gen_rotating_tensors(M, N, K, bs, bs_fp16 = gen_input(M, 1, dtype_b, need_Trans_b, 2, init_type, device='cuda') bias.append(bs.squeeze()) - in_outs = {"rotating_num": block_count, + in_outs = {"rotating_num": block_count, "input_a": a, "input_b": b, "output_c": c, "bias": bias} - + return in_outs @@ -671,6 +352,7 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps grid = triton.cdiv(M, block_m) * triton.cdiv(N, block_n), split_k stride_bias = bias.stride(0) if use_bias else 0 + EVEN_K = K % block_k 
== 0 matmul_kernel[grid]( a, b, c, bias, M, N, K, @@ -689,6 +371,7 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack, BIAS=use_bias, + EVEN_K=EVEN_K ) return c @@ -727,19 +410,6 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type print(f'{size_str} Incorrect❌') -def get_default_tuning_result_filename(): - git_branch_name = run_bash_command("git rev-parse --abbrev-ref HEAD") - git_branch_name = git_branch_name[0].decode() - # handle branch name of "xxx/xxx" format - git_branch_name = git_branch_name.replace('/', '_') - git_commit_hash = run_bash_command("git rev-parse --short HEAD") - git_commit_hash = git_commit_hash[0].decode() - - dt_string = datetime.now().strftime("%m-%d-%Y-%H:%M:%S") - defaultName = f"tuning_results_{git_branch_name}@{git_commit_hash}_{dt_string}.yaml" - return defaultName - - def parse_args(): parser = argparse.ArgumentParser( prog="tune a specific gemm size", @@ -782,29 +452,7 @@ def parse_args(): return args -TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') -TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') -tl_to_torch_types = { - tl.float16: torch.float16, - tl.bfloat16: torch.bfloat16, - tl.float32: torch.float32, - tl.int8: torch.int8, - tl.int32: torch.int32, -} -if TORCH_HAS_FP8E5B16: - tl_to_torch_types[tl.float8e5b16] = torch.float8_e5m2fnuz -if TORCH_HAS_FP8E4B8: - tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz - -name_to_tl_types = { - 'int8': tl.int8, - 'int32': tl.int32, - 'fp16': tl.float16, - 'fp32': tl.float32, - 'bf16': tl.bfloat16, - 'fp8': tl.float8e4b8, - 'bf8': tl.float8e5b16, -} + def process_item(item): M = item['M'] @@ -925,6 +573,10 @@ def main(): return configs_full = get_full_tuning_space() + ## Generate a file named myKernels.py that contains all the kernels + ## in the un-pruned space + generate_matmul_kernels(configs_full) + start_time = datetime.now() # Append to the output file so that we can save all results into one file @@ -936,6 +588,8 @@ def main(): else: print(f"Tuning {len(mnks)} gemm sizes starts at: {start_time}", flush=True) + ## Big for loop of tuning + ## Each iteration performs tuning for one gemm size for (M, N, K, col_a, col_b, myConfig) in mnks: start_local_time = datetime.now() # Obtain a pruned tuning space according to gemm size @@ -973,7 +627,7 @@ def main(): if not run_bench: print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True) - bestConfig_compact_str = gen_configStr(bestConfig, bias_size) + bestConfig_compact_str = gen_configStr(bestConfig) if not run_bench: print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True) @@ -991,13 +645,13 @@ def main(): # remove generated files if asked to if not keepTmp: if not skipWarmup: - os.remove(generate_compile_kernel_filename()) + os.remove(get_filename_compile_driver()) try: - os.remove(generate_compile_kernel_filename() + ".failed_configs") + os.remove(get_filename_compile_driver() + ".failed_configs") except OSError: pass for i in range(jobs): - generated_script = generated_kernel_name(M, N, K, i) + generated_script = get_filename_profile_driver(M, N, K, i) os.remove(generated_script) for f in glob.glob(f"results_{i}.*"): os.remove(f) @@ -1005,7 +659,7 @@ def main(): # Check correctness if asked to if args.compare: print("correctness: ", end=" ", flush=True) - test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, bestConfig, False) + test_correctness(M, N, K, col_a, col_b, dtype_a, 
dtype_b, dtype_c, init_type, bestConfig, bias_vector, False) elif not run_bench: print("", flush=True) diff --git a/scripts/amd/gemm/utils/file_generator.py b/scripts/amd/gemm/utils/file_generator.py new file mode 100644 index 0000000000000..b759d9638fe6d --- /dev/null +++ b/scripts/amd/gemm/utils/file_generator.py @@ -0,0 +1,350 @@ +import os +from .name_utils import * + + +def read_config(config): + block_m = config.get('BLOCK_SIZE_M') + block_n = config.get('BLOCK_SIZE_N') + block_k = config.get('BLOCK_SIZE_K') + group_m = config.get('GROUP_SIZE_M') + split_k = config.get('SPLIT_K') + num_warps = config.get('num_warps') + num_stages = config.get('num_stages') + waves_per_eu = config.get('waves_per_eu') + mfma_instr_size = config.get('matrix_instr_nonkdim') + kpack = config.get('kpack') + return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack + + +def gen_configStr(config): + block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config( + config) + + ## {M}_{N}_{K} is removed since the same kernel can be used for differen gemm sizes + configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}" + + return configStr + + +def generate_matmul_kernels(configs): + """ + Generate all the kernels in the tuning space based on ++configs++ + + Use the matmul_kernel template (../matmul_kernel.py) and append config to the + kernel name. E.g. matmul_kernel_BM256_BN256_BK64_GM1_SK1_nW1_nS0_EU0_kP2_mfma16() + """ + + f_kernel = open(get_filename_myKernels(), 'w') + + # write imports + import_str = """import triton +import triton.language as tl""" + f_kernel.write(import_str) + + with open( + os.path.dirname(os.path.abspath(__file__)) + + "/../matmul_kernel.py") as file: + matmul_kernel_code = file.read() + + for config in configs: + configStr = gen_configStr(config) + # Copy the matmul_kernel with name replaced + matmul_kernel_config = matmul_kernel_code.replace( + "matmul_kernel", f"matmul_kernel_{configStr}") + matmul_kernel_config = matmul_kernel_config.replace( + "import triton.language as tl", "") + matmul_kernel_config = matmul_kernel_config.replace( + "import triton", "") + f_kernel.write(matmul_kernel_config) + + f_kernel.close() + + +## construct the configStr and generate the wrapper function matmul_{configStr}() +## If `warmup` is set, the generated kernel will be **compiled** +def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, + dtype_c, bias_size, warmup): + block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config( + config) + + configStr = gen_configStr(config) + + use_bias = bias_size > 0 + + if warmup: + torch_dtype_a = 'fp16' + torch_dtype_b = 'fp16' + torch_dtype_c = 'fp16' + if dtype_a: + torch_dtype_a = tl_to_torch_types[name_to_tl_types[dtype_a]] + if dtype_b: + torch_dtype_b = tl_to_torch_types[name_to_tl_types[dtype_b]] + if dtype_c: + torch_dtype_c = tl_to_torch_types[name_to_tl_types[dtype_c]] + + matmul_def_str = f""" +def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn): + matmul_kernel_{configStr}.warmup( + {torch_dtype_a}, {torch_dtype_b}, {torch_dtype_c}, {torch_dtype_c}, + M, N, K, + am, ak, bk, bn, cm, cn, biasn, + BLOCK_SIZE_M = {block_m}, + BLOCK_SIZE_N = {block_n}, + BLOCK_SIZE_K = {block_k}, + GROUP_SIZE_M = {group_m}, + SPLIT_K = {split_k}, + num_warps = {num_warps}, + 
num_stages = {num_stages}, + waves_per_eu = {waves_per_eu}, + matrix_instr_nonkdim = {mfmaInstrSize}, + kpack = {kpack}, + BIAS={use_bias}, + EVEN_K={EVEN_K}, + grid=(1,), + ) + return None + +def try_compile_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn): + try: + matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn) + return True + except Exception as e: + print(f'invalid config(compilation): {configStr}: ', e, flush=True) + return False +""" + else: + matmul_def_str = f""" +def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn): + grid = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n}), {split_k} + matmul_kernel_{configStr}[grid]( + a, b, c, bias, + M, N, K, + am, ak, bk, bn, cm, cn, biasn, + BLOCK_SIZE_M = {block_m}, + BLOCK_SIZE_N = {block_n}, + BLOCK_SIZE_K = {block_k}, + GROUP_SIZE_M = {group_m}, + SPLIT_K = {split_k}, + num_warps = {num_warps}, + num_stages = {num_stages}, + waves_per_eu = {waves_per_eu}, + matrix_instr_nonkdim = {mfmaInstrSize}, + kpack = {kpack}, + BIAS = {use_bias}, + EVEN_K = {EVEN_K} + ) + return c +""" + return configStr, matmul_def_str + + +def generate_compile_driver(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, + init_type, configs, rotating_buffer_size, + bias_size): + """ + Generate a single file that contains all kernels in the tuning space. + This file is used to **compile** the kernels in parallel + """ + + filename = get_filename_compile_driver() + f_kernel = open(filename, 'w') + + # write imports + import_str = f"""import torch +import triton +import triton.language as tl +import argparse +import sys +import multiprocessing +from tune_gemm import gen_rotating_tensors +from {get_filename_without_extension(get_filename_myKernels())} import * +""" + + f_kernel.write(import_str + "\n") + + for config in configs: + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, matmul_def_str = gen_kernel_and_configStr_from_config( + config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, True) + # Copy the matmul_kernel with name replaced + f_kernel.write(matmul_def_str + "\n") + + # write compile_kernels + # pre string + compile_kernels_pre_str = f"""def compile_kernels(M, N, K, rotating_buffer_size, bias_size, num_threads): + thread_pool = multiprocessing.Pool(processes=num_threads) + tensors = gen_rotating_tensors(M, N, K, '{dtype_a}', {col_a}, '{dtype_b}', {col_b}, '{dtype_c}', + 1, '{init_type}', rotating_buffer_size, bias_size, device='cuda') + + a = tensors['input_a'][0] + b = tensors['input_b'][0] + c = tensors['output_c'][0] + assert bias_size == M or bias_size == 0 + + stride_bias = tensors['bias'][0].stride(0) if bias_size > 0 else 0 + task_args = (M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), stride_bias) + + results = [] + config_names = [] +""" + f_kernel.write(compile_kernels_pre_str + "\n") + + # warm up call of all matmul functions in parallel + for config in configs: + configStr = gen_configStr(config) + task_str = f" results += [thread_pool.apply_async(try_compile_config_{configStr}, args=task_args)]\n" + \ + f" config_names += ['{configStr}']\n" + f_kernel.write(task_str) + + threadpool_str = """ + failed_configs = [] + for i in range(len(results)): + results[i].wait() + res = results[i].get() + if not res: + failed_configs += [config_names[i]] + thread_pool.close() + thread_pool.join() + if failed_configs: + with open("{filename}.failed_configs", "w") as f: + for cfg in failed_configs: + f.write(cfg + "\\n") 
+""".format(filename=filename) + f_kernel.write(threadpool_str) + + # def main and call compile_kernels + def_main_str = f""" +def main(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False,) + parser.add_argument("-n", type=int, default=32, help='number of threads') + parser.add_argument("-rotating_tensor", type=int, default={rotating_buffer_size}, help='size of rotating buffer (MB), default: {rotating_buffer_size}') + args = parser.parse_args() + numThreads = args.n + rotating_buffer_size = args.rotating_tensor + """ + compile_kernels_call_str = f'compile_kernels({M}, {N}, {K}, rotating_buffer_size, {M}, numThreads)' + + f_kernel.write(def_main_str) + f_kernel.write(compile_kernels_call_str + "\n\n") + f_kernel.write("""if __name__ == '__main__': + sys.exit(main())""") + f_kernel.close() + + return filename + + +def generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, + init_type, configs, jobs, iters, run_bench, + rotating_buffer_size, bias_size, icache_flush): + """ + Open {len(jobs)} files + generated_kernelM-N-K-0.py, generated_kernelM-N-K-1.py, ..., generated_kernelM-N-K-{njobs-1}.py + and generate + 1. matmul kernels of all configs + 2. wrapper function matmul to invoke all the generated kernels + 3. test_gemm to invoke matmul in a loop of {iters} iterations + """ + + filenames = [] + for i in range(jobs): + filenames.append(get_filename_profile_driver(M, N, K, i)) + f_kernel = [open(path, 'w') for path in filenames] + + # write imports + import_str = f"""import torch +import triton +import triton.language as tl +import argparse +import sys +import multiprocessing +from tune_gemm import gen_rotating_tensors +from {get_filename_without_extension(get_filename_myKernels())} import * +""" + if icache_flush: + import_str += """ +from icache_flush import icache_flush +""" + for fi in range(jobs): + f_kernel[fi].write(import_str + "\n") + + idx = 0 + for config in configs: + file_idx = idx % jobs + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, matmul_def_str = gen_kernel_and_configStr_from_config( + config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, False) + # Copy the matmul_kernel with name replaced + f_kernel[file_idx].write(matmul_def_str + "\n") + idx += 1 + + # write test_gemm + # pre string + test_gemm_pre_str = f"""def test_gemm(M, N, K, rotating_buffer_size, bias_size): + tensors = gen_rotating_tensors(M, N, K, '{dtype_a}', {col_a}, '{dtype_b}', {col_b}, '{dtype_c}', + 1, '{init_type}', rotating_buffer_size, bias_size, device='cuda') + + a = tensors['input_a'][0] + b = tensors['input_b'][0] + c = tensors['output_c'][0] + assert bias_size == M or bias_size == 0 + + stride_bias = tensors['bias'][0].stride(0) if bias_size > 0 else 0 + + try: + with open("{get_filename_compile_driver()}.failed_configs", "r") as f: + failed_configs = [cfg.strip() for cfg in f.readlines()] + except Exception: + failed_configs = [] +""" + for fi in range(jobs): + f_kernel[fi].write(test_gemm_pre_str + "\n") + + # call all matmul_xxx functions + idx = 0 + runs = iters if run_bench else 200 + call_icache_flush = 'icache_flush()' if icache_flush else '' + for config in configs: + configStr = gen_configStr(config) + matmul_call_str = f""" + if '{configStr}' not in failed_configs: + rotating_num = tensors['rotating_num'] + for i in range({runs}): + a = tensors['input_a'][i % rotating_num] + b = tensors['input_b'][i % rotating_num] + c = tensors['output_c'][i % rotating_num] + bias = tensors['bias'][i % 
rotating_num] if bias_size > 0 else None""" + if icache_flush: + matmul_call_str += f""" + icache_flush()""" + matmul_call_str += f""" + d = matmul_{configStr}(a, b, c, bias, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), bias.stride(0))""" + f_kernel[idx % jobs].write(matmul_call_str + "\n") + idx += 1 + # post string + for fi in range(jobs): + f_kernel[fi].write(" return d\n") + + # def main and call test_gemm + def_main_str = f""" +def main(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False,) + parser.add_argument("-n", type=int, default=1, help='number of threads') + parser.add_argument("-rotating_tensor", type=int, default={rotating_buffer_size}, help='size of rotating buffer (MB), default: {rotating_buffer_size}') + args = parser.parse_args() + numThreads = args.n + rotating_buffer_size = args.rotating_tensor + """ + test_gemm_call_str = f'test_gemm({M}, {N}, {K}, rotating_buffer_size, {M})' + for fi in range(jobs): + f_kernel[fi].write(def_main_str) + f_kernel[fi].write(test_gemm_call_str + "\n\n") + f_kernel[fi].write("""if __name__ == '__main__': + sys.exit(main())""") + f_kernel[fi].close() diff --git a/scripts/amd/gemm/utils/name_utils.py b/scripts/amd/gemm/utils/name_utils.py new file mode 100644 index 0000000000000..5b969672c707a --- /dev/null +++ b/scripts/amd/gemm/utils/name_utils.py @@ -0,0 +1,65 @@ +import torch +import triton +import triton.language as tl + +import os + +TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') +TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') +tl_to_torch_types = { + tl.float16: torch.float16, + tl.bfloat16: torch.bfloat16, + tl.float32: torch.float32, + tl.int8: torch.int8, + tl.int32: torch.int32, +} +if TORCH_HAS_FP8E5B16: + tl_to_torch_types[tl.float8e5b16] = torch.float8_e5m2fnuz +if TORCH_HAS_FP8E4B8: + tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz + +name_to_tl_types = { + 'int8': tl.int8, + 'int32': tl.int32, + 'fp16': tl.float16, + 'fp32': tl.float32, + 'bf16': tl.bfloat16, + 'fp8': tl.float8e4b8, + 'bf8': tl.float8e5b16, +} + + +def get_filename_myKernels(): + path = os.path.dirname(os.path.abspath(__file__)) + return f"{path}/../myKernels.py" + + +def get_filename_without_extension(file_path): + base_name = os.path.basename(file_path) + file_name, _ = os.path.splitext(base_name) + return file_name + + +def get_filename_compile_driver(): + path = os.path.dirname(os.path.abspath(__file__)) + return f"{path}/../compile_driver.py" + + +def get_filename_profile_driver(M, N, K, job_id): + path = os.path.dirname(os.path.abspath(__file__)) + return f"{path}/../profile_driver_{M}x{N}x{K}_{job_id}.py" + + +def get_default_tuning_result_filename(): + git_branch_name = run_bash_command("git rev-parse --abbrev-ref HEAD") + git_branch_name = git_branch_name[0].decode() + # handle branch name of "xxx/xxx" format + git_branch_name = git_branch_name.replace('/', '_') + git_commit_hash = run_bash_command("git rev-parse --short HEAD") + git_commit_hash = git_commit_hash[0].decode() + + dt_string = datetime.now().strftime("%m-%d-%Y-%H:%M:%S") + + path = os.path.dirname(os.path.abspath(__file__)) + defaultName = f"{path}/../tuning_results_{git_branch_name}@{git_commit_hash}_{dt_string}.yaml" + return defaultName
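Putting the pieces together, the hedged sketch below condenses the per-session / per-size flow that `tune_gemm.py` drives after this change: generate `myKernels.py` once, then generate and run the compile driver and the profile tasks for each gemm size. The concrete sizes, dtypes, and thread/job counts are illustrative assumptions, and the per-size pruning step is elided; the authoritative orchestration is `tune_gemm_config()` in `tune_gemm.py`.

```python
# Hedged sketch of the generation/compile/profile flow introduced by this PR.
# Assumes it runs from scripts/amd/gemm on a ROCm machine; values are examples only.
from utils.file_generator import (generate_matmul_kernels,
                                  generate_compile_driver,
                                  generate_profile_tasks)
from utils.name_utils import get_filename_profile_driver
from tune_gemm import get_full_tuning_space, run_bash_command

M, N, K = 4096, 4096, 4096
col_a, col_b = False, False            # row-major A and B
dtype_a = dtype_b = dtype_c = 'fp16'
init_type = 'randn'                    # example init_type; see tune_gemm.py for accepted values
rotating_buffer_size, bias_size = 256, 0
jobs, iters = 1, 200

# Once per tuning session: myKernels.py holds every kernel in the un-pruned space.
configs_full = get_full_tuning_space()
generate_matmul_kernels(configs_full)

# Per gemm size: the real script prunes the space for (M, N, K) first; we reuse
# the full space here for illustration.
configs = configs_full

# 1. Compilation stage: one driver file, compiled in parallel (32 threads by
#    default); kernels are imported from myKernels.py.
fname = generate_compile_driver(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c,
                                init_type, configs, rotating_buffer_size, bias_size)
run_bash_command(f"python {fname} -n 32")

# 2. Profiling stage: one task file per job, each profiled with rocprofv2.
generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type,
                       configs, jobs, iters, False, rotating_buffer_size,
                       bias_size, False)
for job_id in range(jobs):
    task = get_filename_profile_driver(M, N, K, job_id)
    run_bash_command(f"rocprofv2 --plugin file --plugin-version 1 "
                     f"--kernel-trace -o {job_id} python {task}")
```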