From 278e629ccc02b5e867eb2b27aa00c5d14140bb3f Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Mon, 25 Mar 2024 12:05:07 +0000 Subject: [PATCH 1/3] [SVE] Support scalable vectors in LoopVectorizer This patch adds support for turning loops marked for vectorizing into scalable vectors if the extent of the loop is a vscale-dependent expression in a correct form. The testing for both scalable and fixed-length vectors in test_tir_transform_vectorize.py has been extended and most of the tests have been converted to TVMScript-based testing against expected output. Co-authored-by: Luke Hutton Co-authored-by: Neil Hickey --- include/tvm/tir/op.h | 11 +- src/tir/ir/expr.cc | 14 +- src/tir/transforms/vectorize_loop.cc | 184 +++++++--- .../test_tir_transform_vectorize.py | 343 +++++++++++++----- 4 files changed, 405 insertions(+), 147 deletions(-) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index ce4a4d6a2845..d06bb779d0bb 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -959,10 +960,16 @@ inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) { template inline PrimExpr make_const(DataType t, ValueType value, Span span) { - if (t.lanes() == 1) { + if (t.is_scalar()) { return MakeConstScalar(t, value, span); } else { - return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span); + if (t.is_fixed_length_vector()) { + return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span); + } else { + PrimExpr lanes = + tir::Mul(tir::Call(DataType::Int(32), tir::builtin::vscale(), {}), t.vscale_factor()); + return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), lanes, span); + } } } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 90dad720393f..6a5585f6aeff 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -196,7 +196,9 @@ TVM_REGISTER_NODE_TYPE(StringImmNode); // Cast Cast::Cast(DataType t, PrimExpr value, 
Span span) { ICHECK(value.defined()); - ICHECK_EQ(t.lanes(), value.dtype().lanes()); + ICHECK_EQ(t.get_lanes_or_vscale_factor(), value.dtype().get_lanes_or_vscale_factor()); + ICHECK((t.is_scalable_vector() == value.dtype().is_scalable_vector()) || + (!t.is_scalable_vector() && !value.dtype().is_scalable_vector())); ObjectPtr node = make_object(); node->dtype = t; node->value = std::move(value); @@ -354,7 +356,8 @@ And::And(PrimExpr a, PrimExpr b, Span span) { ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; ObjectPtr node = make_object(); - node->dtype = DataType::Bool(a.dtype().lanes()); + node->dtype = + DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), a.dtype().is_scalable_vector()); node->a = std::move(a); node->b = std::move(b); node->span = std::move(span); @@ -376,7 +379,8 @@ Or::Or(PrimExpr a, PrimExpr b, Span span) { ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; ObjectPtr node = make_object(); - node->dtype = DataType::Bool(a.dtype().lanes()); + node->dtype = + DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), a.dtype().is_scalable_vector()); node->a = std::move(a); node->b = std::move(b); node->span = std::move(span); @@ -412,7 +416,9 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp ICHECK(true_value.defined()) << "ValueError: true_value is undefined"; ICHECK(false_value.defined()) << "ValueError: true_value is undefined"; ICHECK(condition.dtype().is_bool()); - ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1); + ICHECK(condition.dtype().get_lanes_or_vscale_factor() == + true_value.dtype().get_lanes_or_vscale_factor() || + condition.dtype().is_scalar()); ICHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types. 
" << "False type: " << false_value.dtype() << "; True type: " << true_value.dtype(); diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 57536422cf64..4b4694d6edfe 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -37,19 +37,35 @@ namespace tvm { namespace tir { -// TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455 -inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { - if (e.dtype().lanes() == lanes) return e; +inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) { + if (is_scalable) { + return Mul(Call(DataType::Int(32), builtin::vscale(), {}), lanes_or_vscale_factor); + } else + return lanes_or_vscale_factor; +} + +inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { + // Check if e is already in the expected form + if (e.dtype().get_lanes_or_vscale_factor() == lanes && + e.dtype().is_scalable_vector() == is_scalable) + return e; + if (const BroadcastNode* op = e.as()) { - ICHECK(!e.dtype().is_scalable_vector()); - int broadcast_lanes = static_cast(Downcast(op->lanes)->value); - if (lanes % broadcast_lanes == 0) { - return Broadcast(op->value, lanes); + ICHECK(op->dtype.is_scalable_vector() == is_scalable) + << "Can't broadcast between scalable and fixed length vectors."; + int e_lanes = is_scalable ? 
op->dtype.vscale_factor() : op->dtype.lanes(); + + if (lanes % e_lanes == 0) { + return Broadcast(op->value, CreateNewLanes(is_scalable, lanes)); } } - ICHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" << e.dtype().lanes() << " to " - << lanes; - return Broadcast(e, lanes); + + ICHECK(e.dtype().is_scalar()) << "Cannot broadcast lanes=" + << e.dtype().get_lanes_or_vscale_factor() + << " is_scalable=" << e.dtype().is_scalable_vector() << " to " + << lanes; + + return Broadcast(e, CreateNewLanes(is_scalable, lanes)); } // Rewrite vectorized allocation access @@ -62,7 +78,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { // class VecAllocAccess : public StmtExprMutator { public: - VecAllocAccess(const VarNode* buf, Var var, int var_lanes) + VecAllocAccess(const VarNode* buf, Var var, PrimExpr var_lanes) : buf_(buf), var_(var), var_lanes_(var_lanes) {} PrimExpr VisitExpr_(const BufferLoadNode* op) final { @@ -138,7 +154,7 @@ class VecAllocAccess : public StmtExprMutator { // variable to be replaced Var var_; // the lanes. 
- int var_lanes_; + PrimExpr var_lanes_; // Analyzer for simplifications arith::Analyzer analyzer_; }; @@ -151,7 +167,7 @@ class Vectorizer : public StmtMutator, public ExprFunctordtype, 0), IntImm(var->dtype, 1), var_lanes); } @@ -182,21 +198,29 @@ class Vectorizer : public StmtMutator, public ExprFunctora) && b.same_as(op->b)) { return GetRef(op); } else { - // TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455 - int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); - if (lanes != 1) { + bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector(); + bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector(); + if (is_vec_a && is_vec_b) { + // Let's not multiply scalable and fixed length vectors + ICHECK(a.dtype().is_scalable_vector() == b.dtype().is_scalable_vector()); + } + if (is_vec_a || is_vec_b) { const RampNode* b_ramp = b.as(); const RampNode* a_ramp = a.as(); - if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) { - int lanes = static_cast(Downcast(a_ramp->lanes)->value); + if (a_ramp && !b_ramp && analyzer_.CanProve(b > 0)) { + PrimExpr lanes = a_ramp->lanes; return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes); } - if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) { - int lanes = static_cast(Downcast(b_ramp->lanes)->value); + if (b_ramp && !a_ramp && analyzer_.CanProve(a > 0)) { + PrimExpr lanes = b_ramp->lanes; return Ramp(b_ramp->base * a, b_ramp->stride * a, lanes); } + int a_lanes = a.dtype().get_lanes_or_vscale_factor(); + int b_lanes = b.dtype().get_lanes_or_vscale_factor(); + int max_lanes = std::max(a_lanes, b_lanes); + bool is_scalable = a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector(); + return Mul(BroadcastTo(a, max_lanes, is_scalable), BroadcastTo(b, max_lanes, is_scalable)); } - return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } return BinaryVec(op); } @@ -227,18 +251,24 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->base); 
PrimExpr stride = this->VisitExpr(op->stride); - // TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455 - int op_lanes = static_cast(Downcast(op->lanes)->value); - if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) { + ICHECK(!base.dtype().is_scalable_vector()) + << "Creating scalable vectors from existing vectors is not supported."; + ICHECK(!stride.dtype().is_scalable_vector()) + << "Ramp stride with scalable dtype is not supported"; + if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) { + ICHECK(op->lanes->IsInstance()) + << "Vectorizing over existing scalable vectors is not supported."; const RampNode* base_ramp = base.as(); + int op_lanes = static_cast(Downcast(op->lanes)->value); int base_ramp_lanes = static_cast(Downcast(base_ramp->lanes)->value); - if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op_lanes))) { + if (analyzer_.CanProve(base_ramp->stride == + stride * make_const(stride.dtype(), base_ramp_lanes))) { return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes); } } int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes()); - base = BroadcastTo(base, lanes); - stride = BroadcastTo(stride, lanes); + base = BroadcastTo(base, lanes, false); + stride = BroadcastTo(stride, lanes, false); Array elems; for (int i = 0; i < lanes; ++i) { elems.push_back( @@ -249,7 +279,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); - if (value.dtype().lanes() != 1) { + if (value.dtype().is_scalable_or_fixed_length_vector()) { need_scalarize_ = true; return GetRef(op); } @@ -267,16 +297,27 @@ class Vectorizer : public StmtMutator, public ExprFunctorcondition) && t.same_as(op->true_value) && f.same_as(op->false_value)) { return GetRef(op); } else { - int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()), f.dtype().lanes()); - return Select(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); + int cond_lanes = 
cond.dtype().get_lanes_or_vscale_factor(); + int t_lanes = t.dtype().get_lanes_or_vscale_factor(); + int f_lanes = f.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(std::max(cond_lanes, t_lanes), f_lanes); + bool is_scalable = cond.dtype().is_scalable_vector() || t.dtype().is_scalable_vector() || + f.dtype().is_scalable_vector(); + return Select(BroadcastTo(cond, lanes, is_scalable), BroadcastTo(t, lanes, is_scalable), + BroadcastTo(f, lanes, is_scalable)); } } + PrimExpr VisitExpr_(const CastNode* op) final { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); } else { - return Cast(op->dtype.with_lanes(value.dtype().lanes()), value); + if (value.dtype().is_scalable_vector()) { + return Cast(op->dtype.with_scalable_vscale_factor(value.dtype().vscale_factor()), value); + } else { + return Cast(op->dtype.with_lanes(value.dtype().lanes()), value); + } } } @@ -312,10 +353,17 @@ class Vectorizer : public StmtMutator, public ExprFunctorargs[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) { return GetRef(op); } else { - int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); - t = BroadcastTo(t, lanes); - f = BroadcastTo(f, lanes); - return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}); + int t_lanes = t.dtype().get_lanes_or_vscale_factor(); + int f_lanes = f.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(t_lanes, f_lanes); + bool is_scalable = t.dtype().is_scalable_vector() || f.dtype().is_scalable_vector(); + t = BroadcastTo(t, lanes, is_scalable); + f = BroadcastTo(f, lanes, is_scalable); + if (is_scalable) { + return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {cond, t, f}); + } else { + return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}); + } } } // Reinterpret expr @@ -325,8 +373,12 @@ class Vectorizer : public StmtMutator, public ExprFunctorargs[0])) { return GetRef(op); } else { - int lanes = value.dtype().lanes(); - return 
Call(op->dtype.with_lanes(lanes), op->op, {value}); + int lanes = value.dtype().get_lanes_or_vscale_factor(); + if (value.dtype().is_scalable_vector()) { + return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {value}); + } else { + return Call(op->dtype.with_lanes(lanes), op->op, {value}); + } } } // Call @@ -351,7 +403,8 @@ class Vectorizer : public StmtMutator, public ExprFunctorop.as(); - bool vectorizable = optional_op && op_vectorizable_.get(optional_op.value(), false); + bool vectorizable = optional_op && op_vectorizable_.get(optional_op.value(), false) && + !op->dtype.is_scalable_vector(); if (!vectorizable) { // Cannot vectorize this op @@ -409,7 +462,8 @@ class Vectorizer : public StmtMutator, public ExprFunctorsecond, value)) << "Let cannot bind the same var to two different values"; } - if (value.dtype().lanes() != op->value.dtype().lanes()) { + if (value.dtype().get_lanes_or_vscale_factor() != + op->value.dtype().get_lanes_or_vscale_factor()) { Var new_var(op->var->name_hint, value.dtype()); let_binding_[op->var] = new_var; return Let(new_var, value, this->VisitExpr(op->body)); @@ -433,20 +487,27 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); if (!indices.same_as(op->indices) || !value.same_as(op->value)) { + ICHECK(!op->buffer->dtype.is_scalable_vector()) + << "Vectorizing over scalable buffer elements is not supported in vectorizer."; // How many lanes of indexing are present in the index and - // buffer element type, excluding the last index. T + // buffer element type, excluding the last index. int other_index_lanes = op->buffer->dtype.lanes(); for (size_t i = 0; i < indices.size() - 1; i++) { other_index_lanes *= indices[i].dtype().lanes(); + // Only allow the last index to be scalable + ICHECK(!indices[i].dtype().is_scalable_vector()) << "Only the last index can be scalable."; } // The total number of lanes of indexing, including the last index. 
- int index_lanes = other_index_lanes * indices[indices.size() - 1].dtype().lanes(); + int lanes_in_last_index = indices[indices.size() - 1].dtype().get_lanes_or_vscale_factor(); + int index_lanes = other_index_lanes * lanes_in_last_index; // The total number of lanes in this store operation. Either // the index or the value will be broadcast out to this number // of lanes, depending on which has more lanes. - int total_lanes = std::max(index_lanes, value.dtype().lanes()); + int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor(); + bool is_last_index_scalable = indices[indices.size() - 1].dtype().is_scalable_vector(); + int total_lanes = std::max(index_lanes, value_dtype_lanes); ICHECK_EQ(total_lanes % other_index_lanes, 0) << "When storing to buffer " << op->buffer->name << ", cannot produce " << total_lanes @@ -455,11 +516,12 @@ class Vectorizer : public StmtMutator, public ExprFunctorindices = indices; - writer->value = BroadcastTo(value, total_lanes); + writer->value = BroadcastTo(value, total_lanes, is_last_index_scalable); } return std::move(store); @@ -512,7 +574,8 @@ class Vectorizer : public StmtMutator, public ExprFunctorvar)) << "SSA violation, a single var is binded twice"; let_binding_[op->var] = value; - if (value.dtype().lanes() != op->value.dtype().lanes()) { + if (value.dtype().get_lanes_or_vscale_factor() != + op->value.dtype().get_lanes_or_vscale_factor()) { Var new_var(op->var->name_hint, value.dtype()); let_binding_[op->var] = new_var; return LetStmt(new_var, value, this->VisitStmt(op->body)); @@ -566,8 +629,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorname_hint + ".s", var_->dtype); stmt = Substitute(stmt, {{var_, idx}}); - return For(idx, IntImm(var_->dtype, 0), IntImm(var_->dtype, var_lanes_), ForKind::kSerial, - stmt); + return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt); } // ProducerStore Stmt VisitStmt_(const ProducerStoreNode* op) final { @@ -582,7 +644,7 @@ class Vectorizer : public 
StmtMutator, public ExprFunctora) && b.same_as(op->b)) { return GetRef(op); } else { - int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); - return TOp(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + int a_lanes = a.dtype().get_lanes_or_vscale_factor(); + int b_lanes = b.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(a_lanes, b_lanes); + bool is_scalable = a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector(); + return TOp(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes, is_scalable)); } } template @@ -635,19 +700,22 @@ class Vectorizer : public StmtMutator, public ExprFunctora) && b.same_as(op->b)) { return GetRef(op); } else { - int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); + int a_lanes = a.dtype().get_lanes_or_vscale_factor(); + int b_lanes = b.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(a_lanes, b_lanes); if (lanes != 1) { const RampNode* b_ramp = b.as(); const RampNode* a_ramp = a.as(); - if (a.dtype().lanes() == 1 && b_ramp) { + if (!a.dtype().is_scalable_or_fixed_length_vector() && b_ramp) { return Ramp(fcompute(a, b_ramp->base), fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes); } - if (b.dtype().lanes() == 1 && a_ramp) { + if (!b.dtype().is_scalable_or_fixed_length_vector() && a_ramp) { return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); } } - return fcompute(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + bool is_scalable = a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector(); + return fcompute(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes, is_scalable)); } } }; @@ -657,11 +725,7 @@ class LoopVectorizer : public StmtMutator { Stmt VisitStmt_(const ForNode* op) final { if (op->kind == ForKind::kVectorized) { ICHECK(is_zero(op->min)); - auto* extent_as_int = op->extent.as(); - if (!extent_as_int || extent_as_int->value < 1) { - LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent; - } - return 
Vectorizer(op->loop_var, static_cast(extent_as_int->value))(op->body); + return Vectorizer(op->loop_var, op->extent)(op->body); } else { return StmtMutator::VisitStmt_(op); } diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index 7d0fac242307..e9f2757f4743 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -19,32 +19,29 @@ from tvm import te from tvm.script import ir as I from tvm.script import tir as T +import pytest -def test_vectorize_loop(): - dtype = "int64" - n = te.var("n") - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - with ib.for_range(0, n) as i: - with ib.for_range(0, 4, kind="vectorize") as j: - A[j] = tvm.tir.const(1, A.dtype) - stmt = ib.get() - - assert isinstance(stmt.body, tvm.tir.For) +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +def test_vectorize_loop(extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((16,), "float32")): + for j in T.vectorized(0, extent): + A[j] = 1 - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) - stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((16,), "float32")): + A[T.Ramp(0, 1, extent)] = T.Broadcast(1, extent) - assert isinstance(stmt, tvm.tir.For) - assert not isinstance(stmt.body, tvm.tir.For) - assert len(stmt.body.indices) == 1 - assert isinstance(stmt.body.indices[0], tvm.tir.Ramp) - assert isinstance(stmt.body.value, tvm.tir.Broadcast) + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) def test_vectorize_vector(): - dtype = "int64" n = te.var("n") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32x4", name="A") @@ -64,28 +61,86 @@ def test_vectorize_vector(): assert isinstance(stmt.body.value, tvm.tir.Broadcast) -def 
test_vectorize_with_if(): - n = te.var("n") - x = te.var("x") - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, kind="vectorize") as i: - with ib.if_scope(x < n): - A[i] = A[i] + 1 - with ib.else_scope(): - with ib.if_scope(i < n): - A[i] = 2.0 - stmt = ib.get() +def test_vectorize_vector_scalable_error(): + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((25,), "float32")): + for j in T.vectorized(T.vscale() * 4): + A[j * 4 : j * 4 + 4] = T.Broadcast(T.float32(1), 4) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt)) - stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + with pytest.raises(tvm.error.InternalError): + tvm.tir.transform.VectorizeLoop()(Module) + + +def test_vectorize_vector_scalable_error2(): + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((25,), "float32xvscalex4")): + for j in T.vectorized(4): + A[j] = T.Broadcast(T.float32(1), T.vscale() * 4) + + with pytest.raises(tvm.error.InternalError): + tvm.tir.transform.VectorizeLoop()(Module) - assert isinstance(stmt, tvm.tir.IfThenElse) - assert len(stmt.then_case.indices) == 1 - assert isinstance(stmt.then_case.indices[0], tvm.tir.Ramp) - assert isinstance(stmt.then_case.value, tvm.tir.Add) - assert stmt.then_case.value.dtype == "float32x4" - assert isinstance(stmt.else_case, tvm.tir.For) + +def test_vectorize_vector_scalable_error3(): + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((25,), "float32")): + for j in T.vectorized(4): + A[j * T.vscale() * 4 : j * T.vscale() * 4 + T.vscale() * 4] = T.Broadcast( + T.float32(1), T.vscale() * 4 + ) + + with pytest.raises(tvm.error.InternalError): + tvm.tir.transform.VectorizeLoop()(Module) + + +def test_vectorize_vector_scalable_error4(): + @I.ir_module + class Module: + @T.prim_func(private=True) + def main(A: T.Buffer((25,), "float32")): + for j in T.vectorized(T.vscale() * 4): + A[j * T.vscale() * 4 : j * T.vscale() * 4 + 
T.vscale() * 4] = T.Broadcast( + T.float32(1), T.vscale() * 4 + ) + + with pytest.raises(tvm.error.InternalError): + tvm.tir.transform.VectorizeLoop()(Module) + + +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +def test_vectorize_with_if(extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): + for i in T.vectorized(extent): + if x < n: + A[i] = A[i] + T.float32(1) + else: + if i < n: + A[i] = T.float32(2) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): + if x < n: + A[T.Ramp(0, 1, extent)] = A[T.Ramp(0, 1, extent)] + T.Broadcast( + T.float32(1), extent + ) + else: + for i_s in range(extent): + if i_s < n: + A[i_s] = T.float32(2) + + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) def test_vectorize_with_if_cond_int64(): @@ -98,25 +153,33 @@ def test_vectorize_with_if_cond_int64(): f = tvm.build(s, [A, B], "llvm") -def test_vectorize_let(): - v = tvm.tir.Var("v", "float32") - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, kind="vectorize") as i: - ib.emit(lambda body: tvm.tir.LetStmt(v, A[i] + 1, body)) - A[i] = v + 2 +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +def test_vectorize_let(extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "float32")): + for i in T.vectorized(extent): + v = A[i] + T.float32(1) + A[i] = v + T.float32(2) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], ib.get())) - stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body - assert isinstance(stmt, tvm.tir.LetStmt) - assert stmt.value.dtype == "float32x4" + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "float32")): + v = A[T.Ramp(0, 1, extent)] + T.Broadcast(T.float32(1), extent) + A[T.Ramp(0, 1, extent)] = v + T.Broadcast(T.float32(2), extent) + mod = 
tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) -def test_vectorize_with_le_cond(): + +@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4)) +def test_vectorize_with_le_cond(extent): n = te.var("n") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, kind="vectorize") as i: + with ib.for_range(0, extent, kind="vectorize") as i: with ib.if_scope(i <= n): A[i] = A[i] + 1 stmt = ib.get() @@ -124,14 +187,16 @@ def test_vectorize_with_le_cond(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + # Check that the loop was't vectorised assert isinstance(stmt, tvm.tir.For) -def test_vectorize_with_ge_cond(): +@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4)) +def test_vectorize_with_ge_cond(extent): n = te.var("n") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, kind="vectorize") as i: + with ib.for_range(0, extent, kind="vectorize") as i: with ib.if_scope(i >= n): A[i] = A[i] + 1 stmt = ib.get() @@ -139,39 +204,51 @@ def test_vectorize_with_ge_cond(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + # Check that the loop wasn't vectorised assert isinstance(stmt, tvm.tir.For) -def test_vectorize_if_then_else(): - n = te.var("n") - x = te.var("x") - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, kind="vectorize") as i: - A[i] = tvm.tir.call_intrin("float32", "tir.if_then_else", i > 0, A[i] + 1, A[i]) - stmt = ib.get() +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +def test_vectorize_if_then_else_scalarize(extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "float32")): + for i in T.vectorized(extent): + A[i] = T.if_then_else(i > 0, A[i] + T.float32(1), A[i]) - mod = 
tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt)) - stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "float32")): + for i_s in range(extent): + A[i_s] = T.if_then_else(i_s > 0, A[i_s] + T.float32(1), A[i_s]) - assert isinstance(stmt, tvm.tir.For) + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - with ib.for_range(0, n) as k: - with ib.for_range(0, 4, kind="vectorize") as i: - A[k * 4 + i] = tvm.tir.call_intrin( - "float32", "tir.if_then_else", k > 0, A[k * 4 + i], 0 - ) - stmt = ib.get() - assert isinstance(stmt.body, tvm.tir.For) +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +def test_vectorize_if_then_else_vector(extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), n: T.int32): + for i in range(n): + for j in T.vectorized(extent): + A[i * extent + j] = T.if_then_else(i > 0, A[i * extent + j], 0) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) - stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), n: T.int32): + for i in range(n): + A[T.Ramp(i * extent, 1, extent)] = T.if_then_else( + i > 0, A[T.Ramp(i * extent, 1, extent)], T.Broadcast(0, extent) + ) - assert not isinstance(stmt.body, tvm.tir.For) - assert isinstance(stmt.body.value.args[2], tvm.tir.Broadcast) + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) def test_vectorize_while_fail(): @@ -229,19 +306,123 @@ def test_vectorize_dtype_mismatch(): tvm.lower(s, [A], "llvm", simple_mode=True) -def test_vectorize_with_reinterpret(): +@pytest.mark.parametrize( + "extent, vec_str", [(16, "float32x16"), (T.vscale() * 8, "float32xvscalex8")] +) +def test_vectorize_with_reinterpret(extent, vec_str): 
@I.ir_module class Before: @T.prim_func def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): - for i in T.vectorized(0, 16): + for i in T.vectorized(0, extent): B[i] = T.reinterpret("float32", A[i]) @I.ir_module class After: @T.prim_func def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): - B[0:16] = T.reinterpret("float32x16", A[0:16]) + B[T.Ramp(0, 1, extent)] = T.reinterpret(vec_str, A[T.Ramp(0, 1, extent)]) + + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +@pytest.mark.parametrize( + "op", + ( + T.Mul, + T.Add, + T.Sub, + T.Div, + T.Mod, + T.FloorDiv, + T.FloorMod, + T.Min, + T.Max, + T.EQ, + T.LT, + T.LE, + T.GE, + T.GT, + T.NE, + ), +) +def test_vectorize_binary(op, extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + for j in T.vectorized(extent): + A[j] = op(T.float32(3), B[j]) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.float32(3), extent), B[T.Ramp(0, 1, extent)]) + + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +@pytest.mark.parametrize("op", (T.And, T.Or)) +def test_vectorize_logical(op, extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): + for j in T.vectorized(extent): + A[j] = op(T.bool(1), B[j]) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): + A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.bool(1), extent), B[T.Ramp(0, 1, extent)]) + + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + +@pytest.mark.parametrize("extent", (4, 
T.vscale() * 4)) +def test_vectorize_select(extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + for j in T.vectorized(extent): + A[j] = T.Select(T.bool(True), A[j], B[j]) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + A[T.Ramp(0, 1, extent)] = T.Select( + T.Broadcast(T.bool(True), extent), + A[T.Ramp(0, 1, extent)], + B[T.Ramp(0, 1, extent)], + ) + + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + +@pytest.mark.parametrize("extent, vec_str", [(4, "int32x4"), (T.vscale() * 4, "int32xvscalex4")]) +def test_vectorize_cast(extent, vec_str): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): + for j in T.vectorized(extent): + A[j] = T.Cast("int32", B[j]) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): + A[T.Ramp(0, 1, extent)] = T.Cast(vec_str, B[T.Ramp(0, 1, extent)]) mod = tvm.tir.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) From bd59c5d4a5e5bed21f6b2fa9df1b86d80ca903d1 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Mon, 25 Mar 2024 17:08:15 +0000 Subject: [PATCH 2/3] Linting... 
--- src/tir/transforms/vectorize_loop.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 4b4694d6edfe..7f96f5abeee2 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -40,8 +40,9 @@ namespace tir { inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) { if (is_scalable) { return Mul(Call(DataType::Int(32), builtin::vscale(), {}), lanes_or_vscale_factor); - } else + } else { return lanes_or_vscale_factor; + } } inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { From 477b89abbfee9ef657d520c634bee36f24d533dd Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Thu, 4 Apr 2024 11:35:23 +0100 Subject: [PATCH 3/3] Respond to review and add one more test --- include/tvm/runtime/data_type.h | 4 ++- src/tir/ir/expr.cc | 3 +-- src/tir/transforms/vectorize_loop.cc | 18 +++++++------ .../test_tir_transform_vectorize.py | 26 ++++++++++++++++--- 4 files changed, 36 insertions(+), 15 deletions(-) diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 8f3ae9b42460..f7284ec690a4 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -111,7 +111,9 @@ class DataType { return -lanes_as_int; } /*! \return get vscale factor or lanes depending on scalability of the vector. */ - int get_lanes_or_vscale_factor() { return is_scalable_vector() ? vscale_factor() : lanes(); } + int get_lanes_or_vscale_factor() const { + return is_scalable_vector() ? vscale_factor() : lanes(); + } /*! \return whether type is a scalar type. */ bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; } /*! \return whether type is a scalar type. 
*/ diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 6a5585f6aeff..2cd2a698debe 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -197,8 +197,7 @@ TVM_REGISTER_NODE_TYPE(StringImmNode); Cast::Cast(DataType t, PrimExpr value, Span span) { ICHECK(value.defined()); ICHECK_EQ(t.get_lanes_or_vscale_factor(), value.dtype().get_lanes_or_vscale_factor()); - ICHECK((t.is_scalable_vector() == value.dtype().is_scalable_vector()) || - (!t.is_scalable_vector() && !value.dtype().is_scalable_vector())); + ICHECK(t.is_scalable_vector() == value.dtype().is_scalable_vector()); ObjectPtr node = make_object(); node->dtype = t; node->value = std::move(value); diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 7f96f5abeee2..a9cc4975801a 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -54,7 +54,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { if (const BroadcastNode* op = e.as()) { ICHECK(op->dtype.is_scalable_vector() == is_scalable) << "Can't broadcast between scalable and fixed length vectors."; - int e_lanes = is_scalable ? 
op->dtype.vscale_factor() : op->dtype.lanes(); + int e_lanes = op->dtype.get_lanes_or_vscale_factor(); if (lanes % e_lanes == 0) { return Broadcast(op->value, CreateNewLanes(is_scalable, lanes)); @@ -203,16 +203,17 @@ class Vectorizer : public StmtMutator, public ExprFunctor(); const RampNode* a_ramp = a.as(); - if (a_ramp && !b_ramp && analyzer_.CanProve(b > 0)) { + if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) { PrimExpr lanes = a_ramp->lanes; return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes); } - if (b_ramp && !a_ramp && analyzer_.CanProve(a > 0)) { + if (b_ramp && a.dtype().is_scalar() && analyzer_.CanProve(a > 0)) { PrimExpr lanes = b_ramp->lanes; return Ramp(b_ramp->base * a, b_ramp->stride * a, lanes); } @@ -500,14 +501,15 @@ class Vectorizer : public StmtMutator, public ExprFunctor(); const RampNode* a_ramp = a.as(); - if (!a.dtype().is_scalable_or_fixed_length_vector() && b_ramp) { + if (a.dtype().is_scalar() && b_ramp) { return Ramp(fcompute(a, b_ramp->base), fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes); } - if (!b.dtype().is_scalable_or_fixed_length_vector() && a_ramp) { + if (b.dtype().is_scalar() && a_ramp) { return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); } } diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index e9f2757f4743..dbca006b19cb 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -69,7 +69,8 @@ def main(A: T.Buffer((25,), "float32")): for j in T.vectorized(T.vscale() * 4): A[j * 4 : j * 4 + 4] = T.Broadcast(T.float32(1), 4) - with pytest.raises(tvm.error.InternalError): + error_msg = f"Creating scalable vectors from existing vectors is not supported." 
+ with pytest.raises(tvm.error.InternalError, match=error_msg): tvm.tir.transform.VectorizeLoop()(Module) @@ -81,7 +82,8 @@ def main(A: T.Buffer((25,), "float32xvscalex4")): for j in T.vectorized(4): A[j] = T.Broadcast(T.float32(1), T.vscale() * 4) - with pytest.raises(tvm.error.InternalError): + error_msg = f"Vectorizing over scalable buffer elements is not supported in vectorizer." + with pytest.raises(tvm.error.InternalError, match=error_msg): tvm.tir.transform.VectorizeLoop()(Module) @@ -95,7 +97,8 @@ def main(A: T.Buffer((25,), "float32")): T.float32(1), T.vscale() * 4 ) - with pytest.raises(tvm.error.InternalError): + error_msg = f"Vectorizing over existing scalable vectors is not supported." + with pytest.raises(tvm.error.InternalError, match=error_msg): tvm.tir.transform.VectorizeLoop()(Module) @@ -109,7 +112,8 @@ def main(A: T.Buffer((25,), "float32")): T.float32(1), T.vscale() * 4 ) - with pytest.raises(tvm.error.InternalError): + error_msg = f"Creating scalable vectors from existing vectors is not supported." + with pytest.raises(tvm.error.InternalError, match=error_msg): tvm.tir.transform.VectorizeLoop()(Module) @@ -428,5 +432,19 @@ def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): tvm.ir.assert_structural_equal(mod, After) +def test_illegal_extent(): + @I.ir_module(check_well_formed=False) + class Mod: + @T.prim_func + def main(A: T.Buffer((25,), "int32")): + n = T.Var("n", dtype="int32") + for j in T.vectorized(n): + A[j] = 3 + + error_msg = f"Invalid expression for scalable lanes n" + with pytest.raises(tvm.error.InternalError, match=error_msg): + tvm.tir.transform.VectorizeLoop()(Mod) + + if __name__ == "__main__": tvm.testing.main()