AMOS: Enabling Automatic Mapping for Tensor Computations On Spatial Accelerators with Hardware Abstraction
AMOS is a mapper that automatically maps tensor computations to spatial accelerators via intrinsics. We divide the mapping problem into two distinct parts: hardware-aware mapping and ISA-aware mapping. Hardware-aware mapping maps software directly to hardware units at the proper spatial-temporal steps. ISA-aware mapping maps software to hardware through intermediate instructions called intrinsics. For example, when we map tensor computations to Tensor Core, we need to emit Tensor Core instructions (e.g., CUDA WMMA or PTX MMA). ISA-aware mapping is hard because:
- Users have to configure and transform the compute according to the constraints of the intrinsics or the hardware.
- It is up to the user to decide how to map software loops to intrinsics. There is usually more than one possible mapping, and an inexperienced user may overlook the best one.
cd ~
git clone https://github.com/KnowingNothing/AMOS.git
Then get the submodules
cd AMOS
git submodule update --init --recursive
mkdir build
cd build
cp ../cmake/config.cmake .
If you are not familiar with TVM, follow the steps below to configure config.cmake; otherwise, jump straight to the cmake step. We recommend consulting the TVM documentation (https://tvm.apache.org/docs/install/from_source.html) for details.
Download the LLVM source code from https://github.com/llvm/llvm-project to ~/LLVM. You can install LLVM anywhere you want; here we choose ~/LLVM/llvm-10.
mkdir -p ~/LLVM
cd ~/LLVM
git clone git@github.com:llvm/llvm-project.git
cd llvm-project
git checkout llvmorg-10.0.0
mkdir build
cd build
cmake -DCMAKE_INSTALL_PREFIX=/home/<your-home-dir>/LLVM/llvm-10 -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_PROJECTS="clang;lld;lldb" ../llvm
make -j 20
make install
Then, go back to AMOS directory and modify the config.cmake file.
cd ~/AMOS/build
vim config.cmake
Change the USE_LLVM variable to the path to llvm-config, i.e., /home/<your-home-dir>/LLVM/llvm-10/bin/llvm-config in our example.
Usually, the CUDA toolkit is installed by the administrator. If you can install CUDA on your own, follow the steps at https://developer.nvidia.com/cuda-downloads.
Assume we have CUDA-11.4 installed in /usr/local/cuda-11.4.
You can add /usr/local/cuda-11.4/bin to your PATH so that you have access to nvcc.
Then modify the config.cmake file further, changing the USE_CUDA variable to ON.
To use OpenCL, we can use Nvidia's OpenCL implementation, which ships with the CUDA toolkit.
You can simply add /usr/local/cuda-11.4/lib64 and /usr/local/cuda-11.4/include to your PATH so that the OpenCL libraries can be found.
Then modify the config.cmake file by changing the value of USE_OPENCL to ON.
cmake ..
make -j 20
First, we recommend using virtualenv to manage your Python libraries.
If you don't have virtualenv, you can install it locally. If you don't have pip installed, there are several workarounds, e.g., you can build Python from source (https://www.python.org/downloads/source/). Details on building Python locally can be found here (https://realpython.com/installing-python/).
python3 -m pip install --user virtualenv
Then use virtualenv to establish your first environment.
cd ~
mkdir venv
cd venv
python3 -m virtualenv <vir-name> -p python3
You can activate your virtual environment with
source ~/venv/<vir-name>/bin/activate
If you find it inconvenient to activate the environment this way, you can create a symbolic link:
mkdir -p ~/.local/bin
cd ~/.local/bin
ln -s /home/<your-home-dir>/venv/<vir-name>/bin/activate <vir-name>
Add /home/<your-home-dir>/.local/bin to your PATH so that a simple source <vir-name> activates your Python environment.
source <vir-name>
Then, install the Python dependencies of TVM after activating your virtual environment.
(<vir-name>) pip install numpy decorator attrs tornado psutil xgboost==1.5.0 cloudpickle synr pebble scikit-learn
Finally, set up the environment variables.
(<vir-name>) export AMOS_HOME=~/AMOS
(<vir-name>) export PYTHONPATH=$PYTHONPATH:$AMOS_HOME/python
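As a quick sanity check of the two variables above, a small helper along the following lines can confirm that AMOS_HOME is set and that its python directory is on PYTHONPATH. This is only a sketch: the function name check_amos_env is hypothetical and not part of AMOS, and it inspects environment variables only (it does not try to import tvm).

```python
import os


def check_amos_env(environ=None):
    """Return a list of problems found in the AMOS-related environment variables.

    `check_amos_env` is a hypothetical helper for illustration, not an AMOS API.
    """
    environ = os.environ if environ is None else environ
    problems = []
    home = environ.get("AMOS_HOME")
    if not home:
        problems.append("AMOS_HOME is not set")
        return problems
    # The directory that the export above appends to PYTHONPATH.
    expected = os.path.normpath(os.path.join(os.path.expanduser(home), "python"))
    entries = [
        os.path.normpath(os.path.expanduser(p))
        for p in environ.get("PYTHONPATH", "").split(os.pathsep)
        if p
    ]
    if expected not in entries:
        problems.append("%s is not on PYTHONPATH" % expected)
    return problems


if __name__ == "__main__":
    issues = check_amos_env()
    print("environment looks fine" if not issues else "\n".join(issues))
```

An empty list means both variables look consistent with the exports shown above.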
This tutorial requires a GPU that supports Tensor Core, i.e., a GPU with the Volta, Turing, or Ampere architecture. First, import AMOS and tvm. AMOS is implemented as part of tvm and serves as a functional unit of tvm, so we can import AMOS (named auto_tensorize in tvm) from tvm.
import tvm
from tvm import auto_tensorize as at
import numpy as np
After that, let's define a Conv2d compute.
def conv2d(N, C, H, W, K, R, S, stride, padding, dilation):
    # effective kernel extent after dilation
    kH = (R - 1) * dilation + 1
    kW = (S - 1) * dilation + 1
    # padded input extent
    pH = H + 2 * padding
    pW = W + 2 * padding
    A = tvm.te.placeholder([N, C, H, W], dtype="float16", name="A")
    B = tvm.te.placeholder([K, C, R, S], dtype="float16", name="B")

    # zero-pad the input on the spatial dimensions
    Pad = tvm.te.compute(
        [N, C, pH, pW],
        lambda n, c, h, w: tvm.tir.if_then_else(
            tvm.tir.all(h >= padding, h - padding < H, w >= padding, w - padding < W),
            A[n, c, h - padding, w - padding],
            tvm.tir.const(0.0, A.dtype),
        ),
        name="Pad",
    )

    # reduce over input channels and the R x S kernel taps;
    # dilation is applied in the index expression below
    rc = tvm.te.reduce_axis([0, C], name="rc")
    rr = tvm.te.reduce_axis([0, R], name="rr")
    rs = tvm.te.reduce_axis([0, S], name="rs")

    # output spatial extent
    P = (pH - kH) // stride + 1
    Q = (pW - kW) // stride + 1
    Conv = tvm.te.compute(
        [N, K, P, Q],
        lambda n, k, p, q: tvm.te.sum(
            (
                Pad[n, rc, p * stride + rr * dilation, q * stride + rs * dilation]
                * B[k, rc, rr, rs]
            ).astype("float32"),
            axis=[rc, rr, rs],
        ),
        name="Conv",
    )
    return [A, B, Conv]


batch = 1  # batch size used in this tutorial
N, H, W, K, C, R, S, stride, padding, dilation = batch, 28, 28, 128, 128, 3, 3, 1, 1, 1
A, B, Conv = conv2d(N, C, H, W, K, R, S, stride, padding, dilation)
This code is no different from a normal scalar-based conv2d program. Later, AMOS will automatically map this program to Tensor Core. A, B, and Conv are the input and output tensors.
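To make the shape arithmetic inside conv2d concrete, the sketch below recomputes the padded input size and the output size in plain Python for the tutorial's parameters. The helper name conv2d_output_shape is hypothetical; it only mirrors the formulas for kH, pH, and P used in the definition.

```python
def conv2d_output_shape(N, K, H, W, R, S, stride, padding, dilation):
    """Mirror the shape formulas of the conv2d definition (hypothetical helper)."""
    kH = (R - 1) * dilation + 1  # effective kernel height after dilation
    kW = (S - 1) * dilation + 1  # effective kernel width after dilation
    pH = H + 2 * padding         # padded input height
    pW = W + 2 * padding         # padded input width
    P = (pH - kH) // stride + 1  # output height
    Q = (pW - kW) // stride + 1  # output width
    return [N, K, P, Q]


# Tutorial setting: batch 1, 28x28 input, 128 output channels, 3x3 kernel.
print(conv2d_output_shape(1, 128, 28, 28, 3, 3, stride=1, padding=1, dilation=1))
# prints [1, 128, 28, 28]
```

With a 3x3 kernel, stride 1, and padding 1, the output keeps the 28x28 spatial size of the input.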
We need to trace the program syntax and construct the compute DAG.
target_dag = at.compute_dag_from_tensors([Conv])
AMOS needs to perform hardware profiling during mapping exploration and optimization. So we need to set proper measure options.
target = "cuda"
measure_opt = at.MeasureOptions(target=target, timeout=10, min_repeat_ms=500)
The timeout is in seconds and bounds the compilation and execution time. If compilation or execution exceeds the timeout, an error is reported.
min_repeat_ms is used to obtain accurate measurements. Hardware profiling can be inaccurate if the target program runs only a few times, so min_repeat_ms forces the program to run for at least that many milliseconds.
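The idea behind min_repeat_ms can be illustrated with a plain Python timing loop. This is only a sketch of the concept, not how AMOS or TVM implement measurement; the function name measure_ms and its parameters are assumptions made for this example.

```python
import time


def measure_ms(fn, min_repeat_ms=500, initial_runs=1):
    """Average the run time of `fn` over a window of at least `min_repeat_ms`.

    Hypothetical helper: keeps doubling the number of runs until the whole
    batch takes long enough, then reports the mean time per run in ms.
    """
    runs = initial_runs
    while True:
        start = time.perf_counter()
        for _ in range(runs):
            fn()
        elapsed_ms = (time.perf_counter() - start) * 1e3
        if elapsed_ms >= min_repeat_ms:
            return elapsed_ms / runs, runs
        runs *= 2  # too short to be reliable: run more iterations


if __name__ == "__main__":
    mean_ms, runs = measure_ms(lambda: sum(range(10000)), min_repeat_ms=50)
    print("mean %.4f ms over %d runs" % (mean_ms, runs))
```

A short kernel is therefore executed many times, so that timer resolution and start-up noise contribute little to the reported average.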
Use AMOS to perform mapping exploration:
result = at.auto_tensorize_v4(
    target_dag,
    "cuda",  # code generation target
    "conv2d_tutorial",  # the log file
    measure_opt,
    schedule_log_dir="conv2d_tutorial",
    trials=1200,
    search_group_size=5,
    transform_dump=False,
)
AMOS has multiple interfaces for mapping exploration. Here we use the latest interface, auto_tensorize_v4.
We use 1200 trials for fast exploration. AMOS automatically increases the number of trials if the given number is not enough to satisfy its exploration requirements.
You can also increase the number of trials to obtain better performance.
During exploration, AMOS prints many useful messages:
Possible matchings:
0 : MatchResult(hw_abs_dag:wmma_fp16_fp32, compute:nnn, shape:16x16x16)
1 : MatchResult(hw_abs_dag:wmma_fp16_fp32, compute:nnn, shape:32x8x16)
2 : MatchResult(hw_abs_dag:wmma_fp16_fp32, compute:nnn, shape:8x32x16)
3 : MatchResult(hw_abs_dag:wmma_fp16_fp32, compute:ntn, shape:16x16x16)
4 : MatchResult(hw_abs_dag:wmma_fp16_fp32, compute:ntn, shape:32x8x16)
5 : MatchResult(hw_abs_dag:wmma_fp16_fp32, compute:ntn, shape:8x32x16)
6 : MatchResult(hw_abs_dag:wmma_fp16_fp32, compute:tnn, shape:16x16x16)
7 : MatchResult(hw_abs_dag:wmma_fp16_fp32, compute:tnn, shape:32x8x16)
8 : MatchResult(hw_abs_dag:wmma_fp16_fp32, compute:tnn, shape:8x32x16)
9 : MatchResult(hw_abs_dag:wmma_fp16_fp32, compute:ttn, shape:16x16x16)
10 : MatchResult(hw_abs_dag:wmma_fp16_fp32, compute:ttn, shape:32x8x16)
11 : MatchResult(hw_abs_dag:wmma_fp16_fp32, compute:ttn, shape:8x32x16)
Logging to devnull...
Totally 35 different mappings for this matching
Logging to devnull...
Totally 35 different mappings for this matching
Catch an infeasible mapping:
{"vmap": [[0, 0, 0, 0, 0, 0, 1], -1]}
Catch an infeasible mapping:
{"vmap": [[0, 0, 0, 0, 1, 0, 0], -1]}
Catch an infeasible mapping:
{"vmap": [[0, 0, 0, 0, 1, 0, 1], -1]}
Catch an infeasible mapping:
{"vmap": [[0, 0, 1, 0, 0, 0, 0], -1]}
Catch an infeasible mapping:
{"vmap": [[0, 0, 1, 0, 0, 0, 1], -1]}
Catch an infeasible mapping:
{"vmap": [[0, 0, 1, 0, 1, 0, 0], -1]}
Catch an infeasible mapping:
{"vmap": [[0, 0, 1, 0, 1, 0, 1], -1]}
Catch an infeasible mapping:
{"vmap": [[0, 1, 0, 0, 0, 0, 0], -1]}
Catch an infeasible mapping:
{"vmap": [[0, 1, 0, 0, 0, 0, 1], -1]}
Catch an infeasible mapping:
{"vmap": [[0, 1, 0, 0, 1, 0, 0], -1]}
Catch an infeasible mapping:
{"vmap": [[0, 1, 1, 0, 0, 0, 0], -1]}
Catch an infeasible mapping:
{"vmap": [[1, 0, 0, 0, 0, 0, 0], -1]}
Catch an infeasible mapping:
{"vmap": [[1, 0, 1, 0, 0, 0, 0], -1]}
Catch an infeasible mapping:
{"vmap": [[1, 1, 0, 0, 0, 0, 0], -1]}
Catch an infeasible mapping:
{"vmap": [[1, 1, 1, 0, 0, 0, 0], -1]}
Total trials: 1200
Num rounds: 10
Num matching: 1
Num mapping: 20
Initial trials per matching: 120
Original weights [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05]
Original trials for each mapping [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
Current explored matching: MatchResult(hw_abs_dag:wmma_fp16_fp32, compute:nnn, shape:8x32x16)
Its axis mapping:
i: int32 : [n, n, n, p, p, q, q]
j: int32 : [k, k, k, k, k, k, k]
rk: int32 : [rc, rr, rs, rc, rs, rc, rr]
Current explored mapping: {"vmap": [[0, 0, 0, 0, 0, 1, 0], -1]}
Logging to conv2d-fp16-layer-0-batch-1/mapping_(0,0,0,0,0,1,0)_conv2d-fp16-layer-0-batch-1.log...
Loading from file conv2d-fp16-layer-0-batch-1/mapping_(0,0,0,0,0,1,0)_conv2d-fp16-layer-0-batch-1.log...
Load 0 entries! The best known is 10000000000000.000000 ms
Using arch: sm_86
.Y.Y.Y.Y.E
*Y*E*E*E
iteration=1: 6827.986947025252/6827.986947025252
.Y
*E
iteration=2: 1e-10/6827.986947025252
Best record value:6827.986947025252 (larger is better)
Round 1, Match 1, Mapping 1: 6827.986947025252/6827.986947025252(0.1464560503349629 ms), {"vmap": [[0, 0, 0, 0, 0, 1, 0], -1]}, {"inline": 0, "vectorize": 4, "spatial_factors": [[1, 1, 4, 1], [4, 1, 1, 1], [1, 1, 1, 1], [2, 14, 1, 1]], "reduce_factors": [[8, 1, 1], [1, 3, 1], [1, 1, 3]], "last_factors": [[98, 1, 32]], "output_unroll_step": 512, "last_unroll_step": 64}
Let's go through the messages line by line.
First, AMOS reports that there are 12 different matches for Tensor Core in total.
A match refers to one applicable intrinsic.
For example, for our float16 Tensor Core, there are 3 different shapes (16x16x16, 32x8x16, 8x32x16) and 4 layouts (nnn, ntn, tnn, ttn, where n means the matrix is not transposed and uses row-major layout, and t means the matrix is transposed and uses column-major layout).
AMOS chooses one of the 12 matches by minimizing the amount of padding and redundant computation.
For this tutorial, AMOS chooses MatchResult(hw_abs_dag:wmma_fp16_fp32, compute:nnn, shape:8x32x16).
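The 12 matches are the product of the 4 layouts and the 3 shapes. The sketch below enumerates them in the order of the log and, as a rough illustration of what "padding" means, counts the extra iteration-space points each shape would need if n, p, q were all fused onto the intrinsic's row dimension, k onto its column dimension, and rc, rr, rs onto its reduction dimension. Both this fused mapping and the padding metric are assumptions made for illustration; the rule AMOS actually uses to pick a match is not reproduced here.

```python
from itertools import product

LAYOUTS = ["nnn", "ntn", "tnn", "ttn"]
SHAPES = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]  # (m, n, k) as printed in the log


def enumerate_matches():
    """List (layout, shape) pairs in the same order as the log."""
    return list(product(LAYOUTS, SHAPES))


def round_up(x, tile):
    """Smallest multiple of `tile` that is >= x."""
    return -(-x // tile) * tile


def padded_elements(shape, M, N, K):
    """Extra iteration-space points if an M x N x K problem is tiled by `shape`."""
    m, n, k = shape
    return round_up(M, m) * round_up(N, n) * round_up(K, k) - M * N * K


# Assumed fused mapping for the tutorial's conv2d:
#   i <- (n, p, q), j <- k, rk <- (rc, rr, rs)
M, N, K = 1 * 28 * 28, 128, 128 * 3 * 3

for idx, (layout, shape) in enumerate(enumerate_matches()):
    print(idx, layout, "x".join(map(str, shape)), padded_elements(shape, M, N, K))
```

Under this assumed mapping, the 16x16x16 and 8x32x16 shapes need no padding, while 32x8x16 does, because 784 is not a multiple of 32.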
For this match, there are 35 different mappings.
It may be surprising that there are 35 different ways to map a single Conv2d to Tensor Core.
AMOS finds these mappings (most of them not previously discovered by developers or other compilers) through a systematic generation and verification process (details are given in our paper).
From the 35 mappings, AMOS further rejects 15 infeasible ones according to the concrete problem size, leaving the 20 mappings reported in the log.
For example, the batch size is 1, so it is infeasible to map only the batch dimension to Tensor Core, which requires at least 8 elements along the matrix row dimension.
After this, AMOS evaluates each mapping sequentially.
In the profiling results, .Y means compilation succeeded and *Y means execution succeeded.
The performance of the mapped program is also shown during exploration.
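The budget numbers in the log (1200 total trials, 10 rounds, 1 matching, 20 mappings, 120 initial trials per matching, weight 0.05 and 6 trials per mapping) are consistent with an even split of the trial budget. The sketch below reproduces them under that assumption; the function name is hypothetical, and since the log calls these the "original" weights, AMOS presumably adjusts them in later rounds.

```python
def initial_trial_plan(total_trials, num_rounds, num_matching, num_mapping):
    """Evenly split a trial budget (hypothetical helper mirroring the log)."""
    per_matching = total_trials // num_rounds // num_matching
    weights = [1.0 / num_mapping] * num_mapping
    per_mapping = [round(per_matching * w) for w in weights]
    return per_matching, weights, per_mapping


per_matching, weights, per_mapping = initial_trial_plan(1200, 10, 1, 20)
print("Initial trials per matching:", per_matching)
print("Original weights", weights)
print("Original trials for each mapping", per_mapping)
```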
After exploration, we can retrieve the results
schedule_gen = result.sch_gen
schedule_app = result.sch_app
# we store 1/time_cost in file
params, value = result.params, result.perf
cost = at.evaluate_params(schedule_app, params, measure_opt, dump=False)
print("Cost is %f ms" % cost)
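The log pairs each recorded value with a time in milliseconds, e.g. 6827.986947025252 with 0.1464560503349629 ms. These two numbers are consistent with the value being the reciprocal of the time in seconds, which fits the "we store 1/time_cost" comment. Assuming that convention, the conversion can be sketched as follows (the function names are hypothetical):

```python
def perf_to_ms(perf):
    """Convert a recorded value (assumed to be 1 / seconds) to milliseconds."""
    return 1e3 / perf


def ms_to_perf(ms):
    """Inverse conversion: milliseconds to 1 / seconds."""
    return 1e3 / ms


# Numbers taken from the "Round 1, Match 1, Mapping 1" log line.
logged_perf = 6827.986947025252
logged_ms = 0.1464560503349629
assert abs(perf_to_ms(logged_perf) - logged_ms) < 1e-9
print("%.6f ms" % perf_to_ms(logged_perf))  # prints 0.146456 ms
```

This also explains the "larger is better" remark in the log: a larger value corresponds to a shorter run time.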
And check the correctness of results
# retrieve schedule from the record
target_dag = schedule_app.target_dag
inputs = target_dag.get_inputs()
args = inputs + list(target_dag.tensors)
sch = tvm.te.create_schedule([x.op for x in target_dag.tensors])
sch = schedule_app.apply(sch, params)
print(tvm.lower(sch, args, simple_mode=True))
func = tvm.build(sch, args, target)
# test correctness
A, B = inputs
(Conv,) = target_dag.tensors
A_np = np.random.uniform(-10, 10, [int(x) for x in A.shape]).astype(A.dtype)
B_np = np.random.uniform(-10, 10, [int(x) for x in B.shape]).astype(B.dtype)
Conv_np = np.random.uniform(-10, 10, [int(x) for x in Conv.shape]).astype(Conv.dtype)
# compute the reference result with TVM's NumPy-based conv2d
from tvm.topi.testing import conv2d_nchw_python
Conv_golden = conv2d_nchw_python(
    A_np.astype("float32"), B_np.astype("float32"), stride, padding
)
ctx = tvm.context(target, 0)
A_tvm = tvm.nd.array(A_np, ctx)
B_tvm = tvm.nd.array(B_np, ctx)
Conv_tvm = tvm.nd.array(Conv_np, ctx)
func(A_tvm, B_tvm, Conv_tvm)
from tvm import testing
testing.assert_allclose(Conv_golden, Conv_tvm.asnumpy(), atol=1e-2, rtol=1e-2)
print("Correctness check passed!")
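The tolerances atol=1e-2 and rtol=1e-2 are fairly loose because the inputs are float16. Assuming tvm.testing.assert_allclose follows the rule of NumPy's assert_allclose (|actual - desired| <= atol + rtol * |desired|), the element-wise check can be sketched in plain Python; the function name within_tolerance is hypothetical.

```python
def within_tolerance(actual, desired, atol=1e-2, rtol=1e-2):
    """Element-wise closeness test in the style of NumPy's assert_allclose."""
    return abs(actual - desired) <= atol + rtol * abs(desired)


print(within_tolerance(1000.0, 1005.0))  # True: allowed error is 0.01 + 10.05
print(within_tolerance(1000.0, 1020.0))  # False: error 20 exceeds 0.01 + 10.2
print(within_tolerance(0.0, 0.005))      # True: small values are covered by atol
```

For large outputs the relative term dominates, so an error of about 1% of the reference value is accepted.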
We can check the generated code to see if Tensor Core is really used
print(func.imported_modules[0].get_source())
The generated code is:
extern "C" __global__ void default_function_kernel0(half* __restrict__ Pad_vmap_input_cmap_input, half* __restrict__ A) {
for (int i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner = 0; i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner < 4; ++i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner) {
Pad_vmap_input_cmap_input[((((((((((((int)blockIdx.x) * 32) + (i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) / 24) * 384) + ((((((int)blockIdx.x) * 4) + i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner) % 3) * 128)) + (((int)threadIdx.y) * 32)) + ((int)threadIdx.x)))] = (((((1 <= ((((((((((((int)blockIdx.x) * 32) + (i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) / 576) * 8) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) / 28) + (((((((((((int)blockIdx.x) * 32) + (i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) % 576) / 24) * 16) + (((int)threadIdx.x) & 15)) % 3))) && (((((((((((((int)blockIdx.x) * 32) + (i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) / 576) * 8) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) / 28) + (((((((((((int)blockIdx.x) * 32) + (i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) % 576) / 24) * 16) + (((int)threadIdx.x) & 15)) % 3)) < 29)) && (1 <= ((((((((((((int)blockIdx.x) * 32) + (i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) / 576) * 8) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) % 28) + (((((int)blockIdx.x) * 4) + 
i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner) % 3)))) && (((((((((((((int)blockIdx.x) * 32) + (i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) / 576) * 8) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) % 28) + (((((int)blockIdx.x) * 4) + i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner) % 3)) < 29)) ? A[(((((((((((((((((((int)blockIdx.x) * 32) + (i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) % 576) / 24) * 16) + (((int)threadIdx.x) & 15)) / 3) * 784) + ((((((((((((int)blockIdx.x) * 32) + (i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) % 576) / 24) * 16) + (((int)threadIdx.x) & 15)) % 3) * 28)) + ((((((((int)blockIdx.x) * 32) + (i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) / 576) * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4)) + (((((int)blockIdx.x) * 4) + i_o_input_rk_o_input_fused_n_main_input_fused_rs_main_input_fused_i_input_fused_rk_input_fused_outer_outer_inner) % 3)) - 29))] : __float2half_rn(0.000000e+00f));
}
}
extern "C" __global__ void default_function_kernel3(float* __restrict__ Conv_vmap_output, float* __restrict__ memcpy_dst) {
for (int n_output_k_output_fused_p_output_fused_q_output_fused_outer_outer_inner = 0; n_output_k_output_fused_p_output_fused_q_output_fused_outer_outer_inner < 4; ++n_output_k_output_fused_p_output_fused_q_output_fused_outer_outer_inner) {
Conv_vmap_output[(((((((int)blockIdx.x) * 512) + (n_output_k_output_fused_p_output_fused_q_output_fused_outer_outer_inner * 128)) + (((int)threadIdx.y) * 32)) + ((int)threadIdx.x)))] = memcpy_dst[(((((((((((((int)blockIdx.x) * 512) + (n_output_k_output_fused_p_output_fused_q_output_fused_outer_outer_inner * 128)) + (((int)threadIdx.y) * 32)) + ((int)threadIdx.x)) % 784) >> 3) * 1024) + ((((((((int)blockIdx.x) * 512) + (n_output_k_output_fused_p_output_fused_q_output_fused_outer_outer_inner * 128)) + (((int)threadIdx.y) * 32)) + ((int)threadIdx.x)) / 25088) * 256)) + ((((int)threadIdx.x) & 7) * 32)) + ((((((((int)blockIdx.x) * 512) + (n_output_k_output_fused_p_output_fused_q_output_fused_outer_outer_inner * 128)) + (((int)threadIdx.y) * 32)) + ((int)threadIdx.x)) % 25088) / 784)))];
}
}
extern "C" __global__ void default_function_kernel1(half* __restrict__ B_vmap_input_cmap_input, half* __restrict__ B) {
for (int j_o_input_rk_o_input_fused_rs_main_input_fused_rk_input_fused_j_input_fused_outer_outer_inner = 0; j_o_input_rk_o_input_fused_rs_main_input_fused_rk_input_fused_j_input_fused_outer_outer_inner < 4; ++j_o_input_rk_o_input_fused_rs_main_input_fused_rk_input_fused_j_input_fused_outer_outer_inner) {
B_vmap_input_cmap_input[((((((((((((int)blockIdx.x) * 16) + (j_o_input_rk_o_input_fused_rs_main_input_fused_rk_input_fused_j_input_fused_outer_outer_inner * 4)) + ((int)threadIdx.y)) / 48) * 1536) + ((((int)blockIdx.x) % 3) * 512)) + (j_o_input_rk_o_input_fused_rs_main_input_fused_rk_input_fused_j_input_fused_outer_outer_inner * 128)) + (((int)threadIdx.y) * 32)) + ((int)threadIdx.x)))] = B[(((((((((((((int)blockIdx.x) * 16) + (j_o_input_rk_o_input_fused_rs_main_input_fused_rk_input_fused_j_input_fused_outer_outer_inner * 4)) + ((int)threadIdx.y)) / 1152) * 36864) + (((int)threadIdx.x) * 1152)) + ((((((((int)blockIdx.x) * 16) + (j_o_input_rk_o_input_fused_rs_main_input_fused_rk_input_fused_j_input_fused_outer_outer_inner * 4)) + ((int)threadIdx.y)) % 1152) / 48) * 48)) + (j_o_input_rk_o_input_fused_rs_main_input_fused_rk_input_fused_j_input_fused_outer_outer_inner * 12)) + (((int)threadIdx.y) * 3)) + (((int)blockIdx.x) % 3)))];
}
}
extern "C" __global__ void default_function_kernel2(half* __restrict__ Pad_vmap_input_cmap_input, half* __restrict__ B_vmap_input_cmap_input, float* __restrict__ memcpy_dst) {
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 8, 32, 16, float> Conv_vmap_main_cmap_main[1];
__shared__ half Pad_vmap_input_cmap_input_shared[1536];
__shared__ half B_vmap_input_cmap_input_shared[6144];
nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 8, 32, 16, half, nvcuda::wmma::row_major> memcpy_dst1[2];
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 8, 32, 16, half, nvcuda::wmma::row_major> memcpy_dst2[2];
(void)nvcuda::wmma::fill_fragment(Conv_vmap_main_cmap_main[0], 0.000000e+00f);
for (int rk_o_main_outer_outer = 0; rk_o_main_outer_outer < 4; ++rk_o_main_outer_outer) {
for (int rs_main_main_outer_outer = 0; rs_main_main_outer_outer < 3; ++rs_main_main_outer_outer) {
__syncthreads();
((uint1*)(Pad_vmap_input_cmap_input_shared + (((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) >> 7) * 128) + ((((((int)threadIdx.y) * 4) + (((int)threadIdx.x) >> 3)) & 7) * 16)) + ((((int)threadIdx.x) & 7) * 2)))))[0] = ((uint1*)(Pad_vmap_input_cmap_input + ((((((((((int)blockIdx.x) >> 1) * 18432) + (rk_o_main_outer_outer * 2304)) + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) >> 7) * 384)) + (rs_main_main_outer_outer * 128)) + ((((((int)threadIdx.y) * 4) + (((int)threadIdx.x) >> 3)) & 7) * 16)) + ((((int)threadIdx.x) & 7) * 2)))))[0];
((uint1*)(Pad_vmap_input_cmap_input_shared + ((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) >> 7) * 128) + ((((((int)threadIdx.y) * 4) + (((int)threadIdx.x) >> 3)) & 7) * 16)) + ((((int)threadIdx.x) & 7) * 2)) + 256))))[0] = ((uint1*)(Pad_vmap_input_cmap_input + (((((((((((int)blockIdx.x) >> 1) * 18432) + (rk_o_main_outer_outer * 2304)) + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) >> 7) * 384)) + (rs_main_main_outer_outer * 128)) + ((((((int)threadIdx.y) * 4) + (((int)threadIdx.x) >> 3)) & 7) * 16)) + ((((int)threadIdx.x) & 7) * 2)) + 768))))[0];
((uint1*)(Pad_vmap_input_cmap_input_shared + ((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) >> 7) * 128) + ((((((int)threadIdx.y) * 4) + (((int)threadIdx.x) >> 3)) & 7) * 16)) + ((((int)threadIdx.x) & 7) * 2)) + 512))))[0] = ((uint1*)(Pad_vmap_input_cmap_input + (((((((((((int)blockIdx.x) >> 1) * 18432) + (rk_o_main_outer_outer * 2304)) + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) >> 7) * 384)) + (rs_main_main_outer_outer * 128)) + ((((((int)threadIdx.y) * 4) + (((int)threadIdx.x) >> 3)) & 7) * 16)) + ((((int)threadIdx.x) & 7) * 2)) + 1536))))[0];
((uint1*)(Pad_vmap_input_cmap_input_shared + ((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) >> 7) * 128) + ((((((int)threadIdx.y) * 4) + (((int)threadIdx.x) >> 3)) & 7) * 16)) + ((((int)threadIdx.x) & 7) * 2)) + 768))))[0] = ((uint1*)(Pad_vmap_input_cmap_input + (((((((((((int)blockIdx.x) >> 1) * 18432) + (rk_o_main_outer_outer * 2304)) + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) >> 7) * 384)) + (rs_main_main_outer_outer * 128)) + ((((((int)threadIdx.y) * 4) + (((int)threadIdx.x) >> 3)) & 7) * 16)) + ((((int)threadIdx.x) & 7) * 2)) + 9216))))[0];
((uint1*)(Pad_vmap_input_cmap_input_shared + (((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 1024) / 768) * 768) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) >> 7) + 2) * 128)) + ((((((int)threadIdx.y) * 4) + (((int)threadIdx.x) >> 3)) & 7) * 16)) + ((((int)threadIdx.x) & 7) * 2)))))[0] = ((uint1*)(Pad_vmap_input_cmap_input + (((((((((((int)blockIdx.x) >> 1) * 18432) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 1024) / 768) * 9216)) + (rk_o_main_outer_outer * 2304)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) >> 7) + 2) * 384)) + (rs_main_main_outer_outer * 128)) + ((((((int)threadIdx.y) * 4) + (((int)threadIdx.x) >> 3)) & 7) * 16)) + ((((int)threadIdx.x) & 7) * 2)))))[0];
((uint1*)(Pad_vmap_input_cmap_input_shared + (((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 1280) / 768) * 768) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) >> 7) + 4) * 128)) + ((((((int)threadIdx.y) * 4) + (((int)threadIdx.x) >> 3)) & 7) * 16)) + ((((int)threadIdx.x) & 7) * 2)))))[0] = ((uint1*)(Pad_vmap_input_cmap_input + (((((((((((int)blockIdx.x) >> 1) * 18432) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 1280) / 768) * 9216)) + (rk_o_main_outer_outer * 2304)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) >> 7) + 4) * 384)) + (rs_main_main_outer_outer * 128)) + ((((((int)threadIdx.y) * 4) + (((int)threadIdx.x) >> 3)) & 7) * 16)) + ((((int)threadIdx.x) & 7) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + (((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)))))[0] = ((uint1*)(B_vmap_input_cmap_input + (((((((((int)blockIdx.x) & 1) * 73728) + (rk_o_main_outer_outer * 9216)) + (rs_main_main_outer_outer * 512)) + (((int)threadIdx.y) * 64)) + (((int)threadIdx.x) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 256))))[0] = ((uint1*)(B_vmap_input_cmap_input + ((((((((((int)blockIdx.x) & 1) * 73728) + (rk_o_main_outer_outer * 9216)) + (rs_main_main_outer_outer * 512)) + (((int)threadIdx.y) * 64)) + (((int)threadIdx.x) * 2)) + 256))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 512))))[0] = ((uint1*)(B_vmap_input_cmap_input + ((((((((((int)blockIdx.x) & 1) * 73728) + (rk_o_main_outer_outer * 9216)) + (rs_main_main_outer_outer * 512)) + (((int)threadIdx.y) * 64)) + (((int)threadIdx.x) * 2)) + 1536))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 768) >> 9) * 512) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0] = ((uint1*)(B_vmap_input_cmap_input + ((((((((((int)blockIdx.x) & 1) * 73728) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 768) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 1024))))[0] = ((uint1*)(B_vmap_input_cmap_input + ((((((((((int)blockIdx.x) & 1) * 73728) + (rk_o_main_outer_outer * 9216)) + (rs_main_main_outer_outer * 512)) + (((int)threadIdx.y) * 64)) + (((int)threadIdx.x) * 2)) + 3072))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 1280) >> 9) * 512) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0] = ((uint1*)(B_vmap_input_cmap_input + ((((((((((int)blockIdx.x) & 1) * 73728) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 1280) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 1536))))[0] = ((uint1*)(B_vmap_input_cmap_input + ((((((((((int)blockIdx.x) & 1) * 73728) + (rk_o_main_outer_outer * 9216)) + (rs_main_main_outer_outer * 512)) + (((int)threadIdx.y) * 64)) + (((int)threadIdx.x) * 2)) + 4608))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 1792) >> 9) * 512) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0] = ((uint1*)(B_vmap_input_cmap_input + ((((((((((int)blockIdx.x) & 1) * 73728) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 1792) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 2048))))[0] = ((uint1*)(B_vmap_input_cmap_input + ((((((((((int)blockIdx.x) & 1) * 73728) + (rk_o_main_outer_outer * 9216)) + (rs_main_main_outer_outer * 512)) + (((int)threadIdx.y) * 64)) + (((int)threadIdx.x) * 2)) + 6144))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 2304) >> 9) * 512) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0] = ((uint1*)(B_vmap_input_cmap_input + ((((((((((int)blockIdx.x) & 1) * 73728) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 2304) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 2560))))[0] = ((uint1*)(B_vmap_input_cmap_input + ((((((((((int)blockIdx.x) & 1) * 73728) + (rk_o_main_outer_outer * 9216)) + (rs_main_main_outer_outer * 512)) + (((int)threadIdx.y) * 64)) + (((int)threadIdx.x) * 2)) + 7680))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 2816) >> 9) * 512) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0] = ((uint1*)(B_vmap_input_cmap_input + ((((((((((int)blockIdx.x) & 1) * 73728) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 2816) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 3072))))[0] = ((uint1*)(B_vmap_input_cmap_input + ((((((((((int)blockIdx.x) & 1) * 73728) + (rk_o_main_outer_outer * 9216)) + (rs_main_main_outer_outer * 512)) + (((int)threadIdx.y) * 64)) + (((int)threadIdx.x) * 2)) + 36864))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 3328) >> 9) * 512) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0] = ((uint1*)(B_vmap_input_cmap_input + (((((((((((int)blockIdx.x) & 1) * 73728) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 3328) / 3072) * 36864)) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 256) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 3584))))[0] = ((uint1*)(B_vmap_input_cmap_input + (((((((((((int)blockIdx.x) & 1) * 73728) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 3584) / 3072) * 36864)) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 512) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((int)threadIdx.y) * 64) + ((((int)threadIdx.x) >> 4) * 32))) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 3840) >> 9) * 512) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0] = ((uint1*)(B_vmap_input_cmap_input + (((((((((((int)blockIdx.x) & 1) * 73728) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 3840) / 3072) * 36864)) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 768) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 4096))))[0] = ((uint1*)(B_vmap_input_cmap_input + (((((((((((int)blockIdx.x) & 1) * 73728) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 4096) / 3072) * 36864)) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 1024) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((int)threadIdx.y) * 64) + ((((int)threadIdx.x) >> 4) * 32))) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 4352) >> 9) * 512) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0] = ((uint1*)(B_vmap_input_cmap_input + (((((((((((int)blockIdx.x) & 1) * 73728) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 4352) / 3072) * 36864)) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 1280) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 4608))))[0] = ((uint1*)(B_vmap_input_cmap_input + (((((((((((int)blockIdx.x) & 1) * 73728) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 4608) / 3072) * 36864)) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 1536) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((int)threadIdx.y) * 64) + ((((int)threadIdx.x) >> 4) * 32))) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 4864) >> 9) * 512) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0] = ((uint1*)(B_vmap_input_cmap_input + (((((((((((int)blockIdx.x) & 1) * 73728) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 4864) / 3072) * 36864)) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 1792) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 5120))))[0] = ((uint1*)(B_vmap_input_cmap_input + (((((((((((int)blockIdx.x) & 1) * 73728) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 5120) / 3072) * 36864)) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 2048) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((int)threadIdx.y) * 64) + ((((int)threadIdx.x) >> 4) * 32))) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 5376) >> 9) * 512) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0] = ((uint1*)(B_vmap_input_cmap_input + (((((((((((int)blockIdx.x) & 1) * 73728) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 5376) / 3072) * 36864)) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 2304) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 5632))))[0] = ((uint1*)(B_vmap_input_cmap_input + (((((((((((int)blockIdx.x) & 1) * 73728) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 5632) / 3072) * 36864)) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 2560) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((int)threadIdx.y) * 64) + ((((int)threadIdx.x) >> 4) * 32))) + ((((int)threadIdx.x) & 15) * 2)))))[0];
((uint1*)(B_vmap_input_cmap_input_shared + ((((((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 5888) >> 9) * 512) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0] = ((uint1*)(B_vmap_input_cmap_input + (((((((((((int)blockIdx.x) & 1) * 73728) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 5888) / 3072) * 36864)) + (rk_o_main_outer_outer * 9216)) + (((((((int)threadIdx.y) * 64) + (((int)threadIdx.x) * 2)) + 2816) >> 9) * 1536)) + (rs_main_main_outer_outer * 512)) + ((((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4)) + 8) * 32)) + ((((int)threadIdx.x) & 15) * 2)))))[0];
__syncthreads();
(void)nvcuda::wmma::load_matrix_sync(memcpy_dst1[0], ((half *)Pad_vmap_input_cmap_input_shared + (((((int)threadIdx.y) >> 1) * 768))), 16);
(void)nvcuda::wmma::load_matrix_sync(memcpy_dst1[1], ((half *)Pad_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) >> 1) * 768) + 128))), 16);
(void)nvcuda::wmma::load_matrix_sync(memcpy_dst2[0], ((half *)B_vmap_input_cmap_input_shared + (((((int)threadIdx.y) & 1) * 3072))), 32);
(void)nvcuda::wmma::load_matrix_sync(memcpy_dst2[1], ((half *)B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) & 1) * 3072) + 512))), 32);
(void)nvcuda::wmma::mma_sync(Conv_vmap_main_cmap_main[0], memcpy_dst1[0], memcpy_dst2[0], Conv_vmap_main_cmap_main[0]);
(void)nvcuda::wmma::mma_sync(Conv_vmap_main_cmap_main[0], memcpy_dst1[1], memcpy_dst2[1], Conv_vmap_main_cmap_main[0]);
(void)nvcuda::wmma::load_matrix_sync(memcpy_dst1[0], ((half *)Pad_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) >> 1) * 768) + 256))), 16);
(void)nvcuda::wmma::load_matrix_sync(memcpy_dst1[1], ((half *)Pad_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) >> 1) * 768) + 384))), 16);
(void)nvcuda::wmma::load_matrix_sync(memcpy_dst2[0], ((half *)B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) & 1) * 3072) + 1024))), 32);
(void)nvcuda::wmma::load_matrix_sync(memcpy_dst2[1], ((half *)B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) & 1) * 3072) + 1536))), 32);
(void)nvcuda::wmma::mma_sync(Conv_vmap_main_cmap_main[0], memcpy_dst1[0], memcpy_dst2[0], Conv_vmap_main_cmap_main[0]);
(void)nvcuda::wmma::mma_sync(Conv_vmap_main_cmap_main[0], memcpy_dst1[1], memcpy_dst2[1], Conv_vmap_main_cmap_main[0]);
(void)nvcuda::wmma::load_matrix_sync(memcpy_dst1[0], ((half *)Pad_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) >> 1) * 768) + 512))), 16);
(void)nvcuda::wmma::load_matrix_sync(memcpy_dst1[1], ((half *)Pad_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) >> 1) * 768) + 640))), 16);
(void)nvcuda::wmma::load_matrix_sync(memcpy_dst2[0], ((half *)B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) & 1) * 3072) + 2048))), 32);
(void)nvcuda::wmma::load_matrix_sync(memcpy_dst2[1], ((half *)B_vmap_input_cmap_input_shared + ((((((int)threadIdx.y) & 1) * 3072) + 2560))), 32);
(void)nvcuda::wmma::mma_sync(Conv_vmap_main_cmap_main[0], memcpy_dst1[0], memcpy_dst2[0], Conv_vmap_main_cmap_main[0]);
(void)nvcuda::wmma::mma_sync(Conv_vmap_main_cmap_main[0], memcpy_dst1[1], memcpy_dst2[1], Conv_vmap_main_cmap_main[0]);
}
}
(void)nvcuda::wmma::store_matrix_sync(((float *)memcpy_dst + ((((((((int)blockIdx.x) >> 1) * 2048) + ((((int)threadIdx.y) >> 1) * 1024)) + ((((int)blockIdx.x) & 1) * 512)) + ((((int)threadIdx.y) & 1) * 256)))), Conv_vmap_main_cmap_main[0], 32, nvcuda::wmma::mem_row_major);
}
We can see nvcuda::wmma::mma_sync, which means that Tensor Core is used.
There are multiple kernels because some of them perform data transformation.
In some cases, these kernels can also be fused by AMOS.
The performance of this code is 0.031456 ms after the first 200 tuning trials (the remaining 1000 trials have not run yet). If we wait for the full tuning to complete, better performance can be obtained. On the same device (RTX 3090), the performance of PyTorch v1.10 + CuDNN v8.2 is 0.10336 ms.
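As a quick check of what these two numbers imply, the short Python snippet below computes the ratio between them. It uses only the latencies quoted above; the variable names are illustrative.

```python
# Latencies quoted above, in milliseconds, both measured on an RTX 3090.
amos_ms = 0.031456     # AMOS after the first 200 tuning trials
pytorch_ms = 0.10336   # PyTorch v1.10 + CuDNN v8.2

# Speedup is the baseline latency divided by the AMOS latency.
speedup = pytorch_ms / amos_ms
print(f"AMOS speedup over PyTorch + CuDNN: {speedup:.2f}x")
```

The ratio comes out at roughly 3.3x for this single conv2d configuration, before the tuning has finished.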
The main body of AMOS is located in:
C++ header files: include/tvm/auto_tensorize/*.h
C++ source files: src/auto_tensorize/*
python files: python/tvm/auto_tensorize/*
tutorial files: tutorials/auto_tensorize/*
We also modified some code outside auto_tensorize to facilitate the code generation process. For example, src/target/source/codegen_c.h and src/target/source/codegen_c.cc are modified. The detailed list of modified code outside the auto_tensorize directory is omitted.
We define HardwareAbstraction as a unified user interface to abstract hardware intrinsics. For a given hardware intrinsic, the HardwareAbstraction describes its loop structure (implemented as a TVM tensor expression) as well as its generation rule (with the help of tvm.te.TensorIntrin). When we want to schedule a compute, we can use the intrinsic's loop structure to match the compute's loops, generate and verify possible software-hardware mappings, and generate the final source code based on the generation rule.
There are two main subclasses of HardwareAbstraction: ComputeAbstraction and MemoryAbstraction. The former describes compute intrinsics such as wmma::mma_sync of Tensor Core. The latter describes memory loading/storing intrinsics such as wmma::load_matrix_sync of Tensor Core.
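To make the idea of "an intrinsic described by its loop structure" concrete, here is a minimal pure-Python sketch. It is not the AMOS implementation: the class names ending in Sketch, the field names, and the 16x16x16 shape are assumptions chosen for illustration. The real classes build on TVM tensor expressions and tvm.te.TensorIntrin.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class HardwareAbstractionSketch:
    """Hypothetical stand-in for an intrinsic description (not the AMOS class)."""
    intrinsic: str                         # name of the hardware intrinsic
    loops: Dict[str, int]                  # loop name -> extent
    operands: Dict[str, Tuple[str, ...]]   # operand name -> loops that index it


class ComputeAbstractionSketch(HardwareAbstractionSketch):
    """Describes a compute intrinsic such as wmma::mma_sync."""


class MemoryAbstractionSketch(HardwareAbstractionSketch):
    """Describes a load/store intrinsic such as wmma::load_matrix_sync."""


# Assumed shape: a 16x16x16 multiply-accumulate, C[i, j] += A[i, k] * B[k, j].
mma = ComputeAbstractionSketch(
    intrinsic="wmma::mma_sync",
    loops={"i": 16, "j": 16, "k": 16},
    operands={"A": ("i", "k"), "B": ("k", "j"), "C": ("i", "j")},
)

# Loading the A fragment only iterates over the loops that index A.
load_a = MemoryAbstractionSketch(
    intrinsic="wmma::load_matrix_sync",
    loops={"i": 16, "k": 16},
    operands={"A": ("i", "k")},
)


def reduction_loops(abstraction):
    """Loops that do not index the output operand (assumed to be listed last)."""
    *_, output = abstraction.operands.values()
    return [name for name in abstraction.loops if name not in output]


print(reduction_loops(mma))  # ['k']
```

Recording which loops index which operand is what later allows software loops to be compared against the intrinsic's loops.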
We implement it mainly in python/tvm/auto_tensorize/hw_abstraction/*, include/tvm/auto_tensorize/hw_abstraction.h, and src/auto_tensorize/hw_abstraction.cc. We also pre-implement abstractions for some existing hardware such as NVIDIA GPU (with Tensor Core), Intel CPU (with AVX512), and Mali GPU (with arm_dot intrinsic) in the corresponding subdirectories of python/tvm/auto_tensorize/hw_abstraction/*.
Generally speaking, compute intrinsics cooperate with memory intrinsics because of the data path between on-chip and off-chip memory. We define HardwareAbstractionDAG to describe the relations among intrinsics. A HardwareAbstractionDAG contains a directed acyclic graph of several HardwareAbstractions with a ComputeAbstraction as the "main abstraction". For example, two MemoryAbstractions for wmma::load_matrix_sync, one MemoryAbstraction for wmma::store_matrix_sync, and one ComputeAbstraction for wmma::mma_sync construct a HardwareAbstractionDAG, where the abstraction for wmma::mma_sync is the main abstraction.
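The Tensor Core example can be sketched as a small dependency graph. The snippet below uses only the Python standard library (graphlib, Python 3.9+). The node names and the dictionary layout are illustrative assumptions, not the data structure used by AMOS.

```python
from graphlib import TopologicalSorter

# Each node maps to the set of nodes whose results it consumes.
hw_abs_dag = {
    "load_a": set(),               # MemoryAbstraction, wmma::load_matrix_sync
    "load_b": set(),               # MemoryAbstraction, wmma::load_matrix_sync
    "mma": {"load_a", "load_b"},   # ComputeAbstraction, wmma::mma_sync
    "store": {"mma"},              # MemoryAbstraction, wmma::store_matrix_sync
}
main_abstraction = "mma"

# static_order() raises graphlib.CycleError if the graph is not acyclic,
# and otherwise yields every node after all of its predecessors.
order = list(TopologicalSorter(hw_abs_dag).static_order())
print(order)
```

In any valid order, both loads come before the main abstraction and the store comes after it, which mirrors the data path from memory into the compute intrinsic and back.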
We implement it mainly in python/tvm/auto_tensorize/hw_abs_dag/*, include/tvm/auto_tensorize/hw_abs_dag.h, and src/auto_tensorize/hw_abs_dag.cc. We also pre-implement abstraction DAGs for some existing hardware such as NVIDIA GPU (with Tensor Core), Intel CPU (with AVX512), and Mali GPU (with arm_dot intrinsic) in the corresponding subdirectories of python/tvm/auto_tensorize/hw_abs_dag/*.
We define a function MatchIntrinsic to generate and verify possible software-hardware mappings based on the graph, expression, and loop structure of a HardwareAbstractionDAG and the structure of the target compute to schedule.
First, it uses HwAbsDAGMatcher to match the DAG of the target compute with the DAG of the hardware abstraction.
Second, for a matched DAG, it uses HwAbsExprMatcher to match the expression structure of the compute with the "main HardwareAbstraction".
Third, if the expression structures match, it uses IndexExprMatcher to generate and verify the possible software-hardware loop mappings based on the loop structures of the target compute and the "main HardwareAbstraction".
We implement it mainly in python/tvm/auto_tensorize/tensorization_phases/intrin_match.py, include/tvm/auto_tensorize/matcher.h, and src/auto_tensorize/matcher.cc.
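The third matching step (loop mapping) can be illustrated with a deliberately simplified sketch. It assumes a software loop may map to an intrinsic loop only when both index the same set of tensors, and that one or more compatible software loops may share one intrinsic loop. The loop names, the access tags, and the enumerate_mappings function are hypothetical. The actual IndexExprMatcher applies its own verification, so the number of mappings printed here is a property of this sketch only.

```python
from itertools import chain, combinations, product

# Tensors indexed by each intrinsic loop of C[i, j] += A[i, k] * B[k, j].
intrinsic_loops = {
    "i": frozenset({"lhs", "out"}),
    "j": frozenset({"rhs", "out"}),
    "k": frozenset({"lhs", "rhs"}),
}

# conv2d: out[n, oc, p, q] += inp[n, ic, p + r, q + s] * weight[oc, ic, r, s]
conv2d_loops = {
    "n": frozenset({"lhs", "out"}),
    "oc": frozenset({"rhs", "out"}),
    "p": frozenset({"lhs", "out"}),
    "q": frozenset({"lhs", "out"}),
    "ic": frozenset({"lhs", "rhs"}),
    "r": frozenset({"lhs", "rhs"}),
    "s": frozenset({"lhs", "rhs"}),
}


def nonempty_subsets(items):
    """All non-empty combinations of items, smallest first."""
    items = list(items)
    sizes = range(1, len(items) + 1)
    return chain.from_iterable(combinations(items, n) for n in sizes)


def enumerate_mappings(software_loops, hardware_loops):
    """Assign each hardware loop a non-empty group of compatible software loops."""
    choices = []
    for hw_access in hardware_loops.values():
        compatible = [name for name, access in software_loops.items()
                      if access == hw_access]
        choices.append(list(nonempty_subsets(compatible)))
    return [dict(zip(hardware_loops, combo)) for combo in product(*choices)]


mappings = enumerate_mappings(conv2d_loops, intrinsic_loops)
print(len(mappings))   # 49 in this simplified model
print(mappings[0])     # {'i': ('n',), 'j': ('oc',), 'k': ('ic',)}
```

Even this toy model yields dozens of candidates for a single conv2d, which is why choosing among mappings is left to automatic exploration rather than to the user.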
For every possible software-hardware mapping, we first generate corresponding schedules with random parameters, then validate them with our validation checker. We then search for the best mapping and its schedule parameters in two steps, one coarse-grained and one fine-grained.
First, each valid mapping and its schedules are estimated by our performance model to find top-K parameter options. Second, we profile these parameters according to the estimated results to find the best one. This process is repeated thousands of times to find the best mapping.
Users can turn on the estimation step by setting enable_perf_model=True, and select the percentage of the chosen parameters by configuring perf_percentage.
Users who want to explore mappings only through performance model estimation can set target='tenet'+'real target', e.g. tenet cuda.
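One round of the coarse-then-fine search described above can be sketched as follows. This is an illustration under stated assumptions, not the AMOS search code: explore_once, estimate, and profile are hypothetical functions, the candidates are toy tile sizes, and perf_percentage is treated here as a fraction between 0 and 1.

```python
from itertools import product


def explore_once(candidates, estimate, profile, perf_percentage=0.25):
    """Rank all candidates with the cheap model, then profile only the top slice."""
    ranked = sorted(candidates, key=estimate)               # coarse-grained step
    top_k = max(1, int(len(ranked) * perf_percentage))
    shortlist = ranked[:top_k]
    best = min(shortlist, key=profile)                      # fine-grained step
    return best, shortlist


def estimate(tile):
    """Toy performance model: prefers tiles covering 1024 elements."""
    m, n = tile
    return abs(m * n - 1024) / 1024


def profile(tile):
    """Toy stand-in for a hardware measurement: also penalizes skewed tiles."""
    m, n = tile
    return estimate(tile) + 0.001 * abs(m - n)


candidates = list(product([8, 16, 32, 64], repeat=2))
best, shortlist = explore_once(candidates, estimate, profile)
print(best, len(shortlist))  # (32, 32) 4
```

Only 4 of the 16 candidates reach the expensive profiling step. A smaller fraction saves measurement time but relies more heavily on the accuracy of the performance model.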
We implement the exploration process mainly in python/tvm/auto_tensorize/search/parameter.py and python/tvm/auto_tensorize/search/measure.py.
The schedule generators and appliers for different target platforms are implemented in python/tvm/auto_tensorize/tensorization_phases/schedulers/*, and the performance model is implemented in python/tvm/auto_tensorize/backend/*.
To run the benchmarks of AMOS, please see the directory benchmark/amos. We have prepared a series of mapping_*.py files. For example, to run the conv2d mapping file and map conv2d to FP16 Tensor Core:
python mapping_conv2d_tensorcore.py --trials 100 --in_dtype float16 --out_dtype float16 --begin 0 --num 1 --batch 1
To run the benchmarks of PyTorch and CuDNN, please see the directory benchmark/pytorch. For example, to run the conv2d benchmark with CuDNN:
python conv2d.py --target cuda --batch 1 --enable_cudnn --number 5 --repeats 5 --begin 0 --num 1 --dtype FP16
Check out the AMOS Documentation site for installation instructions, tutorials, API references, and more.
@inproceedings{DBLP:conf/isca/0001CWJHLWLY022,
author = {Size Zheng and
Renze Chen and
Anjiang Wei and
Yicheng Jin and
Qin Han and
Liqiang Lu and
Bingyang Wu and
Xiuhong Li and
Shengen Yan and
Yun Liang},
editor = {Valentina Salapura and
Mohamed Zahran and
Fred Chong and
Lingjia Tang},
title = {{AMOS:} enabling automatic mapping for tensor computations on spatial
accelerators with hardware abstraction},
booktitle = {{ISCA} '22: The 49th Annual International Symposium on Computer Architecture,
New York, New York, USA, June 18 - 22, 2022},
pages = {874--887},
publisher = {{ACM}},
year = {2022},
url = {https://doi.org/10.1145/3470496.3527440},
doi = {10.1145/3470496.3527440},
timestamp = {Wed, 01 Jun 2022 14:59:23 +0200},
biburl = {https://dblp.org/rec/conf/isca/0001CWJHLWLY022.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}