From 2c5d90c396f1b6a84f1a9efe851b1fc677d08b8b Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Tue, 16 Apr 2024 16:08:13 +0100 Subject: [PATCH] [SVE] Check for SVE target in VectorizeLoop Check that we are compiling for an SVE enabled target when the extent of a loop marked for vectorizing has a vscale dependent extent. --- src/driver/driver_api.cc | 4 + src/tir/transforms/vectorize_loop.cc | 25 +++- .../test_tir_transform_vectorize.py | 109 +++++++++++++----- 3 files changed, 106 insertions(+), 32 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7ea5032fa0cc9..e88137989969f 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -161,6 +161,7 @@ Array CreatePassList(bool disable_loop_partition) { .value(); bool instrument_lwp = pass_ctx->GetConfig("tir.instrument_lwp", Bool(false)).value(); + Target current_target = Target::Current(); Array user_lower_phase0 = Array(); Array user_lower_phase1 = Array(); @@ -196,6 +197,9 @@ Array CreatePassList(bool disable_loop_partition) { Array pass_list = user_lower_phase0; // PHASE 1 + if (current_target.defined()) { + pass_list.push_back(tir::transform::BindTarget(current_target)); + } pass_list.push_back(tir::transform::InjectPrefetch()); pass_list.push_back(tir::transform::TextureFlatten()); pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index a9cc4975801a0..541ec80bbccf3 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -34,6 +34,9 @@ #include #include +#include "../../src/arith/scalable_expression.h" +#include "../../tir/analysis/check_contains.h" + namespace tvm { namespace tir { @@ -725,17 +728,33 @@ class Vectorizer : public StmtMutator, public ExprFunctorattrs.GetAttr(tvm::attr::kTarget); + if (target.defined()) { + target_ = Downcast(target); + has_sve_ = target_->GetFeature("has_sve").value_or(Bool(false)); + } + } + Stmt VisitStmt_(const ForNode* op) final { if (op->kind == ForKind::kVectorized) { + auto* extent_as_int = op->extent.as(); + if (!extent_as_int || extent_as_int->value < 1) { + bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall); + ICHECK(is_scalable_expr && has_sve_) + << "Failed to vectorize loop with extent " << op->extent << " for target " << target_; + } ICHECK(is_zero(op->min)); return Vectorizer(op->loop_var, op->extent)(op->body); } else { return StmtMutator::VisitStmt_(op); } } -}; -Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); } + private: + bool has_sve_{false}; + Target target_{}; +}; class VectorizeSkipper : public StmtMutator { public: @@ -759,7 +778,7 @@ Pass VectorizeLoop(bool enable_vectorize) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); if (enable_vectorize) { - n->body = LoopVectorizer()(std::move(n->body)); + n->body = LoopVectorizer(f)(std::move(n->body)); } else { n->body = VectorizeSkipper()(std::move(n->body)); } diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index dbca006b19cb3..4dce7def8604f 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -22,12 +22,17 @@ import pytest -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_loop(extent): +simple_target = tvm.target.Target("llvm -mtriple=x86_64-linux-gnu") +sve_target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve") + + +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_loop(extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((16,), "float32")): + T.func_attr({"target": target}) for j in T.vectorized(0, extent): A[j] = 1 @@ -35,6 +40,7 @@ def main(A: T.Buffer((16,), "float32")): class After: @T.prim_func def main(A: T.Buffer((16,), "float32")): + T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = T.Broadcast(1, extent) mod = tvm.tir.transform.VectorizeLoop()(Before) @@ -66,6 +72,7 @@ def test_vectorize_vector_scalable_error(): class Module: @T.prim_func def main(A: T.Buffer((25,), "float32")): + T.func_attr({"target": sve_target}) for j in T.vectorized(T.vscale() * 4): A[j * 4 : j * 4 + 4] = T.Broadcast(T.float32(1), 4) @@ -99,7 +106,8 @@ def main(A: T.Buffer((25,), "float32")): error_msg = f"Vectorizing over existing scalable vectors is not supported." with pytest.raises(tvm.error.InternalError, match=error_msg): - tvm.tir.transform.VectorizeLoop()(Module) + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"): + tvm.tir.transform.VectorizeLoop()(Module) def test_vectorize_vector_scalable_error4(): @@ -107,6 +115,7 @@ def test_vectorize_vector_scalable_error4(): class Module: @T.prim_func(private=True) def main(A: T.Buffer((25,), "float32")): + T.func_attr({"target": sve_target}) for j in T.vectorized(T.vscale() * 4): A[j * T.vscale() * 4 : j * T.vscale() * 4 + T.vscale() * 4] = T.Broadcast( T.float32(1), T.vscale() * 4 @@ -114,15 +123,17 @@ def main(A: T.Buffer((25,), "float32")): error_msg = f"Creating scalable vectors from existing vectors is not supported." with pytest.raises(tvm.error.InternalError, match=error_msg): - tvm.tir.transform.VectorizeLoop()(Module) + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"): + tvm.tir.transform.VectorizeLoop()(Module) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_with_if(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_with_if(extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): + T.func_attr({"target": target}) for i in T.vectorized(extent): if x < n: A[i] = A[i] + T.float32(1) @@ -134,6 +145,7 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): class After: @T.prim_func def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): + T.func_attr({"target": target}) if x < n: A[T.Ramp(0, 1, extent)] = A[T.Ramp(0, 1, extent)] + T.Broadcast( T.float32(1), extent @@ -157,12 +169,13 @@ def test_vectorize_with_if_cond_int64(): f = tvm.build(s, [A, B], "llvm") -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_let(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_let(extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) for i in T.vectorized(extent): v = A[i] + T.float32(1) A[i] = v + T.float32(2) @@ -171,6 +184,7 @@ def main(A: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) v = A[T.Ramp(0, 1, extent)] + T.Broadcast(T.float32(1), extent) A[T.Ramp(0, 1, extent)] = v + T.Broadcast(T.float32(2), extent) @@ -178,8 +192,8 @@ def main(A: T.Buffer((25,), "float32")): tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4)) -def test_vectorize_with_le_cond(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)]) +def test_vectorize_with_le_cond(extent, target): n = te.var("n") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") @@ -188,15 +202,16 @@ def test_vectorize_with_le_cond(extent): A[i] = A[i] + 1 stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + func = tvm.tir.PrimFunc([A, n], stmt).with_attr("target", target) + mod = tvm.IRModule.from_expr(func) stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body # Check that the loop was't vectorised assert isinstance(stmt, tvm.tir.For) -@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4)) -def test_vectorize_with_ge_cond(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)]) +def test_vectorize_with_ge_cond(extent, target): n = te.var("n") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") @@ -205,19 +220,21 @@ def test_vectorize_with_ge_cond(extent): A[i] = A[i] + 1 stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + func = tvm.tir.PrimFunc([A, n], stmt).with_attr("target", target) + mod = tvm.IRModule.from_expr(func) stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body # Check that the loop wasn't vectorised assert isinstance(stmt, tvm.tir.For) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_if_then_else_scalarize(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_if_then_else_scalarize(extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) for i in T.vectorized(extent): A[i] = T.if_then_else(i > 0, A[i] + T.float32(1), A[i]) @@ -225,6 +242,7 @@ def main(A: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) for i_s in range(extent): A[i_s] = T.if_then_else(i_s > 0, A[i_s] + T.float32(1), A[i_s]) @@ -232,12 +250,13 @@ def main(A: T.Buffer((25,), "float32")): tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_if_then_else_vector(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_if_then_else_vector(extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "float32"), n: T.int32): + T.func_attr({"target": target}) for i in range(n): for j in T.vectorized(extent): A[i * extent + j] = T.if_then_else(i > 0, A[i * extent + j], 0) @@ -246,6 +265,7 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32): class After: @T.prim_func def main(A: T.Buffer((25,), "float32"), n: T.int32): + T.func_attr({"target": target}) for i in range(n): A[T.Ramp(i * extent, 1, extent)] = T.if_then_else( i > 0, A[T.Ramp(i * extent, 1, extent)], T.Broadcast(0, extent) @@ -311,13 +331,15 @@ def test_vectorize_dtype_mismatch(): @pytest.mark.parametrize( - "extent, vec_str", [(16, "float32x16"), (T.vscale() * 8, "float32xvscalex8")] + "extent, vec_str, target", + [(16, "float32x16", simple_target), (T.vscale() * 8, "float32xvscalex8", sve_target)], ) -def test_vectorize_with_reinterpret(extent, vec_str): +def test_vectorize_with_reinterpret(extent, vec_str, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): + T.func_attr({"target": target}) for i in T.vectorized(0, extent): B[i] = T.reinterpret("float32", A[i]) @@ -325,13 +347,14 @@ def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): class After: @T.prim_func def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): + T.func_attr({"target": target}) B[T.Ramp(0, 1, extent)] = T.reinterpret(vec_str, A[T.Ramp(0, 1, extent)]) mod = tvm.tir.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @pytest.mark.parametrize( "op", ( @@ -352,11 +375,12 @@ def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): T.NE, ), ) -def test_vectorize_binary(op, extent): +def test_vectorize_binary(op, extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) for j in T.vectorized(extent): A[j] = op(T.float32(3), B[j]) @@ -364,19 +388,21 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.float32(3), extent), B[T.Ramp(0, 1, extent)]) mod = tvm.tir.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @pytest.mark.parametrize("op", (T.And, T.Or)) -def test_vectorize_logical(op, extent): +def test_vectorize_logical(op, extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): + T.func_attr({"target": target}) for j in T.vectorized(extent): A[j] = op(T.bool(1), B[j]) @@ -384,18 +410,20 @@ def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): class After: @T.prim_func def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): + T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.bool(1), extent), B[T.Ramp(0, 1, extent)]) mod = tvm.tir.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_select(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_select(extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) for j in T.vectorized(extent): A[j] = T.Select(T.bool(True), A[j], B[j]) @@ -403,6 +431,7 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = T.Select( T.Broadcast(T.bool(True), extent), A[T.Ramp(0, 1, extent)], @@ -413,12 +442,16 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent, vec_str", [(4, "int32x4"), (T.vscale() * 4, "int32xvscalex4")]) -def test_vectorize_cast(extent, vec_str): +@pytest.mark.parametrize( + "extent, vec_str, target", + [(4, "int32x4", simple_target), (T.vscale() * 4, "int32xvscalex4", sve_target)], +) +def test_vectorize_cast(extent, vec_str, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) for j in T.vectorized(extent): A[j] = T.Cast("int32", B[j]) @@ -426,6 +459,7 @@ def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = T.Cast(vec_str, B[T.Ramp(0, 1, extent)]) mod = tvm.tir.transform.VectorizeLoop()(Before) @@ -441,10 +475,27 @@ def main(A: T.Buffer((25,), "int32")): for j in T.vectorized(n): A[j] = 3 - error_msg = f"Invalid expression for scalable lanes n" + error_msg = f"Failed to vectorize loop with extent n for target \\(nullptr\\)" with pytest.raises(tvm.error.InternalError, match=error_msg): tvm.tir.transform.VectorizeLoop()(Mod) +def test_illegal_vscale_in_non_sve_compilation(): + @I.ir_module + class Mod: + @T.prim_func + def main(A: T.Buffer((16,), "float32")): + T.func_attr({"target": simple_target}) + for j in T.vectorized(0, 4 * T.vscale()): + A[j] = 13 + + msg = ( + f"Failed to vectorize loop with extent T.vscale\\(\\) \\* 4 for target " + f"llvm -keys=cpu -mtriple=x86_64-linux-gnu" + ) + with pytest.raises(tvm.error.InternalError, match=msg): + tvm.tir.transform.VectorizeLoop()(Mod) + + if __name__ == "__main__": tvm.testing.main()