diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py index 7e0f566b96f4..15d84c20ed23 100644 --- a/python/tvm/topi/arm_cpu/arm_utils.py +++ b/python/tvm/topi/arm_cpu/arm_utils.py @@ -43,12 +43,21 @@ def get_arch_version(target_mattr): def is_dotprod_available(): - """ Checks whether the hardware has support for fast Int8 arithmetic operations. """ + """ Checks whether the hardware has support for udot/sdot instructions. """ target = tvm.target.Target.current(allow_none=False) arch_version = get_arch_version(target.mattr) return arch_version >= 8.4 or ((arch_version in (8.2, 8.3)) and "+dotprod" in target.mattr) +def is_mmla_available(): + """ Checks whether the hardware has support for ummla/smmla instructions. """ + target = tvm.target.Target.current(allow_none=False) + arch_version = get_arch_version(target.mattr) + return arch_version >= 8.6 or ( + (arch_version in (8.2, 8.3, 8.4, 8.5)) and "+i8mm" in target.mattr + ) + + def is_aarch64_arm(): """ Checks whether we are compiling for an AArch64 target. """ target = tvm.target.Target.current(allow_none=False) @@ -63,8 +72,10 @@ def get_tiling_B_interleaved_t(interleave_A): tile computation. Please refer to: - - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product - - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h + - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-performance-for-armv8-architectures # pylint: disable=line-too-long + - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product + - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-through-mmla-instruction + - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h In order to have more information Parameters @@ -77,7 +88,13 @@ def get_tiling_B_interleaved_t(interleave_A): tile_rows_B: the output tile rows of B' tile_cols_B: the output tile columns of B' """ - if is_dotprod_available(): + if is_mmla_available(): + # If smmla/ummla is available, A must be interleaved. + # Each load from B' will contain 8 elements + # and we are loading 12 rows of B' (i.e., 12 columns of B) + tile_rows_B = 12 + tile_cols_B = 8 + elif is_dotprod_available(): # The number of tile rows of B' vary depending on the # strategy: # * If we are interleaving A, then we select 12 columns from B'(i.e., @@ -92,7 +109,7 @@ def get_tiling_B_interleaved_t(interleave_A): # rows of the original matrix B) need to be 4. tile_cols_B = 4 else: - # If dot product is not available, A must be interleaved. In this case + # If no acceleration is available, A must be interleaved. In this case # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements tile_rows_B = 4 tile_cols_B = 16 diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index 81326f169260..6a5cb2ae890e 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -28,8 +28,9 @@ gemm_quantized_impl, gemm_acc_4x4_int8_int8_int32, gemm_acc_nx16_int8_int8_int32, + gemm_acc_2x2_int8_int8_int32, ) -from .arm_utils import is_aarch64_arm, is_dotprod_available +from .arm_utils import is_aarch64_arm, is_dotprod_available, is_mmla_available def configure_knobs(cfg, M, K): @@ -130,11 +131,18 @@ def compute_conv2d_gemm_without_weight_transform( # the tile computation. 
# # Please refer to: - # - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product - # - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h + # - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-performance-for-armv8-architectures # pylint: disable=line-too-long + # - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product + # - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-through-mmla-instruction + # - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h # In order to have more information # - if is_dotprod_available() and interleave_A: + if is_mmla_available(): + # If smmla/ummla is enabled, we are loading 8 rows from A. Each row + # will contain 8 elements + tile_rows_A = 8 + tile_cols_A = 8 + elif is_dotprod_available() and interleave_A: # If dot product has been enabled, and we are interleaving A # tile size should be 8x4 tile_rows_A = 8 @@ -177,24 +185,71 @@ def compute_conv2d_gemm_without_weight_transform( lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y], name="A_interleaved", ) - # Execute GEMM - C_interleaved = te.compute( - (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B), - lambda b, x, y, w, z: te.sum( - A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32") - * B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"), - axis=k, - ), - name="C_interleaved", - ) - # Unpack the result - C = te.compute( - (batches, M, N), - lambda b, x, y: C_interleaved[ - b, x // tile_rows_A, y // tile_rows_B, idxm(x, tile_rows_A), idxm(y, tile_rows_B) - ].astype(out_dtype), - name="C", - ) + if is_mmla_available(): + # Execute GEMM. In the case of mmla, we need to enforce the tiling + # from the compute. This is because mmla is doing a tiled computation + # as well. So we have a big 8x12 tile, with small 2x2 sub-tiles + # generated by mmla. 
In theory we could make the tile 2x2 and + # fuse and split during scheduling, but this would not work + # because of possible padding + C_interleaved = te.compute( + ( + batches, + M_padded // tile_rows_A, + N_transformed, + tile_rows_A // 2, + tile_rows_B // 2, + 2, + 2, + ), + lambda b, x, y, w, z, s, t: te.sum( + A_interleaved[b, x, k // tile_cols_A, 2 * w + s, idxm(k, tile_cols_A)].astype( + "int32" + ) + * B_interleaved_t[y, k // tile_cols_B, 2 * z + t, idxm(k, tile_cols_B)].astype( + "int32" + ), + axis=k, + ), + name="C_interleaved", + ) + # Unpack the result + C = te.compute( + (batches, M, N), + lambda b, x, y: C_interleaved[ + b, + x // tile_rows_A, + y // tile_rows_B, + idxm(x, tile_rows_A) // 2, + idxm(y, tile_rows_B) // 2, + idxm(idxm(x, tile_rows_A), 2), + idxm(idxm(y, tile_rows_B), 2), + ].astype(out_dtype), + name="C", + ) + else: + # Execute GEMM + C_interleaved = te.compute( + (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B), + lambda b, x, y, w, z: te.sum( + A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32") + * B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"), + axis=k, + ), + name="C_interleaved", + ) + # Unpack the result + C = te.compute( + (batches, M, N), + lambda b, x, y: C_interleaved[ + b, + x // tile_rows_A, + y // tile_rows_B, + idxm(x, tile_rows_A), + idxm(y, tile_rows_B), + ].astype(out_dtype), + name="C", + ) zero = tvm.tir.const(0) else: # No need to pack/unpack, execute GEMM directly @@ -255,7 +310,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out): s[data_im2col].compute_inline() # Computation(through tensorize) - b, xo, yo, xi, yi = C_interleaved.op.axis + b, xo, yo, xi, yi = C_interleaved.op.axis[0:5] outer_gemm, inner_gemm = cfg["reorder_gemm"].apply(s, C_interleaved, [xo, yo]) b_outer_gemm_fused = s[C_interleaved].fuse(b, outer_gemm) @@ -271,40 +326,50 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out): k = C_interleaved.op.reduce_axis[0] _, M, N = C.shape - if is_dotprod_available(): - gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type) - xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile( - xi, yi, x_factor=8, y_factor=4 - ) - k_outer, k_inner = s[C_interleaved].split(k, 4) - xi_inner_outer, xi_inner_inner = s[C_interleaved].split(xi_inner, 4) - s[C_interleaved].reorder( - b_outer_gemm_fused, - inner_gemm, - xi_outer, - yi_outer, - k_outer, - xi_inner_outer, - xi_inner_inner, - yi_inner, - k_inner, - ) - s[C_interleaved].tensorize(xi_inner_inner, gemm_acc) - s[C_interleaved].unroll(xi_inner_outer) - - elif is_aarch64_arm(): - s[C_interleaved].reorder(yi, xi) - K = A_interleaved_input.shape[2] - assert in_type in ["int8", "uint8"], "Only int8 and uint8 gemm are supported" - unroll = cfg["gemm_quantized_unroll"].val - interleave = cfg["gemm_quantized_interleave"].val - gemm = gemm_quantized(M, N, K, unroll, interleave, in_type, out_type) - s[C_interleaved].pragma( - b_outer_gemm_fused, - "import_llvm", - gemm_quantized_impl(M, N, K, unroll, interleave, in_type), - ) - s[C_interleaved].tensorize(yi, gemm) + if in_type in ["int8", "uint8"]: + if is_mmla_available(): + gemm_acc = gemm_acc_2x2_int8_int8_int32(in_type) + xi_inner, yi_inner = C_interleaved.op.axis[-2:] + k_outer, k_inner = s[C_interleaved].split(k, 8) + s[C_interleaved].reorder( + b_outer_gemm_fused, inner_gemm, k_outer, xi, yi, xi_inner, yi_inner, k_inner + ) + s[C_interleaved].tensorize(xi_inner, gemm_acc) + s[C_interleaved].unroll(xi) + s[C_interleaved].unroll(yi) + 
elif is_dotprod_available(): + gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type) + xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile( + xi, yi, x_factor=8, y_factor=4 + ) + k_outer, k_inner = s[C_interleaved].split(k, 4) + xi_inner_outer, xi_inner_inner = s[C_interleaved].split(xi_inner, 4) + s[C_interleaved].reorder( + b_outer_gemm_fused, + inner_gemm, + xi_outer, + yi_outer, + k_outer, + xi_inner_outer, + xi_inner_inner, + yi_inner, + k_inner, + ) + s[C_interleaved].tensorize(xi_inner_inner, gemm_acc) + s[C_interleaved].unroll(xi_inner_outer) + + elif is_aarch64_arm(): + s[C_interleaved].reorder(yi, xi) + K = A_interleaved_input.shape[2] + unroll = cfg["gemm_quantized_unroll"].val + interleave = cfg["gemm_quantized_interleave"].val + gemm = gemm_quantized(M, N, K, unroll, interleave, in_type, out_type) + s[C_interleaved].pragma( + b_outer_gemm_fused, + "import_llvm", + gemm_quantized_impl(M, N, K, unroll, interleave, in_type), + ) + s[C_interleaved].tensorize(yi, gemm) # Output transform if out != final_out: diff --git a/python/tvm/topi/arm_cpu/tensor_intrin.py b/python/tvm/topi/arm_cpu/tensor_intrin.py index 1b999dfe4e80..8ccbe0c41298 100644 --- a/python/tvm/topi/arm_cpu/tensor_intrin.py +++ b/python/tvm/topi/arm_cpu/tensor_intrin.py @@ -411,6 +411,7 @@ def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type): intrin : TensorIntrin The ARM uint8/int8 TensorIntrin that can be used in tensorizing schedule """ + assert in_type in ["uint8", "int8"] A = te.placeholder((K // 16, te.var("m"), 16), dtype=in_type, name="A") B = te.placeholder((K // 16, te.var("n"), 16), dtype=in_type, name="B") @@ -627,7 +628,7 @@ def gemm_acc_4x4_int8_int8_int32(dtype): Int8 4x4 matrix multiplication and accumulation using sdot/udot instructions. This function takes two arrays of int8 datatype -- A[4][4] and B[4][4] and produces a 4x4 matrix - which is equal to A*B. + which is equal to A*B'. The pseudo code is as follows. @@ -643,7 +644,6 @@ def gemm_acc_4x4_int8_int8_int32(dtype): } Notes: - * The rows of matrix B are transposed * The tiling strategy is picked to maximize register usage. Parameters @@ -656,6 +656,7 @@ def gemm_acc_4x4_int8_int8_int32(dtype): intrin : TensorIntrin The Arm TensorIntrin that can be used in tensorizing schedule """ + assert dtype in ["uint8", "int8"] # This needs to be a variable number of "rows" since TVM # "thinks" I only need to compute one row because of # padding @@ -755,7 +756,7 @@ def gemm_acc_nx16_int8_int8_int32(dtype, rows): """ Int8 nx16 matrix multiplication and accumulation using sdot/udot instructions This function takes two arrays of int8 datatype -- A[n][4] and - B[4][16] and produces a rowsx16 matrix which is equal to A*B + B[4][16] and produces a rowsx16 matrix which is equal to A*B' The pseudo code is as follows. .. code-block:: c @@ -771,7 +772,6 @@ def gemm_acc_nx16_int8_int8_int32(dtype, rows): } Notes: - * The rows of matrix B are transposed * The tile size of B is 16x4. Since the reduction variable k moves between 0 and 16 we need 4 tiles of B to compute a single row of the output. 
The first 4 values of k will be fetched from B[0][j][k], the second batch of 4 from B[1][j][k] and so on @@ -789,6 +789,7 @@ def gemm_acc_nx16_int8_int8_int32(dtype, rows): intrin : TensorIntrin The Arm TensorIntrin that can be used in tensorizing schedule """ + assert dtype in ["uint8", "int8"] A = te.placeholder((rows, 16), dtype, name="A") B = te.placeholder((4, 16, 4), dtype, name="B") dtype_vec = dtype + "x16" @@ -969,6 +970,103 @@ def _instr(index): ) + +def gemm_acc_2x2_int8_int8_int32(dtype): + """ + Int8 2x2 matrix multiplication and accumulation using smmla/ummla instructions + This function takes two arrays of int8 datatype -- A[2][8] and + B[2][8] and produces a 2x2 matrix which is equal to A*B' + The pseudo code is as follows. + + .. code-block:: c + + void mmla_2x2_int8_int8_int32(int8 A[2][8], int8 B[2][8], int32 C[2][2]){ + for (int i = 0; i < 2; i++){ + for (int j = 0; j < 2; j++){ + for (int k = 0; k < 8; k++){ + C[i][j] += A[i][k] * B[j][k] + } + } + } + + Parameters + ---------- + dtype: str, {"uint8", "int8"} + Whether it works on unsigned int or signed int + + Returns + ------- + intrin : TensorIntrin + The Arm TensorIntrin that can be used in tensorizing schedule + """ + assert dtype in ["uint8", "int8"] + A = te.placeholder((2, 8), dtype, name="A") + B = te.placeholder((2, 8), dtype, name="B") + dtype_vec = dtype + "x16" + + k = te.reduce_axis((0, 8), name="k") + C = te.compute( + (2, 2), + lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k), + name="C", + ) + + aa_buffer = tvm.tir.decl_buffer( + A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1] + ) + bb_buffer = tvm.tir.decl_buffer( + B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1] + ) + cc_buffer = tvm.tir.decl_buffer( + C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1] + ) + + llvm_intrin = "llvm.aarch64.neon.smmla" if dtype == "int8" else "llvm.aarch64.neon.ummla" + + def _intrin_func(ins, outs): + def _instr(index): + ib = tvm.tir.ir_builder.create() + if index == 1: + ib.emit(outs[0].vstore([0, 0], tvm.tir.const(0, "int32x4"))) + return ib.get() + # Load in vec_a the two rows of A + # vec_a = [a, b, c, d, e, f, g, h; + # i, j, k, l, m, n, o, p,] + vec_a = ins[0].vload([0, 0], dtype_vec) + # Load in vec_b the two rows of B + # vec_b = [0, 2, 4, 6, 8, 10, 12, 14; + # 1, 3, 5, 7, 9, 11, 13, 15,] + vec_b = ins[1].vload([0, 0], dtype_vec) + + # Execute the matrix multiplication via (s/u)mmla: + # vec_c = [a*0 + b*2 + c*4 + d*6 +e*8 + f*10 + g*12 + h*14; + # a*1 + b*3 + c*5 + d*7 +e*9 + f*11 + g*13 + h*15; + # i*0 + j*2 + k*4 + l*6 +m*8 + n*10 + o*12 + p*14; + # i*1 + j*3 + k*5 + l*7 +m*9 + n*11 + o*13 + p*15] + vec_c = outs[0].vload([0, 0], "int32x4") + vmmla = tvm.tir.call_llvm_intrin( + "int32x4", + llvm_intrin, + tvm.tir.const(3, "uint32"), + vec_c, + vec_a, + vec_b, + ) + # Store the result + ib.emit(outs[0].vstore([0, 0], vmmla)) + return ib.get() + + # body, reset, update + return _instr(0), _instr(1), _instr(2) + + buffer_params = {"offset_factor": 1} + return te.decl_tensor_intrin( + C.op, + _intrin_func, + binds={A: aa_buffer, B: bb_buffer, C: cc_buffer}, + default_buffer_params=buffer_params, + ) + + def _q_multiply_shift_arm(op): """ Implementation of q_multiply_shift_arm through arm intrinsics diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py index e5b0689da008..1bf83eba53ac 100644 --- a/tests/python/topi/python/test_topi_conv2d_int8.py +++
b/tests/python/topi/python/test_topi_conv2d_int8.py @@ -72,6 +72,12 @@ def compile_conv2d_NHWC_gemm_int8_arm( topi.arm_cpu.compute_conv2d_NHWC_quantized_native, topi.arm_cpu.schedule_conv2d_NHWC_quantized_native, ), + # TODO(giuseros) Need LLVM-11 in order to compile with +i8mm extension + # ( + # "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm", + # topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved, + # topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved, + # ), ] for device_tuple in devices:
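For reference, the following is a minimal NumPy sketch (illustrative only, not part of the patch) of the 2x2 multiply-accumulate that the new gemm_acc_2x2_int8_int8_int32 intrinsic maps onto a single smmla/ummla call: each output element is the dot product of one row of A with one row of B (i.e. C += A*B'), matching the pseudo code in the intrinsic's docstring.

    import numpy as np

    def mmla_2x2_reference(A, B, C):
        # A, B: int8 arrays of shape (2, 8); C: int32 accumulator of shape (2, 2).
        # Computes C[i][j] += dot(A[i, :], B[j, :]), i.e. C += A * B'.
        assert A.shape == (2, 8) and B.shape == (2, 8) and C.shape == (2, 2)
        C += A.astype(np.int32) @ B.astype(np.int32).T
        return C

    # Example: accumulate one 8-deep slice of the reduction into a zeroed tile
    rng = np.random.default_rng(0)
    A = rng.integers(-128, 128, size=(2, 8), dtype=np.int8)
    B = rng.integers(-128, 128, size=(2, 8), dtype=np.int8)
    C = np.zeros((2, 2), dtype=np.int32)
    mmla_2x2_reference(A, B, C)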
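Similarly, the sketch below models the mmla branch of compute_conv2d_gemm_without_weight_transform: the 7-dimensional C_interleaved tensor holds 8x12 tiles split into the 2x2 sub-tiles produced by each mmla call, and the final compute unpacks them into the (batches, M, N) output. Shapes and index arithmetic follow the te.compute definitions in the diff (tile_rows_A = tile_cols_A = tile_cols_B = 8, tile_rows_B = 12); this is an illustrative reference under those assumptions, not the code generated by the schedule.

    import numpy as np

    def gemm_interleaved_mmla_reference(A_interleaved, B_interleaved_t, M, N):
        # A_interleaved:   (batches, M_padded // 8, K_padded // 8, 8, 8), int8
        # B_interleaved_t: (N_transformed, K_padded // 8, 12, 8), int8
        batches, m_tiles, k_tiles, _, _ = A_interleaved.shape
        n_tiles = B_interleaved_t.shape[0]
        A32 = A_interleaved.astype(np.int32)
        B32 = B_interleaved_t.astype(np.int32)

        # C_interleaved[b, x, y, w, z, s, t] =
        #     sum_k A[b, x, k // 8, 2 * w + s, k % 8] * B[y, k // 8, 2 * z + t, k % 8]
        C_interleaved = np.zeros((batches, m_tiles, n_tiles, 4, 6, 2, 2), dtype=np.int32)
        for b in range(batches):
            for x in range(m_tiles):
                for y in range(n_tiles):
                    for kt in range(k_tiles):
                        a = A32[b, x, kt]    # (8, 8): 8 rows of A, 8-deep in k
                        bt = B32[y, kt]      # (12, 8): 12 columns of B, 8-deep in k
                        acc = a @ bt.T       # (8, 12) partial tile for this k slice
                        # split the 8x12 tile into its 2x2 mmla sub-tiles
                        C_interleaved[b, x, y] += acc.reshape(4, 2, 6, 2).transpose(0, 2, 1, 3)

        # Unpack, mirroring the indexing of the final te.compute for C
        C = np.zeros((batches, M, N), dtype=np.int32)
        for xx in range(M):
            for yy in range(N):
                C[:, xx, yy] = C_interleaved[
                    :, xx // 8, yy // 12,
                    (xx % 8) // 2, (yy % 12) // 2,
                    (xx % 8) % 2, (yy % 12) % 2,
                ]
        return C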