Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Disable concise scoping when the scope stmt is explicitly annotated #16271

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions src/script/printer/tir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ Doc DoConciseScoping(const Optional<ExprDoc>& lhs, const ExprDoc& rhs, Array<Stm
}
}

bool AllowConciseScoping(const IRDocsifier& d) {
bool AllowConciseScoping(const IRDocsifier& d, const ObjectRef& obj) {
if (d->cfg.defined()) {
if (d->cfg->obj_to_annotate.count(obj)) {
// if the object requires annotation, do not fold this frame
return false;
}
}
ICHECK(!d->frames.empty());
if (const auto* f = d->frames.back().as<TIRFrameNode>()) {
return f->allow_concise_scoping;
Expand Down Expand Up @@ -69,7 +75,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::LetStmt>("", [](tir::LetStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
bool concise = AllowConciseScoping(d, stmt);
// Step 1. Type annotation
Optional<ExprDoc> type_doc = d->AsDoc<ExprDoc>(stmt->var->type_annotation, //
p->Attr("var")->Attr("type_annotation"));
Expand Down Expand Up @@ -105,7 +111,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::AssertStmt>(
"", [](tir::AssertStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
bool concise = AllowConciseScoping(d, stmt);
ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"));
ExprDoc msg = d->AsDoc<ExprDoc>(stmt->message, p->Attr("message"));
With<TIRFrame> f(d, stmt);
Expand All @@ -129,7 +135,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
namespace {
Doc DeclBufferDoc(tir::DeclBuffer stmt, ObjectPath p, IRDocsifier d,
BufferVarDefinition var_definitions) {
bool concise = AllowConciseScoping(d);
bool concise = AllowConciseScoping(d, stmt);
ExprDoc rhs = BufferDecl(stmt->buffer, "decl_buffer", {}, p->Attr("buffer"), d->frames.back(), d,
var_definitions);
With<TIRFrame> f(d, stmt);
Expand Down Expand Up @@ -203,7 +209,7 @@ bool IsAllocateDeclBufferPattern(const tir::AllocateNode* allocate) {
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Allocate>( //
"", [](tir::Allocate stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
bool concise = AllowConciseScoping(d, stmt_p);
if (d->cfg->syntax_sugar && IsAllocateDeclBufferPattern(stmt.get())) {
return DeclBufferDoc(Downcast<tir::DeclBuffer>(stmt->body), stmt_p->Attr("body"), d,
BufferVarDefinition::DataPointer);
Expand Down Expand Up @@ -261,7 +267,7 @@ ExprDoc PrintNDArray(::tvm::runtime::NDArray arr) {
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::AllocateConst>(
"", [](tir::AllocateConst stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
bool concise = AllowConciseScoping(d, stmt);
String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var);
Array<ExprDoc> args;
Array<String> kwargs_keys;
Expand Down Expand Up @@ -379,7 +385,7 @@ ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const ObjectPath& at
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::BufferRealize>( //
"", [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
bool concise = AllowConciseScoping(d, stmt);
ExprDoc rhs = DocsifyBufferRealize(stmt.get(), NullOpt, p, d);
With<TIRFrame> f(d, stmt);
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
Expand All @@ -389,7 +395,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::AttrStmt>( //
"", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
bool concise = AllowConciseScoping(d, stmt);
Optional<ExprDoc> lhs = NullOpt;
Optional<ExprDoc> rhs = NullOpt;
Optional<tir::Var> define_var = NullOpt;
Expand Down
25 changes: 25 additions & 0 deletions tests/python/tvmscript/test_tvmscript_printer_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,28 @@ def main():
T.evaluate(6)
T.evaluate(7) # annotation 7"""
)


def test_disable_concise_scoping_when_scope_annotated():
@T.prim_func
def _func():
x = 1
y = x + 1
T.evaluate(y - 1)

result = _func.with_attr("global_symbol", "main").script(
obj_to_annotate={
_func.body.body: "annotation 1",
}
)
assert (
result
== """# from tvm.script import tir as T

@T.prim_func
def main():
x: T.int32 = 1
# annotation 1
with T.LetStmt(x + 1) as y:
T.evaluate(y - 1)"""
)
Loading