Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TIR] Enhance loop unroll with unroll local access
Browse files Browse the repository at this point in the history
This PR enhances the unroller with an unroll local access option.
This option detects loop variables that are used to index local memory
and unrolls those loops independently of the other options.

A test case is added. This option is turned off by default and
can be useful in certain cases to improve the unroller, as these
local memory accesses have to be unrolled at some point in order
to be lifted into registers.
tqchen committed Mar 7, 2023
1 parent 2c4af88 commit b61381b
Showing 2 changed files with 101 additions and 5 deletions.
64 changes: 59 additions & 5 deletions src/tir/transforms/unroll_loop.cc
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@
#include <unordered_set>
#include <vector>

#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"

namespace tvm {
@@ -43,6 +44,7 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode<UnrollLoopConfigNode> {
int auto_max_depth;
int auto_max_extent;
int explicit_unroll;
int unroll_local_access;

TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") {
TVM_ATTR_FIELD(auto_max_step)
@@ -57,6 +59,9 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode<UnrollLoopConfigNode> {
TVM_ATTR_FIELD(explicit_unroll)
.describe("Whether to explicitly unroll the loop instead of setting a pragma")
.set_default(true);
TVM_ATTR_FIELD(unroll_local_access)
.describe("Whether to always unroll local access")
.set_default(false);
}
};

@@ -68,14 +73,30 @@ class UnrollLoopConfig : public Attrs {
TVM_REGISTER_NODE_TYPE(UnrollLoopConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);

class VarLocalAccessMarker : public ExprVisitor {
public:
explicit VarLocalAccessMarker(
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>* var_touched_local)
: var_touched_local_(var_touched_local) {}

void VisitExpr_(const VarNode* op) final { var_touched_local_->insert(GetRef<Var>(op)); }

private:
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>* var_touched_local_;
};

// The visitor above is used to record which variables are used as indices into local memory.
// If a loop variable is used as an index into local memory, the loop must be unrolled so
// that the local memory access can be turned into a register access.
class LoopUnroller : public StmtExprMutator {
public:
explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent,
bool explicit_unroll)
bool explicit_unroll, bool unroll_local_access)
: auto_max_step_(auto_max_step),
auto_max_depth_(auto_max_depth),
auto_max_extent_(auto_max_extent),
explicit_unroll_(explicit_unroll) {}
explicit_unroll_(explicit_unroll),
unroll_local_access_(unroll_local_access) {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "pragma_auto_unroll_max_step") {
@@ -96,6 +117,7 @@ class LoopUnroller : public StmtExprMutator {
}

Stmt VisitStmt_(const ForNode* op) {
// Post order so we can collect more information
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
int value = GetExtent(op);
@@ -111,6 +133,12 @@ class LoopUnroller : public StmtExprMutator {
auto_unroll = true;
}

// If a loop var is used as indices to a local memory, it must be unrolled so
// the local memory access can be turned into register access.
if (this->var_touched_local_.count(op->loop_var) && value > 0 && unroll_local_access_) {
auto_unroll = true;
}

if (auto_unroll) {
step_count_ *= value;
unroll_depth_ += 1;
@@ -137,14 +165,36 @@ class LoopUnroller : public StmtExprMutator {
LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
}

// Records the variables used to index local-scope buffers, then continues the default traversal.
//
// When unroll_local_access_ is set and the loaded buffer's storage scope has rank kLocal, every
// variable appearing in the load's index expressions is added to var_touched_local_. The ForNode
// visitor consults that set when deciding whether a loop must be unrolled.
//
// The node is then handed to the base mutator so that its index expressions are visited as
// well. Returning the node directly would skip them, and a local-buffer load nested inside the
// indices of a non-local load (e.g. B[A_local[i]]) would never be marked.
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
  if (unroll_local_access_) {
    auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
    if (storage_scope.rank == runtime::StorageRank::kLocal) {
      VarLocalAccessMarker marker(&var_touched_local_);
      for (PrimExpr e : op->indices) {
        marker(e);
      }
    }
  }
  return StmtExprMutator::VisitExpr_(op);
}

// Counts the store as one step and records the variables used to index local-scope buffers.
//
// When unroll_local_access_ is set and the stored-to buffer's storage scope has rank kLocal,
// every variable appearing in the store's index expressions is added to var_touched_local_.
// The marking has to run before the return: a leftover early return ahead of it would make
// the whole local-access detection unreachable for stores.
Stmt VisitStmt_(const BufferStoreNode* op) final {
  ++step_count_;
  if (unroll_local_access_) {
    auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
    if (storage_scope.rank == runtime::StorageRank::kLocal) {
      VarLocalAccessMarker marker(&var_touched_local_);
      for (PrimExpr e : op->indices) {
        marker(e);
      }
    }
  }
  // Default traversal of the store (value and indices) via the base statement mutator.
  return StmtMutator::VisitStmt_(op);
}

// Counts the evaluate statement as one step, then applies the default traversal.
// A single return is kept, using the same base-class spelling as the BufferStore visitor.
Stmt VisitStmt_(const EvaluateNode* op) final {
  ++step_count_;
  return StmtMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const SeqStmtNode* op) final {
@@ -202,19 +252,23 @@ class LoopUnroller : public StmtExprMutator {
// this does not count the total steps; it only counts the number of loops
int auto_max_extent_;
bool explicit_unroll_;
// Whether to always unroll loops whose loop variable indexes local memory.
bool unroll_local_access_{false};
// Number of normal loops in scope
int normal_loop_depth_{0};
// number of unrolled cases in current scope.
int unroll_depth_{0};
// Number of total steps unrolled
int step_count_{0};
// Set of variables that appear in the index expressions of local memory accesses
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_touched_local_;
// analyzer
arith::Analyzer analyzer_;
};

Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) {
Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent,
cfg->explicit_unroll)(stmt);
cfg->explicit_unroll, cfg->unroll_local_access)(stmt);
if (!ret.same_as(stmt)) {
return ConvertSSA(ret);
} else {
42 changes: 42 additions & 0 deletions tests/python/unittest/test_tir_transform_unroll_loop.py
Original file line number Diff line number Diff line change
@@ -134,7 +134,49 @@ def main():
tvm.ir.assert_structural_equal(after, expected)


def test_unroll_local_access():
    """Check that ``unroll_local_access`` unrolls a loop that indexes a local buffer.

    ``Before`` writes ``A_local[i]`` inside a serial loop of extent 4, where
    ``A_local`` is allocated with ``scope="local"``. ``Expected`` is the same
    function with the loop replaced by four straight-line stores.
    """

    @tvm.script.ir_module
    class Before:
        @T.prim_func
        def main(B: T.Buffer((64,), "float32")):
            for bx in T.thread_binding(4, thread="blockIdx.x"):
                for tx in T.thread_binding(4, thread="threadIdx.x"):
                    A_local_data = T.allocate([4], dtype="float32", scope="local")
                    A_local = T.Buffer([4], dtype="float32", data=A_local_data)
                    for i in T.serial(4):
                        A_local[i] = T.float32(i)

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def main(B: T.Buffer((64,), "float32")):
            for bx in T.thread_binding(4, thread="blockIdx.x"):
                for tx in T.thread_binding(4, thread="threadIdx.x"):
                    A_local_data = T.allocate([4], dtype="float32", scope="local")
                    A_local = T.Buffer([4], dtype="float32", data=A_local_data)
                    A_local[0] = T.float32(0)
                    A_local[1] = T.float32(1)
                    A_local[2] = T.float32(2)
                    A_local[3] = T.float32(3)

    # The automatic-unroll limits are set low (depth 0, extent 1), presumably so that
    # unroll_local_access is the only option that can trigger the unrolling here.
    with tvm.transform.PassContext(
        config={
            "tir.UnrollLoop": {
                "auto_max_depth": 0,
                "auto_max_extent": 1,
                "explicit_unroll": True,
                "unroll_local_access": True,
            }
        }
    ):
        after = tvm.tir.transform.UnrollLoop()(Before)
        # Simplify the unrolled module before comparing it with Expected.
        after = tvm.tir.transform.Simplify()(after)

    tvm.ir.assert_structural_equal(after, Expected)


# Run the tests directly when this file is executed as a script.
if __name__ == "__main__":
    test_unroll_local_access()
    test_unroll_loop()
    test_unroll_fake_loop()
    test_unroll_single_count_loops()

0 comments on commit b61381b

Please sign in to comment.