Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][TVMScript] Optionally hide StructInfo that can be inferred #16356

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions include/tvm/node/script_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,40 @@ class PrinterConfigNode : public Object {
bool syntax_sugar = true;
/*! \brief Whether variable names should include the object's address */
bool show_object_address = false;

/*! \brief In Relax, whether to show all StructInfo annotations
*
* If true (default), all variable bindings will be annotated with
* the struct info of the variable being bound.
*
* If false, the annotations will only be shown when they are
* required for correct parsing of the Relax function. For example,
* function parameters must always have struct info annotations, but
* the struct info for expressions within a function body may be inferred from their
* arguments, and is therefore omitted.
*
* Example:
*
* # func.show(show_all_struct_info=True)
* @R.function
* def func(
* A: R.Tensor((10, 20), dtype="float32"),
* B: R.Tensor((10,20), dtype="float32"),
* ) -> R.Tensor((10, 20), dtype="float32"):
* C: R.Tensor((10,20), dtype="float32") = R.add(A, B)
* return C
*
* # func.show(show_all_struct_info=False)
* @R.function
* def func(
* A: R.Tensor((10, 20), dtype="float32"),
* B: R.Tensor((10,20), dtype="float32"),
* ) -> R.Tensor((10, 20), dtype="float32"):
* C = R.add(A, B)
* return C
*/
bool show_all_struct_info = true;

/*! \brief Object path to be underlined. */
Array<ObjectPath> path_to_underline = Array<ObjectPath>();
/*! \brief Object path to be annotated. */
Expand All @@ -97,6 +131,7 @@ class PrinterConfigNode : public Object {
v->Visit("num_context_lines", &num_context_lines);
v->Visit("syntax_sugar", &syntax_sugar);
v->Visit("show_object_address", &show_object_address);
v->Visit("show_all_struct_info", &show_all_struct_info);
v->Visit("path_to_underline", &path_to_underline);
v->Visit("path_to_annotate", &path_to_annotate);
v->Visit("obj_to_underline", &obj_to_underline);
Expand Down
24 changes: 20 additions & 4 deletions python/tvm/runtime/script_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class PrinterConfig(Object):
num_context_lines: int
syntax_sugar: bool
show_object_address: bool
show_all_struct_info: bool
path_to_underline: Optional[List[ObjectPath]]
path_to_annotate: Optional[Dict[ObjectPath, str]]
obj_to_underline: Optional[List[Object]]
Expand All @@ -67,6 +68,7 @@ def __init__(
num_context_lines: Optional[int] = None,
syntax_sugar: bool = True,
show_object_address: bool = False,
show_all_struct_info: bool = True,
path_to_underline: Optional[List[ObjectPath]] = None,
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
obj_to_underline: Optional[List[Object]] = None,
Expand All @@ -89,6 +91,7 @@ def __init__(
"num_context_lines": num_context_lines,
"syntax_sugar": syntax_sugar,
"show_object_address": show_object_address,
"show_all_struct_info": show_all_struct_info,
"path_to_underline": path_to_underline,
"path_to_annotate": path_to_annotate,
"obj_to_underline": obj_to_underline,
Expand Down Expand Up @@ -132,6 +135,7 @@ def script(
num_context_lines: int = -1,
syntax_sugar: bool = True,
show_object_address: bool = False,
show_all_struct_info: bool = True,
path_to_underline: Optional[List[ObjectPath]] = None,
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
obj_to_underline: Optional[List[Object]] = None,
Expand Down Expand Up @@ -169,9 +173,13 @@ def script(
num_context_lines : int = -1
The number of lines of context to print before and after the line to underline.
syntax_sugar: bool = True
Whether to output with syntax sugar, set false for complete printing.
Whether to output with syntax sugar, set false for complete printing.
show_object_address: bool = False
Whether to include the object's address as part of the TVMScript name
Whether to include the object's address as part of the TVMScript name
show_all_struct_info: bool = True
If True (default), annotate all variable bindings with the struct
info of that variable. If False, only add annotations where
required for unambiguous round-trip of Relax -> TVMScript -> Relax.
path_to_underline : Optional[List[ObjectPath]] = None
Object path to be underlined
path_to_annotate : Optional[Dict[ObjectPath, str]] = None
Expand All @@ -185,6 +193,7 @@ def script(
-------
script : str
The TVM Script of the given TVM IR

"""
return _script(
self,
Expand All @@ -204,6 +213,7 @@ def script(
num_context_lines=num_context_lines,
syntax_sugar=syntax_sugar,
show_object_address=show_object_address,
show_all_struct_info=show_all_struct_info,
path_to_underline=path_to_underline,
path_to_annotate=path_to_annotate,
obj_to_underline=obj_to_underline,
Expand Down Expand Up @@ -279,6 +289,7 @@ def show(
num_context_lines: int = -1,
syntax_sugar: bool = True,
show_object_address: bool = False,
show_all_struct_info: bool = True,
path_to_underline: Optional[List[ObjectPath]] = None,
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
obj_to_underline: Optional[List[Object]] = None,
Expand Down Expand Up @@ -339,9 +350,13 @@ def show(
num_context_lines : int = -1
The number of lines of context to print before and after the line to underline.
syntax_sugar: bool = True
Whether to output with syntax sugar, set false for complete printing.
Whether to output with syntax sugar, set false for complete printing.
show_object_address: bool = False
Whether to include the object's address as part of the TVMScript name
Whether to include the object's address as part of the TVMScript name
show_all_struct_info: bool = True
If True (default), annotate all variable bindings with the struct
info of that variable. If False, only add annotations where
required for unambiguous round-trip of Relax -> TVMScript -> Relax.
path_to_underline : Optional[List[ObjectPath]] = None
Object path to be underlined
path_to_annotate : Optional[Dict[ObjectPath, str]] = None
Expand Down Expand Up @@ -377,6 +392,7 @@ def show(
num_context_lines=num_context_lines,
syntax_sugar=syntax_sugar,
show_object_address=show_object_address,
show_all_struct_info=show_all_struct_info,
path_to_underline=path_to_underline,
path_to_annotate=path_to_annotate,
obj_to_underline=obj_to_underline,
Expand Down
3 changes: 3 additions & 0 deletions src/node/script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) {
if (auto v = config_dict.Get("show_object_address")) {
n->show_object_address = Downcast<IntImm>(v)->value;
}
if (auto v = config_dict.Get("show_all_struct_info")) {
n->show_all_struct_info = Downcast<IntImm>(v)->value;
}

// Checking prefixes if they are valid Python identifiers.
CHECK(IsIdentifier(n->ir_prefix)) << "Invalid `ir_prefix`: " << n->ir_prefix;
Expand Down
5 changes: 4 additions & 1 deletion src/script/printer/relax/binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
"", [](relax::MatchCast n, ObjectPath n_p, IRDocsifier d) -> Doc {
using relax::StructInfo;
using relax::MatchStructInfo;
Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
Optional<ExprDoc> ann = NullOpt;
if (d->cfg->show_all_struct_info) {
ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
}
ExprDoc rhs = Relax(d, "match_cast")
->Call({d->AsDoc<ExprDoc>(n->value, n_p->Attr("value")),
d->AsDoc<ExprDoc>(n->struct_info, n_p->Attr("struct_info_"))});
Expand Down
39 changes: 39 additions & 0 deletions src/script/printer/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#ifndef TVM_SCRIPT_PRINTER_RELAX_UTILS_H_
#define TVM_SCRIPT_PRINTER_RELAX_UTILS_H_

#include <tvm/relax/analysis.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/type.h>
#include <tvm/relax/utils.h>
Expand Down Expand Up @@ -82,10 +84,47 @@ inline Optional<ExprDoc> StructInfoAsAnn(const relax::Var& v, const ObjectPath&
if (!v->struct_info_.defined()) {
return NullOpt;
}
bool attempt_to_hide_struct_info = !d->cfg->show_all_struct_info;

if (const auto* call = rhs.as<relax::CallNode>()) {
static const Op& call_tir_op = Op::Get("relax.call_tir");
static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op)) {
attempt_to_hide_struct_info = true;
}
}
if (attempt_to_hide_struct_info) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, I wonder if this is more complicated than what we would really want. Perhaps a syntactic rule for when to write StructInfo might make more sense? E.g., have it for op calls but not tuple indices, etc.? I'm not sure what situations are sufficiently non-obvious to require an annotation if a user wants to hide them. It's not wrong, but it seems a little unusual for the TVMScript printer to be running type inference itself.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My worry with having separate rules for it is that those rules could diverge from the actual struct inference. Even when the syntactic rules allow the struct info to be hidden, it would still need to compare against the actual struct info to see if the annotations provide different information. That comparison would require StructInfo to compare against, and so this function would effectively be another implementation of InferStructInfo, but one which could get out of sync with the canonical version.

By implementing it in terms of the struct inference, the same inference rules are used for both parsing and printing. Just as parsing can fill in missing information by calling the type inference, the printing can check if information can be omitted by calling the type inference.

Copy link
Contributor

@slyubomirsky slyubomirsky Feb 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems a little like overkill, but we are running the normalizer between every pass anyway so I guess it's not so bad.

e: And you're right, it's futureproof too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, probably overkill, but I tend to be more okay with overkill for non-default behavior. Good point on the comparison with normalizing on each pass.

Optional<relax::StructInfo> inferred_sinfo = NullOpt;
if (auto opt = rhs.as<relax::Call>()) {
auto call = opt.value();
if (auto opt = call->op.as<Op>()) {
auto op = opt.value();

static auto op_map_infer_struct_info =
Op::GetAttrMap<relax::FInferStructInfo>("FInferStructInfo");

auto temp_builder = relax::BlockBuilder::Create(NullOpt);
inferred_sinfo = op_map_infer_struct_info[op](call, temp_builder);
} else if (auto opt = call->op.as<relax::FuncStructInfo>()) {
auto temp_builder = relax::BlockBuilder::Create(NullOpt);
inferred_sinfo =
DeriveCallRetStructInfo(opt.value(), call, temp_builder, temp_builder->GetAnalyzer());
}

} else if (const auto* tuple = rhs.as<relax::TupleNode>()) {
inferred_sinfo = relax::TupleStructInfo(tuple->fields.Map(relax::GetStructInfo));

} else if (const auto* get_item = rhs.as<relax::TupleGetItemNode>()) {
if (auto ptr = get_item->tuple->struct_info_.as<relax::TupleStructInfoNode>();
ptr && get_item->index < static_cast<int>(ptr->fields.size())) {
inferred_sinfo = ptr->fields[get_item->index];
}

} else if (const auto* trivial_binding = rhs.as<relax::VarNode>()) {
inferred_sinfo = trivial_binding->struct_info_.as<relax::StructInfo>();
}

if (inferred_sinfo && StructuralEqual()(inferred_sinfo, v->struct_info_)) {
return NullOpt;
}
}
Expand Down
49 changes: 49 additions & 0 deletions tests/python/relax/test_tvmscript_printer_relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,5 +829,54 @@ def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype
)


def test_hide_inferable_struct_info():
"""Redundant type annotations can be omitted

When `show_all_struct_info=False`, TVMScript type annotations that
provide redundant struct info can be omitted.
"""

@R.function
def func(A: R.Tensor([10, 20], "float32"), B: R.Tensor(ndim=2, dtype="float32")):
# R.match_cast has the struct info as an argument, so it can
# be omitted from the variable annotation.
B2 = R.match_cast(B, R.Tensor([10, 20], "float32"))

# Call nodes may have inferable shapes from their arguments.
C = R.add(A, B2)

# Trivial bindings can be inferred to have the same struct
# info as the RHS.
D = C

# Here, the struct info cannot be omitted. `R.add(D,B)` has
# struct info `R.Tensor(ndim=2)`, but the variable has a shape
# `R.Tensor([10,20])`. This is compatible, so it is not an
# error to have this annotation, but it is not inferable from
# the RHS. Therefore, it must still be printed.
E: R.Tensor([10, 20], "float32") = R.add(D, B)

# The return type can be inferred from function body, but is
# still always printed in the TVMScript. When parsing an
# IRModule with functions calling each other, the return type
# of each callee must be available for use in the caller's
# shape inference.
return E

# In the expected script below, only `E` keeps its annotation; `B2`,
# `C` and `D` are printed as bare assignments, and the function's
# parameter and return annotations are still present.
_assert_print(
func.script(show_all_struct_info=False),
"""
# from tvm.script import relax as R

@R.function
def func(A: R.Tensor((10, 20), dtype="float32"), B: R.Tensor(dtype="float32", ndim=2)) -> R.Tensor((10, 20), dtype="float32"):
B2 = R.match_cast(B, R.Tensor((10, 20), dtype="float32"))
C = R.add(A, B2)
D = C
E: R.Tensor((10, 20), dtype="float32") = R.add(D, B)
return E""",
)


if __name__ == "__main__":
tvm.testing.main()
42 changes: 41 additions & 1 deletion tests/python/tvmscript/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tvm
import tvm.testing
from tvm import tir
from tvm.script import tir as T, ir as I
from tvm.script import tir as T, ir as I, relax as R

import numpy as np

Expand Down Expand Up @@ -3996,6 +3996,24 @@ def func():
yield make_ir_generator(op, arg)


# Round-trip test case: returns a Relax function that binds an ExternFunc
# named "dummy_func" and calls it twice through `R.call_dps_packed`.
def relax_extern_func():
@R.function
def func(A: R.Tensor([10, 20], "float32")):
func = R.ExternFunc("dummy_func")

# The annotation on `B` is identical to the `out_sinfo` argument.
B: R.Tensor([10, 20], "float32") = R.call_dps_packed(
func, [A], out_sinfo=R.Tensor([10, 20], "float32")
)

# The annotation on `C` (ndim=2 only, no shape) differs from the
# `out_sinfo` argument.
# NOTE(review): presumably this checks that such an annotation survives
# printing and re-parsing -- confirm against the printer.
C: R.Tensor(ndim=2, dtype="float32") = R.call_dps_packed(
func, [B], out_sinfo=R.Tensor([10, 20], "float32")
)

return C

return func


ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
Expand Down Expand Up @@ -4081,13 +4099,35 @@ def func():
*op_of_literal(),
)

# Relax function generators consumed by `test_relax_roundtrip`.
relax_ir_generator = tvm.testing.parameter(
relax_extern_func,
)

# Values passed as `show_all_struct_info` when printing in
# `test_relax_roundtrip`: True under the "show_all_struct_info" key and
# False under the "hide_inferable_struct_info" key.
# NOTE(review): the dict keys are presumably used as the test-case ids --
# confirm against `tvm.testing.parameter`.
show_all_relax_struct_info = tvm.testing.parameter(
by_dict={
"show_all_struct_info": True,
"hide_inferable_struct_info": False,
}
)


def test_roundtrip(ir_generator):
    """Print the generated IR as TVMScript, re-parse it, and compare.

    The object produced by ``ir_generator`` is printed with
    ``show_meta=True``, parsed back with ``tvm.script.from_source``, and
    the result must be structurally equal to the object that was printed.
    """
    expected = ir_generator()
    printed = expected.script(show_meta=True)
    reparsed = tvm.script.from_source(printed)
    tvm.ir.assert_structural_equal(expected, reparsed, True)


def test_relax_roundtrip(relax_ir_generator, show_all_relax_struct_info):
    """Print a Relax function as TVMScript, re-parse it, and compare.

    The function produced by ``relax_ir_generator`` is printed with
    ``show_meta=True`` and with ``show_all_struct_info`` set to the value
    of ``show_all_relax_struct_info``.  The printed script is parsed back
    with ``tvm.script.from_source`` and must be structurally equal to the
    function that was printed.
    """
    expected = relax_ir_generator()
    printed = expected.script(
        show_meta=True,
        show_all_struct_info=show_all_relax_struct_info,
    )
    reparsed = tvm.script.from_source(printed)
    tvm.ir.assert_structural_equal(expected, reparsed, True)


def test_return_none_no_trailing_type():
func = return_none()
script = func.script()
Expand Down
Loading