Fix Op Pattern Detection (tlc-pack#5)
jinhongyii authored and MasterJH5574 committed Mar 31, 2022
1 parent ab6c35a commit 46685cb
Showing 2 changed files with 96 additions and 66 deletions.
138 changes: 72 additions & 66 deletions src/tir/schedule/analysis/analysis.cc
@@ -1917,100 +1917,106 @@ bool CheckSameArray(const Array<PrimExpr>& arr1, const Array<PrimExpr>& arr2) {
  return true;
}

/*! \brief Check if the store and load indices are identical element-wise. */
bool CheckElemwisePattern(const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r) {
  if (indices_l.size() != indices_r.size()) {
    return false;
  }
  int n = indices_l.size();
  for (int i = 0; i < n; i++) {
    if (!indices_l[i].same_as(indices_r[i])) {
      return false;
    }
  }
  return true;
}
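// Illustrative example (not part of the diff): with store indices [i, j], a
// load A[i, j] passes this check, while A[j, i] fails at position 0, since
// same_as compares the index expressions positionally.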

/*! \brief Check if the load indices appear, in order, as a subsequence of the store indices. */
bool CheckBroadcastPattern(const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r) {
  if (indices_l.size() < indices_r.size()) {
    return false;
  }
  int j = 0;
  for (int i = 0; i < static_cast<int>(indices_r.size()); i++) {
    // Scan forward in indices_l until the current indices_r[i] is found.
    for (; j < static_cast<int>(indices_l.size()) && !indices_l[j].same_as(indices_r[i]); j++) {
    }
    if (j == static_cast<int>(indices_l.size())) {
      return false;
    }
  }
  return true;
}
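// Illustrative example (not part of the diff): with store indices [i, j, k],
// a load B[i, k] passes because its indices appear in order as a subsequence
// of the store indices; B[k, i] fails because the scan over indices_l only
// moves forward.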

/*! \brief Check if every store index is a plain var and the load indices use only those vars. */
bool CheckInjectivePattern(const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r) {
  std::unordered_set<const VarNode*> vars;
  for (int i = 0; i < static_cast<int>(indices_l.size()); i++) {
    if (const auto* v = indices_l[i].as<VarNode>()) {
      vars.insert(v);
    } else {
      return false;
    }
  }
  for (int i = 0; i < static_cast<int>(indices_r.size()); i++) {
    if (tir::UsesVar(indices_r[i], [&vars](const VarNode* var) { return !vars.count(var); })) {
      return false;
    }
  }
  return true;
}
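// Illustrative example (not part of the diff): with store indices [i, j],
// loads C[j, i] and C[i + j] pass (every store index is a plain var and the
// loads use no other vars); C[i, n] fails for any var n outside {i, j}.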

class PatternKindAnalyzer : public StmtExprVisitor {
  void VisitStmt_(const BufferStoreNode* op) final {
    store_indices_ = op->indices;
    StmtVisitor::VisitStmt_(op);
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    load_indices_.push_back(op->indices);
    ExprVisitor::VisitExpr_(op);
  }

  void VisitExpr_(const CallNode* op) final {
    // A call in the block body cannot be classified, so the whole function is opaque.
    kind_ = relay::kOpaque;
  }

  void VisitStmt_(const BlockNode* op) final {
    if (op->name_hint == "root") {
      StmtVisitor::VisitStmt(op->body);
      return;
    }

    // Collect the store indices and all load indices of this block, then
    // classify every (store, load) index pair with the helpers above.
    load_indices_.clear();
    store_indices_.clear();
    StmtVisitor::VisitStmt(op->body);

    relay::OpPatternKind index_pair_pattern = relay::kElemWise;
    if (load_indices_.empty()) {
      index_pair_pattern = relay::kBroadcast;
    } else {
      for (int i = 0; i < static_cast<int>(load_indices_.size()); i++) {
        if (CheckElemwisePattern(store_indices_, load_indices_[i])) {
          index_pair_pattern = std::max(index_pair_pattern, relay::kElemWise);
        } else if (CheckBroadcastPattern(store_indices_, load_indices_[i])) {
          index_pair_pattern = std::max(index_pair_pattern, relay::kBroadcast);
        } else if (CheckInjectivePattern(store_indices_, load_indices_[i])) {
          index_pair_pattern = std::max(index_pair_pattern, relay::kInjective);
        } else {
          index_pair_pattern = relay::kOpaque;
          break;
        }
      }
    }
    if (index_pair_pattern != relay::kOpaque) {
      kind_ = std::max(kind_, index_pair_pattern);
      return;
    }

    // Otherwise, test whether the block is a reduction.
    for (const IterVar& it : op->iter_vars) {
      if (it->iter_type == kCommReduce) {
        kind_ = std::max(kind_, relay::kCommReduce);
        return;
      }
    }

    kind_ = relay::kOpaque;
  }
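  // Worked example (matches the bias-add test added below): for
  //   T_add[ax0, ax1] = A[ax0, ax1] + B[ax1]
  // the store indices are [ax0, ax1]; the load A[ax0, ax1] is elemwise and the
  // load B[ax1] is broadcast, so the block is classified as relay::kBroadcast.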

  Array<PrimExpr> store_indices_;
  Array<Array<PrimExpr>> load_indices_;
  relay::OpPatternKind kind_ = relay::kElemWise;
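  // Note (assumption, not part of this diff): the std::max escalation relies
  // on relay::OpPatternKind being ordered by generality in
  // relay/op_attr_types.h: kElemWise = 0 < kBroadcast = 1 < kInjective = 2 <
  // kCommReduce = 3 < ... < kOpaque = 8.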

 public:
24 changes: 24 additions & 0 deletions tests/python/relax/test_transform.py
@@ -457,6 +457,30 @@ def foo(x: Tensor[(m, n), "float32"], w: Tensor[(n, k), "float32"]) -> Tensor:
    new_mod = relax.transform.AnnotateOpKind()(mod)
    assert new_mod["injective"].attrs["op_pattern"] == 2

def test_annotate_op_kind_bias_add():
    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def tir_bias_add(rxplaceholder_2: T.Buffer[(1, 1000), "float32"], rxplaceholder_3: T.Buffer[(1000,), "float32"], T_add_1: T.Buffer[(1, 1000), "float32"]) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "tir_bias_add", "tir.noalias": True, "op_pattern": 8})
            # body
            # with T.block("root")
            for i0, i1 in T.grid(1, 1000):
                with T.block("T_add"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder_2[ax0, ax1], rxplaceholder_3[ax1])
                    T.writes(T_add_1[ax0, ax1])
                    T_add_1[ax0, ax1] = rxplaceholder_2[ax0, ax1] + rxplaceholder_3[ax1]

        @R.function
        def foo(x: Tensor[(1, 1000), "float32"], y: Tensor[(1000,), "float32"]) -> Tensor:
            gv0 = R.call_tir(tir_bias_add, (x, y), (1, 1000), dtype="float32")
            return gv0

    mod = InputModule
    new_mod = relax.transform.AnnotateOpKind()(mod)
    assert new_mod["tir_bias_add"].attrs["op_pattern"] == 1

def test_layout_rewrite():
@tvm.script.ir_module
