[SVE][TOPI] Add conv2d NHWC hybrid SVE schedule for arm_cpu
This commit adds an `arm_cpu` conv2d NHWC schedule which generates SVE instructions by extending the hybrid GeMM approach implemented in #16106 to use scalable expressions as splitting factors.
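As an illustration, the core of the extension is to split and vectorize the GEMM N axis by a vscale-dependent factor instead of a fixed constant. Below is a minimal, self-contained sketch of such a split (toy shapes and names, not the actual schedule; the real split lives in `schedule_conv2d_gemm_native`):

```python
import tvm
from tvm import te

# Toy compute standing in for the GEMM result; shapes are arbitrary.
A = te.placeholder((64, 1024), name="A", dtype="float32")
B = te.compute((64, 1024), lambda i, j: A[i, j] * 2.0, name="B")

s = te.create_schedule(B.op)
i, j = B.op.axis
# Scalable splitting factor: 16 * vscale lanes per vector iteration.
# disable_predication assumes the axis extent is already padded to a
# multiple of the scalable factor, as the conv2d schedule ensures.
jo, ji = s[B].split(j, factor=16 * tvm.tir.vscale(), disable_predication=True)
s[B].vectorize(ji)

print(tvm.lower(s, [A, B], simple_mode=True))
```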

Various vscale-related fixes needed to implement the schedule are also included, such as:

 - adding vscale bounds in the `ConstIntBoundAnalyzer` and `IntervalSetEvaluator`
 - simplifying `MinNode` and `MaxNode` that have scalable expression operands in `RewriteSimplifier`, which would appear when defining the shape of a buffer padded to be a multiple of vscale and in its respective buffer access indices (e.g. `C_1 = T.Buffer((1024 * (T.vscale() * 16 + 256 - 16 % T.vscale() * 16),), data=C)` instead of `C_1 = T.Buffer((1024 * (T.max(255, T.vscale() * 16 + 255 - 16 % T.vscale() * 16) + 1),), data=C)`)
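For instance, a minimal sketch of the new simplification from the Python side (expected results are shown as comments; the exact printed form may differ):

```python
import tvm
from tvm import arith, tir

ana = arith.Analyzer()
vec = 16 * tir.vscale()  # bounded to [16, 256] once vscale is bounded to [1, 16]

# Min/Max nodes with a scalable operand now fold when one side is provably
# smaller / larger than the other:
print(ana.simplify(tir.min(vec, 256)))  # expected: T.vscale() * 16
print(ana.simplify(tir.max(vec, 16)))   # expected: T.vscale() * 16
```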

The correctness of the new schedule is checked using a TOPI test, while the presence of generated SVE instructions is verified by a codegen test. The new `rewrite_simplify` rules are also covered by additional test cases.
Anndrey24 committed Apr 17, 2024
1 parent d030ce2 commit 59b5b30
Showing 16 changed files with 232 additions and 38 deletions.
11 changes: 11 additions & 0 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -252,6 +252,17 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
)
# Non-quantized cases
if is_aarch64 and data.dtype in ["float32", "float16"]:
if target.features.has_sve:
# This strategy is currently suboptimal because of LLVM's limited support
# for scalable vector alias analysis, which causes redundant loads / stores
# to remain after LLVM's optimisation passes, unlike the non-scalable case.
# Hence, it is given a lower priority level until these issues are resolved.
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE),
name="conv2d_NHWC_hybrid_SVE.arm_cpu",
plevel=5,
)
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid),
14 changes: 13 additions & 1 deletion python/tvm/topi/arm_cpu/arm_utils.py
@@ -17,10 +17,11 @@
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Arm target utility functions"""

import tvm
from tvm.target import Target


def get_tiling_B_transformed(interleave_A, in_dtype):
def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False):
"""Compute the tiling information for matrix B', where B'
is the tiled, interleaved (and transposed) version of matrix B in C=A*B.
@@ -40,6 +41,8 @@ def get_tiling_B_transformed(interleave_A, in_dtype):
determines if A is expected to be interleaved
in_dtype : str
input datatype
use_scalable_vectors : bool
determines if operations on scalable vectors are expected
Returns
@@ -75,6 +78,15 @@ def get_tiling_B_transformed(interleave_A, in_dtype):
tile_N = 4
tile_K = 16
# In non-quantized cases, A is not interleaved.
elif use_scalable_vectors:
if in_dtype == "float16":
# Each load from B' contains 32 * vscale elements (i.e. 32 * vscale columns from B)
tile_N = 32 * tvm.tir.vscale()
else:
# Each load from B' contains 16 * vscale elements (i.e. 16 * vscale columns from B)
tile_N = 16 * tvm.tir.vscale()
# We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
tile_K = 4
elif in_dtype == "float16" and target.features.has_fp16_simd:
# Each load from B' contains 32 elements (i.e. 32 columns from B)
# We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
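For reference, a rough usage sketch of the new flag (the target string is only an example; the helper consults the current target for the other branches):

```python
import tvm
from tvm.topi.arm_cpu.arm_utils import get_tiling_B_transformed

# Example AArch64 + SVE target; assumed here purely for illustration.
with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
    tile_N, tile_K = get_tiling_B_transformed(
        interleave_A=False, in_dtype="float32", use_scalable_vectors=True
    )
print(tile_N)  # expected: 16 * T.vscale()
print(tile_K)  # expected: 4
```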
42 changes: 38 additions & 4 deletions python/tvm/topi/arm_cpu/conv2d.py
@@ -517,14 +517,34 @@ def schedule_conv2d_nhwc_dsp(cfg, outs):
return conv2d_nhwc_dsp_schedule(cfg, outs)


def compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, interleave_A):
def compute_conv2d_NHWC(
cfg,
data,
kernel,
strides,
padding,
dilation,
out_dtype,
interleave_A,
use_scalable_vectors=False,
):
N, IH, IW, IC = get_const_tuple(data.shape)
KH, KW, _, OC = get_const_tuple(kernel.shape)
tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype)
tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype, use_scalable_vectors)

kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K)
kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors)
return compute_conv2d_gemm_without_weight_transform(
cfg, data, kernel, strides, padding, dilation, out_dtype, (KH, KW), OC, interleave_A
cfg,
data,
kernel,
strides,
padding,
dilation,
out_dtype,
(KH, KW),
OC,
interleave_A,
use_scalable_vectors,
)


@@ -620,3 +640,17 @@ def schedule_conv2d_NHWC_hybrid(cfg, outs):
def schedule_conv2d_NHWC_hybrid_without_transform(cfg, outs):
"""Interface for hybrid schedule_conv2d_NHWC_hybrid"""
return schedule_conv2d_NHWC(cfg, outs, False)


@autotvm.register_topi_compute("conv2d_NHWC_hybrid_SVE.arm_cpu")
def compute_conv2d_NHWC_hybrid_SVE(cfg, data, kernel, strides, padding, dilation, out_dtype):
"""Interface for hybrid compute_conv2d_NHWC_hybrid_SVE"""
return compute_conv2d_NHWC(
cfg, data, kernel, strides, padding, dilation, out_dtype, False, True
)


@autotvm.register_topi_schedule("conv2d_NHWC_hybrid_SVE.arm_cpu")
def schedule_conv2d_NHWC_hybrid_SVE(cfg, outs):
"""Interface for hybrid schedule_conv2d_NHWC_hybrid_SVE"""
return schedule_conv2d_NHWC(cfg, outs, False)
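As a usage sketch (not part of this diff), the new implementation becomes available when the compilation target reports the `has_sve` feature; note it is registered at a lower priority than the existing hybrid schedule, so it is not picked by default. The target string below is only an example and assumes an LLVM build with the AArch64 backend:

```python
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 56, 56, 64), dtype="float32")
weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="float32")
conv = relay.nn.conv2d(
    data, weight, kernel_size=(3, 3), padding=(1, 1),
    data_layout="NHWC", kernel_layout="HWIO", out_dtype="float32",
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# Illustrative SVE-enabled target; exact -mattr flags depend on the platform.
target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sve")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)
```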
63 changes: 44 additions & 19 deletions python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -67,6 +67,7 @@ def compute_conv2d_gemm_without_weight_transform(
kernel_size,
output_channels,
interleave_A,
use_scalable_vectors=False,
):
"""Compute conv2d by transforming the input,
executing GEMM and transforming the output back"""
@@ -120,13 +121,7 @@ def compute_conv2d_gemm_without_weight_transform(
)

# Pad if necessary
N_transformed = B_interleaved_t.shape[0]
if in_dtype in ["int8", "uint8"]:
tile_N = B_interleaved_t.shape[2]
tile_K_B = B_interleaved_t.shape[3]
else:
tile_N = B_interleaved_t.shape[3]
tile_K_B = B_interleaved_t.shape[2]
tile_N, tile_K_B = get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors)

# Select the tiling strategy for A.
# The tiling information is chosen to maximize register usage during
@@ -164,6 +159,7 @@ def compute_conv2d_gemm_without_weight_transform(
tile_K_A = 4

pad_M = 0
pad_N = 0
pad_K = 0

if M % tile_M != 0:
@@ -172,9 +168,17 @@
if K % tile_K_A != 0:
pad_K = tile_K_A - (K % tile_K_A)

if N % tile_N != 0:
pad_N = tile_N - (N % tile_N)

M_padded = M + pad_M
K_padded = K + pad_K
N_padded = N_transformed * tile_N

if use_scalable_vectors:
N_padded = N + pad_N
else:
N_transformed = B_interleaved_t.shape[0]
N_padded = N_transformed * tile_N

pad_before = (0, 0, 0)
pad_after = (0, pad_M, pad_K)
@@ -322,10 +326,24 @@ def compute_conv2d_gemm_without_weight_transform(
tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
- tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
)
elif use_scalable_vectors:
assert len(B_interleaved_t.shape) == 2
C = te.compute(
(batches, M_padded, N_padded),
lambda b, x, y: te.sum(
A[b, x, k].astype(in_dtype) * B_interleaved_t[k, y].astype(in_dtype),
axis=k,
),
name="C",
)
# Ensure padding on the N axis does not get removed during tir passes
# by adding a dummy reference to the specific padded area of the result
zero = (
tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
- tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
)
else:
# Configuration space
configure_knobs(cfg, M_padded, K_padded, target)

assert len(B_interleaved_t.shape) == 4
C = te.compute(
(batches, M_padded, N_padded),
lambda b, x, y: te.sum(
@@ -356,6 +374,7 @@ def compute_conv2d_gemm_without_weight_transform(
out_shape,
lambda b, x, y, z: (C(b, y + OW * x, z) + zero).astype(out_dtype),
name="conv2d_gemm_output",
attrs={"use_scalable_vectors": use_scalable_vectors},
)
return out

@@ -463,7 +482,8 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
C = out.op.input_tensors[0]
A = C.op.input_tensors[0]
in_type = A.dtype
y_tile_size, _ = get_tiling_B_transformed(False, in_type)
use_scalable_vectors = out.op.attrs["use_scalable_vectors"].value
y_tile_size, _ = get_tiling_B_transformed(False, in_type, use_scalable_vectors)

# Computation
b, x, y = C.op.axis
@@ -478,23 +498,21 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
s[C].tensorize(y_inner, gemm_acc)
s[C].parallel(x_outer)
else:
k_outer, k_inner = s[C].split(k, 4)
x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
y_inner_outer, y_inner_inner = s[C].split(y_inner, nparts=4)
k_outer, k_inner = s[C].split(k, factor=4)
x_outer, x_inner = s[C].split(x, factor=4)
y_outer, y_inner = s[C].split(y, factor=y_tile_size, disable_predication=True)
b_x_outer_fused = s[C].fuse(b, x_outer)
s[C].parallel(b_x_outer_fused)
s[C].reorder(
b_x_outer_fused,
y_outer,
k_outer,
k_inner,
y_inner_outer,
x_inner,
y_inner_inner,
y_inner,
)
s[C].unroll(y_inner_outer)
s[C].unroll(x_inner)
s[C].vectorize(y_inner_inner)
s[C].vectorize(y_inner)

# Input transform
if A.op.name == "A_padded_K" or A.op.name == "A_padded_M":
@@ -547,6 +565,13 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
s[A_pad].parallel(n_h_fused)
s[A_pad].vectorize(c)

# Weight transform
if use_scalable_vectors:
B_pad = C.op.input_tensors[1]
s[B_pad].parallel(B_pad.op.axis[0])
B_flat = B_pad.op.input_tensors[0]
s[B_flat].compute_inline()

# Output transform
if out != final_out:
n, h, w, c = out.op.axis
7 changes: 6 additions & 1 deletion python/tvm/topi/nn/conv2d.py
@@ -615,7 +615,7 @@ def conv2d_NCHWc_int8(
)


def conv2d_gemm_weight_transform(kernel, tile_N, tile_K):
def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=False):
"""Weight transformation for winograd
Parameters
@@ -626,6 +626,8 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K):
Tile size across N axis of the weight transformation for ConvGemm. (N = OC)
tile_K: int
Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC)
use_scalable_vectors : bool
determines if operations on scalable vectors are expected
Returns
-------
Expand All @@ -650,6 +652,9 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K):
kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), name="weight_padding"
)

if use_scalable_vectors:
return kernel_flat

if kernel.dtype in ["int8", "uint8"]:
B_inter_t = te.compute(
(N_padded // tile_N, K_padded // tile_K, tile_N, tile_K),
2 changes: 1 addition & 1 deletion src/arith/analyzer.cc
@@ -234,7 +234,7 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
// "T.vscale" and the compile target uses a scalable architecture extension like
// SVE, we can make some assumptions about the value of vscale and iterate over a
// space of pre-defined values to attempt to prove the expression.
if (tir::CheckContains::ExprContains(expr, IsVScaleCall)) {
if (tir::CheckContains::ExprContains(simplified, IsVScaleCall)) {
Target curr_target = tvm::Target::Current();
if (curr_target.defined() && curr_target->features.defined() &&
(curr_target->features.find("has_sve") != curr_target->features.end()) &&
2 changes: 2 additions & 0 deletions src/arith/const_int_bound.cc
@@ -369,6 +369,8 @@ class ConstIntBoundAnalyzer::Impl
return VisitLeftShift(op);
} else if (op->op.same_as(tir::builtin::bitwise_and())) {
return VisitBitwiseAnd(op);
} else if (op->op.same_as(tir::builtin::vscale())) {
return MakeBound(1, 16);
} else {
return Everything(op->dtype);
}
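A small sketch of the effect from the Python API (expected bounds shown as comments):

```python
import tvm
from tvm import arith, tir

ana = arith.Analyzer()
print(ana.const_int_bound(tir.vscale()))       # expected: ConstIntBound[1, 16]
print(ana.const_int_bound(16 * tir.vscale()))  # expected: ConstIntBound[16, 256]
```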
6 changes: 6 additions & 0 deletions src/arith/int_set.cc
@@ -532,6 +532,12 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
}

IntervalSet VisitExpr_(const CallNode* op) final {
if (op->op.same_as(tir::builtin::vscale()))
return IntervalSet(GetRef<PrimExpr>(op), GetRef<PrimExpr>(op));
return IntervalSet::Everything();
}

IntervalSet VisitExprDefault_(const Object* op) final {
DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey();
return IntervalSet::Everything();
20 changes: 20 additions & 0 deletions src/arith/rewrite_simplify.cc
@@ -1415,6 +1415,16 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) {
}
}

// vscale expression comparison
if (ContainsVscaleCall(op->a) || ContainsVscaleCall(op->b)) {
if (analyzer_->CanProve(op->a <= op->b)) {
return op->a;
}
if (analyzer_->CanProve(op->b <= op->a)) {
return op->b;
}
}

// canonicalization
TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1));
TVM_TRY_RECURSIVE_REWRITE_IF(min(c1 - x, c2), c1 - max(x, c1 - c2), c2.Eval()->value != 0);
@@ -1598,6 +1608,16 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) {
}
}

// vscale expression comparison
if (ContainsVscaleCall(op->a) || ContainsVscaleCall(op->b)) {
if (analyzer_->CanProve(op->a >= op->b)) {
return op->a;
}
if (analyzer_->CanProve(op->b >= op->a)) {
return op->b;
}
}

// canonicalization
TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1));
TVM_TRY_RECURSIVE_REWRITE_IF(max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0);
5 changes: 5 additions & 0 deletions src/arith/scalable_expression.cc
@@ -29,6 +29,7 @@

#include <vector>

#include "../tir/analysis/check_contains.h"
#include "../tir/transforms/replace_selected_expr.h"
#include "./pattern_match.h"

@@ -42,6 +43,10 @@ bool IsVScaleCall(const PrimExpr& expr) {
return false;
}

bool ContainsVscaleCall(const PrimExpr& expr) {
return tir::CheckContains::ExprContains(expr, IsVScaleCall);
}

PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int vscale_value) {
std::function<bool(const PrimExpr&)> predicate_selector = [](const PrimExpr& current_expr) {
return IsVScaleCall(current_expr);
7 changes: 7 additions & 0 deletions src/arith/scalable_expression.h
@@ -45,6 +45,13 @@ static const std::vector<unsigned int> kAArch64VScaleValues = {1, 2, 3, 4, 5,
*/
bool IsVScaleCall(const PrimExpr& expr);

/*!
* \brief Check if an expr contains a call to the vscale intrinsic.
* \param expr The expr to check
* \return True if the expr contains a call to the vscale intrinsic, false if not.
*/
bool ContainsVscaleCall(const PrimExpr& expr);

/*!
* \brief Substitute a vscale intrinsic call with a known scalar value.
* \param expr The expr to apply substitutions to.
4 changes: 3 additions & 1 deletion src/relay/backend/utils.cc
@@ -30,6 +30,7 @@
#include <tvm/runtime/ndarray.h>
#include <tvm/tir/stmt_functor.h>

#include "../../arith/scalable_expression.h"
#include "../../te/operation/create_primfunc.h"

namespace tvm {
@@ -421,7 +422,8 @@ Optional<tir::PrimFunc> DefaultTIRConverterImpl(const Array<te::Tensor>& args,
bool dynamic_loop_extent = false;
tir::PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void {
if (const auto* loop = obj.as<tir::ForNode>()) {
if (!loop->extent->IsInstance<IntImmNode>()) {
if (!loop->extent->IsInstance<IntImmNode>() &&
!tvm::arith::ContainsVscaleCall(loop->extent)) {
dynamic_loop_extent = true;
}
}