Skip to content

Commit

Permalink
[SVE] Support scalable vectors in LoopVectorizer (#16782)
Browse files Browse the repository at this point in the history
This patch add support for turning loops marked for vectorizing into
scalable vectors if the extent of the loop is a vscale dependent
expression in a correct form.

The testing for both scalable and fixed length vectors in
test_tir_transform.py has been extended and most of the tests
have been converted to TVMScript based testing against expected
output.

Co-authored-by: Luke Hutton <[email protected]>
Co-authored-by: Neil Hickey <[email protected]>
  • Loading branch information
3 people authored Apr 9, 2024
1 parent a309b6b commit 4d4f050
Show file tree
Hide file tree
Showing 5 changed files with 428 additions and 148 deletions.
4 changes: 3 additions & 1 deletion include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ class DataType {
return -lanes_as_int;
}
/*! \return get vscale factor or lanes depending on scalability of the vector. */
int get_lanes_or_vscale_factor() { return is_scalable_vector() ? vscale_factor() : lanes(); }
int get_lanes_or_vscale_factor() const {
return is_scalable_vector() ? vscale_factor() : lanes();
}
/*! \return whether type is a scalar type. */
bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
/*! \return whether type is a scalar type. */
Expand Down
11 changes: 9 additions & 2 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/op.h>
#include <tvm/ir/type.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

Expand Down Expand Up @@ -959,10 +960,16 @@ inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) {

template <typename ValueType, typename>
inline PrimExpr make_const(DataType t, ValueType value, Span span) {
if (t.lanes() == 1) {
if (t.is_scalar()) {
return MakeConstScalar(t, value, span);
} else {
return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
if (t.is_fixed_length_vector()) {
return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
} else {
PrimExpr lanes =
tir::Mul(tir::Call(DataType::Int(32), tir::builtin::vscale(), {}), t.vscale_factor());
return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), lanes, span);
}
}
}

Expand Down
13 changes: 9 additions & 4 deletions src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ TVM_REGISTER_NODE_TYPE(StringImmNode);
// Cast
Cast::Cast(DataType t, PrimExpr value, Span span) {
ICHECK(value.defined());
ICHECK_EQ(t.lanes(), value.dtype().lanes());
ICHECK_EQ(t.get_lanes_or_vscale_factor(), value.dtype().get_lanes_or_vscale_factor());
ICHECK(t.is_scalable_vector() == value.dtype().is_scalable_vector());
ObjectPtr<CastNode> node = make_object<CastNode>();
node->dtype = t;
node->value = std::move(value);
Expand Down Expand Up @@ -354,7 +355,8 @@ And::And(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";

ObjectPtr<AndNode> node = make_object<AndNode>();
node->dtype = DataType::Bool(a.dtype().lanes());
node->dtype =
DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), a.dtype().is_scalable_vector());
node->a = std::move(a);
node->b = std::move(b);
node->span = std::move(span);
Expand All @@ -376,7 +378,8 @@ Or::Or(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";

ObjectPtr<OrNode> node = make_object<OrNode>();
node->dtype = DataType::Bool(a.dtype().lanes());
node->dtype =
DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), a.dtype().is_scalable_vector());
node->a = std::move(a);
node->b = std::move(b);
node->span = std::move(span);
Expand Down Expand Up @@ -412,7 +415,9 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp
ICHECK(true_value.defined()) << "ValueError: true_value is undefined";
ICHECK(false_value.defined()) << "ValueError: true_value is undefined";
ICHECK(condition.dtype().is_bool());
ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1);
ICHECK(condition.dtype().get_lanes_or_vscale_factor() ==
true_value.dtype().get_lanes_or_vscale_factor() ||
condition.dtype().is_scalar());
ICHECK(false_value.dtype() == true_value.dtype())
<< "TypeError: mismatched types. "
<< "False type: " << false_value.dtype() << "; True type: " << true_value.dtype();
Expand Down
Loading

0 comments on commit 4d4f050

Please sign in to comment.