Skip to content

Commit

Permalink
Address comments and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Anndrey24 committed Apr 22, 2024
1 parent 0d408e4 commit 63a5001
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 116 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/op/strategy/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
# for scalable vector alias analysis, which causes redundant loads / stores
# to remain after LLVM's optimisation passes, unlike the non-scalable case.
# Hence, it is given a lower priority level until these issues are resolved.
# Last checked manually using: LLVM 18.1.0
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE),
Expand Down
85 changes: 85 additions & 0 deletions python/tvm/topi/arm_cpu/arm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,59 @@
from tvm.target import Target


def get_tiling_A(interleave_A, in_dtype):
"""Compute the tiling information for matrix A in C=A*B,
which corresponds to the im2col-transformed input matrix.
The tiling information is chosen to maximize register usage during
the tile computation.
Please refer to:
- https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-performance-for-armv8-architectures # pylint: disable=line-too-long
- https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
- https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-through-mmla-instruction
- Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
In order to have more information
Parameters
----------
interleave_A : bool
determines if A is expected to be interleaved
in_dtype : str
input datatype
Returns
----------
tile_M: the output tile size of A on M axis (M = OH * OW)
tile_K: the output tile size of A on K axis (K = KW * KH * IC)
"""
target = Target.current(allow_none=False)
if in_dtype in ["int8", "uint8"]:
if target.features.has_matmul_i8:
# If smmla/ummla is enabled, we are loading 8 rows from A. Each row
# will contain 8 elements
tile_M = 8
tile_K = 8
elif target.features.has_dotprod and interleave_A:
# If dot product has been enabled, and we are interleaving A
# tile size should be 8x4
tile_M = 8
tile_K = 4
else:
# If either there is no dot product or if we are using a native strategy
# tile size should be 4x16
tile_M = 4
tile_K = 16
else:
# In non-quantized cases, A is not interleaved.
# We are loading 4 rows from A.
# Each row will contain 4 elements, along the dimension of reduction
tile_M = 4
tile_K = 4

return tile_M, tile_K


def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False):
"""Compute the tiling information for matrix B', where B'
is the tiled, interleaved (and transposed) version of matrix B in C=A*B.
Expand Down Expand Up @@ -101,6 +154,38 @@ def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False)
return tile_N, tile_K


def get_conv2d_im2col_padding(M, K, tile_M, tile_K):
"""Compute the necessary padding for matrix A in C=A*B,
which corresponds to the im2col-transformed input matrix.
Parameters
----------
M : int
Number of rows in A = OH * OW
K : int
Number of columns in A = KW * KH * IC
tile_M : int
tile size of A on M axis
tile_K : int
tile size of A on K axis
Returns
----------
pad_M : padding for M axis
pad_K : padding for K axis
"""
pad_M = 0
pad_K = 0

if M % tile_M != 0:
pad_M = tile_M - (M % tile_M)

if K % tile_K != 0:
pad_K = tile_K - (K % tile_K)

return pad_M, pad_K


def get_conv2d_weights_padding(N, K, tile_N, tile_K):
"""Compute the necessary padding for matrix B', where B'
is the transformed version of matrix B in C=A*B.
Expand Down
101 changes: 30 additions & 71 deletions python/tvm/topi/arm_cpu/conv2d_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from tvm.target import Target
from tvm import te
from tvm.topi import nn
from tvm.topi.arm_cpu import arm_utils
from tvm.autotvm.task.space import AnnotateEntity, ReorderEntity, OtherOptionEntity
from tvm.topi.arm_cpu.arm_utils import get_tiling_B_transformed
from ..utils import get_const_tuple, get_const_int
from ..nn.utils import get_pad_tuple
from .tensor_intrin import (
Expand Down Expand Up @@ -93,14 +93,16 @@ def compute_conv2d_gemm_without_weight_transform(

OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1

# Input padding (if necessary)
if pad_top or pad_left or pad_down or pad_right:
data_pad = nn.pad(
data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="data_pad"
)
else:
data_pad = data

# Im2col
# Im2col transformation
M = OH * OW
K = IC * kernel_area
N = OC
Expand All @@ -120,65 +122,19 @@ def compute_conv2d_gemm_without_weight_transform(
name="data_im2col",
)

# Pad if necessary
tile_N, tile_K_B = get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors)

# Select the tiling strategy for A.
# The tiling information is chosen to maximize register usage during
# the tile computation.
#
# Please refer to:
# - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-performance-for-armv8-architectures # pylint: disable=line-too-long
# - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
# - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-through-mmla-instruction
# - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
# In order to have more information
#
target = Target.current(allow_none=False)
if in_dtype in ["int8", "uint8"]:
if target.features.has_matmul_i8:
# If smmla/ummla is enabled, we are loading 8 rows from A. Each row
# will contain 8 elements
tile_M = 8
tile_K_A = 8
elif target.features.has_dotprod and interleave_A:
# If dot product has been enabled, and we are interleaving A
# tile size should be 8x4
tile_M = 8
tile_K_A = 4
else:
# If either there is no dot product or if we are using a native strategy
# tile size should be 4x16
tile_M = 4
tile_K_A = 16
else:
# In non-quantized cases, A is not interleaved.
# We are loading 4 rows from A.
# Each row will contain 4 elements, along the dimension of reduction
tile_M = 4
tile_K_A = 4

pad_M = 0
pad_N = 0
pad_K = 0

if M % tile_M != 0:
pad_M = tile_M - (M % tile_M)

if K % tile_K_A != 0:
pad_K = tile_K_A - (K % tile_K_A)
# Select the tiling strategy for A and B
tile_M, tile_K_A = arm_utils.get_tiling_A(interleave_A, in_dtype)
tile_N, tile_K_B = arm_utils.get_tiling_B_transformed(
interleave_A, in_dtype, use_scalable_vectors
)

if N % tile_N != 0:
pad_N = tile_N - (N % tile_N)
# Pad to tiles (if necessary)
pad_M, pad_K = arm_utils.get_conv2d_im2col_padding(M, K, tile_M, tile_K_A)
pad_N, _ = arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K_B)

M_padded = M + pad_M
K_padded = K + pad_K

if use_scalable_vectors:
N_padded = N + pad_N
else:
N_transformed = B_interleaved_t.shape[0]
N_padded = N_transformed * tile_N
N_padded = N + pad_N

pad_before = (0, 0, 0)
pad_after = (0, pad_M, pad_K)
Expand All @@ -191,7 +147,10 @@ def compute_conv2d_gemm_without_weight_transform(
idxm = tvm.tir.indexmod
k = te.reduce_axis((0, K_padded), "k")

# Determine matrix multiplication compute definition
target = Target.current(allow_none=False)
if in_dtype in ["int8", "uint8"]:
assert len(B_interleaved_t.shape) == 4
if interleave_A:
# Configuration space
configure_knobs(cfg, M_padded, K_padded, target)
Expand All @@ -208,7 +167,7 @@ def compute_conv2d_gemm_without_weight_transform(
lambda b, x, y, z, w: A[b, z + tile_M * x, w + tile_K_A * y],
name="A_interleaved",
)
target = Target.current(allow_none=False)
N_transformed = B_interleaved_t.shape[0]
if target.features.has_matmul_i8:
# Execute GEMM. In the case of mmla, we need to enforce the tiling
# from the compute. This is because mmla is doing a tiled computation
Expand Down Expand Up @@ -384,6 +343,8 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
C = out.op.input_tensors[0]
C_interleaved = C.op.input_tensors[0]
A_interleaved = C_interleaved.op.input_tensors[0]
in_type = A_interleaved.dtype
tile_M, tile_K = arm_utils.get_tiling_A(True, in_type)

# Input transform
A_interleaved_input = A_interleaved.op.input_tensors[0]
Expand Down Expand Up @@ -422,17 +383,14 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
s, A_interleaved, [outer_A_interleaved, inner_A_interleaved]
)

in_type = A_interleaved.dtype
out_type = C.dtype

k = C_interleaved.op.reduce_axis[0]
_, M, N = C.shape
if in_type in ["int8", "uint8"]:
target = Target.current(allow_none=False)
if target.features.has_matmul_i8:
gemm_acc = gemm_acc_2x2_int8_int8_int32(in_type)
xi_inner, yi_inner = C_interleaved.op.axis[-2:]
k_outer, k_inner = s[C_interleaved].split(k, 8)
k_outer, k_inner = s[C_interleaved].split(k, tile_K)
s[C_interleaved].reorder(
b_outer_gemm_fused, inner_gemm, k_outer, xi, yi, xi_inner, yi_inner, k_inner
)
Expand All @@ -442,9 +400,9 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
elif target.features.has_dotprod:
gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type)
xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile(
xi, yi, x_factor=8, y_factor=4
xi, yi, x_factor=tile_M, y_factor=4
)
k_outer, k_inner = s[C_interleaved].split(k, 4)
k_outer, k_inner = s[C_interleaved].split(k, tile_K)
xi_inner_outer, xi_inner_inner = s[C_interleaved].split(xi_inner, 4)
s[C_interleaved].reorder(
b_outer_gemm_fused,
Expand Down Expand Up @@ -483,24 +441,25 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
A = C.op.input_tensors[0]
in_type = A.dtype
use_scalable_vectors = out.op.attrs["use_scalable_vectors"].value
y_tile_size, _ = get_tiling_B_transformed(False, in_type, use_scalable_vectors)
tile_M, tile_K = arm_utils.get_tiling_A(False, in_type)
tile_N, _ = arm_utils.get_tiling_B_transformed(False, in_type, use_scalable_vectors)

# Computation
b, x, y = C.op.axis
(k,) = C.op.reduce_axis

if in_type in ["int8", "uint8"]:
k_outer, k_inner = s[C].split(k, 16)
x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
k_outer, k_inner = s[C].split(k, tile_K)
x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=tile_M, y_factor=tile_N)
s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner)
gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1)
s[C].unroll(x_inner)
s[C].tensorize(y_inner, gemm_acc)
s[C].parallel(x_outer)
else:
k_outer, k_inner = s[C].split(k, factor=4)
x_outer, x_inner = s[C].split(x, factor=4)
y_outer, y_inner = s[C].split(y, factor=y_tile_size, disable_predication=True)
k_outer, k_inner = s[C].split(k, factor=tile_K)
x_outer, x_inner = s[C].split(x, factor=tile_M)
y_outer, y_inner = s[C].split(y, factor=tile_N, disable_predication=use_scalable_vectors)
b_x_outer_fused = s[C].fuse(b, x_outer)
s[C].parallel(b_x_outer_fused)
s[C].reorder(
Expand Down Expand Up @@ -552,7 +511,7 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
s[data_im2col].vectorize(n_inner)
elif padding_A:
s[data_im2col].compute_inline()
_, n_inner = s[A].split(A.op.axis[2], y_tile_size)
_, n_inner = s[A].split(A.op.axis[2], tile_N)
s[A].vectorize(n_inner)
s[A].compute_at(s[C], x_inner)
else:
Expand Down
3 changes: 2 additions & 1 deletion src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "constraint_extract.h"
#include "int_operator.h"
#include "pattern_match.h"
#include "scalable_expression.h"

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -369,7 +370,7 @@ class ConstIntBoundAnalyzer::Impl
return VisitLeftShift(op);
} else if (op->op.same_as(tir::builtin::bitwise_and())) {
return VisitBitwiseAnd(op);
} else if (op->op.same_as(tir::builtin::vscale())) {
} else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE()) {
return MakeBound(1, 16);
} else {
return Everything(op->dtype);
Expand Down
Loading

0 comments on commit 63a5001

Please sign in to comment.