Skip to content

Commit

Permalink
[SVE] Support scalable vectors in LoopVectorizer
Browse files Browse the repository at this point in the history
This patch add support for turning loops marked for vectorizing into
scalable vectors if the extent of the loop is a vscale dependent
expression in a correct form.

The testing for both scalable and fixed length vectors in
test_tir_transform.py has been extended and most of the tests
have been converted to TVMScript based testing against expected
output.

Co-authored-by: Luke Hutton <[email protected]>
Co-authored-by: Neil Hickey <[email protected]>
  • Loading branch information
3 people committed Mar 25, 2024
1 parent 9899f9c commit 2aa3efd
Show file tree
Hide file tree
Showing 4 changed files with 405 additions and 147 deletions.
11 changes: 9 additions & 2 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/op.h>
#include <tvm/ir/type.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

Expand Down Expand Up @@ -959,10 +960,16 @@ inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) {

template <typename ValueType, typename>
inline PrimExpr make_const(DataType t, ValueType value, Span span) {
if (t.lanes() == 1) {
if (t.is_scalar()) {
return MakeConstScalar(t, value, span);
} else {
return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
if (t.is_fixed_length_vector()) {
return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
} else {
PrimExpr lanes =
tir::Mul(tir::Call(DataType::Int(32), tir::builtin::vscale(), {}), t.vscale_factor());
return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), lanes, span);
}
}
}

Expand Down
14 changes: 10 additions & 4 deletions src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,9 @@ TVM_REGISTER_NODE_TYPE(StringImmNode);
// Cast
Cast::Cast(DataType t, PrimExpr value, Span span) {
ICHECK(value.defined());
ICHECK_EQ(t.lanes(), value.dtype().lanes());
ICHECK_EQ(t.get_lanes_or_vscale_factor(), value.dtype().get_lanes_or_vscale_factor());
ICHECK((t.is_scalable_vector() == value.dtype().is_scalable_vector()) ||
(!t.is_scalable_vector() && !value.dtype().is_scalable_vector()));
ObjectPtr<CastNode> node = make_object<CastNode>();
node->dtype = t;
node->value = std::move(value);
Expand Down Expand Up @@ -354,7 +356,8 @@ And::And(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";

ObjectPtr<AndNode> node = make_object<AndNode>();
node->dtype = DataType::Bool(a.dtype().lanes());
node->dtype =
DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), a.dtype().is_scalable_vector());
node->a = std::move(a);
node->b = std::move(b);
node->span = std::move(span);
Expand All @@ -376,7 +379,8 @@ Or::Or(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";

ObjectPtr<OrNode> node = make_object<OrNode>();
node->dtype = DataType::Bool(a.dtype().lanes());
node->dtype =
DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), a.dtype().is_scalable_vector());
node->a = std::move(a);
node->b = std::move(b);
node->span = std::move(span);
Expand Down Expand Up @@ -412,7 +416,9 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp
ICHECK(true_value.defined()) << "ValueError: true_value is undefined";
ICHECK(false_value.defined()) << "ValueError: true_value is undefined";
ICHECK(condition.dtype().is_bool());
ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1);
ICHECK(condition.dtype().get_lanes_or_vscale_factor() ==
true_value.dtype().get_lanes_or_vscale_factor() ||
condition.dtype().is_scalar());
ICHECK(false_value.dtype() == true_value.dtype())
<< "TypeError: mismatched types. "
<< "False type: " << false_value.dtype() << "; True type: " << true_value.dtype();
Expand Down
Loading

0 comments on commit 2aa3efd

Please sign in to comment.