From d329729b23de90058528fcb151ef3318720eb0ad Mon Sep 17 00:00:00 2001 From: wrongtest Date: Thu, 21 Dec 2023 11:22:24 +0000 Subject: [PATCH] disable concise scoping when the scope stmt is explicitly annotated --- src/script/printer/tir/stmt.cc | 22 ++++++++++------ .../test_tvmscript_printer_annotation.py | 25 +++++++++++++++++++ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 01899d9001fe..beba290581d6 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -37,7 +37,13 @@ Doc DoConciseScoping(const Optional& lhs, const ExprDoc& rhs, Arraycfg.defined()) { + if (d->cfg->obj_to_annotate.count(obj)) { + // if the object requires annotation, do not fold this frame + return false; + } + } ICHECK(!d->frames.empty()); if (const auto* f = d->frames.back().as()) { return f->allow_concise_scoping; @@ -69,7 +75,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::LetStmt stmt, ObjectPath p, IRDocsifier d) -> Doc { - bool concise = AllowConciseScoping(d); + bool concise = AllowConciseScoping(d, stmt); // Step 1. Type annotation Optional type_doc = d->AsDoc(stmt->var->type_annotation, // p->Attr("var")->Attr("type_annotation")); @@ -105,7 +111,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "", [](tir::AssertStmt stmt, ObjectPath p, IRDocsifier d) -> Doc { - bool concise = AllowConciseScoping(d); + bool concise = AllowConciseScoping(d, stmt); ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); ExprDoc msg = d->AsDoc(stmt->message, p->Attr("message")); With f(d, stmt); @@ -129,7 +135,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) namespace { Doc DeclBufferDoc(tir::DeclBuffer stmt, ObjectPath p, IRDocsifier d, BufferVarDefinition var_definitions) { - bool concise = AllowConciseScoping(d); + bool concise = AllowConciseScoping(d, stmt); ExprDoc rhs = BufferDecl(stmt->buffer, "decl_buffer", {}, p->Attr("buffer"), d->frames.back(), d, var_definitions); With f(d, stmt); @@ -203,7 +209,7 @@ bool IsAllocateDeclBufferPattern(const tir::AllocateNode* allocate) { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::Allocate stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { - bool concise = AllowConciseScoping(d); + bool concise = AllowConciseScoping(d, stmt_p); if (d->cfg->syntax_sugar && IsAllocateDeclBufferPattern(stmt.get())) { return DeclBufferDoc(Downcast(stmt->body), stmt_p->Attr("body"), d, BufferVarDefinition::DataPointer); @@ -261,7 +267,7 @@ ExprDoc PrintNDArray(::tvm::runtime::NDArray arr) { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "", [](tir::AllocateConst stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { - bool concise = AllowConciseScoping(d); + bool concise = AllowConciseScoping(d, stmt); String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var); Array args; Array kwargs_keys; @@ -379,7 +385,7 @@ ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const ObjectPath& at TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc { - bool concise = AllowConciseScoping(d); + bool concise = AllowConciseScoping(d, stmt); ExprDoc rhs = DocsifyBufferRealize(stmt.get(), NullOpt, p, d); With f(d, stmt); AsDocBody(stmt->body, p->Attr("body"), f->get(), d); @@ -389,7 +395,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { - bool concise = AllowConciseScoping(d); + bool concise = AllowConciseScoping(d, stmt); Optional lhs = NullOpt; Optional rhs = NullOpt; Optional define_var = NullOpt; diff --git a/tests/python/tvmscript/test_tvmscript_printer_annotation.py b/tests/python/tvmscript/test_tvmscript_printer_annotation.py index 98e6d7c0596c..fb57ae9ce635 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_annotation.py +++ b/tests/python/tvmscript/test_tvmscript_printer_annotation.py @@ -84,3 +84,28 @@ def main(): T.evaluate(6) T.evaluate(7) # annotation 7""" ) + + +def test_disable_concise_scoping_when_scope_annotated(): + @T.prim_func + def _func(): + x = 1 + y = x + 1 + T.evaluate(y - 1) + + result = _func.with_attr("global_symbol", "main").script( + obj_to_annotate={ + _func.body.body: "annotation 1", + } + ) + assert ( + result + == """# from tvm.script import tir as T + +@T.prim_func +def main(): + x: T.int32 = 1 + # annotation 1 + with T.LetStmt(x + 1) as y: + T.evaluate(y - 1)""" + )