Skip to content

Commit

Permalink
Use Target::Current()
Browse files Browse the repository at this point in the history
Use Target::Current() in LoopVectorizer to check for SVE

Change-Id: I15363bad540d6752d6c2098c93efce25c107309b
  • Loading branch information
ekalda committed Apr 18, 2024
1 parent 2c5d90c commit 8c95c41
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 77 deletions.
4 changes: 0 additions & 4 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
.value();

bool instrument_lwp = pass_ctx->GetConfig<Bool>("tir.instrument_lwp", Bool(false)).value();
Target current_target = Target::Current();

Array<transform::Pass> user_lower_phase0 = Array<transform::Pass>();
Array<transform::Pass> user_lower_phase1 = Array<transform::Pass>();
Expand Down Expand Up @@ -197,9 +196,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
Array<tvm::transform::Pass> pass_list = user_lower_phase0;

// PHASE 1
if (current_target.defined()) {
pass_list.push_back(tir::transform::BindTarget(current_target));
}
pass_list.push_back(tir::transform::InjectPrefetch());
pass_list.push_back(tir::transform::TextureFlatten());
pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
Expand Down
23 changes: 8 additions & 15 deletions src/tir/transforms/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -728,32 +728,25 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp

class LoopVectorizer : public StmtMutator {
public:
LoopVectorizer(PrimFunc f) {
auto target = f->attrs.GetAttr<tvm::Target>(tvm::attr::kTarget);
if (target.defined()) {
target_ = Downcast<Target>(target);
has_sve_ = target_->GetFeature<Bool>("has_sve").value_or(Bool(false));
}
}

Stmt VisitStmt_(const ForNode* op) final {
if (op->kind == ForKind::kVectorized) {
auto* extent_as_int = op->extent.as<IntImmNode>();
if (!extent_as_int || extent_as_int->value < 1) {
Target current_target = Target::Current();
bool has_sve{false};
if (current_target.defined()) {
has_sve = current_target->GetFeature<Bool>("has_sve").value_or(Bool(false));
}
bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall);
ICHECK(is_scalable_expr && has_sve_)
<< "Failed to vectorize loop with extent " << op->extent << " for target " << target_;
ICHECK(is_scalable_expr && has_sve) << "Failed to vectorize loop with extent " << op->extent
<< " for target " << current_target;
}
ICHECK(is_zero(op->min));
return Vectorizer(op->loop_var, op->extent)(op->body);
} else {
return StmtMutator::VisitStmt_(op);
}
}

private:
bool has_sve_{false};
Target target_{};
};

class VectorizeSkipper : public StmtMutator {
Expand All @@ -778,7 +771,7 @@ Pass VectorizeLoop(bool enable_vectorize) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
if (enable_vectorize) {
n->body = LoopVectorizer(f)(std::move(n->body));
n->body = LoopVectorizer()(std::move(n->body));
} else {
n->body = VectorizeSkipper()(std::move(n->body));
}
Expand Down
Loading

0 comments on commit 8c95c41

Please sign in to comment.