Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SVE][TOPI] Add conv2d NHWC hybrid SVE schedule for arm_cpu #16899

Merged
merged 4 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,18 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
)
# Non-quantized cases
if is_aarch64 and data.dtype in ["float32", "float16"]:
if target.features.has_sve:
# This strategy is currently suboptimal because of LLVM's limited support
# for scalable vector alias analysis, which causes redundant loads / stores
Anndrey24 marked this conversation as resolved.
Show resolved Hide resolved
# to remain after LLVM's optimisation passes, unlike the non-scalable case.
# Hence, it is given a lower priority level until these issues are resolved.
# Last checked manually using: LLVM 18.1.0
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE),
name="conv2d_NHWC_hybrid_SVE.arm_cpu",
plevel=5,
)
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid),
Expand Down
99 changes: 98 additions & 1 deletion python/tvm/topi/arm_cpu/arm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,64 @@
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Arm target utility functions"""

import tvm
from tvm.target import Target


def get_tiling_B_transformed(interleave_A, in_dtype):
def get_tiling_A(interleave_A, in_dtype):
    """Compute the tiling information for matrix A in C=A*B,
    which corresponds to the im2col-transformed input matrix.

    The tiling information is chosen to maximize register usage during
    the tile computation.

    For more details, please refer to:
    - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-performance-for-armv8-architectures # pylint: disable=line-too-long
    - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
    - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-through-mmla-instruction
    - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h

    Parameters
    ----------
    interleave_A : bool
        determines if A is expected to be interleaved
    in_dtype : str
        input datatype

    Returns
    ----------
    tile_M: the output tile size of A on M axis (M = OH * OW)
    tile_K: the output tile size of A on K axis (K = KW * KH * IC)
    """
    # A target must be in scope to decide the quantized tiling below.
    target = Target.current(allow_none=False)

    if in_dtype not in ("int8", "uint8"):
        # Non-quantized cases: A is not interleaved.
        # We load 4 rows from A, each containing 4 elements along the
        # dimension of reduction.
        return 4, 4

    if target.features.has_matmul_i8:
        # If smmla/ummla is enabled, we load 8 rows from A,
        # each row containing 8 elements.
        return 8, 8
    if target.features.has_dotprod and interleave_A:
        # If dot product has been enabled, and we are interleaving A,
        # tile size should be 8x4.
        return 8, 4
    # If either there is no dot product or if we are using a native
    # strategy, tile size should be 4x16.
    return 4, 16


def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False):
"""Compute the tiling information for matrix B', where B'
is the tiled, interleaved (and transposed) version of matrix B in C=A*B.

Expand All @@ -40,6 +94,8 @@ def get_tiling_B_transformed(interleave_A, in_dtype):
determines if A is expected to be interleaved
in_dtype : str
input datatype
use_scalable_vectors : bool
determines if operations on scalable vectors are expected


Returns
Expand Down Expand Up @@ -75,6 +131,15 @@ def get_tiling_B_transformed(interleave_A, in_dtype):
tile_N = 4
tile_K = 16
# In non-quantized cases, A is not interleaved.
elif use_scalable_vectors:
if in_dtype == "float16":
# Each load from B' contains 32 * vscale elements (i.e. 32 * vscale columns from B)
tile_N = 32 * tvm.tir.vscale()
else:
# Each load from B' contains 16 * vscale elements (i.e. 16 * vscale columns from B)
tile_N = 16 * tvm.tir.vscale()
# We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
tile_K = 4
elif in_dtype == "float16" and target.features.has_fp16_simd:
# Each load from B' contains 32 elements (i.e. 32 columns from B)
# We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
Expand All @@ -89,6 +154,38 @@ def get_tiling_B_transformed(interleave_A, in_dtype):
return tile_N, tile_K


def get_conv2d_im2col_padding(M, K, tile_M, tile_K):
    """Compute the necessary padding for matrix A in C=A*B,
    which corresponds to the im2col-transformed input matrix.

    Each axis is padded up to the next multiple of its tile size, so
    the padded matrix can be evenly partitioned into tiles.

    Parameters
    ----------
    M : int
        Number of rows in A = OH * OW
    K : int
        Number of columns in A = KW * KH * IC
    tile_M : int
        tile size of A on M axis
    tile_K : int
        tile size of A on K axis

    Returns
    ----------
    pad_M : padding for M axis
    pad_K : padding for K axis
    """
    rem_M = M % tile_M
    rem_K = K % tile_K

    # A remainder of zero means the axis is already tile-aligned.
    pad_M = tile_M - rem_M if rem_M != 0 else 0
    pad_K = tile_K - rem_K if rem_K != 0 else 0

    return pad_M, pad_K


def get_conv2d_weights_padding(N, K, tile_N, tile_K):
"""Compute the necessary padding for matrix B', where B'
is the transformed version of matrix B in C=A*B.
Expand Down
43 changes: 39 additions & 4 deletions python/tvm/topi/arm_cpu/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,14 +517,35 @@ def schedule_conv2d_nhwc_dsp(cfg, outs):
return conv2d_nhwc_dsp_schedule(cfg, outs)


def compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, interleave_A):
def compute_conv2d_NHWC(
    cfg,
    data,
    kernel,
    strides,
    padding,
    dilation,
    out_dtype,
    interleave_A,
    use_scalable_vectors=False,
):
    """Compute definition for conv2d NHWC using the GeMM (im2col) approach.

    Transforms the weights into the tiled layout expected by the GeMM
    kernel, then delegates to the shared compute without weight transform.

    Parameters
    ----------
    cfg : autotvm config for this workload
    data : tvm.te.Tensor
        NHWC input tensor
    kernel : tvm.te.Tensor
        4-D weight tensor; the last axis is the output channels (OC)
        # assumes HWIO layout (KH, KW, IC, OC) — TODO confirm against callers
    strides, padding, dilation :
        standard conv2d parameters, forwarded unchanged
    out_dtype : str
        output data type
    interleave_A : bool
        whether the im2col matrix A is expected to be interleaved
    use_scalable_vectors : bool
        whether operations on scalable (SVE) vectors are expected

    Returns
    -------
    The output tensor of the GeMM-based conv2d compute.
    """
    # NOTE: the original unpacked data.shape into N/IH/IW/IC locals that
    # were never used; only the kernel shape is needed here.
    KH, KW, _, OC = get_const_tuple(kernel.shape)

    # The tiling of B' depends on target features and on whether scalable
    # vectors are in use.
    tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype, use_scalable_vectors)

    # Transform the weights into the tiled (and possibly interleaved) layout.
    kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors)
    return compute_conv2d_gemm_without_weight_transform(
        cfg,
        data,
        kernel,
        strides,
        padding,
        dilation,
        out_dtype,
        (KH, KW),
        OC,
        interleave_A,
        use_scalable_vectors,
    )


Expand Down Expand Up @@ -620,3 +641,17 @@ def schedule_conv2d_NHWC_hybrid(cfg, outs):
def schedule_conv2d_NHWC_hybrid_without_transform(cfg, outs):
    """Interface for the hybrid conv2d NHWC schedule, used when the
    weight transform is performed outside of this compute."""
    # Delegates to the shared NHWC schedule; the final flag mirrors the
    # other hybrid schedule entry points in this file.
    return schedule_conv2d_NHWC(cfg, outs, False)


@autotvm.register_topi_compute("conv2d_NHWC_hybrid_SVE.arm_cpu")
def compute_conv2d_NHWC_hybrid_SVE(cfg, data, kernel, strides, padding, dilation, out_dtype):
    """Interface for hybrid compute_conv2d_NHWC_hybrid_SVE"""
    # Keyword arguments make the trailing boolean flags self-describing:
    # no interleaving of A, and SVE scalable vectors enabled.
    return compute_conv2d_NHWC(
        cfg,
        data,
        kernel,
        strides,
        padding,
        dilation,
        out_dtype,
        interleave_A=False,
        use_scalable_vectors=True,
    )


@autotvm.register_topi_schedule("conv2d_NHWC_hybrid_SVE.arm_cpu")
def schedule_conv2d_NHWC_hybrid_SVE(cfg, outs):
    """Interface for hybrid schedule_conv2d_NHWC_hybrid_SVE"""
    # Reuses the shared NHWC schedule; the flag value matches the other
    # hybrid schedule entry points in this file.
    return schedule_conv2d_NHWC(cfg, outs, False)
Loading
Loading