Skip to content

Commit

Permalink
[TVMScript][Fix] Print Multi-line String as Metadata (#13965)
Browse files Browse the repository at this point in the history
Multi-line strings might make less sense to be printed out by default,
as they could be LLVM snippets, CUDA source code and anything hard to
comprehend but easy to mess up with the TVMScript itself. Therefore,
this PR is introduced to print them as metadata by default.
  • Loading branch information
junrushao authored Feb 12, 2023
1 parent 43c2810 commit 09f38ac
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 5 deletions.
2 changes: 2 additions & 0 deletions python/tvm/script/parser/core/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None)
The parsed TVMScript program.
"""
if extra_vars is None:
import tvm # pylint: disable=import-outside-toplevel
from tvm.script.parser import ir # pylint: disable=import-outside-toplevel
from tvm.script.parser import tir # pylint: disable=import-outside-toplevel

extra_vars = {
"tvm": tvm,
"I": ir,
"ir": ir,
"T": tir,
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/script/parser/ir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,14 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
node : doc.ClassDef
The doc AST expression node.
"""


@dispatch.register(token="default", type_name="Assign")
def visit_assign(self: Parser, node: doc.Assign) -> None:
if len(node.targets) != 1:
self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
lhs = node.targets[0]
rhs = self.eval_expr(node.value)
self.eval_assign(
target=lhs, source=rhs, bind_value=lambda _a, _b, _c, value: value, allow_shadowing=True
)
3 changes: 3 additions & 0 deletions src/script/printer/ir/misc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ namespace printer {

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<String>("", [](String s, ObjectPath p, IRDocsifier d) -> Doc {
if (HasMultipleLines(s)) {
return d->AddMetadata(s);
}
return LiteralDoc::Str(s, p);
});

Expand Down
6 changes: 5 additions & 1 deletion src/script/printer/tir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::StringImm>("", [](tir::StringImm s, ObjectPath p, IRDocsifier d) -> Doc {
return d->AsDoc<ExprDoc>(s->value, p->Attr("value"));
if (HasMultipleLines(s->value)) {
return d->AddMetadata(s);
} else {
return d->AsDoc<ExprDoc>(s->value, p->Attr("value"));
}
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand Down
14 changes: 11 additions & 3 deletions src/script/printer/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include <utility>
#include <vector>

#include "../../support/str_escape.h"

namespace tvm {
namespace script {
namespace printer {
Expand Down Expand Up @@ -76,9 +78,10 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra
std::ostringstream os;
if (!d->metadata.empty()) {
if (d->cfg->show_meta) {
os << "metadata = tvm.ir.load_json("
<< SaveJSON(Map<String, ObjectRef>(d->metadata.begin(), d->metadata.end())) << ")"
<< "\n";
os << "metadata = tvm.ir.load_json(\""
<< support::StrEscape(
SaveJSON(Map<String, ObjectRef>(d->metadata.begin(), d->metadata.end())))
<< "\")\n";
} else {
f->stmts.push_back(
CommentDoc("Metadata omitted. Use show_meta=True in script() method to show it."));
Expand Down Expand Up @@ -130,6 +133,11 @@ inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& doc) {
return doc;
}

/*! \brief Check if a string has multiple lines. */
inline bool HasMultipleLines(const std::string& str) {
return str.find_first_of('\n') != std::string::npos;
}

inline Optional<String> GetBindingName(const IRDocsifier& d) {
return d->cfg->binding_names.empty() ? Optional<String>(NullOpt) : d->cfg->binding_names.back();
}
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3630,7 +3630,7 @@ def func():

def test_roundtrip(ir_generator):
original = ir_generator()
after_roundtrip = tvm.script.from_source(original.script())
after_roundtrip = tvm.script.from_source(original.script(show_meta=True))
tvm.ir.assert_structural_equal(original, after_roundtrip, True)


Expand Down

0 comments on commit 09f38ac

Please sign in to comment.