Skip to content

Commit

Permalink
[TVMScript] Support show_meta (apache#13934)
Browse files Browse the repository at this point in the history
This PR adds the functionality to roundtrip metadata during printing.
Users may turn on the flag below to allow the printer to dump metadata to screen.

```python
ir_node.show(show_meta=True, ...)
```
  • Loading branch information
junrushao authored Feb 8, 2023
1 parent 36f45bb commit 45a92df
Show file tree
Hide file tree
Showing 16 changed files with 336 additions and 152 deletions.
2 changes: 2 additions & 0 deletions include/tvm/node/repr_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class ReprLegacyPrinter {
TVM_DLL void Print(const ObjectRef& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
/*! \brief Could the LegacyPrinter dispatch the node */
TVM_DLL static bool CanDispatch(const ObjectRef& node);
/*! \brief Return the ostream it maintains */
TVM_DLL std::ostream& Stream() const;
// Allow registration to be printer.
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/node/script_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ namespace tvm {

class PrinterConfigNode : public Object {
public:
/*! \brief A stack that tracks the names of the binding hierarchy */
Array<String> binding_names = {};
/*! \brief Whether or not to show metadata. */
bool show_meta = false;
/*! \brief The prefix of IR nodes */
std::string ir_prefix = "I";
/*! \brief The prefix of TIR nodes */
Expand Down Expand Up @@ -71,6 +75,8 @@ class PrinterConfigNode : public Object {
Map<ObjectRef, String> obj_to_annotate = Map<ObjectRef, String>();

void VisitAttrs(AttrVisitor* v) {
v->Visit("binding_names", &binding_names);
v->Visit("show_meta", &show_meta);
v->Visit("ir_prefix", &ir_prefix);
v->Visit("buffer_dtype", &buffer_dtype);
v->Visit("int_dtype", &int_dtype);
Expand Down
9 changes: 5 additions & 4 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ class IRDocsifierNode : public Object {
* when converting IR node object to Doc.
*/
Array<String> dispatch_tokens;
/*! \brief The IRModule to be docsifier is handling */
Optional<IRModule> mod;
/*! \brief Mapping from a var to its info */
std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> obj2info;
/*! \brief Metadata printing */
std::unordered_map<String, Array<ObjectRef>> metadata;
/*! \brief The variable names used already */
std::unordered_set<String> defined_names;
/*! \brief Common prefixes of variable usages */
Expand All @@ -155,8 +155,8 @@ class IRDocsifierNode : public Object {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("frames", &frames);
v->Visit("dispatch_tokens", &dispatch_tokens);
v->Visit("mod", &mod);
// `obj2info` is not visited
// `metadata` is not visited
// `defined_names` is not visited
// `common_prefix` is not visited
// `ir_usage` is not visited
Expand Down Expand Up @@ -204,7 +204,8 @@ class IRDocsifierNode : public Object {
* \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt.
*/
Optional<ExprDoc> GetVarDoc(const ObjectRef& obj) const;

/*! \brief Add a TVM object to the metadata section*/
ExprDoc AddMetadata(const ObjectRef& obj);
/*!
* \brief Check if a variable exists in the table.
* \param obj The variable object.
Expand Down
113 changes: 93 additions & 20 deletions python/tvm/runtime/script_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""Configuration of TVMScript printer"""
from typing import List, Dict, Optional
from typing import Dict, List, Optional, Sequence

from tvm._ffi import register_object
from tvm._ffi import get_global_func, register_object
from tvm.runtime import Object

from . import _ffi_node_api
Expand All @@ -28,6 +28,8 @@
class PrinterConfig(Object):
"""Configuration of TVMScript printer"""

binding_names: Sequence[str]
show_meta: bool
ir_prefix: str
tir_prefix: str
relax_prefix: str
Expand All @@ -47,6 +49,8 @@ class PrinterConfig(Object):
def __init__(
self,
*,
name: Optional[str] = None,
show_meta: bool = False,
ir_prefix: str = "I",
tir_prefix: str = "T",
relax_prefix: str = "R",
Expand All @@ -65,38 +69,49 @@ def __init__(
) -> None:
if num_context_lines is None:
num_context_lines = -1
cfg = {
"show_meta": show_meta,
"ir_prefix": ir_prefix,
"tir_prefix": tir_prefix,
"relax_prefix": relax_prefix,
"buffer_dtype": buffer_dtype,
"int_dtype": int_dtype,
"float_dtype": float_dtype,
"verbose_expr": verbose_expr,
"indent_spaces": indent_spaces,
"print_line_numbers": print_line_numbers,
"num_context_lines": num_context_lines,
"syntax_sugar": syntax_sugar,
"path_to_underline": path_to_underline,
"path_to_annotate": path_to_annotate,
"obj_to_underline": obj_to_underline,
"obj_to_annotate": obj_to_annotate,
}

if name is not None:
cfg["name"] = name
self.__init_handle_by_constructor__(
_ffi_node_api.PrinterConfig, # type: ignore # pylint: disable=no-member
{
"ir_prefix": ir_prefix,
"tir_prefix": tir_prefix,
"relax_prefix": relax_prefix,
"buffer_dtype": buffer_dtype,
"int_dtype": int_dtype,
"float_dtype": float_dtype,
"verbose_expr": verbose_expr,
"indent_spaces": indent_spaces,
"print_line_numbers": print_line_numbers,
"num_context_lines": num_context_lines,
"syntax_sugar": syntax_sugar,
"path_to_underline": path_to_underline,
"path_to_annotate": path_to_annotate,
"obj_to_underline": obj_to_underline,
"obj_to_annotate": obj_to_annotate,
},
_ffi_node_api.PrinterConfig, cfg # type: ignore # pylint: disable=no-member
)


def _script(obj: Object, config: PrinterConfig) -> str:
return _ffi_node_api.TVMScriptPrinterScript(obj, config) # type: ignore # pylint: disable=no-member


def _relax_script(obj: Object, config: PrinterConfig) -> str:
func = get_global_func("script.printer.ReprPrintRelax")
return func(obj, config)


class Scriptable:
"""A base class that enables the script() and show() method."""

def script(
self,
*,
name: Optional[str] = None,
show_meta: bool = False,
ir_prefix: str = "I",
tir_prefix: str = "T",
relax_prefix: str = "R",
Expand All @@ -117,6 +132,10 @@ def script(
Parameters
----------
name : Optional[str] = None
The name of the object
show_meta : bool = False
Whether to print the meta data of the object
ir_prefix : str = "I"
The prefix of AST nodes from tvm.ir
tir_prefix : str = "T"
Expand Down Expand Up @@ -156,6 +175,52 @@ def script(
return _script(
self,
PrinterConfig(
name=name,
show_meta=show_meta,
ir_prefix=ir_prefix,
tir_prefix=tir_prefix,
relax_prefix=relax_prefix,
buffer_dtype=buffer_dtype,
int_dtype=int_dtype,
float_dtype=float_dtype,
verbose_expr=verbose_expr,
indent_spaces=indent_spaces,
print_line_numbers=print_line_numbers,
num_context_lines=num_context_lines,
syntax_sugar=syntax_sugar,
path_to_underline=path_to_underline,
path_to_annotate=path_to_annotate,
obj_to_underline=obj_to_underline,
obj_to_annotate=obj_to_annotate,
),
)

def _relax_script(
self,
*,
name: Optional[str] = None,
show_meta: bool = False,
ir_prefix: str = "I",
tir_prefix: str = "T",
relax_prefix: str = "R",
buffer_dtype: str = "float32",
int_dtype: str = "int32",
float_dtype: str = "void",
verbose_expr: bool = False,
indent_spaces: int = 4,
print_line_numbers: bool = False,
num_context_lines: int = -1,
syntax_sugar: bool = True,
path_to_underline: Optional[List[ObjectPath]] = None,
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
obj_to_underline: Optional[List[Object]] = None,
obj_to_annotate: Optional[Dict[Object, str]] = None,
) -> str:
return _relax_script(
self,
PrinterConfig(
name=name,
show_meta=show_meta,
ir_prefix=ir_prefix,
tir_prefix=tir_prefix,
relax_prefix=relax_prefix,
Expand All @@ -179,6 +244,8 @@ def show(
style: Optional[str] = None,
black_format: bool = True,
*,
name: Optional[str] = None,
show_meta: bool = False,
ir_prefix: str = "I",
tir_prefix: str = "T",
relax_prefix: str = "R",
Expand All @@ -204,6 +271,10 @@ def show(
`tvm.script.highlight.cprint` for more details.
black_format: bool
If true (default), use the formatter Black to format the TVMScript
name : Optional[str] = None
The name of the object
show_meta : bool = False
Whether to print the meta data of the object
ir_prefix : str = "I"
The prefix of AST nodes from tvm.ir
tir_prefix : str = "T"
Expand Down Expand Up @@ -241,6 +312,8 @@ def show(

cprint(
self.script(
name=name,
show_meta=show_meta,
ir_prefix=ir_prefix,
tir_prefix=tir_prefix,
relax_prefix=relax_prefix,
Expand Down
12 changes: 11 additions & 1 deletion python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

# pylint: disable=unused-import
from tvm.target.codegen import llvm_lookup_intrinsic_id
from tvm.tir import Buffer, BufferRegion, PrimExpr
from tvm.tir import Buffer, BufferRegion, IndexMap, PrimExpr
from tvm.tir import op as _tir_op
from tvm.tir import type_annotation

Expand Down Expand Up @@ -1522,6 +1522,15 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer:
return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity)


def index_map(
mapping: Callable,
*,
inverse_index_map: Optional[Callable] = None,
) -> IndexMap:
"""Create a TIR Index mapping"""
return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map)


def target(target_config: Union[Dict, str]) -> Target:
"""
Create a target
Expand Down Expand Up @@ -1824,6 +1833,7 @@ def wrapped(*args, **kwargs):
"max",
"iter_var",
"comm_reducer",
"index_map",
"target",
"buffer_var",
"abs",
Expand Down
12 changes: 11 additions & 1 deletion src/node/repr_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,20 @@ void ReprLegacyPrinter::Print(const ObjectRef& node) {
} else if (f.can_dispatch(node)) {
f(node, this);
} else {
stream << node; // Use ReprPrinter
try {
stream << node; // Use ReprPrinter
} catch (const tvm::Error& e) {
LOG(WARNING) << "ReprPrinter fails";
stream << node->GetTypeKey() << '(' << node.get() << ')';
}
}
}

bool ReprLegacyPrinter::CanDispatch(const ObjectRef& node) {
static const FType& f = vtable();
return !node.defined() || f.can_dispatch(node);
}

void ReprLegacyPrinter::PrintIndent() {
for (int i = 0; i < indent; ++i) {
stream << ' ';
Expand Down
6 changes: 6 additions & 0 deletions src/node/script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional<Print

PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) {
runtime::ObjectPtr<PrinterConfigNode> n = make_object<PrinterConfigNode>();
if (auto v = config_dict.Get("name")) {
n->binding_names.push_back(Downcast<String>(v));
}
if (auto v = config_dict.Get("show_meta")) {
n->show_meta = Downcast<IntImm>(v)->value;
}
if (auto v = config_dict.Get("ir_prefix")) {
n->ir_prefix = Downcast<String>(v);
}
Expand Down
8 changes: 8 additions & 0 deletions src/relay/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,14 @@ TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttr")
return NullOpt;
});

TVM_REGISTER_GLOBAL("relay.ir.FuncWithoutAttr")
.set_body_typed([](BaseFunc func, String key) -> Optional<Function> {
if (func->IsInstance<relay::FunctionNode>()) {
return WithoutAttr(Downcast<relay::Function>(std::move(func)), key);
}
return NullOpt;
});

TVM_REGISTER_NODE_TYPE(FunctionNode);

TVM_REGISTER_GLOBAL("relay.ir.Function")
Expand Down
30 changes: 17 additions & 13 deletions src/script/printer/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,24 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}
return lhs_name < rhs_name;
});
ICHECK(!d->mod.defined());
d->mod = mod;
{
With<IRFrame> f(d);
(*f)->AddDispatchToken(d, "ir");
for (const auto& kv : functions) {
GlobalVar gv = kv.first;
BaseFunc func = kv.second;
(*f)->stmts.push_back(d->AsDoc<FunctionDoc>(func, p->Attr("functions")->MapValue(gv)));
With<IRFrame> f(d);
(*f)->AddDispatchToken(d, "ir");
for (const auto& kv : functions) {
GlobalVar gv = kv.first;
BaseFunc func = kv.second;
d->cfg->binding_names.push_back(gv->name_hint);
Doc doc = d->AsDoc(func, p->Attr("functions")->MapValue(gv));
d->cfg->binding_names.pop_back();
if (const auto* stmt_block = doc.as<StmtBlockDocNode>()) {
(*f)->stmts.push_back(stmt_block->stmts.back());
} else if (const auto* stmt = doc.as<StmtDocNode>()) {
(*f)->stmts.push_back(GetRef<StmtDoc>(stmt));
} else {
(*f)->stmts.push_back(Downcast<FunctionDoc>(doc));
}
return ClassDoc(IdDoc("Module"), {IR(d, "ir_module")}, (*f)->stmts);
}
return HeaderWrapper(d, ClassDoc(IdDoc(GetBindingName(d).value_or("Module")),
{IR(d, "ir_module")}, (*f)->stmts));
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand Down Expand Up @@ -119,9 +125,7 @@ std::string ReprPrintIRModule(const ObjectRef& mod, const PrinterConfig& cfg) {
return s.value();
}
}
IRDocsifier d(cfg);
Doc doc = HeaderWrapper(d, d->AsDoc(mod, ObjectPath::Root()));
return DocToPythonScript(doc, cfg);
return ReprPrintIR(mod, cfg);
}

TVM_SCRIPT_REPR(TypeVarNode, ReprPrintIR);
Expand Down
Loading

0 comments on commit 45a92df

Please sign in to comment.