Skip to content

Commit

Permalink
[SVE] Check for SVE target in VectorizeLoop
Browse files Browse the repository at this point in the history
Check that we are compiling for an SVE enabled target when the extent
of a loop marked for vectorizing has a vscale dependent extent.
  • Loading branch information
ekalda committed Apr 18, 2024
1 parent de91c5c commit 2c5d90c
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 32 deletions.
4 changes: 4 additions & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
.value();

bool instrument_lwp = pass_ctx->GetConfig<Bool>("tir.instrument_lwp", Bool(false)).value();
Target current_target = Target::Current();

Array<transform::Pass> user_lower_phase0 = Array<transform::Pass>();
Array<transform::Pass> user_lower_phase1 = Array<transform::Pass>();
Expand Down Expand Up @@ -196,6 +197,9 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
Array<tvm::transform::Pass> pass_list = user_lower_phase0;

// PHASE 1
if (current_target.defined()) {
pass_list.push_back(tir::transform::BindTarget(current_target));
}
pass_list.push_back(tir::transform::InjectPrefetch());
pass_list.push_back(tir::transform::TextureFlatten());
pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
Expand Down
25 changes: 22 additions & 3 deletions src/tir/transforms/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
#include <unordered_map>
#include <vector>

#include "../../src/arith/scalable_expression.h"
#include "../../tir/analysis/check_contains.h"

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -725,17 +728,33 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp

class LoopVectorizer : public StmtMutator {
public:
LoopVectorizer(PrimFunc f) {
auto target = f->attrs.GetAttr<tvm::Target>(tvm::attr::kTarget);
if (target.defined()) {
target_ = Downcast<Target>(target);
has_sve_ = target_->GetFeature<Bool>("has_sve").value_or(Bool(false));
}
}

Stmt VisitStmt_(const ForNode* op) final {
if (op->kind == ForKind::kVectorized) {
auto* extent_as_int = op->extent.as<IntImmNode>();
if (!extent_as_int || extent_as_int->value < 1) {
bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall);
ICHECK(is_scalable_expr && has_sve_)
<< "Failed to vectorize loop with extent " << op->extent << " for target " << target_;
}
ICHECK(is_zero(op->min));
return Vectorizer(op->loop_var, op->extent)(op->body);
} else {
return StmtMutator::VisitStmt_(op);
}
}
};

Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); }
private:
bool has_sve_{false};
Target target_{};
};

class VectorizeSkipper : public StmtMutator {
public:
Expand All @@ -759,7 +778,7 @@ Pass VectorizeLoop(bool enable_vectorize) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
if (enable_vectorize) {
n->body = LoopVectorizer()(std::move(n->body));
n->body = LoopVectorizer(f)(std::move(n->body));
} else {
n->body = VectorizeSkipper()(std::move(n->body));
}
Expand Down
Loading

0 comments on commit 2c5d90c

Please sign in to comment.