Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript][Printer] Remove relax prefix for now #14140

Merged
merged 1 commit into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions include/tvm/node/script_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ class PrinterConfigNode : public Object {
std::string ir_prefix = "I";
/*! \brief The prefix of TIR nodes */
std::string tir_prefix = "T";
/*! \brief The prefix of Relax nodes */
std::string relax_prefix = "R";
/*! \brief Default data type of TIR buffer */
DataType buffer_dtype = DataType::Float(32);
/*! \brief Default data type of integer literals */
Expand Down
57 changes: 2 additions & 55 deletions python/tvm/runtime/script_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ class PrinterConfig(Object):
show_meta: bool
ir_prefix: str
tir_prefix: str
relax_prefix: str
buffer_dtype: str
int_dtype: str
float_dtype: str
Expand All @@ -53,7 +52,6 @@ def __init__(
show_meta: bool = False,
ir_prefix: str = "I",
tir_prefix: str = "T",
relax_prefix: str = "R",
buffer_dtype: str = "float32",
int_dtype: str = "int32",
float_dtype: str = "void",
Expand All @@ -73,7 +71,6 @@ def __init__(
"show_meta": show_meta,
"ir_prefix": ir_prefix,
"tir_prefix": tir_prefix,
"relax_prefix": relax_prefix,
"buffer_dtype": buffer_dtype,
"int_dtype": int_dtype,
"float_dtype": float_dtype,
Expand Down Expand Up @@ -114,7 +111,6 @@ def script(
show_meta: bool = False,
ir_prefix: str = "I",
tir_prefix: str = "T",
relax_prefix: str = "R",
buffer_dtype: str = "float32",
int_dtype: str = "int32",
float_dtype: str = "void",
Expand All @@ -140,8 +136,7 @@ def script(
The prefix of AST nodes from tvm.ir
tir_prefix : str = "T"
The prefix of AST nodes from tvm.tir
relax_prefix : str = "R"
The prefix of AST nodes from tvm.relax
buffer_dtype : str = "float32"
The default data type of buffer
int_dtype : str = "int32"
Expand Down Expand Up @@ -179,51 +174,6 @@ def script(
show_meta=show_meta,
ir_prefix=ir_prefix,
tir_prefix=tir_prefix,
relax_prefix=relax_prefix,
buffer_dtype=buffer_dtype,
int_dtype=int_dtype,
float_dtype=float_dtype,
verbose_expr=verbose_expr,
indent_spaces=indent_spaces,
print_line_numbers=print_line_numbers,
num_context_lines=num_context_lines,
syntax_sugar=syntax_sugar,
path_to_underline=path_to_underline,
path_to_annotate=path_to_annotate,
obj_to_underline=obj_to_underline,
obj_to_annotate=obj_to_annotate,
),
)

def _relax_script(
self,
*,
name: Optional[str] = None,
show_meta: bool = False,
ir_prefix: str = "I",
tir_prefix: str = "T",
relax_prefix: str = "R",
buffer_dtype: str = "float32",
int_dtype: str = "int32",
float_dtype: str = "void",
verbose_expr: bool = False,
indent_spaces: int = 4,
print_line_numbers: bool = False,
num_context_lines: int = -1,
syntax_sugar: bool = True,
path_to_underline: Optional[List[ObjectPath]] = None,
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
obj_to_underline: Optional[List[Object]] = None,
obj_to_annotate: Optional[Dict[Object, str]] = None,
) -> str:
return _relax_script(
self,
PrinterConfig(
name=name,
show_meta=show_meta,
ir_prefix=ir_prefix,
tir_prefix=tir_prefix,
relax_prefix=relax_prefix,
buffer_dtype=buffer_dtype,
int_dtype=int_dtype,
float_dtype=float_dtype,
Expand All @@ -248,7 +198,6 @@ def show(
show_meta: bool = False,
ir_prefix: str = "I",
tir_prefix: str = "T",
relax_prefix: str = "R",
buffer_dtype: str = "float32",
int_dtype: str = "int32",
float_dtype: str = "void",
Expand Down Expand Up @@ -279,8 +228,7 @@ def show(
The prefix of AST nodes from tvm.ir
tir_prefix : str = "T"
The prefix of AST nodes from tvm.tir
relax_prefix : str = "R"
The prefix of AST nodes from tvm.relax
buffer_dtype : str = "float32"
The default data type of buffer
int_dtype : str = "int32"
Expand Down Expand Up @@ -316,7 +264,6 @@ def show(
show_meta=show_meta,
ir_prefix=ir_prefix,
tir_prefix=tir_prefix,
relax_prefix=relax_prefix,
buffer_dtype=buffer_dtype,
int_dtype=int_dtype,
float_dtype=float_dtype,
Expand Down
4 changes: 1 addition & 3 deletions src/node/script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) {
if (auto v = config_dict.Get("tir_prefix")) {
n->tir_prefix = Downcast<String>(v);
}
if (auto v = config_dict.Get("relax_prefix")) {
n->relax_prefix = Downcast<String>(v);
}

if (auto v = config_dict.Get("buffer_dtype")) {
n->buffer_dtype = DataType(runtime::String2DLDataType(Downcast<String>(v)));
}
Expand Down
10 changes: 1 addition & 9 deletions src/script/printer/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,6 @@ inline ExprDoc TIR(const IRDocsifier& d, const String& attr) {
return IdDoc(d->cfg->tir_prefix)->Attr(attr);
}

/*! \brief Creates the TIR common prefix, which is by default `T` */
inline ExprDoc Relax(const IRDocsifier& d, const String& attr) {
d->ir_usage.insert("relax");
return IdDoc(d->cfg->relax_prefix)->Attr(attr);
}

inline std::string DType2Str(const runtime::DataType& dtype) {
return dtype.is_void() ? "void" : runtime::DLDataType2String(dtype);
}
Expand All @@ -123,9 +117,7 @@ inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& doc) {
if (d->ir_usage.count("tir")) {
stmts.push_back(CommentDoc("from tvm.script import tir as " + d->cfg->tir_prefix));
}
if (d->ir_usage.count("relax")) {
stmts.push_back(CommentDoc("from tvm.script import relax as " + d->cfg->relax_prefix));
}

stmts.push_back(CommentDoc(""));
stmts.push_back(Downcast<StmtDoc>(doc));
return StmtBlockDoc(stmts);
Expand Down