diff --git a/scripts/amd/gemm/README.md b/scripts/amd/gemm/README.md index b1ab65b410e7..0abb5e410821 100644 --- a/scripts/amd/gemm/README.md +++ b/scripts/amd/gemm/README.md @@ -1,44 +1,176 @@ -# GEMM tuning script v2 +# GEMM tuning script (current v3.3) -This is the v2 version of the gemm tuning script, which is based on @scxiao's v1 (https://github.com/ROCmSoftwarePlatform/triton/pull/309) and @alefimov-amd's thread pool https://github.com/ROCmSoftwarePlatform/triton/pull/310 +## matmul kernel -### Main features -- `rocprof` is used to measure the time for kernels in the full tuning space -- Each kernel is executed 10 times and the execution time of the last instance is used -- All kernels are compiled in parallel -- Two modes for correctness checking - - During tuning, check correctness with the best perf_config for the current gemm size - - Without tuning, check correctness based on the tuning results, which includes best perf_config for each gemm size -- The process takes about 30 - 40 minutes for the full tuning space with ~15000 configs -- Limitations - - For now, only support fp16 as inputs. It should be trivial to extend to other types, but may require some work for mixed inputs +The matmul kernel implementation can be found as [matmul_kernel.py](https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/matmul_kernel.py), which includes the following features: +- grouping order of workgroup id, which is controlled by `GROUP_SIZE_M`, that +implements L2 cache optimization introduced in the [tutorial](https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#l2-cache-optimizations). +- split-k algorithm, which is controlled by `SPLIT_K`. +- Bias along M dim, which is controlled by `BIAS` and `bias_ptr`. +- Masked load along K dim inside the loop, which is controlled by `EVEN_K`. +This means `BLOCK_SIZE_K` does not need to divide K dim. -### Usage -Go to the script dir -```bash -cd triton/scripts/amd/gemm/ +### Differences between the tutorial + +Unlike the [matmul tutorial](https://github.com/triton-lang/triton/blob/main/python/tutorials/03-matrix-multiplication.py) (referred as the tutorial), +the matmul kernel used in the tuning script (referred as the kernel) does not +guard load along M and N dim +([this](https://github.com/triton-lang/triton/blob/main/python/tutorials/03-matrix-multiplication.py#L282-L283) shows how this is done in the tutorial). +When `BLOCK_SIZE_M` or `BLOCK_SIZE_N` does not divide M or N, the kernel will +load out-of-bound data. +In most cases this is fine, since the kernel does masked store at the end. +However, this may lead to GPU memory access fault in some cases, especially +when the tensor is large. +We will fix this issue in the future. + + +## Tuning script usage + +### Tuning mode + +The tuning script can take one or more gemm sizes and run tuning for them. +The input gemm sizes are prepared in a yaml file. Here is an example yaml file: +```yaml +- {'M': 4864, 'N': 4096, 'K': 8256, 'rowMajorA': 'T', 'rowMajorB': 'N'} +- {'M': 512, 'N': 512, 'K': 512, 'rowMajorA': 'T', 'rowMajorB': 'N'} ``` -1. Tune gemm sizes given in a yaml file and check correctness on the way -```bash -python tune_gemm.py --gemm_size_file input_gemm_sizes.yaml --compare +The tuning script works as follows +```python +./tune_gemm.py --gemm_size_file input.yaml [options] +``` +The following `options` are supported in the tuning mode + +- Input data types: + - `-dtype_a dtype`, `-dtype_b dtype`, and `-dtype_c dtype`: input and output element type. + - Supported `dtype`: fp16 (default), bf16, fp8, bf8, int8, int32, fp32 +- Parallel compilation of kernels: + - `num_threads n` controls that n threads will + be used in the compilation stage. The default value is 32. + - `--no_warmup` can be used to skip the compilation stage. Thus kernels will be + compiled during the profiling stage. This increases tuning time. But it's + required for some old torch version, in which some function used in the warmup + kernel launch is not supported. +- Parallel profiling of kernels: The tuning space is first divided into a number +of tasks, which is controlled by `--jobs n`. And all the tasks can be profiled in +parallel on a number of GPUs in the system. There are two ways to specify which +GPU(s) we want to use for profiling. Note that these flags cannot be use together. +By default, only one task is generated and profiled on GPU0. + - `--ngpus n`: GPU 0,1,.., n-1 will be used. + - `--gpu_ids ids`: `ids` are comma separated gpu ids and GPUs in `ids` will be used. +- General tuning control flags + - `--init_type INIT_TYPE` defines how input data are initialized. `INIT_TYPE` can be + - hpl: uniform distribution between -.5 and .5 + - trig_float: the distribution of elements in the flattened tensor follow + the `sin` function. + - zeros: initialize all data as 0, i.e. `torch.zeros` + - randn (default): normal distribution, i.e. `torch.randn` + - `--rotating_tensor SIZE`: provide the size of memory used for rotating tensor. + The default is 0, meaning rotating tensor is not used. + - `--icahe_flush`: If true, the script will generate a kernel to flush i-cache. + The default is False. + - `--bias_vector`: If true, a bias vector along the M dim is applied. + The default is False. +- Correctness check + - `--compare` will check the correctness of the best config for each gemm size. + - `--compare_wo_tuning` will check the correctness of the config provided in + the yaml file. If this is set, user needs to provide all the parameters in + the input yaml file. Example can be found in the benchmark mode section. +- Logistics + - `--keep` can be used to keep the files generated during the tuning process. + Be default, intermediate files are removed at the end. + - `--time_breakdown`: If set, the script will print out elapsed time during + each stage of the tuning in real-time. The default is False. + - `--verbose` will enable more logging message than `--time_breakdown`, such + as output from rocprofv2 + - `--o OUTPUT` can be used to control the output filename to store the tuning + result. The default filename is `tuning_results_branchName@gitCommit_timeStamp.yaml`. + Therefore, each time the user runs the tuning script, a different output file + will be generated. +- Hacks + - `--hack_triton_compiler`: If set, the triton source code will be modified + to provide a static backend target so that the compiler will not query + GPU information. This makes sure that during the compilation stage, no + hip runtime kernels are launched. + Note that this is a very hacky option, because + - It modifies the triton compiler directly, which is located from + `pip show triton`. + - It does string match and replace to modify the code. + - It does not restore the code when the tuning session terminates. + +Here are some example usages of running the script for tuning: + +Tune some gemm sizes with f16 input +```python +./tune_gemm.py --gemm_size_file input.yaml --ngpus 8 --jobs 32 --o output.yaml ``` +It's recommended to use as many GPUs as possible and set `--jobs` to +a value that is 4 to 6 times the number of GPUs. -2. Tune a single gemm size -```bash -python tune_gemm.py -m 16 -n 16 -k 16 +If you are only allowed to use a subset of the GPUs, you can +```python +./tune_gemm.py --gemm_size_file input.yaml --gpu_ids 0,1,3,4 --jobs 32 --o output.yaml ``` +This runs the profiling on GPU 0,1,3,4. -3. Choose the file to store tuning results -```bash -python tune_gemm.py --gemm_size_file input_gemm_sizes.yaml --o output_tuning.yaml +For bf8 input +```python +./tune_gemm.py --gemm_size_file input.yaml --ngpus 8 --jobs 32 -dtype_a bf8 -dtype_b bf8 ``` -4. Only check correctness given the tuning results -```bash -python tune_gemm.py --gemm_size_file output_tuning.yaml --compare_wo_tuning +Check correctness of the tuned configs +```python +./tune_gemm.py --gemm_size_file output.yaml --compare_wo_tuning ``` -Note that the tuning results file are provided as the `gemm_size_file` in this scenario. + + +### Benchmark mode + +In benchmark mode, the script will run a single given config multiple times to +collect performance data. The benchmark mode works as +The tuning script works as follows +```python +./tune_gemm.py --gemm_size_file input.yaml [options] --benchmark +``` +The supported `options` are as followings +- `-dtype_a dtype`, `-dtype_b dtype`, and `-dtype_c dtype`: same as tuning mode. +- `--iters n` controls the number of iterations to run the kernel. +The default value is 1000. +- `--icahe`: same as tuning mode +- `--rotating_tensor SIZE`: same as tuning mode + + +## Tuning script implementation overview + +The general idea of the tuning script can be summarized as +- Compile all the kernels in the tuning space in parallel. +- Divide the tuning space into tasks and invoke `rocprofv2` once per +task. This will save invocation overhead of the profiler. +- Profile tasks in parallel on multiple GPUs. + +For detailed implementation, please refer to the changelog of each version. + + +# Changelog + +## GEMM tuning script v1 + +Shucai (@scxiao) implemented the first version of gemm tuning script: https://github.com/ROCmSoftwarePlatform/triton/pull/309 + +## GEMM tuning script v2 + +This version is based on v1 and @alefimov-amd's thread pool https://github.com/ROCmSoftwarePlatform/triton/pull/310 + +### Main features +- `rocprof` is used to measure the time for kernels in the full tuning space +- Each kernel is executed 10 times and the execution time of the last instance is used +- All kernels are compiled in parallel +- Two modes for correctness checking + - During tuning, check correctness with the best perf_config for the current gemm size + - Without tuning, check correctness based on the tuning results, which includes best perf_config for each gemm size +- The process takes about 30 - 40 minutes for the full tuning space with ~15000 configs +- Limitations + - For now, only support fp16 as inputs. It should be trivial to extend to other types, but may require some work for mixed inputs ### Overview of implementations @@ -63,7 +195,7 @@ Workflow of the tuning process 5. Invoke `rocprof` on the generated script 6. Post process `results.csv` by extract the execution time of the last instance of each kernel. Pick the best one, write to file, and return. -# GEMM Tuning Script v3 +## GEMM Tuning Script v3 ### API changes @@ -89,41 +221,76 @@ This is necessary to keep each file "small" in terms of execution time. - Added error recovery. This helps when rocprof crashes in multi-processing mode. -### Example Usage -Let's say we have an input yaml file, named `gemm_input.yaml`, that contains the following configs -```yaml -- {'M': 4864, 'N': 4096, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N'} -- {'M': 8192, 'N': 8192, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N'} -``` -1. Tuning with bf8 input types with gpu 4,5,6,7, and save output to `output.yaml` -```bash -python ./tune_gemm.py --gemm_size_file gemm_input.yaml -dtype_a bf8 -dtype_b bf8 --gpu_ids 4,5,6,7 --o output.yaml -``` +## GEMM Tuning Script v3.1 -2. Check the correctness of the tuned configs -```bash -python ./tune_gemm.py --gemm_size_file output.yaml -dtype_a bf8 -dtype_b bf8 --compare_wo_tuning -``` +### API changes -3. Run benchmark of the tuned configs -```bash -python ./tune_gemm.py --gemm_size_file output.yaml -dtype_a bf8 -dtype_b bf8 --benchmark -``` +- Added `matrix_instr_nonkdim` into the tuning space. Now we can tune mfma instruction size. -A sample output from `benchmark` looks like -```bash -Benchmarking gemm with bf8 inputs (peak tflops: 1298) -trans M N K TFLOPS Efficiency -NT 4864 4096 8192 841.22 65% -NT 8192 8192 8192 745.31 57% -``` -# GEMM Tuning Script v3.1 +## GEMM Tuning Script v3.2 ### API changes -- Added `matrix_instr_nonkdim` into the tuning space. Now we can tune mfma instruction size. +- Added `--rotating_tensor ` to use rotating memory blocks in each iteration, size in MB. Default is 0MB. +- Added `--icache_flush` to flush icache in each iteration. +Note, icache flush needs the module `python-hip`, which can be installed as: +`python3 -m pip install -i https://test.pypi.org/simple hip-python~=$rocm_version` +Rotating tensor and icache flush are to make perf numbers are closer to that in real applications. +- Added `--bias_vector` to support kernel execution with bias (bias vector is of the same size as the number of rows of the output matrix, +so each element of the bias vector is added to all elements of the corresponding row of the output matrix.) + + +## GEMM Tuning Script v3.3 + +### API changes + +no API changes + +### Implementation changes + +- We use a dedicated file (named `get_filename_myKernels()`) to keep all the kernels +in the tuning space. +- Inside the for loop of tuning, each iteration tunes one gemm size + 1. Update kernel stage: Different gemm size may need different configs. We keep track + of the current tuning space. And if the current gemm size needs some configs that is + not included in the current tuning space, we expand the tuning space with the newly + added configs. + - This means if two gemm sizes share some configs, these configs will be compiled + once. This will greatly reduce batch tuning time. + 2. Compilation stage: + - We generate a single compilation driver file, named compile_driver.py (this is + obtained from `get_filename_compile_driver`) which contains the wrapper functions + of all the configs in the **pruned** tuning space for this gemm size. + - All the kernels will be compiled by 32 threads by default. Compiling all the + kernels in a single file in parallel is faster than splitting them into multiple + files. This can greatly reduce the compile time of the tuning process. + - Note that we no longer generate matmul_kernel in this file. Kernels are imported + from myKernels.py. + 3. Profile stage + - We generate one task file per job, named `profile_driver_MxNxK_{job_id}.py` + (this is obtained from `get_filename_profile_driver`). The only difference is + that we no longer generate matmul_kernel in this file. Kernels are imported + from myKernels.py. +- `configStr` does not contain gemm size anymore. This allows the same matmul_{configStr} +kernel to be reused by different gemm sizes. +- `configStr` does not contain `_bias` if bias is provided. This is because we do not +expect to compare the same kernel w/ and w/o bias. Therefore, we treat bias in the same +way as gemm sizes. +- Add support for `EVEN_K` in the matmul kernel. Now the kernel support `BLOCK_SIZE_K` +that cannot divide `K`. +- Tuning result file is open and closed inside the tuning loop, enabling timely flush +of the tuning results. +- Now we use `rocprofv2` to measure kernel time. +- We can use `--hack_triton_compile` to avoid all GPU activities during the compilation +stage. This is achieved by modifying the triton frontend compiler in the following +places: + - Return True from the `is_active()` function in the hip hackend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/amd/backend/driver.py#L433) + - Return statically constructed GPUTarget from the `get_current_target()` + function in the hip backend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/amd/backend/driver.py#L437) + - Return False from the `is_active()` function in the cuda hackend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/nvidia/backend/driver.py#L383) + - Statically set `device` and `stream` in the [jit.py](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/python/triton/runtime/jit.py#L588-L589) # One config running script @@ -131,7 +298,7 @@ NT 8192 8192 8192 745.31 57% `one_config.py` is a script that runs one given matmul config. It is an interface to `tune_gemm.py` functionality and could be used for triton debugging. -### Usage +## Usage This script supports two methods to specify configuration parameters. @@ -147,15 +314,3 @@ This is how configs are printed by `tune_gemm.py` script ```bash python one_config.py --config_str M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0_kP2_mfma16 ``` - -# GEMM Tuning Script v3.2 - -### API changes - -- Added `--rotating_tensor ` to use rotating memory blocks in each iteration, size in MB. Default is 0MB. -- Added `--icache_flush` to flush icache in each iteration. -Note, icache flush needs the module `python-hip`, which can be installed as: -`python3 -m pip install -i https://test.pypi.org/simple hip-python~=$rocm_version` -Rotating tensor and icache flush are to make perf numbers are closer to that in real applications. -- Added `--bias_vector` to support kernel execution with bias (bias vector is of the same size as the number of rows of the output matrix, -so each element of the bias vector is added to all elements of the corresponding row of the output matrix.) diff --git a/scripts/amd/gemm/matmul_kernel.py b/scripts/amd/gemm/matmul_kernel.py index 4559c46fdcc2..d5f854f3d8a1 100644 --- a/scripts/amd/gemm/matmul_kernel.py +++ b/scripts/amd/gemm/matmul_kernel.py @@ -12,6 +12,7 @@ def matmul_kernel( stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BIAS: tl.constexpr, + EVEN_K: tl.constexpr ): pid = tl.program_id(axis=0) pid_z = tl.program_id(1) @@ -41,8 +42,12 @@ def matmul_kernel( acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): - a = tl.load(a_ptrs) - b = tl.load(b_ptrs) + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) accumulator += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk @@ -56,4 +61,4 @@ def matmul_kernel( if SPLIT_K == 1: tl.store(c_ptrs, c, mask=c_mask) else: - tl.atomic_add(c_ptrs, c, mask=c_mask) \ No newline at end of file + tl.atomic_add(c_ptrs, c, mask=c_mask) diff --git a/scripts/amd/gemm/tune_gemm.py b/scripts/amd/gemm/tune_gemm.py old mode 100644 new mode 100755 index e5dee0d826cc..3fdd7da082b5 --- a/scripts/amd/gemm/tune_gemm.py +++ b/scripts/amd/gemm/tune_gemm.py @@ -1,10 +1,10 @@ -# fp8 +#!/usr/bin/env python3 + import argparse import sys import yaml import os import glob -import subprocess import torch import triton @@ -16,6 +16,10 @@ import multiprocessing import pandas as pd +from utils.file_generator import * +from utils.utils import * + + def is_hip_available(): try: __import__("hip") @@ -51,7 +55,28 @@ def get_full_tuning_space(): for waves_per_eu in waves_per_eu_range: for matrix_instr_nonkdim in matrix_instr_nonkdim_range: for kpack in kpack_range: - configs.append({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu, 'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack}) + configs.append({ + 'BLOCK_SIZE_M': + block_m, + 'BLOCK_SIZE_N': + block_n, + 'BLOCK_SIZE_K': + block_k, + 'GROUP_SIZE_M': + group_m, + 'SPLIT_K': + split_k, + 'num_warps': + num_warps, + 'num_stages': + num_stages, + 'waves_per_eu': + waves_per_eu, + 'matrix_instr_nonkdim': + matrix_instr_nonkdim, + 'kpack': + kpack + }) return configs @@ -60,6 +85,7 @@ def get_default_config(): full_configs = get_full_tuning_space() return full_configs[0] + def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): pruned_configs = [] @@ -70,7 +96,7 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): # TODO (zhanglx): figure out the boundary between large and small gemms large_gemm = False - if M >= 2048 and N >=2048: + if M >= 2048 and N >= 2048: large_gemm = True for config in configs: @@ -109,7 +135,7 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): # skip split_k that leads to EVEN_K = false leap = SPLIT_K * BLOCK_SIZE_K modv = K % leap - if modv != 0: + if modv != 0 and SPLIT_K != 1: continue # skip large GROUP_M if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: @@ -141,263 +167,18 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): def need_split_k(SIZE_M, SIZE_N, SIZE_K): return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 -def run_bash_command_wrapper(commandstring, capture=True): - try: - run_bash_command(commandstring, capture) - except subprocess.CalledProcessError as e: - if not capture: - print(f"running {commandstring} one more time") - run_bash_command(commandstring, capture) - -def run_bash_command(commandstring, capture=True): - if capture: - proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout=subprocess.PIPE) - return proc.stdout.splitlines() - proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash') - return None - -def read_config(config): - block_m = config.get('BLOCK_SIZE_M') - block_n = config.get('BLOCK_SIZE_N') - block_k = config.get('BLOCK_SIZE_K') - group_m = config.get('GROUP_SIZE_M') - split_k = config.get('SPLIT_K') - num_warps = config.get('num_warps') - num_stages = config.get('num_stages') - waves_per_eu = config.get('waves_per_eu') - mfma_instr_size = config.get('matrix_instr_nonkdim') - kpack = config.get('kpack') - return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack - - -def gen_kernel_and_configStr_from_config(M, N, K, config, dtype_a, dtype_b, dtype_c, bias_size): - block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(config) - torch_dtype_a = 'fp16' - torch_dtype_b = 'fp16' - torch_dtype_c = 'fp16' - if dtype_a: - torch_dtype_a = tl_to_torch_types[name_to_tl_types[dtype_a]] - if dtype_b: - torch_dtype_b = tl_to_torch_types[name_to_tl_types[dtype_b]] - if dtype_c: - torch_dtype_c = tl_to_torch_types[name_to_tl_types[dtype_c]] - configStr = f"M{M}_N{N}_K{K}_BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}" - if bias_size > 0: - configStr += "_bias" - use_bias = bias_size > 0 - matmul_def_str = f""" -def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn, warmup=False): - grid = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n}), {split_k} - #print(f'config: matmul_kernel_{configStr}', flush=True) - if warmup: - matmul_kernel_{configStr}.warmup( - {torch_dtype_a}, {torch_dtype_b}, {torch_dtype_c}, {torch_dtype_c}, - M, N, K, - am, ak, bk, bn, cm, cn, biasn, - BLOCK_SIZE_M = {block_m}, - BLOCK_SIZE_N = {block_n}, - BLOCK_SIZE_K = {block_k}, - GROUP_SIZE_M = {group_m}, - SPLIT_K = {split_k}, - num_warps = {num_warps}, - num_stages = {num_stages}, - waves_per_eu = {waves_per_eu}, - matrix_instr_nonkdim = {mfmaInstrSize}, - kpack = {kpack}, - BIAS={use_bias}, - grid=(1,), - ) - return None - else: - matmul_kernel_{configStr}[grid]( - a, b, c, bias, - M, N, K, - am, ak, bk, bn, cm, cn, biasn, - BLOCK_SIZE_M = {block_m}, - BLOCK_SIZE_N = {block_n}, - BLOCK_SIZE_K = {block_k}, - GROUP_SIZE_M = {group_m}, - SPLIT_K = {split_k}, - num_warps = {num_warps}, - num_stages = {num_stages}, - waves_per_eu = {waves_per_eu}, - matrix_instr_nonkdim = {mfmaInstrSize}, - kpack = {kpack}, - BIAS = {use_bias}, - ) - return c - -def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn): - try: - matmul_{configStr}(None, None, None, None, M, N, K, am, ak, bk, bn, cm, cn, biasn, True) - return True - except Exception as e: - print(f'invalid config(compilation): {configStr}: ', e, flush=True) - return False -""" - return configStr, matmul_def_str - - -def generated_kernel_name(M, N, K, gpu_id): - path = os.path.dirname(os.path.abspath(__file__)) - return f"{path}/generated_kernel{M}-{N}-{K}-{gpu_id}.py" - - -# Open {len(gpus)} files -# generated_kernelM-N-K-{gpus[0]}.py, generated_kernelM-N-K-{gpus[1]}.py, ..., generated_kernelM-N-K-{gpus[-1]}.py -# and generate -# 1. matmul kernels of all configs -# 2. wrapper function matmul to invoke all the generated kernels -# 3. Another wraper function try_config to invoke matmul function -# 4. test_gemm to invoke -# 4.1 run try_config in parallel -# 4.2 matmul in a loop of 10 iterations -def generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, iters, run_bench, rotating_buffer_size, bias_size, icache_flush): - filenames = [] - for i in range(jobs): - filenames.append(generated_kernel_name(M, N, K, i)) - f_kernel = [open(path, 'w') for path in filenames] - - # write imports - import_str = """import torch -import triton -import triton.language as tl -import argparse -import sys -import multiprocessing -from tune_gemm import gen_rotating_tensors -""" - if icache_flush: - import_str += """ -from icache_flush import icache_flush -""" - for fi in range(jobs): - f_kernel[fi].write(import_str + "\n") - - # write definitions of matmul_kernel_xxx - # and matmul_xxx and try_config - with open(os.path.dirname(os.path.abspath(__file__))+"/matmul_kernel.py") as file: - matmul_kernel_code = file.read() - idx = 0 - for config in configs: - file_idx = idx % jobs - configStr, matmul_def_str = gen_kernel_and_configStr_from_config(M, N, K, config, dtype_a, dtype_b, dtype_c, bias_size) - # Copy the matmul_kernel with name replaced - matmul_kernel_config = matmul_kernel_code.replace("matmul_kernel", f"matmul_kernel_{configStr}") - matmul_kernel_config = matmul_kernel_config.replace("import triton.language as tl", "") - matmul_kernel_config = matmul_kernel_config.replace("import triton", "") - f_kernel[file_idx].write(matmul_kernel_config + "\n\n") - f_kernel[file_idx].write(matmul_def_str + "\n") - idx += 1 - - # write test_gemm - # pre string - test_gemm_pre_str = f"""def test_gemm(M, N, K, rotating_buffer_size, bias_size, num_threads): - thread_pool = multiprocessing.Pool(processes=num_threads) - tensors = gen_rotating_tensors(M, N, K, '{dtype_a}', {col_a}, '{dtype_b}', {col_b}, '{dtype_c}', - 1, '{init_type}', rotating_buffer_size, bias_size, device='cuda') - - a = tensors['input_a'][0] - b = tensors['input_b'][0] - c = tensors['output_c'][0] - assert bias_size == M or bias_size == 0 - - stride_bias = tensors['bias'][0].stride(0) if bias_size > 0 else 0 - task_args = (M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), stride_bias) - - if num_threads > 1: - results = [] - config_names = [] -""" - for fi in range(jobs): - f_kernel[fi].write(test_gemm_pre_str + "\n") - - # warm up call of all matmul functions in parallel - idx = 0 - for config in configs: - configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config, None, None, None, bias_size) - task_str = f" results += [thread_pool.apply_async(try_config_{configStr}, args=task_args)]\n" + \ - f" config_names += ['{configStr}']\n" - f_kernel[idx % jobs].write(task_str) - idx += 1 - - for fi in range(jobs): - threadpool_str = """ - failed_configs = [] - for i in range(len(results)): - results[i].wait() - res = results[i].get() - if not res: - failed_configs += [config_names[i]] - thread_pool.close() - thread_pool.join() - with open("{filename}.failed_configs", "w") as f: - for cfg in failed_configs: - f.write(cfg + "\\n") - else: - try: - with open("{filename}.failed_configs", "r") as f: - failed_configs = [cfg.strip() for cfg in f.readlines()] - except Exception: - failed_configs = [] - """.format(filename=filenames[fi]) - f_kernel[fi].write(threadpool_str) - # call all matmul_xxx functions - idx = 0 - runs = iters if run_bench else 200 - call_icache_flush = 'icache_flush()' if icache_flush else '' - for config in configs: - configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config, None, None, None, bias_size) - matmul_call_str = f""" - if '{configStr}' not in failed_configs: - rotating_num = tensors['rotating_num'] - for i in range({runs}): - a = tensors['input_a'][i % rotating_num] - b = tensors['input_b'][i % rotating_num] - c = tensors['output_c'][i % rotating_num] - bias = tensors['bias'][i % rotating_num] if bias_size > 0 else None""" - if icache_flush: - matmul_call_str += f""" - icache_flush()""" - matmul_call_str += f""" - d = matmul_{configStr}(a, b, c, bias, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), bias.stride(0))""" - f_kernel[idx % jobs].write(matmul_call_str + "\n") - idx += 1 - # post string - for fi in range(jobs): - f_kernel[fi].write(" return d\n") - - # def main and call test_gemm - def_main_str = f""" -def main(): - parser = argparse.ArgumentParser( - prog="tune a specific gemm size", - allow_abbrev=False,) - parser.add_argument("-n", type=int, default=1, help='number of threads') - parser.add_argument("-rotating_tensor", type=int, default={rotating_buffer_size}, help='size of rotating buffer (MB), default: 256') - args = parser.parse_args() - numThreads = args.n - rotating_buffer_size = args.rotating_tensor - """ - test_gemm_call_str = f'test_gemm({M}, {N}, {K}, rotating_buffer_size, {M}, numThreads)' - for fi in range(jobs): - f_kernel[fi].write(def_main_str) - f_kernel[fi].write(test_gemm_call_str + "\n\n") - f_kernel[fi].write("""if __name__ == '__main__': - sys.exit(main())""") - f_kernel[fi].close() - def extract_kernel_time(M, N, K, config, df, bias_size): # Correct the header by removing 'sig' and 'obj' to reduce number from 21 to 19 - # once the bug(https://github.com/ROCm/rocprofiler/issues/144) fixed, we should + # once the bug(https://github.com/ROCm/rocprofiler/issues/144) fixed, we should # not need below two lines - cols = ['Index','KernelName','gpu-id','queue-id','queue-index','pid','tid','grd','wgr','lds','scr','arch_vgpr','accum_vgpr','sgpr','wave_size','DispatchNs','BeginNs','EndNs','CompleteNs'] + cols = [ + 'Index', 'KernelName', 'gpu-id', 'queue-id', 'queue-index', 'pid', + 'tid', 'grd', 'wgr', 'lds', 'scr', 'arch_vgpr', 'accum_vgpr', 'sgpr', + 'wave_size', 'DispatchNs', 'BeginNs', 'EndNs', 'CompleteNs' + ] df.columns = cols - configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config, None, None, None, bias_size) + configStr = gen_configStr(config) filtered_df = df[df['KernelName'].str.contains(configStr, na=False)].copy() filtered_df['DurationNs'] = filtered_df['EndNs'] - filtered_df['BeginNs'] meanTime = filtered_df['DurationNs'].tail(100).mean() @@ -412,35 +193,62 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose): os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid) jobId = gpuIdx while jobId < jobs: - kernel_name = generated_kernel_name(M, N, K, jobId) + kernel_name = get_filename_profile_driver(M, N, K, jobId) if verbose: print(f"profiling {kernel_name} on GPU {gpuid}") - run_bash_command_wrapper(f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {generated_kernel_name(M, N, K, jobId)}", capture=(verbose < 2)) + run_bash_command_wrapper( + f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {get_filename_profile_driver(M, N, K, jobId)}", + capture=(verbose < 2)) jobId += ngpus -def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, - run_bench, jobs, iters, skipWarmup, verbose=0, num_threads=16, - gpus=[0], rotating_buffer_size=256, bias_size = 0, icache_flush = False): - # Generate kernel out of all configs - generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, iters, run_bench, rotating_buffer_size, bias_size, icache_flush) - - # remove any compiled kernel in the cache - run_bash_command("rm -rf ~/.triton/cache") +def tune_gemm_config(M, + N, + K, + col_a, + col_b, + dtype_a, + dtype_b, + dtype_c, + init_type, + configs, + run_bench, + jobs, + iters, + skipWarmup, + verbose=0, + num_threads=32, + gpus=[0], + rotating_buffer_size=256, + bias_size=0, + icache_flush=False): # precompile the kernels in parallel start_time = datetime.now() if not skipWarmup: - for i in range(jobs): - kernel_name = generated_kernel_name(M, N, K, i) - run_bash_command(f"python {kernel_name} -n {num_threads}", capture=(verbose < 2)) + # Generate kernel out of all configs + fname = generate_compile_driver(M, N, K, col_a, col_b, dtype_a, + dtype_b, dtype_c, init_type, configs, + rotating_buffer_size, bias_size) + + run_bash_command(f"python {fname} -n {num_threads}", + capture=(verbose < 2)) compile_end = datetime.now() compile_time = compile_end - start_time if verbose: print(f"compile time: {compile_time}", flush=True) + # Generate kernels out of all configs + generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, + init_type, configs, jobs, iters, run_bench, + rotating_buffer_size, bias_size, icache_flush) + # profile generated kernels - running = [multiprocessing.Process(target=profile_batch_kernels, args=(M, N, K, gpu_id, gpus, jobs, verbose)) for gpu_id in gpus] + running = [ + multiprocessing.Process(target=profile_batch_kernels, + args=(M, N, K, gpu_id, gpus, jobs, verbose)) + for gpu_id in gpus + ] for p in running: p.start() for p in running: @@ -457,10 +265,21 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type thread_pool = multiprocessing.Pool(processes=num_threads) tasks = [] idx = 0 - df_prof = [pd.read_csv(f"results_{i}.csv", skiprows=1, header=None, delimiter=',', quotechar='"', escapechar='\\') for i in range(jobs)] + df_prof = [ + pd.read_csv(f"results_{i}.csv", + skiprows=1, + header=None, + delimiter=',', + quotechar='"', + escapechar='\\') for i in range(jobs) + ] for config in configs: file_idx = idx % jobs - tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx], bias_size))] + tasks += [ + thread_pool.apply_async(extract_kernel_time, + args=(M, N, K, config, df_prof[file_idx], + bias_size)) + ] idx += 1 thread_pool.close() thread_pool.join() @@ -474,20 +293,24 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type bestConfig = config else: min_us = -1 - print(f"invalid config(post processing): SIZE {M} {N} {K}: {config}", flush=True) + print( + f"invalid config(post processing): SIZE {M} {N} {K}: {config}", + flush=True) post_end = datetime.now() post_time = post_end - profile_end if verbose: print(f"post procesing time: {post_time}", flush=True) return minTime, bestConfig, compile_time, profile_time, post_time + def gen_input(M, N, ty_name, needTrans, seed, init_type, device='cuda'): d_type = name_to_tl_types[ty_name] torch.manual_seed(seed) torch.cuda.manual_seed(seed) @triton.jit - def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + def copy_kernel(input_ptr, output_ptr, n_elements, + BLOCK_SIZE: tl.constexpr): offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements input = tl.load(input_ptr + offsets, mask=mask) @@ -496,11 +319,13 @@ def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): def init_by_size_and_type(size, dtype, init_type): if init_type == 'hpl': - return torch.empty(size, device='cuda', dtype=dtype).uniform_(-0.5, 0.5) + return torch.empty(size, device='cuda', + dtype=dtype).uniform_(-0.5, 0.5) # This init type has element[i] in row[j] equal to sin(i+j*N) elif init_type == 'trig_float': M, N = size - return torch.reshape(torch.arange(0, M*N), (M, N)).sin().to(dtype=dtype, device='cuda') + return torch.reshape(torch.arange(0, M * N), + (M, N)).sin().to(dtype=dtype, device='cuda') elif init_type == 'zeros': return torch.zeros(size, dtype=dtype, device='cuda') elif init_type == "randn": @@ -509,7 +334,8 @@ def init_by_size_and_type(size, dtype, init_type): else: raise ValueError("Bad matrix initialization type.") - raw_data = init_by_size_and_type((N,M) if needTrans else (M,N), torch.float32, init_type) + raw_data = init_by_size_and_type((N, M) if needTrans else (M, N), + torch.float32, init_type) if needTrans: raw_data = raw_data.T if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \ @@ -522,7 +348,7 @@ def init_by_size_and_type(size, dtype, init_type): f8_tensor = f8_tensor & 0b00111111 input = triton.reinterpret(f8_tensor, d_type) input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16) - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) n_elements = raw_data.numel() copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024) @@ -530,12 +356,19 @@ def init_by_size_and_type(size, dtype, init_type): # generate inputs/outputs according to rotating tensor size -def gen_rotating_tensors(M, N, K, - dtype_a, need_Trans_a, - dtype_b, need_Trans_b, - dtype_c, seed, init_type, - rotating_buffer_size, - bias_size, device='cuda'): +def gen_rotating_tensors(M, + N, + K, + dtype_a, + need_Trans_a, + dtype_b, + need_Trans_b, + dtype_c, + seed, + init_type, + rotating_buffer_size, + bias_size, + device='cuda'): a_size = M * K * type_name_to_bytes(dtype_a) b_size = K * N * type_name_to_bytes(dtype_b) c_size = M * N * type_name_to_bytes(dtype_c) @@ -551,26 +384,50 @@ def gen_rotating_tensors(M, N, K, c = [] bias = [] for i in range(block_count): - in_a, in_a_fp16 = gen_input(M, K, dtype_a, need_Trans_a, 1, init_type, device='cuda') + in_a, in_a_fp16 = gen_input(M, + K, + dtype_a, + need_Trans_a, + 1, + init_type, + device='cuda') a.append(in_a) - in_b, in_b_fp16 = gen_input(K, N, dtype_b, need_Trans_b, 2, init_type, device='cuda') + in_b, in_b_fp16 = gen_input(K, + N, + dtype_b, + need_Trans_b, + 2, + init_type, + device='cuda') b.append(in_b) - out_c = torch.zeros((M, N), dtype=tl_to_torch_types[name_to_tl_types[dtype_c]], device='cuda') + out_c = torch.zeros((M, N), + dtype=tl_to_torch_types[name_to_tl_types[dtype_c]], + device='cuda') c.append(out_c) if bias_size > 0: - bs, bs_fp16 = gen_input(M, 1, dtype_b, need_Trans_b, 2, init_type, device='cuda') + bs, bs_fp16 = gen_input(M, + 1, + dtype_b, + need_Trans_b, + 2, + init_type, + device='cuda') bias.append(bs.squeeze()) - in_outs = {"rotating_num": block_count, - "input_a": a, - "input_b": b, - "output_c": c, - "bias": bias} - + in_outs = { + "rotating_num": block_count, + "input_a": a, + "input_b": b, + "output_c": c, + "bias": bias + } + return in_outs -def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, use_bias): +def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, + num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, + use_bias): # Check constraints. assert a.shape[1] == b.shape[0], "Incompatible dimensions" #assert a.is_contiguous(), "Matrix A must be contiguous" @@ -581,30 +438,40 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps grid = triton.cdiv(M, block_m) * triton.cdiv(N, block_n), split_k stride_bias = bias.stride(0) if use_bias else 0 - matmul_kernel[grid]( - a, b, c, bias, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - stride_bias=stride_bias, - BLOCK_SIZE_M=block_m, - BLOCK_SIZE_N=block_n, - BLOCK_SIZE_K=block_k, - GROUP_SIZE_M=group_m, - SPLIT_K=split_k, - num_warps=num_warps, - num_stages=num_stages, - waves_per_eu=waves_per_eu, - matrix_instr_nonkdim=mfmaInstrSize, - kpack=kpack, - BIAS=use_bias, - ) + EVEN_K = K % block_k == 0 + matmul_kernel[grid](a, + b, + c, + bias, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + stride_bias=stride_bias, + BLOCK_SIZE_M=block_m, + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_K=block_k, + GROUP_SIZE_M=group_m, + SPLIT_K=split_k, + num_warps=num_warps, + num_stages=num_stages, + waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfmaInstrSize, + kpack=kpack, + BIAS=use_bias, + EVEN_K=EVEN_K) return c -def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, bias_vector, verbose): - block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(config) +def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, + init_type, config, bias_vector, verbose): + block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config( + config) use_bias = bias_vector torch.manual_seed(0) #a = torch.randn((M, K), device='cuda', dtype=datatype) @@ -613,12 +480,22 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, init_type, device='cuda') bias = None if use_bias: - bias, bias_fp16 = gen_input(M, 1, dtype_b, col_b, 2, init_type, device='cuda') + bias, bias_fp16 = gen_input(M, + 1, + dtype_b, + col_b, + 2, + init_type, + device='cuda') bias = bias.squeeze() bias_fp16 = bias.squeeze() # Allocates output. - c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]]) - triton_output = matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, use_bias) + c = torch.zeros((M, N), + device=a.device, + dtype=tl_to_torch_types[name_to_tl_types[dtype_c]]) + triton_output = matmul(a, b, c, bias, block_m, block_n, block_k, group_m, + split_k, num_warps, num_stages, waves_per_eu, + mfmaInstrSize, kpack, use_bias) torch_output = torch.matmul(a_fp16, b_fp16) if use_bias: torch_output += bias_fp16[:, None] @@ -629,7 +506,10 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type size_str = '' if verbose: size_str = f'SIZE M: {M}, N: {N}, K: {K}, trans: {row_a_str}{row_b_str}' - if torch.allclose(triton_output.to(torch.float16), torch_output, atol=atol, rtol=rtol): + if torch.allclose(triton_output.to(torch.float16), + torch_output, + atol=atol, + rtol=rtol): print(f'{size_str} Correct✅') else: print(f"triton_output={triton_output}") @@ -637,19 +517,6 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type print(f'{size_str} Incorrect❌') -def get_default_tuning_result_filename(): - git_branch_name = run_bash_command("git rev-parse --abbrev-ref HEAD") - git_branch_name = git_branch_name[0].decode() - # handle branch name of "xxx/xxx" format - git_branch_name = git_branch_name.replace('/', '_') - git_commit_hash = run_bash_command("git rev-parse --short HEAD") - git_commit_hash = git_commit_hash[0].decode() - - dt_string = datetime.now().strftime("%m-%d-%Y-%H:%M:%S") - defaultName = f"tuning_results_{git_branch_name}@{git_commit_hash}_{dt_string}.yaml" - return defaultName - - def parse_args(): parser = argparse.ArgumentParser( prog="tune a specific gemm size", @@ -659,29 +526,112 @@ def parse_args(): parser.add_argument("-m", type=int, default=0) parser.add_argument("-n", type=int, default=0) parser.add_argument("-k", type=int, default=0) - parser.add_argument("-col_a", action='store_true', default=False, help='whether matrix a is column major') - parser.add_argument("-col_b", action='store_true', default=False, help='whether matrix b is column major') - parser.add_argument("-dtype_a", type=str, default='fp16', help="matrix a element data type") - parser.add_argument("-dtype_b", type=str, default='fp16', help="matrix b element data type") - parser.add_argument("-dtype_c", type=str, default='fp16', help="output element data type") - parser.add_argument("--ngpus", type=int, default=0, help='number of GPUs used in the profiling step') - parser.add_argument("--gpu_ids", type=lambda s: [int(id) for id in s.split(',')], default=[], help='list of gpu ids to use for tuning') - parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size') - parser.add_argument("--o", type=str, default='', help='yaml file to store tuning results') - parser.add_argument("--keep", action='store_true', default=False, help='keep generated files') - parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness") - parser.add_argument("--compare_wo_tuning", action='store_true', default=False, help="Whether check result correctness") - parser.add_argument("--benchmark", action='store_true', default=False, help="Benchmark the given config") - parser.add_argument("--time_breakdown", action='store_true', default=False, help="Show detailed time breakdown of each step during the tuning") - parser.add_argument("--verbose", action='store_true', default=False, help="enables time_breakdown and additional logging messages") - parser.add_argument("--num_threads", type=int, default=16, help="number of threads to use for kernel compilation and post processing") - parser.add_argument("--jobs", type=int, default=1, help="number of generated files") - parser.add_argument("--iters", type=int, default=1000, help="number of generated files") - parser.add_argument("--init_type", type=str, default='randn', help="Initialization type for input matrices (default uniform rand [0, 1.0)])") - parser.add_argument("--rotating_tensor", type=int, default=0, help="total size (MB) of all tensors (default 0 MB (no rotating tensor), need to be larger than the L1, L2, MALL size)") - parser.add_argument("--bias_vector", action='store_true', default=False, help="apply bias vector") - parser.add_argument("--icache_flush", action='store_true', default=False, help="apply icache flush in tuning performance") - parser.add_argument("--no_warmup", action='store_true', default=False, help="Do not call the warmup kernel") + parser.add_argument("-col_a", + action='store_true', + default=False, + help='whether matrix a is column major') + parser.add_argument("-col_b", + action='store_true', + default=False, + help='whether matrix b is column major') + parser.add_argument("-dtype_a", + type=str, + default='fp16', + help="matrix a element data type") + parser.add_argument("-dtype_b", + type=str, + default='fp16', + help="matrix b element data type") + parser.add_argument("-dtype_c", + type=str, + default='fp16', + help="output element data type") + parser.add_argument("--ngpus", + type=int, + default=0, + help='number of GPUs used in the profiling step') + parser.add_argument("--gpu_ids", + type=lambda s: [int(id) for id in s.split(',')], + default=[], + help='list of gpu ids to use for tuning') + parser.add_argument("--gemm_size_file", + type=str, + default="", + help='yaml file to indicate matrix size') + parser.add_argument("--o", + type=str, + default='', + help='yaml file to store tuning results') + parser.add_argument("--keep", + action='store_true', + default=False, + help='keep generated files') + parser.add_argument("--compare", + action='store_true', + default=False, + help="Whether check result correctness") + parser.add_argument( + "--compare_wo_tuning", + action='store_true', + default=False, + help="Whether check result correctness without tuning.") + parser.add_argument("--benchmark", + action='store_true', + default=False, + help="Benchmark the given config") + parser.add_argument( + "--time_breakdown", + action='store_true', + default=False, + help="Show detailed time breakdown of each step during the tuning") + parser.add_argument( + "--verbose", + action='store_true', + default=False, + help="enables time_breakdown and additional logging messages") + parser.add_argument( + "--num_threads", + type=int, + default=32, + help= + "number of threads to use for kernel compilation and post processing") + parser.add_argument("--jobs", + type=int, + default=1, + help="number of tasks during the profiling process") + parser.add_argument("--iters", + type=int, + default=1000, + help="number of iterations used in --benchmark mode") + parser.add_argument( + "--init_type", + type=str, + default='randn', + choices=['randn', 'hpl', 'trig_float', 'zeros'], + help="Input tensor initialization (default normal distribution)") + parser.add_argument( + "--rotating_tensor", + type=int, + default=0, + help="total size (MB) of all tensors (a, b, c, bias)." + " The default value is 0 (no rotating tensor)." + " When set, it needs to be larger than the L1, L2, MALL size)") + parser.add_argument("--bias_vector", + action='store_true', + default=False, + help="apply bias vector") + parser.add_argument("--icache_flush", + action='store_true', + default=False, + help="apply icache flush in tuning performance") + parser.add_argument("--no_warmup", + action='store_true', + default=False, + help="Whether we want to skip the compilation stage") + parser.add_argument("--hack_triton_compiler", + action='store_true', + default=False, + help="Modify the triton source to avoid backend query") args = parser.parse_args() if not args.o: if args.benchmark: @@ -692,30 +642,6 @@ def parse_args(): return args -TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') -TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') -tl_to_torch_types = { - tl.float16: torch.float16, - tl.bfloat16: torch.bfloat16, - tl.float32: torch.float32, - tl.int8: torch.int8, - tl.int32: torch.int32, -} -if TORCH_HAS_FP8E5B16: - tl_to_torch_types[tl.float8e5b16] = torch.float8_e5m2fnuz -if TORCH_HAS_FP8E4B8: - tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz - -name_to_tl_types = { - 'int8': tl.int8, - 'int32': tl.int32, - 'fp16': tl.float16, - 'fp32': tl.float32, - 'bf16': tl.bfloat16, - 'fp8': tl.float8e4b8, - 'bf8': tl.float8e5b16, -} - def process_item(item): M = item['M'] N = item['N'] @@ -770,6 +696,7 @@ def main(): jobs = args.jobs iters = args.iters skipWarmup = args.no_warmup + hack_triton = args.hack_triton_compiler # Get GPU ids ngpus = args.ngpus @@ -793,7 +720,9 @@ def main(): dtype_b = args.dtype_b dtype_c = args.dtype_c if not dtype_a in name_to_tl_types or not dtype_b in name_to_tl_types or not dtype_c in name_to_tl_types: - print(f"Unsupported dtype_a {args.dtype_a} or dtype_b {args.dtype_b} or dtype_c {args.dtype_c}") + print( + f"Unsupported dtype_a {args.dtype_a} or dtype_b {args.dtype_b} or dtype_c {args.dtype_c}" + ) print("Supported types: ", list(name_to_tl_types.keys())) sys.exit(1) rotating_buffer_size = args.rotating_tensor @@ -830,8 +759,10 @@ def main(): if args.compare_wo_tuning: for (M, N, K, col_a, col_b, myConfig) in mnks: if myConfig is None: - raise Exception("kernel config is None, need to provide a tuning config") - test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, myConfig, bias_vector, True) + raise Exception( + "kernel config is None, need to provide a tuning config") + test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, + init_type, myConfig, bias_vector, True) return configs_full = get_full_tuning_space() @@ -844,21 +775,54 @@ def main(): print("trans M N K TFLOPS us") f_results.write("trans,M,N,K,TFLOPS,us\n") else: - print(f"Tuning {len(mnks)} gemm sizes starts at: {start_time}", flush=True) + print(f"Tuning {len(mnks)} gemm sizes starts at: {start_time}", + flush=True) + + f_results.close() + + ## Before tuning starts, clear cache and previously generated kernel files + run_bash_command("rm -rf ~/.triton/cache") + run_bash_command(f"rm -rf {get_filename_myKernels()}") + + ## Modify triton compiler + ## Hacky !!! + if hack_triton: + patch_triton_compiler() + + configs = [] + ## Big for loop of tuning + ## Each iteration performs tuning for one gemm size for (M, N, K, col_a, col_b, myConfig) in mnks: + + f_results = open(output_file, 'a') + start_local_time = datetime.now() # Obtain a pruned tuning space according to gemm size # If running benchmark, use the provided config - pruned_configs = [myConfig] if run_bench else prune_configs(M, N, K, configs_full, type_name_to_bytes(dtype_a), type_name_to_bytes(dtype_b)) + pruned_configs = [myConfig] if run_bench else prune_configs( + M, N, K, configs_full, type_name_to_bytes(dtype_a), + type_name_to_bytes(dtype_b)) + + ## Only append new configs from the current gemm size + delta_configs = [ + config for config in pruned_configs if config not in configs + ] + configs += delta_configs + + ## Append new configs into the tuning space + generate_matmul_kernels(delta_configs) row_a_str = 'N' if col_a else 'T' row_b_str = 'N' if col_b else 'T' size_str = f'SIZE: {M} {N} {K} {row_a_str}{row_b_str}' if not run_bench: - print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True) + print(f"{size_str} nConfigs: {len(pruned_configs)}", + end=" ", + flush=True) else: - print(f"{row_a_str}{row_b_str} {M:5d} {N:5d} {K:5d} ", end="") + print(f"{row_a_str}{row_b_str} {M:5d} {N:5d} {K:5d} ", + end="") f_results.write(f"{row_a_str}{row_b_str},{M},{N},{K},") # The main tuning funtion for one gemm size @@ -870,10 +834,26 @@ def main(): # we consider bias size as M for now. bias_size = M if bias_vector else 0 minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config( - M, N, K, col_a, col_b, dtype_a, - dtype_b, dtype_c, init_type, pruned_configs, - run_bench, jobs, iters, skipWarmup, num_threads=args.num_threads, gpus=gpus, - verbose=verbose_level, rotating_buffer_size=rotating_buffer_size, bias_size=bias_size, icache_flush=icache_flush) + M, + N, + K, + col_a, + col_b, + dtype_a, + dtype_b, + dtype_c, + init_type, + pruned_configs, + run_bench, + jobs, + iters, + skipWarmup, + num_threads=args.num_threads, + gpus=gpus, + verbose=verbose_level, + rotating_buffer_size=rotating_buffer_size, + bias_size=bias_size, + icache_flush=icache_flush) # post processing the numbers perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6) @@ -881,45 +861,65 @@ def main(): formatted_tflops = format_output(tri_tflops) minTime = format_output(minTime) if not run_bench: - print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True) + print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', + end=" ", + flush=True) - bestConfig_compact_str, _ = gen_kernel_and_configStr_from_config(M, N, K, bestConfig, None, None, None, bias_size) + bestConfig_compact_str = gen_configStr(bestConfig) if not run_bench: - print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True) + print(f'best_config: {bestConfig_compact_str}', + end=" ", + flush=True) # write best config to tuning_results.yaml if run_bench: print(f"{formatted_tflops} {minTime}") f_results.write(f"{formatted_tflops},{minTime}\n") - sizeDict = {'M': M, 'N': N, 'K': K, 'rowMajorA': row_a_str, 'rowMajorB': row_b_str} + sizeDict = { + 'M': M, + 'N': N, + 'K': K, + 'rowMajorA': row_a_str, + 'rowMajorB': row_b_str + } sizeDict.update(bestConfig) if not run_bench: f_results.write("- " + str(sizeDict) + " ") - f_results.write(f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n') + f_results.write( + f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n') # remove generated files if asked to if not keepTmp: + if not skipWarmup: + os.remove(get_filename_compile_driver()) + try: + os.remove(get_filename_compile_driver() + + ".failed_configs") + except OSError: + pass for i in range(jobs): - generated_script = generated_kernel_name(M, N, K, i) + generated_script = get_filename_profile_driver(M, N, K, i) os.remove(generated_script) - if not skipWarmup: - os.remove(generated_script + ".failed_configs") for f in glob.glob(f"results_{i}.*"): os.remove(f) # Check correctness if asked to if args.compare: print("correctness: ", end=" ", flush=True) - test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, bestConfig, False) + test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, + init_type, bestConfig, bias_vector, False) elif not run_bench: print("", flush=True) end_local_time = datetime.now() if not run_bench: - print(f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)", flush=True) + print( + f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)", + flush=True) - f_results.close() + f_results.close() + ## End big loop for tuning end_time = datetime.now() tuning_time = end_time - start_time @@ -927,6 +927,11 @@ def main(): print(f"Tuning ends at: {end_time}") print(f"Total tuning time (h:m:s): {tuning_time}") + if hack_triton: + print( + "Triton compiler is hacked, don't forget to git restore the changes :)" + ) + if __name__ == '__main__': sys.exit(main()) diff --git a/scripts/amd/gemm/utils/file_generator.py b/scripts/amd/gemm/utils/file_generator.py new file mode 100644 index 000000000000..eea92cf6bf48 --- /dev/null +++ b/scripts/amd/gemm/utils/file_generator.py @@ -0,0 +1,355 @@ +import os +from .utils import * + + +def read_config(config): + block_m = config.get('BLOCK_SIZE_M') + block_n = config.get('BLOCK_SIZE_N') + block_k = config.get('BLOCK_SIZE_K') + group_m = config.get('GROUP_SIZE_M') + split_k = config.get('SPLIT_K') + num_warps = config.get('num_warps') + num_stages = config.get('num_stages') + waves_per_eu = config.get('waves_per_eu') + mfma_instr_size = config.get('matrix_instr_nonkdim') + kpack = config.get('kpack') + return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack + + +def gen_configStr(config): + block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config( + config) + + ## {M}_{N}_{K} is removed since the same kernel can be used for differen gemm sizes + configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}" + + return configStr + + +def generate_matmul_kernels(configs): + """ + Generate kernels based on configs and append them to get_filename_myKernels() + + Use the matmul_kernel template (../matmul_kernel.py) and append config to the + kernel name. E.g. matmul_kernel_BM256_BN256_BK64_GM1_SK1_nW1_nS0_EU0_kP2_mfma16() + """ + + if len(configs) == 0: + return + + f_kernel = open(get_filename_myKernels(), 'a') + + # write imports + import_str = """import triton +import triton.language as tl""" + f_kernel.write(import_str) + + with open( + os.path.dirname(os.path.abspath(__file__)) + + "/../matmul_kernel.py") as file: + matmul_kernel_code = file.read() + + for config in configs: + configStr = gen_configStr(config) + # Copy the matmul_kernel with name replaced + matmul_kernel_config = matmul_kernel_code.replace( + "matmul_kernel", f"matmul_kernel_{configStr}") + matmul_kernel_config = matmul_kernel_config.replace( + "import triton.language as tl", "") + matmul_kernel_config = matmul_kernel_config.replace( + "import triton", "") + f_kernel.write(matmul_kernel_config) + + f_kernel.close() + + +## construct the configStr and generate the wrapper function matmul_{configStr}() +## If `warmup` is set, the generated kernel will be **compiled** +def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, + dtype_c, bias_size, warmup): + block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config( + config) + + configStr = gen_configStr(config) + + use_bias = bias_size > 0 + + if warmup: + torch_dtype_a = 'fp16' + torch_dtype_b = 'fp16' + torch_dtype_c = 'fp16' + if dtype_a: + torch_dtype_a = tl_to_torch_types[name_to_tl_types[dtype_a]] + if dtype_b: + torch_dtype_b = tl_to_torch_types[name_to_tl_types[dtype_b]] + if dtype_c: + torch_dtype_c = tl_to_torch_types[name_to_tl_types[dtype_c]] + + matmul_def_str = f""" +def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn): + matmul_kernel_{configStr}.warmup( + {torch_dtype_a}, {torch_dtype_b}, {torch_dtype_c}, {torch_dtype_c}, + M, N, K, + am, ak, bk, bn, cm, cn, biasn, + BLOCK_SIZE_M = {block_m}, + BLOCK_SIZE_N = {block_n}, + BLOCK_SIZE_K = {block_k}, + GROUP_SIZE_M = {group_m}, + SPLIT_K = {split_k}, + num_warps = {num_warps}, + num_stages = {num_stages}, + waves_per_eu = {waves_per_eu}, + matrix_instr_nonkdim = {mfmaInstrSize}, + kpack = {kpack}, + BIAS={use_bias}, + EVEN_K={EVEN_K}, + grid=(1,), + ) + return None + +def try_compile_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn): + try: + matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn) + return True + except Exception as e: + print(f'invalid config(compilation): {configStr}: ', e, flush=True) + return False +""" + else: + matmul_def_str = f""" +def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn): + grid = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n}), {split_k} + matmul_kernel_{configStr}[grid]( + a, b, c, bias, + M, N, K, + am, ak, bk, bn, cm, cn, biasn, + BLOCK_SIZE_M = {block_m}, + BLOCK_SIZE_N = {block_n}, + BLOCK_SIZE_K = {block_k}, + GROUP_SIZE_M = {group_m}, + SPLIT_K = {split_k}, + num_warps = {num_warps}, + num_stages = {num_stages}, + waves_per_eu = {waves_per_eu}, + matrix_instr_nonkdim = {mfmaInstrSize}, + kpack = {kpack}, + BIAS = {use_bias}, + EVEN_K = {EVEN_K} + ) + return c +""" + return configStr, matmul_def_str + + +def generate_compile_driver(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, + init_type, configs, rotating_buffer_size, + bias_size): + """ + Generate a single file that contains all kernels in the tuning space. + This file is used to **compile** the kernels in parallel + """ + + filename = get_filename_compile_driver() + f_kernel = open(filename, 'w') + + # write imports + import_str = f"""import torch +import triton +import triton.language as tl +import argparse +import sys +import multiprocessing +from tune_gemm import gen_rotating_tensors +from {get_filename_without_extension(get_filename_myKernels())} import * +""" + + f_kernel.write(import_str + "\n") + + for config in configs: + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, matmul_def_str = gen_kernel_and_configStr_from_config( + config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, True) + # Copy the matmul_kernel with name replaced + f_kernel.write(matmul_def_str + "\n") + + # write compile_kernels + # pre string + stride_a_str = "1, M" if col_a else "M, 1" + stride_b_str = "1, N" if col_b else "N, 1" + stride_c_str = "N, 1" + compile_kernels_pre_str = f"""def compile_kernels(M, N, K, rotating_buffer_size, bias_size, num_threads): + thread_pool = multiprocessing.Pool(processes=num_threads) + + assert bias_size == M or bias_size == 0 + + stride_bias = 1 if bias_size > 0 else 0 + stride_am, stride_ak = {stride_a_str} + stride_bk, stride_bn = {stride_b_str} + stride_cm, stride_cn = {stride_c_str} + task_args = (M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, stride_bias) + + results = [] + config_names = [] +""" + f_kernel.write(compile_kernels_pre_str + "\n") + + # warm up call of all matmul functions in parallel + for config in configs: + configStr = gen_configStr(config) + task_str = f" results += [thread_pool.apply_async(try_compile_config_{configStr}, args=task_args)]\n" + \ + f" config_names += ['{configStr}']\n" + f_kernel.write(task_str) + + threadpool_str = """ + failed_configs = [] + for i in range(len(results)): + results[i].wait() + res = results[i].get() + if not res: + failed_configs += [config_names[i]] + thread_pool.close() + thread_pool.join() + if failed_configs: + with open("{filename}.failed_configs", "w") as f: + for cfg in failed_configs: + f.write(cfg + "\\n") +""".format(filename=filename) + f_kernel.write(threadpool_str) + + # def main and call compile_kernels + def_main_str = f""" +def main(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False,) + parser.add_argument("-n", type=int, default=32, help='number of threads') + parser.add_argument("-rotating_tensor", type=int, default={rotating_buffer_size}, help='size of rotating buffer (MB), default: {rotating_buffer_size}') + args = parser.parse_args() + numThreads = args.n + rotating_buffer_size = args.rotating_tensor + """ + compile_kernels_call_str = f'compile_kernels({M}, {N}, {K}, rotating_buffer_size, {bias_size}, numThreads)' + + f_kernel.write(def_main_str) + f_kernel.write(compile_kernels_call_str + "\n\n") + f_kernel.write("""if __name__ == '__main__': + sys.exit(main())""") + f_kernel.close() + + return filename + + +def generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, + init_type, configs, jobs, iters, run_bench, + rotating_buffer_size, bias_size, icache_flush): + """ + Open {len(jobs)} files + generated_kernelM-N-K-0.py, generated_kernelM-N-K-1.py, ..., generated_kernelM-N-K-{njobs-1}.py + and generate + 1. matmul kernels of all configs + 2. wrapper function matmul to invoke all the generated kernels + 3. test_gemm to invoke matmul in a loop of {iters} iterations + """ + + filenames = [] + for i in range(jobs): + filenames.append(get_filename_profile_driver(M, N, K, i)) + f_kernel = [open(path, 'w') for path in filenames] + + # write imports + import_str = f"""import torch +import triton +import triton.language as tl +import argparse +import sys +import multiprocessing +from tune_gemm import gen_rotating_tensors +from {get_filename_without_extension(get_filename_myKernels())} import * +""" + if icache_flush: + import_str += """ +from icache_flush import icache_flush +""" + for fi in range(jobs): + f_kernel[fi].write(import_str + "\n") + + idx = 0 + for config in configs: + file_idx = idx % jobs + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, matmul_def_str = gen_kernel_and_configStr_from_config( + config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, False) + # Copy the matmul_kernel with name replaced + f_kernel[file_idx].write(matmul_def_str + "\n") + idx += 1 + + # write test_gemm + # pre string + test_gemm_pre_str = f"""def test_gemm(M, N, K, rotating_buffer_size, bias_size): + tensors = gen_rotating_tensors(M, N, K, '{dtype_a}', {col_a}, '{dtype_b}', {col_b}, '{dtype_c}', + 1, '{init_type}', rotating_buffer_size, bias_size, device='cuda') + + a = tensors['input_a'][0] + b = tensors['input_b'][0] + c = tensors['output_c'][0] + assert bias_size == M or bias_size == 0 + + stride_bias = tensors['bias'][0].stride(0) if bias_size > 0 else 0 + + try: + with open("{get_filename_compile_driver()}.failed_configs", "r") as f: + failed_configs = [cfg.strip() for cfg in f.readlines()] + except Exception: + failed_configs = [] +""" + for fi in range(jobs): + f_kernel[fi].write(test_gemm_pre_str + "\n") + + # call all matmul_xxx functions + idx = 0 + runs = iters if run_bench else 200 + call_icache_flush = 'icache_flush()' if icache_flush else '' + for config in configs: + configStr = gen_configStr(config) + matmul_call_str = f""" + if '{configStr}' not in failed_configs: + rotating_num = tensors['rotating_num'] + for i in range({runs}): + a = tensors['input_a'][i % rotating_num] + b = tensors['input_b'][i % rotating_num] + c = tensors['output_c'][i % rotating_num] + bias = tensors['bias'][i % rotating_num] if bias_size > 0 else None + bias_stride = bias.stride(0) if bias_size > 0 else 0""" + if icache_flush: + matmul_call_str += f""" + icache_flush()""" + matmul_call_str += f""" + d = matmul_{configStr}(a, b, c, bias, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), bias_stride)""" + f_kernel[idx % jobs].write(matmul_call_str + "\n") + idx += 1 + # post string + for fi in range(jobs): + f_kernel[fi].write(" return d\n") + + # def main and call test_gemm + def_main_str = f""" +def main(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False,) + parser.add_argument("-n", type=int, default=1, help='number of threads') + parser.add_argument("-rotating_tensor", type=int, default={rotating_buffer_size}, help='size of rotating buffer (MB), default: {rotating_buffer_size}') + args = parser.parse_args() + numThreads = args.n + rotating_buffer_size = args.rotating_tensor + """ + test_gemm_call_str = f'test_gemm({M}, {N}, {K}, rotating_buffer_size, {bias_size})' + for fi in range(jobs): + f_kernel[fi].write(def_main_str) + f_kernel[fi].write(test_gemm_call_str + "\n\n") + f_kernel[fi].write("""if __name__ == '__main__': + sys.exit(main())""") + f_kernel[fi].close() diff --git a/scripts/amd/gemm/utils/utils.py b/scripts/amd/gemm/utils/utils.py new file mode 100644 index 000000000000..9b6b50ea626b --- /dev/null +++ b/scripts/amd/gemm/utils/utils.py @@ -0,0 +1,115 @@ +import torch +import triton +import triton.language as tl + +import os +import subprocess +from datetime import datetime + +TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') +TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') +tl_to_torch_types = { + tl.float16: torch.float16, + tl.bfloat16: torch.bfloat16, + tl.float32: torch.float32, + tl.int8: torch.int8, + tl.int32: torch.int32, +} +if TORCH_HAS_FP8E5B16: + tl_to_torch_types[tl.float8e5b16] = torch.float8_e5m2fnuz +if TORCH_HAS_FP8E4B8: + tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz + +name_to_tl_types = { + 'int8': tl.int8, + 'int32': tl.int32, + 'fp16': tl.float16, + 'fp32': tl.float32, + 'bf16': tl.bfloat16, + 'fp8': tl.float8e4b8, + 'bf8': tl.float8e5b16, +} + + +def run_bash_command_wrapper(commandstring, capture=True): + try: + run_bash_command(commandstring, capture) + except subprocess.CalledProcessError as e: + if not capture: + print(f"running {commandstring} one more time") + run_bash_command(commandstring, capture) + + +def run_bash_command(commandstring, capture=True): + if capture: + proc = subprocess.run(commandstring, + shell=True, + check=True, + executable='/bin/bash', + stdout=subprocess.PIPE) + return proc.stdout.splitlines() + proc = subprocess.run(commandstring, + shell=True, + check=True, + executable='/bin/bash') + return None + + +def get_filename_myKernels(): + path = os.path.dirname(os.path.abspath(__file__)) + return f"{path}/../myKernels.py" + + +def get_filename_without_extension(file_path): + base_name = os.path.basename(file_path) + file_name, _ = os.path.splitext(base_name) + return file_name + + +def get_filename_compile_driver(): + path = os.path.dirname(os.path.abspath(__file__)) + return f"{path}/../compile_driver.py" + + +def get_filename_profile_driver(M, N, K, job_id): + path = os.path.dirname(os.path.abspath(__file__)) + return f"{path}/../profile_driver_{M}x{N}x{K}_{job_id}.py" + + +def get_default_tuning_result_filename(): + git_branch_name = run_bash_command("git rev-parse --abbrev-ref HEAD") + git_branch_name = git_branch_name[0].decode() + # handle branch name of "xxx/xxx" format + git_branch_name = git_branch_name.replace('/', '_') + git_commit_hash = run_bash_command("git rev-parse --short HEAD") + git_commit_hash = git_commit_hash[0].decode() + + dt_string = datetime.now().strftime("%m-%d-%Y-%H:%M:%S") + + path = os.path.dirname(os.path.abspath(__file__)) + defaultName = f"{path}/../tuning_results_{git_branch_name}@{git_commit_hash}_{dt_string}.yaml" + return defaultName + + +def patch_triton_compiler(): + device = triton.runtime.driver.active.get_current_device() + stream = triton.runtime.driver.active.get_current_stream(device) + target = triton.runtime.driver.active.get_current_target() + + triton_location_str = run_bash_command("pip show triton | grep Editable") + if not triton_location_str: + print("triton source not found from pip show triton") + + triton_dir = triton_location_str[0].split()[-1].decode('utf-8') + + jit_filename = os.path.join(triton_dir, "triton/runtime", "jit.py") + + run_bash_command(f"sed -i 's/driver.active.get_current_device()/{device}/g' {jit_filename}") + run_bash_command(f"sed -i 's/driver.active.get_current_stream(device)/{stream}/g' {jit_filename}") + + hip_driver_filename = os.path.join(triton_dir, "../third_party/amd/backend/", "driver.py") + cuda_driver_filename = os.path.join(triton_dir, "../third_party/nvidia/backend/", "driver.py") + + run_bash_command(f"sed -i 's/import torch/return True/g' {hip_driver_filename}") + run_bash_command(f"sed -i 's/device = self.get_current_device()/return GPUTarget(\"hip\", \"{target.arch}\", 64)/g' {hip_driver_filename}") + run_bash_command(f"sed -i 's/import torch/return False/g' {cuda_driver_filename}") diff --git a/scripts/amd/occ.sh b/scripts/amd/occ.sh index f34246a173eb..51c8f9095907 100755 --- a/scripts/amd/occ.sh +++ b/scripts/amd/occ.sh @@ -67,3 +67,5 @@ echo "$perf" sed -i '/local_/! {/\.loc/d}' output.mlir sed -i '/\.Ltmp.*:/d' output.mlir sed -i '/AMD clang version/d' output.mlir + +sed -n '/AMDGCN/, $p' output.mlir > output.amdgcn