From 07b4d68b48e459c84cc837bf4f1700ae1f26cbd9 Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Mon, 22 Apr 2024 09:18:23 +0000 Subject: [PATCH 1/4] [SVE][TOPI] Add conv2d NHWC hybrid SVE schedule for `arm_cpu` This commit adds an `arm_cpu` conv2d NHWC schedule which generates SVE instructions by extending the hybrid GeMM approach implemented in #16106 to use scalable expressions as splitting factors. Various vscale-related fixes needed to implement the schedule are also included, such as: - adding vscale bounds in the `ConstIntBoundAnalyzer` and `IntervalSetEvaluator` - simplifying `MinNode` and `MaxNode` that have scalable expression operands in `RewriteSimplifier`, which would appear when defining the shape of a buffer padded to be a multiple of vscale and in its respective buffer access indices (e.g. `C_1 = T.Buffer((1024 * (T.vscale() * 16 + 256 - 16 % T.vscale() * 16),), data=C)` instead of `C_1 = T.Buffer((1024 * (T.max(255, T.vscale() * 16 + 255 - 16 % T.vscale() * 16) + 1),), data=C)`) The correctness of the new schedule is checked using a TOPI test, while the presence of generated SVE instructions is verified by a codegen test. The new `rewrite_simplify` rules are also covered by additional test cases. --- python/tvm/relay/op/strategy/arm_cpu.py | 11 ++++ python/tvm/topi/arm_cpu/arm_utils.py | 14 ++++- python/tvm/topi/arm_cpu/conv2d.py | 42 +++++++++++-- python/tvm/topi/arm_cpu/conv2d_gemm.py | 63 +++++++++++++------ python/tvm/topi/nn/conv2d.py | 7 ++- src/arith/analyzer.cc | 2 +- src/arith/const_int_bound.cc | 2 + src/arith/int_set.cc | 6 ++ src/arith/rewrite_simplify.cc | 20 ++++++ src/arith/scalable_expression.cc | 5 ++ src/arith/scalable_expression.h | 7 +++ src/relay/backend/utils.cc | 4 +- src/tir/transforms/storage_rewrite.cc | 7 +++ .../arith/test_arith_rewrite_simplify.py | 13 ++++ .../codegen/test_target_codegen_aarch64.py | 35 +++++++++++ tests/python/topi/test_topi_conv2d_nhwc.py | 32 ++++++---- 16 files changed, 232 insertions(+), 38 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 1a2f7abb6f37..88202458afbf 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -252,6 +252,17 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): ) # Non-quantized cases if is_aarch64 and data.dtype in ["float32", "float16"]: + if target.features.has_sve: + # This strategy is currently suboptimal because of LLVM's limited support + # for scalable vector alias analysis, which causes redundant loads / stores + # to remain after LLVM's optimisation passes, unlike the non-scalable case. + # Hence, it is given a lower priority level until these issues are resolved. + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE), + name="conv2d_NHWC_hybrid_SVE.arm_cpu", + plevel=5, + ) strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid), diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py index 91a6762717c9..0dd17ce4fa34 100644 --- a/python/tvm/topi/arm_cpu/arm_utils.py +++ b/python/tvm/topi/arm_cpu/arm_utils.py @@ -17,10 +17,11 @@ # pylint: disable=invalid-name,unused-variable,unused-argument,no-member """Arm target utility functions""" +import tvm from tvm.target import Target -def get_tiling_B_transformed(interleave_A, in_dtype): +def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False): """Compute the tiling information for matrix B', where B' is the tiled, interleaved (and transposed) version of matrix B in C=A*B. @@ -40,6 +41,8 @@ def get_tiling_B_transformed(interleave_A, in_dtype): determines if A is expected to be interleaved in_dtype : str input datatype + use_scalable_vectors : bool + determines if operations on scalable vectors are expected Returns @@ -75,6 +78,15 @@ def get_tiling_B_transformed(interleave_A, in_dtype): tile_N = 4 tile_K = 16 # In non-quantized cases, A is not interleaved. + elif use_scalable_vectors: + if in_dtype == "float16": + # Each load from B' contains 32 * vscale elements (i.e. 32 * vscale columns from B) + tile_N = 32 * tvm.tir.vscale() + else: + # Each load from B' contains 16 * vscale elements (i.e. 16 * vscale columns from B) + tile_N = 16 * tvm.tir.vscale() + # We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B) + tile_K = 4 elif in_dtype == "float16" and target.features.has_fp16_simd: # Each load from B' contains 32 elements (i.e. 32 columns from B) # We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B) diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index 90e199f36a03..b6b5df8677ad 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -517,14 +517,34 @@ def schedule_conv2d_nhwc_dsp(cfg, outs): return conv2d_nhwc_dsp_schedule(cfg, outs) -def compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, interleave_A): +def compute_conv2d_NHWC( + cfg, + data, + kernel, + strides, + padding, + dilation, + out_dtype, + interleave_A, + use_scalable_vectors=False, +): N, IH, IW, IC = get_const_tuple(data.shape) KH, KW, _, OC = get_const_tuple(kernel.shape) - tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype) + tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype, use_scalable_vectors) - kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K) + kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors) return compute_conv2d_gemm_without_weight_transform( - cfg, data, kernel, strides, padding, dilation, out_dtype, (KH, KW), OC, interleave_A + cfg, + data, + kernel, + strides, + padding, + dilation, + out_dtype, + (KH, KW), + OC, + interleave_A, + use_scalable_vectors, ) @@ -620,3 +640,17 @@ def schedule_conv2d_NHWC_hybrid(cfg, outs): def schedule_conv2d_NHWC_hybrid_without_transform(cfg, outs): """Interface for hybrid schedule_conv2d_NHWC_hybrid""" return schedule_conv2d_NHWC(cfg, outs, False) + + +@autotvm.register_topi_compute("conv2d_NHWC_hybrid_SVE.arm_cpu") +def compute_conv2d_NHWC_hybrid_SVE(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Interface for hybrid compute_conv2d_NHWC_hybrid_SVE""" + return compute_conv2d_NHWC( + cfg, data, kernel, strides, padding, dilation, out_dtype, False, True + ) + + +@autotvm.register_topi_schedule("conv2d_NHWC_hybrid_SVE.arm_cpu") +def schedule_conv2d_NHWC_hybrid_SVE(cfg, outs): + """Interface for hybrid schedule_conv2d_NHWC_hybrid_SVE""" + return schedule_conv2d_NHWC(cfg, outs, False) diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index b725984ae1d8..c454d72e2642 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -67,6 +67,7 @@ def compute_conv2d_gemm_without_weight_transform( kernel_size, output_channels, interleave_A, + use_scalable_vectors=False, ): """Compute conv2d by transforming the input, executing GEMM and transforming the output back""" @@ -120,13 +121,7 @@ def compute_conv2d_gemm_without_weight_transform( ) # Pad if necessary - N_transformed = B_interleaved_t.shape[0] - if in_dtype in ["int8", "uint8"]: - tile_N = B_interleaved_t.shape[2] - tile_K_B = B_interleaved_t.shape[3] - else: - tile_N = B_interleaved_t.shape[3] - tile_K_B = B_interleaved_t.shape[2] + tile_N, tile_K_B = get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors) # Select the tiling strategy for A. # The tiling information is chosen to maximize register usage during @@ -164,6 +159,7 @@ def compute_conv2d_gemm_without_weight_transform( tile_K_A = 4 pad_M = 0 + pad_N = 0 pad_K = 0 if M % tile_M != 0: @@ -172,9 +168,17 @@ def compute_conv2d_gemm_without_weight_transform( if K % tile_K_A != 0: pad_K = tile_K_A - (K % tile_K_A) + if N % tile_N != 0: + pad_N = tile_N - (N % tile_N) + M_padded = M + pad_M K_padded = K + pad_K - N_padded = N_transformed * tile_N + + if use_scalable_vectors: + N_padded = N + pad_N + else: + N_transformed = B_interleaved_t.shape[0] + N_padded = N_transformed * tile_N pad_before = (0, 0, 0) pad_after = (0, pad_M, pad_K) @@ -322,10 +326,24 @@ def compute_conv2d_gemm_without_weight_transform( tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1] - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1] ) + elif use_scalable_vectors: + assert len(B_interleaved_t.shape) == 2 + C = te.compute( + (batches, M_padded, N_padded), + lambda b, x, y: te.sum( + A[b, x, k].astype(in_dtype) * B_interleaved_t[k, y].astype(in_dtype), + axis=k, + ), + name="C", + ) + # Ensure padding on the N axis does not get removed during tir passes + # by adding a dummy reference to the specific padded area of the result + zero = ( + tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1] + - tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1] + ) else: - # Configuration space - configure_knobs(cfg, M_padded, K_padded, target) - + assert len(B_interleaved_t.shape) == 4 C = te.compute( (batches, M_padded, N_padded), lambda b, x, y: te.sum( @@ -356,6 +374,7 @@ def compute_conv2d_gemm_without_weight_transform( out_shape, lambda b, x, y, z: (C(b, y + OW * x, z) + zero).astype(out_dtype), name="conv2d_gemm_output", + attrs={"use_scalable_vectors": use_scalable_vectors}, ) return out @@ -463,7 +482,8 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): C = out.op.input_tensors[0] A = C.op.input_tensors[0] in_type = A.dtype - y_tile_size, _ = get_tiling_B_transformed(False, in_type) + use_scalable_vectors = out.op.attrs["use_scalable_vectors"].value + y_tile_size, _ = get_tiling_B_transformed(False, in_type, use_scalable_vectors) # Computation b, x, y = C.op.axis @@ -478,9 +498,9 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): s[C].tensorize(y_inner, gemm_acc) s[C].parallel(x_outer) else: - k_outer, k_inner = s[C].split(k, 4) - x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size) - y_inner_outer, y_inner_inner = s[C].split(y_inner, nparts=4) + k_outer, k_inner = s[C].split(k, factor=4) + x_outer, x_inner = s[C].split(x, factor=4) + y_outer, y_inner = s[C].split(y, factor=y_tile_size, disable_predication=True) b_x_outer_fused = s[C].fuse(b, x_outer) s[C].parallel(b_x_outer_fused) s[C].reorder( @@ -488,13 +508,11 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): y_outer, k_outer, k_inner, - y_inner_outer, x_inner, - y_inner_inner, + y_inner, ) - s[C].unroll(y_inner_outer) s[C].unroll(x_inner) - s[C].vectorize(y_inner_inner) + s[C].vectorize(y_inner) # Input transform if A.op.name == "A_padded_K" or A.op.name == "A_padded_M": @@ -547,6 +565,13 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): s[A_pad].parallel(n_h_fused) s[A_pad].vectorize(c) + # Weight transform + if use_scalable_vectors: + B_pad = C.op.input_tensors[1] + s[B_pad].parallel(B_pad.op.axis[0]) + B_flat = B_pad.op.input_tensors[0] + s[B_flat].compute_inline() + # Output transform if out != final_out: n, h, w, c = out.op.axis diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 93ad00586a6f..e21c0bd4e106 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -615,7 +615,7 @@ def conv2d_NCHWc_int8( ) -def conv2d_gemm_weight_transform(kernel, tile_N, tile_K): +def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=False): """Weight transformation for winograd Parameters @@ -626,6 +626,8 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K): Tile size across N axis of the weight transformation for ConvGemm. (N = OC) tile_K: int Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC) + use_scalable_vectors : bool + determines if operations on scalable vectors are expected Returns ------- @@ -650,6 +652,9 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K): kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), name="weight_padding" ) + if use_scalable_vectors: + return kernel_flat + if kernel.dtype in ["int8", "uint8"]: B_inter_t = te.compute( (N_padded // tile_N, K_padded // tile_K, tile_N, tile_K), diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index db39e4c0a42a..23e86cf9e7af 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -234,7 +234,7 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { // "T.vscale" and the compile target uses a scalable architecture extension like // SVE, we can make some assumptions about the value of vscale and iterate over a // space of pre-defined values to attempt to prove the expression. - if (tir::CheckContains::ExprContains(expr, IsVScaleCall)) { + if (tir::CheckContains::ExprContains(simplified, IsVScaleCall)) { if (TargetHasSVE()) { return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues); } diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 8d41f0f2c6e7..3b4c6a62c7fa 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -369,6 +369,8 @@ class ConstIntBoundAnalyzer::Impl return VisitLeftShift(op); } else if (op->op.same_as(tir::builtin::bitwise_and())) { return VisitBitwiseAnd(op); + } else if (op->op.same_as(tir::builtin::vscale())) { + return MakeBound(1, 16); } else { return Everything(op->dtype); } diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 579870e5f5c0..587e0121f057 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -532,6 +532,12 @@ class IntervalSetEvaluator : public ExprFunctor { return IntervalSet::SinglePoint(GetRef(op)); } + IntervalSet VisitExpr_(const CallNode* op) final { + if (op->op.same_as(tir::builtin::vscale())) + return IntervalSet(GetRef(op), GetRef(op)); + return IntervalSet::Everything(); + } + IntervalSet VisitExprDefault_(const Object* op) final { DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey(); return IntervalSet::Everything(); diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index a4602bb8b96b..42447ef2f8f2 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1415,6 +1415,16 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { } } + // vscale expression comparison + if (ContainsVscaleCall(op->a) || ContainsVscaleCall(op->b)) { + if (analyzer_->CanProve(op->a <= op->b)) { + return op->a; + } + if (analyzer_->CanProve(op->b <= op->a)) { + return op->b; + } + } + // canonicalization TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1)); TVM_TRY_RECURSIVE_REWRITE_IF(min(c1 - x, c2), c1 - max(x, c1 - c2), c2.Eval()->value != 0); @@ -1598,6 +1608,16 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { } } + // vscale expression comparison + if (ContainsVscaleCall(op->a) || ContainsVscaleCall(op->b)) { + if (analyzer_->CanProve(op->a >= op->b)) { + return op->a; + } + if (analyzer_->CanProve(op->b >= op->a)) { + return op->b; + } + } + // canonicalization TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1)); TVM_TRY_RECURSIVE_REWRITE_IF(max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0); diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc index 0c5aea4e7da7..2df035d6151a 100644 --- a/src/arith/scalable_expression.cc +++ b/src/arith/scalable_expression.cc @@ -29,6 +29,7 @@ #include +#include "../tir/analysis/check_contains.h" #include "../tir/transforms/replace_selected_expr.h" #include "./pattern_match.h" @@ -42,6 +43,10 @@ bool IsVScaleCall(const PrimExpr& expr) { return false; } +bool ContainsVscaleCall(const PrimExpr& expr) { + return tir::CheckContains::ExprContains(expr, IsVScaleCall); +} + PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int vscale_value) { std::function predicate_selector = [](const PrimExpr& current_expr) { return IsVScaleCall(current_expr); diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h index 091783a59f8c..800d920fb707 100644 --- a/src/arith/scalable_expression.h +++ b/src/arith/scalable_expression.h @@ -45,6 +45,13 @@ static const std::vector kAArch64VScaleValues = {1, 2, 3, 4, 5, */ bool IsVScaleCall(const PrimExpr& expr); +/*! + * \brief Check if an expr contains a call to the vscale intrinsic. + * \param expr The expr to check + * \return True if the expr contains a call to the vscale intrinsic, false if not. + */ +bool ContainsVscaleCall(const PrimExpr& expr); + /*! * \brief Substitute a vscale intrinsic call with a known scalar value. * \param expr The expr to apply substitutions to. diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index f7af74c4dbe0..b7453590742d 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -30,6 +30,7 @@ #include #include +#include "../../arith/scalable_expression.h" #include "../../te/operation/create_primfunc.h" namespace tvm { @@ -421,7 +422,8 @@ Optional DefaultTIRConverterImpl(const Array& args, bool dynamic_loop_extent = false; tir::PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void { if (const auto* loop = obj.as()) { - if (!loop->extent->IsInstance()) { + if (!loop->extent->IsInstance() && + !tvm::arith::ContainsVscaleCall(loop->extent)) { dynamic_loop_extent = true; } } diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 3f34f2e870fd..2ebb7671492a 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -1473,6 +1473,13 @@ class VectorTypeRewriter : public StmtExprMutator { Array indices = node->indices; const PrimExpr& last_dim_index = indices[indices.size() - 1]; const RampNode* ramp_index = indices[indices.size() - 1].as(); + + if (node->buffer->dtype.is_scalable_vector() || last_dim_index.dtype().is_scalable_vector()) { + // Scalable types are not currently supported in storage_rewrite. Scalable buffer + // accesses are not currently checked and therefore are not rewritten. + return {node, shuffle_index}; + } + if (ramp_index && is_one(ramp_index->stride) && ramp_index->lanes->IsInstance()) { int lanes = static_cast(Downcast(ramp_index->lanes)->value); PrimExpr new_index = ramp_index->base / make_const(ramp_index->base.dtype(), lanes); diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 6180167555d2..816c85b834ee 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -745,6 +745,11 @@ class TestMinIndex(BaseCompare): TestCase(tvm.te.min(tvm.te.max(x, 4), fld(x + 3, 4) * 4), tvm.te.max(x, 4), x > 0), TestCase(tvm.te.min(fld(x, 10), fld(y, 10)), fld(tvm.te.min(x, y), 10)), TestCase(tvm.te.min(fld(x, (-10)), fld(y, (-10))), fld(tvm.te.max(x, y), (-10))), + # vscale expression comparison + TestCase(tvm.te.min(x + tir.vscale() * 4, x), x), + TestCase(tvm.te.min(x - tir.vscale() * 4, x), x + tir.vscale() * -4), + TestCase(tvm.te.min(x + tir.vscale() * 4, x + tir.vscale() * 8), tir.vscale() * 4 + x), + TestCase(tvm.te.min(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), x), x), ) @@ -811,6 +816,14 @@ class TestMaxIndex(BaseCompare): TestCase(tvm.te.max(fld(x + 3, 4) * 4, x), fld(x + 3, 4) * 4), TestCase(tvm.te.max(fld(x, 4) * 4, x), x), TestCase(tvm.te.max(x, fld(x, 4) * 4), x), + # vscale expression comparison + TestCase(tvm.te.max(x + tir.vscale() * 4, x), x + tir.vscale() * 4), + TestCase(tvm.te.max(x - tir.vscale() * 4, x), x), + TestCase(tvm.te.max(x + tir.vscale() * 4, x + tir.vscale() * 4), x + tir.vscale() * 4), + TestCase( + tvm.te.max(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), x), + x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), + ), ) diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 80aedd60b3f7..8f22ba5b73ed 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -645,5 +645,40 @@ def prim_func(a: T.handle, c: T.handle): tvm.build(prim_func, target=target) +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +@pytest.mark.parametrize("dtype", ["float16", "float32"]) +def test_conv2d_sve(dtype): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(dtype): + A = te.placeholder((1, 32, 32, 3), dtype=dtype, name="A") + W = te.placeholder((3, 3, 3, 8), dtype=dtype, name="B") + stride = padding = dilation = 1 + + compute = tvm.topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE + schedule = tvm.topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE + B = compute(A, W, stride, padding, dilation, dtype) + s = schedule([B]) + + f = tvm.build(s, [A, W, B], target) + assembly = f.get_source("asm") + + loads = re.findall(r"ld1[r]?[q]?[whdb]\t{\s?z", assembly) + compute_ops = re.findall( + r"fm(la|ad)\tz\d+.[shdb], (p\d+\/[zm], )?z\d+.[shdb], z\d+.[shdb]", + assembly, + ) + stores = re.findall(r"st1[whdb]\t{\s?z", assembly) + + assert len(loads) > 0 + assert len(compute_ops) > 0 + assert len(stores) > 0 + + with tvm.target.Target(target): + check_correct_assembly(dtype=dtype) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index 05f9cb9c0570..0084d3f4b647 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ -57,11 +57,17 @@ topi.arm_cpu.compute_conv2d_NHWC_hybrid, topi.arm_cpu.schedule_conv2d_NHWC_hybrid, ), + ( + "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.6a,+sve", + topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE, + topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE, + ), ) dtype = tvm.testing.parameter("float32") batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters( + (1, 1, 3, 15, 1, 1, "SAME", 1), (1, 256, 32, 256, 3, 1, "SAME", 1), (4, 128, 16, 128, 5, 2, "SAME", 1), (4, 128, 16, 256, 5, 2, "SAME", 1), @@ -100,19 +106,23 @@ def test_conv2d_nhwc_gemm_fp32(device, ref_data, dtype, stride, padding, dilatio target, compute, schedule = device dev = tvm.device(target, 0) - with tvm.target.Target(target): + with tvm.target.Target(target) as target: B = compute(A, W, stride, padding, dilation, dtype) s = schedule([B]) - a = tvm.nd.array(a_np, dev) - w = tvm.nd.array(w_np, dev) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) - func = tvm.build(s, [A, W, B], target) - - build_only = platform.machine() != "aarch64" - if build_only: - return - - func(a, w, b) + a = tvm.nd.array(a_np, dev) + w = tvm.nd.array(w_np, dev) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) + func = tvm.build(s, [A, W, B], target) + + # Run only on AArch64 devices + # Do not run SVE schedules on non-SVE devices + build_only = platform.machine() != "aarch64" or ( + target.features.has_sve and not tvm.testing.requires_aarch64_sve.run_time_check() + ) + if build_only: + return + + func(a, w, b) tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) From 0d408e4d2ef884a65d633f04e3090bd786c405bd Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Mon, 22 Apr 2024 09:18:59 +0000 Subject: [PATCH 2/4] Remove unnecessary import and fix linting --- python/tvm/topi/arm_cpu/conv2d.py | 1 + src/arith/analyzer.cc | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index b6b5df8677ad..44c4f7f76f69 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -528,6 +528,7 @@ def compute_conv2d_NHWC( interleave_A, use_scalable_vectors=False, ): + """Compute definition for conv2d NHWC""" N, IH, IW, IC = get_const_tuple(data.shape) KH, KW, _, OC = get_const_tuple(kernel.shape) tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype, use_scalable_vectors) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 23e86cf9e7af..0c4248bd3f26 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -25,7 +25,6 @@ #include #include -#include "../tir/analysis/check_contains.h" #include "./scalable_expression.h" #include "const_fold.h" #include "product_normal_form.h" @@ -234,7 +233,7 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { // "T.vscale" and the compile target uses a scalable architecture extension like // SVE, we can make some assumptions about the value of vscale and iterate over a // space of pre-defined values to attempt to prove the expression. - if (tir::CheckContains::ExprContains(simplified, IsVScaleCall)) { + if (ContainsVscaleCall(simplified)) { if (TargetHasSVE()) { return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues); } From 63a5001120e0b92539a18068c32c6e82e5dbfef3 Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Mon, 22 Apr 2024 14:49:55 +0000 Subject: [PATCH 3/4] Address comments and fix tests --- python/tvm/relay/op/strategy/arm_cpu.py | 1 + python/tvm/topi/arm_cpu/arm_utils.py | 85 +++++++++++++++ python/tvm/topi/arm_cpu/conv2d_gemm.py | 101 ++++++------------ src/arith/const_int_bound.cc | 3 +- .../test_tir_schedule_split_fuse.py | 86 +++++++-------- tests/python/topi/test_topi_conv2d_nhwc.py | 15 ++- 6 files changed, 175 insertions(+), 116 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 88202458afbf..2fc148c3effd 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -257,6 +257,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): # for scalable vector alias analysis, which causes redundant loads / stores # to remain after LLVM's optimisation passes, unlike the non-scalable case. # Hence, it is given a lower priority level until these issues are resolved. + # Last checked manually using: LLVM 18.1.0 strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE), diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py index 0dd17ce4fa34..c350b87167b2 100644 --- a/python/tvm/topi/arm_cpu/arm_utils.py +++ b/python/tvm/topi/arm_cpu/arm_utils.py @@ -21,6 +21,59 @@ from tvm.target import Target +def get_tiling_A(interleave_A, in_dtype): + """Compute the tiling information for matrix A in C=A*B, + which corresponds to the im2col-transformed input matrix. + + The tiling information is chosen to maximize register usage during + the tile computation. + + Please refer to: + - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-performance-for-armv8-architectures # pylint: disable=line-too-long + - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product + - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-through-mmla-instruction + - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h + In order to have more information + + Parameters + ---------- + interleave_A : bool + determines if A is expected to be interleaved + in_dtype : str + input datatype + + Returns + ---------- + tile_M: the output tile size of A on M axis (M = OH * OW) + tile_K: the output tile size of A on K axis (K = KW * KH * IC) + """ + target = Target.current(allow_none=False) + if in_dtype in ["int8", "uint8"]: + if target.features.has_matmul_i8: + # If smmla/ummla is enabled, we are loading 8 rows from A. Each row + # will contain 8 elements + tile_M = 8 + tile_K = 8 + elif target.features.has_dotprod and interleave_A: + # If dot product has been enabled, and we are interleaving A + # tile size should be 8x4 + tile_M = 8 + tile_K = 4 + else: + # If either there is no dot product or if we are using a native strategy + # tile size should be 4x16 + tile_M = 4 + tile_K = 16 + else: + # In non-quantized cases, A is not interleaved. + # We are loading 4 rows from A. + # Each row will contain 4 elements, along the dimension of reduction + tile_M = 4 + tile_K = 4 + + return tile_M, tile_K + + def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False): """Compute the tiling information for matrix B', where B' is the tiled, interleaved (and transposed) version of matrix B in C=A*B. @@ -101,6 +154,38 @@ def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False) return tile_N, tile_K +def get_conv2d_im2col_padding(M, K, tile_M, tile_K): + """Compute the necessary padding for matrix A in C=A*B, + which corresponds to the im2col-transformed input matrix. + + Parameters + ---------- + M : int + Number of rows in A = OH * OW + K : int + Number of columns in A = KW * KH * IC + tile_M : int + tile size of A on M axis + tile_K : int + tile size of A on K axis + + Returns + ---------- + pad_M : padding for M axis + pad_K : padding for K axis + """ + pad_M = 0 + pad_K = 0 + + if M % tile_M != 0: + pad_M = tile_M - (M % tile_M) + + if K % tile_K != 0: + pad_K = tile_K - (K % tile_K) + + return pad_M, pad_K + + def get_conv2d_weights_padding(N, K, tile_N, tile_K): """Compute the necessary padding for matrix B', where B' is the transformed version of matrix B in C=A*B. diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index c454d72e2642..26a65f0f224d 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -21,8 +21,8 @@ from tvm.target import Target from tvm import te from tvm.topi import nn +from tvm.topi.arm_cpu import arm_utils from tvm.autotvm.task.space import AnnotateEntity, ReorderEntity, OtherOptionEntity -from tvm.topi.arm_cpu.arm_utils import get_tiling_B_transformed from ..utils import get_const_tuple, get_const_int from ..nn.utils import get_pad_tuple from .tensor_intrin import ( @@ -93,6 +93,8 @@ def compute_conv2d_gemm_without_weight_transform( OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1 OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1 + + # Input padding (if necessary) if pad_top or pad_left or pad_down or pad_right: data_pad = nn.pad( data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="data_pad" @@ -100,7 +102,7 @@ def compute_conv2d_gemm_without_weight_transform( else: data_pad = data - # Im2col + # Im2col transformation M = OH * OW K = IC * kernel_area N = OC @@ -120,65 +122,19 @@ def compute_conv2d_gemm_without_weight_transform( name="data_im2col", ) - # Pad if necessary - tile_N, tile_K_B = get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors) - - # Select the tiling strategy for A. - # The tiling information is chosen to maximize register usage during - # the tile computation. - # - # Please refer to: - # - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-performance-for-armv8-architectures # pylint: disable=line-too-long - # - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product - # - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-through-mmla-instruction - # - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h - # In order to have more information - # - target = Target.current(allow_none=False) - if in_dtype in ["int8", "uint8"]: - if target.features.has_matmul_i8: - # If smmla/ummla is enabled, we are loading 8 rows from A. Each row - # will contain 8 elements - tile_M = 8 - tile_K_A = 8 - elif target.features.has_dotprod and interleave_A: - # If dot product has been enabled, and we are interleaving A - # tile size should be 8x4 - tile_M = 8 - tile_K_A = 4 - else: - # If either there is no dot product or if we are using a native strategy - # tile size should be 4x16 - tile_M = 4 - tile_K_A = 16 - else: - # In non-quantized cases, A is not interleaved. - # We are loading 4 rows from A. - # Each row will contain 4 elements, along the dimension of reduction - tile_M = 4 - tile_K_A = 4 - - pad_M = 0 - pad_N = 0 - pad_K = 0 - - if M % tile_M != 0: - pad_M = tile_M - (M % tile_M) - - if K % tile_K_A != 0: - pad_K = tile_K_A - (K % tile_K_A) + # Select the tiling strategy for A and B + tile_M, tile_K_A = arm_utils.get_tiling_A(interleave_A, in_dtype) + tile_N, tile_K_B = arm_utils.get_tiling_B_transformed( + interleave_A, in_dtype, use_scalable_vectors + ) - if N % tile_N != 0: - pad_N = tile_N - (N % tile_N) + # Pad to tiles (if necessary) + pad_M, pad_K = arm_utils.get_conv2d_im2col_padding(M, K, tile_M, tile_K_A) + pad_N, _ = arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K_B) M_padded = M + pad_M K_padded = K + pad_K - - if use_scalable_vectors: - N_padded = N + pad_N - else: - N_transformed = B_interleaved_t.shape[0] - N_padded = N_transformed * tile_N + N_padded = N + pad_N pad_before = (0, 0, 0) pad_after = (0, pad_M, pad_K) @@ -191,7 +147,10 @@ def compute_conv2d_gemm_without_weight_transform( idxm = tvm.tir.indexmod k = te.reduce_axis((0, K_padded), "k") + # Determine matrix multiplication compute definition + target = Target.current(allow_none=False) if in_dtype in ["int8", "uint8"]: + assert len(B_interleaved_t.shape) == 4 if interleave_A: # Configuration space configure_knobs(cfg, M_padded, K_padded, target) @@ -208,7 +167,7 @@ def compute_conv2d_gemm_without_weight_transform( lambda b, x, y, z, w: A[b, z + tile_M * x, w + tile_K_A * y], name="A_interleaved", ) - target = Target.current(allow_none=False) + N_transformed = B_interleaved_t.shape[0] if target.features.has_matmul_i8: # Execute GEMM. In the case of mmla, we need to enforce the tiling # from the compute. This is because mmla is doing a tiled computation @@ -384,6 +343,8 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out): C = out.op.input_tensors[0] C_interleaved = C.op.input_tensors[0] A_interleaved = C_interleaved.op.input_tensors[0] + in_type = A_interleaved.dtype + tile_M, tile_K = arm_utils.get_tiling_A(True, in_type) # Input transform A_interleaved_input = A_interleaved.op.input_tensors[0] @@ -422,9 +383,6 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out): s, A_interleaved, [outer_A_interleaved, inner_A_interleaved] ) - in_type = A_interleaved.dtype - out_type = C.dtype - k = C_interleaved.op.reduce_axis[0] _, M, N = C.shape if in_type in ["int8", "uint8"]: @@ -432,7 +390,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out): if target.features.has_matmul_i8: gemm_acc = gemm_acc_2x2_int8_int8_int32(in_type) xi_inner, yi_inner = C_interleaved.op.axis[-2:] - k_outer, k_inner = s[C_interleaved].split(k, 8) + k_outer, k_inner = s[C_interleaved].split(k, tile_K) s[C_interleaved].reorder( b_outer_gemm_fused, inner_gemm, k_outer, xi, yi, xi_inner, yi_inner, k_inner ) @@ -442,9 +400,9 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out): elif target.features.has_dotprod: gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type) xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile( - xi, yi, x_factor=8, y_factor=4 + xi, yi, x_factor=tile_M, y_factor=4 ) - k_outer, k_inner = s[C_interleaved].split(k, 4) + k_outer, k_inner = s[C_interleaved].split(k, tile_K) xi_inner_outer, xi_inner_inner = s[C_interleaved].split(xi_inner, 4) s[C_interleaved].reorder( b_outer_gemm_fused, @@ -483,24 +441,25 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): A = C.op.input_tensors[0] in_type = A.dtype use_scalable_vectors = out.op.attrs["use_scalable_vectors"].value - y_tile_size, _ = get_tiling_B_transformed(False, in_type, use_scalable_vectors) + tile_M, tile_K = arm_utils.get_tiling_A(False, in_type) + tile_N, _ = arm_utils.get_tiling_B_transformed(False, in_type, use_scalable_vectors) # Computation b, x, y = C.op.axis (k,) = C.op.reduce_axis if in_type in ["int8", "uint8"]: - k_outer, k_inner = s[C].split(k, 16) - x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size) + k_outer, k_inner = s[C].split(k, tile_K) + x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=tile_M, y_factor=tile_N) s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner) gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1) s[C].unroll(x_inner) s[C].tensorize(y_inner, gemm_acc) s[C].parallel(x_outer) else: - k_outer, k_inner = s[C].split(k, factor=4) - x_outer, x_inner = s[C].split(x, factor=4) - y_outer, y_inner = s[C].split(y, factor=y_tile_size, disable_predication=True) + k_outer, k_inner = s[C].split(k, factor=tile_K) + x_outer, x_inner = s[C].split(x, factor=tile_M) + y_outer, y_inner = s[C].split(y, factor=tile_N, disable_predication=use_scalable_vectors) b_x_outer_fused = s[C].fuse(b, x_outer) s[C].parallel(b_x_outer_fused) s[C].reorder( @@ -552,7 +511,7 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): s[data_im2col].vectorize(n_inner) elif padding_A: s[data_im2col].compute_inline() - _, n_inner = s[A].split(A.op.axis[2], y_tile_size) + _, n_inner = s[A].split(A.op.axis[2], tile_N) s[A].vectorize(n_inner) s[A].compute_at(s[C], x_inner) else: diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 3b4c6a62c7fa..b82fff218f68 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -31,6 +31,7 @@ #include "constraint_extract.h" #include "int_operator.h" #include "pattern_match.h" +#include "scalable_expression.h" namespace tvm { namespace arith { @@ -369,7 +370,7 @@ class ConstIntBoundAnalyzer::Impl return VisitLeftShift(op); } else if (op->op.same_as(tir::builtin::bitwise_and())) { return VisitBitwiseAnd(op); - } else if (op->op.same_as(tir::builtin::vscale())) { + } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE()) { return MakeBound(1, 16); } else { return Everything(op->dtype); diff --git a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py index 93c36ef67218..f5e5b3b54e76 100644 --- a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py +++ b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py @@ -662,32 +662,31 @@ def test_sve_scalable_split_predicated(num_elements): compile-time, we don't know if vscale is a multiple of the extent of the loop to be split. """ - - @T.prim_func - def before(a: T.handle): - A = T.match_buffer(a, (num_elements,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) - for i in T.serial(num_elements): - with T.block("A"): - v_i = T.axis.remap("S", [i]) - A[v_i] = 1.0 - - @T.prim_func - def after(a: T.handle): - A = T.match_buffer(a, (num_elements,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) - for i_0, i_1 in T.grid( - (T.vscale() * 4 + (num_elements - 1)) // (T.vscale() * 4), T.vscale() * 4 - ): - with T.block("A"): - v_i = T.axis.spatial(num_elements, i_0 * (T.vscale() * 4) + i_1) - T.where(i_0 * (T.vscale() * 4) + i_1 < num_elements) - A[v_i] = 1.0 - with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"): + outer_extent = tvm.arith.Analyzer().simplify(T.ceildiv(num_elements, 4 * T.vscale())) + + @T.prim_func + def before(a: T.handle): + A = T.match_buffer(a, (num_elements,), "float32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i in T.serial(num_elements): + with T.block("A"): + v_i = T.axis.remap("S", [i]) + A[v_i] = 1.0 + + @T.prim_func + def after(a: T.handle): + A = T.match_buffer(a, (num_elements,), "float32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i_0, i_1 in T.grid(outer_extent, T.vscale() * 4): + with T.block("A"): + v_i = T.axis.spatial(num_elements, i_0 * (T.vscale() * 4) + i_1) + T.where(i_0 * (T.vscale() * 4) + i_1 < num_elements) + A[v_i] = 1.0 + sch = tvm.tir.Schedule(before) (a,) = sch.get_loops("A") - sch.split(a, factors=[T.ceildiv(num_elements, 4 * T.vscale()), 4 * T.vscale()]) + sch.split(a, factors=[outer_extent, 4 * T.vscale()]) tvm.ir.assert_structural_equal(sch.mod["main"], after) @@ -699,31 +698,32 @@ def test_sve_scalable_split_assume_exact_multiple(): a predicate is not created. This can be used to ensure predication is not inserted. """ - - @T.prim_func - def before(a: T.handle): - A = T.match_buffer(a, (128,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) - for i in T.serial(128): - with T.block("A"): - v_i = T.axis.remap("S", [i]) - A[v_i] = 1.0 - - @T.prim_func - def after(a: T.handle): - A = T.match_buffer(a, (128,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) - for i_0, i_1 in T.grid((T.vscale() * 4 + (128 - 1)) // (T.vscale() * 4), T.vscale() * 4): - with T.block("A"): - v_i = T.axis.spatial(128, i_0 * (T.vscale() * 4) + i_1) - A[v_i] = 1.0 - with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"): + outer_extent = tvm.arith.Analyzer().simplify(T.ceildiv(128, 4 * T.vscale())) + + @T.prim_func + def before(a: T.handle): + A = T.match_buffer(a, (128,), "float32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i in T.serial(128): + with T.block("A"): + v_i = T.axis.remap("S", [i]) + A[v_i] = 1.0 + + @T.prim_func + def after(a: T.handle): + A = T.match_buffer(a, (128,), "float32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i_0, i_1 in T.grid(outer_extent, T.vscale() * 4): + with T.block("A"): + v_i = T.axis.spatial(128, i_0 * (T.vscale() * 4) + i_1) + A[v_i] = 1.0 + sch = tvm.tir.Schedule(before) (a,) = sch.get_loops("A") sch.split( a, - factors=[T.ceildiv(128, 4 * T.vscale()), 4 * T.vscale()], + factors=[outer_extent, 4 * T.vscale()], disable_predication=True, ) diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index 0084d3f4b647..e9e532ef4c6d 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ -67,12 +67,25 @@ dtype = tvm.testing.parameter("float32") batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters( + # Pad M, N, K (1, 1, 3, 15, 1, 1, "SAME", 1), + # Pad M, K + (1, 3, 9, 16, 3, 1, "SAME", 1), + # Pad M, N + (1, 2, 9, 15, 4, 1, "SAME", 1), + # Pad K, N + (1, 7, 4, 15, 3, 1, "SAME", 1), + # Pad M + (1, 2, 9, 16, 4, 1, "SAME", 1), + # Pad K + (1, 7, 4, 16, 3, 1, "SAME", 1), + # Pad N + (1, 2, 4, 15, 4, 1, "SAME", 1), + # Large workloads (1, 256, 32, 256, 3, 1, "SAME", 1), (4, 128, 16, 128, 5, 2, "SAME", 1), (4, 128, 16, 256, 5, 2, "SAME", 1), (1, 256, 32, 256, 3, 1, "VALID", 1), - (1, 256, 32, 256, 3, 1, "VALID", 1), (4, 128, 16, 128, 5, 2, "VALID", 1), (4, 128, 16, 256, 5, 2, "VALID", 1), (1, 128, 16, 256, 3, 2, (0, 0, 1, 1), 1), From 16b015d9199ef201d083762a689af8e3a20fd29b Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Tue, 23 Apr 2024 08:39:43 +0000 Subject: [PATCH 4/4] Fix scalable index `rewrite_simplify` tests --- .../arith/test_arith_rewrite_simplify.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 816c85b834ee..fcb6aa572910 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -745,11 +745,6 @@ class TestMinIndex(BaseCompare): TestCase(tvm.te.min(tvm.te.max(x, 4), fld(x + 3, 4) * 4), tvm.te.max(x, 4), x > 0), TestCase(tvm.te.min(fld(x, 10), fld(y, 10)), fld(tvm.te.min(x, y), 10)), TestCase(tvm.te.min(fld(x, (-10)), fld(y, (-10))), fld(tvm.te.max(x, y), (-10))), - # vscale expression comparison - TestCase(tvm.te.min(x + tir.vscale() * 4, x), x), - TestCase(tvm.te.min(x - tir.vscale() * 4, x), x + tir.vscale() * -4), - TestCase(tvm.te.min(x + tir.vscale() * 4, x + tir.vscale() * 8), tir.vscale() * 4 + x), - TestCase(tvm.te.min(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), x), x), ) @@ -816,7 +811,19 @@ class TestMaxIndex(BaseCompare): TestCase(tvm.te.max(fld(x + 3, 4) * 4, x), fld(x + 3, 4) * 4), TestCase(tvm.te.max(fld(x, 4) * 4, x), x), TestCase(tvm.te.max(x, fld(x, 4) * 4), x), - # vscale expression comparison + ) + + +class TestScalableIndex(BaseCompare): + x, y = te.var("x"), te.var("y") + test_case = tvm.testing.parameter( + # MinNode + TestCase(tvm.te.min(x + tir.vscale() * 4, x), x), + TestCase(tvm.te.min(x - tir.vscale() * 4, x), x + tir.vscale() * -4), + TestCase(tvm.te.min(x + tir.vscale() * 4, x + tir.vscale() * 8), tir.vscale() * 4 + x), + TestCase(tvm.te.min(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), x), x), + TestCase(tvm.te.min(tir.vscale() * x, tir.vscale() * y), tir.vscale() * x, x < y), + # MaxNode TestCase(tvm.te.max(x + tir.vscale() * 4, x), x + tir.vscale() * 4), TestCase(tvm.te.max(x - tir.vscale() * 4, x), x), TestCase(tvm.te.max(x + tir.vscale() * 4, x + tir.vscale() * 4), x + tir.vscale() * 4), @@ -824,8 +831,13 @@ class TestMaxIndex(BaseCompare): tvm.te.max(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), x), x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), ), + TestCase(tvm.te.max(tir.vscale() * x, tir.vscale() * y), tir.vscale() * x, x > y), ) + def test_simplify(self, test_case): + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"): + super().test_simplify(test_case) + class TestComparisons(BaseCompare): x, y, z = te.var("x"), te.var("y"), te.var("z")