Skip to content

Commit

Permalink
[TIR] Preserve annotations after lower opaque block (apache#12572)
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif authored Aug 31, 2022
1 parent f7cc992 commit f114d55
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 14 deletions.
60 changes: 46 additions & 14 deletions src/tir/transforms/lower_opaque_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ class OpaqueBlockLower : public StmtExprMutator {
}
body = Allocate(buffer->data, buffer->dtype, new_shape, const_true(), std::move(body));
}
// Step 4. Handle annotations, block annotations are not preserved by default.
std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
HandleAnnotations(new_block->annotations, &pragma_attrs, /*is_block=*/true);
for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
body = AttrStmt(Integer(0), it->first, it->second, std::move(body));
}
return body;
}

Expand All @@ -72,7 +78,11 @@ class OpaqueBlockLower : public StmtExprMutator {
}
// Step 2. Visit recursively
Stmt body = this->VisitStmt(op->body);
// Step 3. Create new For loop accordingly
// Step 3. Handle annotations
std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
Map<String, ObjectRef> new_annotations =
HandleAnnotations(op->annotations, &pragma_attrs, /*is_block=*/false);
// Step 4. Create new For loop accordingly
if (op->kind == ForKind::kThreadBinding) {
// Case 1. Thread binding
ICHECK(op->thread_binding.defined());
Expand All @@ -83,20 +93,12 @@ class OpaqueBlockLower : public StmtExprMutator {
return body;
} else {
// Case 3. An ordinary loop
body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body));
}
// Step 4. Handle annotations
std::set<std::string> ordered_ann_keys;
for (const auto& annotation : op->annotations) {
ordered_ann_keys.insert(annotation.first);
body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body),
NullOpt, new_annotations);
}
for (auto it = ordered_ann_keys.rbegin(); it != ordered_ann_keys.rend(); ++it) {
const std::string& ann_key = *it;
const ObjectRef& ann_value = op->annotations.at(ann_key);
if (attr::IsPragmaKey(ann_key)) {
body =
AttrStmt(op->loop_var, ann_key, ConvertAttrValue(ann_key, ann_value), std::move(body));
}
// Step 5. Insert nested attrs
for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
body = AttrStmt(op->loop_var, it->first, it->second, std::move(body));
}
return body;
}
Expand Down Expand Up @@ -146,8 +148,38 @@ class OpaqueBlockLower : public StmtExprMutator {
}
}

/*!
* \brief Helper to handle annotation dict.
* (1) if the attr key is prefixed by `pragma_`, move to ordered kv list. They
* are lowered to `AttrStmt` by legacy TE schedule convention.
* (2) the non-pragma loop annotations are preserved
* (3) the non-pragma block annotations are dropped
* \return New annotation dict with preserved keys. Also update pragma attr pairs ordered by key.
*/
Map<String, ObjectRef> HandleAnnotations(
const Map<String, ObjectRef>& annotations,
std::vector<std::pair<std::string, PrimExpr>>* pragma_attrs, bool is_block) {
Map<String, ObjectRef> preserved_annotations;
pragma_attrs->clear();
for (const auto& kv : annotations) {
const String& key = kv.first;
if (attr::IsPragmaKey(key)) {
pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second));
} else if (!is_block) {
// the loop annotation is preserved
preserved_annotations.Set(key, kv.second);
}
}
std::sort(pragma_attrs->begin(), pragma_attrs->end(),
[](const auto& p1, const auto& p2) { return p1.first < p2.first; });
return preserved_annotations;
}

/*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> unit_loop_vars_;

/*! \brief Attr keys to preserve into loop annotations. */
std::unordered_set<std::string> preserved_annotations_;
};

PrimFunc LowerOpaqueBlock(PrimFunc f) {
Expand Down
37 changes: 37 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_opaque_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,43 @@ def test_annotated_loops():
tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0))


def test_annotated_block():
@T.prim_func
def annotated_block() -> None:
with T.block():
T.block_attr({"pragma_1": "str_value", "pragma_2": 1, "pragma_3": 0.0})
T.evaluate(0)

mod = tvm.IRModule.from_expr(annotated_block)
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
attr1 = mod["main"].body
attr2 = attr1.body
attr3 = attr2.body
assert attr1.attr_key == "pragma_1" and attr1.value == "str_value"
assert attr2.attr_key == "pragma_2"
tvm.ir.assert_structural_equal(attr2.value, tvm.tir.IntImm("int32", 1))
assert attr3.attr_key == "pragma_3"
tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0))


def test_preserved_annotations():
@T.prim_func
def before(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]):
for i in T.serial(8, annotations={"k_0": 1, "k_1": [2, 3], "k_2": 3.14}):
with T.block("block"):
T.block_attr({"k_3": "oops"})
B[i] = A[i] + 1.0

@T.prim_func
def after(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]):
for i in T.serial(8, annotations={"k_0": 1, "k_1": [2, 3], "k_2": 3.14}):
B[i] = A[i] + 1.0

mod = tvm.IRModule.from_expr(before)
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
tvm.ir.assert_structural_equal(mod["main"], after)


def test_boolean_handling():
_check(boolean_handling_before, boolean_handling_after)

Expand Down

0 comments on commit f114d55

Please sign in to comment.