Skip to content

Commit

Permalink
[TVMScript] IR Fragment Printing
Browse files Browse the repository at this point in the history
This PR introduces support for TIR fragment printing
  • Loading branch information
junrushao committed Jan 10, 2023
1 parent d00168f commit bcf2252
Show file tree
Hide file tree
Showing 23 changed files with 865 additions and 886 deletions.
2 changes: 1 addition & 1 deletion include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ inline std::string DLDataType2String(DLDataType t) {
inline DLDataType String2DLDataType(std::string s) {
DLDataType t;
// handle void type
if (s.length() == 0) {
if (s.length() == 0 || s == "void") {
t = DataType::Void();
return t;
}
Expand Down
12 changes: 2 additions & 10 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,6 @@ class IRDocsifierNode : public Object {
/*! \brief The name of the variable */
Optional<String> name;
};
/*!
* \brief This map connects IR dispatch token to the name of identifier.
*/
Map<String, String> ir_prefix;
/*!
* \brief The stack of frames.
* \sa FrameNode
Expand All @@ -152,7 +148,6 @@ class IRDocsifierNode : public Object {
std::unordered_map<const Object*, std::vector<const Object*>> common_prefix;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ir_prefix", &ir_prefix);
v->Visit("frames", &frames);
v->Visit("dispatch_tokens", &dispatch_tokens);
v->Visit("mod", &mod);
Expand Down Expand Up @@ -236,11 +231,8 @@ class IRDocsifierNode : public Object {
class IRDocsifier : public ObjectRef {
public:
using FType = IRDocsifierFunctor<printer::Doc, ObjectPath, IRDocsifier>;
/*!
* \brief Create a IRDocsifier.
* \param ir_prefix The ir_prefix to use for this IRDocsifier.
*/
explicit IRDocsifier(Map<String, String> ir_prefix);
/*! \brief Create a IRDocsifier. */
IRDocsifier();
/*! \brief The registration table for IRDocsifier. */
TVM_DLL static FType& vtable();

Expand Down
13 changes: 7 additions & 6 deletions include/tvm/script/printer/printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ namespace printer {

/*! \brief Default values in the TVMScript printer */
struct Default {
/*! \brief The prefix of IR nodes */
std::unordered_map<std::string, std::string> ir_prefix = {{"ir", "I"}, {"tir", "T"}};
/*! \brief Default data type of TIR buffer */
DataType buffer_dtype = DataType::Float(32);
/*! \brief Default data type of integer literals */
Expand All @@ -43,6 +45,7 @@ struct Default {
DataType float_dtype = DataType::Void();
/*! \brief Returns a singleton of the configuration */
static Default* Instance();
static std::string& Prefix(const std::string& ir) { return Instance()->ir_prefix.at(ir); }
static DataType& BufferDType() { return Instance()->buffer_dtype; }
static DataType& IntDType() { return Instance()->int_dtype; }
static DataType& FloatDType() { return Instance()->float_dtype; }
Expand All @@ -51,18 +54,16 @@ struct Default {
/*!
* \brief The entry method for TVMScript printing
* \param obj The object to be printed
* \param ir_prefix The prefix of IR nodes
* \param indent_spaces Number of spaces used for indentation
* \param print_line_numbers Whether to print line numbers
* \param num_context_lines Number of context lines to print around the underlined text
* \param path_to_underline Object path to be underlined
* \return The TVMScript text format
*/
String Script(ObjectRef obj, //
Map<String, String> ir_prefix = {{"ir", "I"}, {"tir", "T"}}, //
int indent_spaces = 4, //
bool print_line_numbers = false, //
int num_context_lines = -1, //
String Script(ObjectRef obj, //
int indent_spaces = 4, //
bool print_line_numbers = false, //
int num_context_lines = -1, //
Optional<ObjectPath> path_to_underline = NullOpt);

/*!
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,17 +1210,17 @@ def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr,
)


def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None:
def prefetch(buffer: Buffer, bounds: List[Range]) -> None:
"""The prefetch hint for a buffer.
Parameters
----------
buffer : Buffer
The buffer to be prefetched.
indices : List[PrimExpr]
The indices of the buffer to extract.
bounds : List[Range]
The range to be prefetched.
"""
return _ffi_api.Prefetch(buffer, indices) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Prefetch(buffer, bounds) # type: ignore[attr-defined] # pylint: disable=no-member


def evaluate(value: PrimExpr) -> None:
Expand Down
14 changes: 2 additions & 12 deletions python/tvm/script/printer/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""The printer interface"""

from typing import Mapping, Optional
from typing import Optional

from tvm.runtime.object_path import ObjectPath

Expand All @@ -25,7 +24,6 @@

def script(
obj,
ir_prefix: Optional[Mapping[str, str]] = None,
indent_space: int = 4,
print_line_number: bool = False,
num_context_lines: int = -1,
Expand All @@ -37,9 +35,6 @@ def script(
----------
obj : object
An TVM object representing TVM IR
ir_prefix : Optional[Mapping[str, str]]
A mapping from IR type to the prefix of the script.
Default to {"ir": "I", "tir": T}
indent_space : int = 4
The number of spaces to indent
print_line_number : bool = False
Expand All @@ -54,11 +49,6 @@ def script(
script : str
The TVMScript text format
"""
if ir_prefix is None:
ir_prefix = {
"ir": "I",
"tir": "T",
}
return _ffi_api.Script( # type: ignore # pylint: disable=no-member
obj, ir_prefix, indent_space, print_line_number, num_context_lines, path_to_underline
obj, indent_space, print_line_number, num_context_lines, path_to_underline
)
35 changes: 0 additions & 35 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,6 @@ TVM_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value

TVM_REGISTER_NODE_TYPE(IntImmNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntImmNode*>(node.get());
if (op->dtype == DataType::Int(32)) {
p->stream << op->value;
} else {
p->stream << "(" << op->dtype << ")" << op->value;
}
});

FloatImm::FloatImm(DataType dtype, double value, Span span) {
ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";

Expand Down Expand Up @@ -149,25 +139,6 @@ TVM_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double valu

TVM_REGISTER_NODE_TYPE(FloatImmNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloatImmNode*>(node.get());
auto& stream = p->stream;
switch (op->dtype.bits()) {
case 64:
stream << op->value;
break;
case 32:
stream << op->value << 'f';
break;
case 16:
stream << op->value << 'h';
break;
default:
LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits();
}
});

Range::Range(PrimExpr begin, PrimExpr end, Span span)
: Range(make_object<RangeNode>(begin, tir::is_zero(begin) ? end : (end - begin), span)) {}

Expand All @@ -183,12 +154,6 @@ TVM_REGISTER_GLOBAL("ir.Range").set_body([](TVMArgs args, TVMRetValue* ret) {

TVM_REGISTER_NODE_TYPE(RangeNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});

GlobalVar::GlobalVar(String name_hint, Type type, Span span) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
Expand Down
22 changes: 0 additions & 22 deletions src/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@ TVM_REGISTER_GLOBAL("ir.PrimType").set_body_typed([](runtime::DataType dtype) {
return PrimType(dtype);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrimTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PrimTypeNode*>(ref.get());
p->stream << node->dtype;
});

PointerType::PointerType(Type element_type, String storage_scope) {
ObjectPtr<PointerTypeNode> n = make_object<PointerTypeNode>();
n->element_type = std::move(element_type);
Expand All @@ -57,16 +51,6 @@ TVM_REGISTER_GLOBAL("ir.PointerType")
return PointerType(element_type, storage_scope);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PointerTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PointerTypeNode*>(ref.get());
if (!node->storage_scope.empty()) {
p->stream << node->storage_scope << " ";
}
p->Print(node->element_type);
p->stream << '*';
});

TypeVar::TypeVar(String name, TypeKind kind, Span span) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
n->name_hint = std::move(name);
Expand Down Expand Up @@ -148,12 +132,6 @@ TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array<Type> fields) {
return TupleType(fields);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TupleTypeNode*>(ref.get());
p->stream << "TupleTypeNode(" << node->fields << ")";
});

IncompleteType::IncompleteType(TypeKind kind, Span span) {
auto n = make_object<IncompleteTypeNode>();
n->kind = std::move(kind);
Expand Down
11 changes: 10 additions & 1 deletion src/script/printer/doc_printer/python_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,16 @@ void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
MaybePrintCommentWithNewLine(doc);
output_ << "for ";
PrintDoc(doc->lhs);
if (const auto* tuple = doc->lhs.as<TupleDocNode>()) {
if (tuple->elements.size() == 1) {
PrintDoc(tuple->elements[0]);
output_ << ",";
} else {
PrintJoinedDocs(tuple->elements, ", ");
}
} else {
PrintDoc(doc->lhs);
}
output_ << " in ";
PrintDoc(doc->rhs);
output_ << ":";
Expand Down
11 changes: 4 additions & 7 deletions src/script/printer/ir_docsifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace printer {

String GenerateUniqueName(std::string name_hint, std::unordered_set<String>* defined_names) {
for (char& c : name_hint) {
if (c != 'c' && !std::isalnum(c)) {
if (c != '_' && !std::isalnum(c)) {
c = '_';
}
}
Expand All @@ -39,19 +39,17 @@ String GenerateUniqueName(std::string name_hint, std::unordered_set<String>* def
}

IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const String& name_hint) {
ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj;
String name = GenerateUniqueName(name_hint, &this->defined_names);
DocCreator doc_factory = [name]() { return IdDoc(name); };
auto result = obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}});
ICHECK(result.second) << "Duplicated object: " << obj;
obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}});
IdDoc def_doc(name);
frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
return def_doc;
}

void IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, DocCreator doc_factory) {
ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj;
ICHECK(!doc_factory()->IsInstance<IdDocNode>())
<< "IRDocsifierNode::Define cannot be used for variable that's mapped to IdDoc.";
obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}});
frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
}
Expand Down Expand Up @@ -146,9 +144,8 @@ void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root,
this->common_prefix = std::move(visitor.common_prefix);
}

IRDocsifier::IRDocsifier(Map<String, String> ir_prefix) {
IRDocsifier::IRDocsifier() {
auto n = make_object<IRDocsifierNode>();
n->ir_prefix = std::move(ir_prefix);
n->dispatch_tokens.push_back("");
data_ = std::move(n);
}
Expand Down
9 changes: 3 additions & 6 deletions src/script/printer/printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,10 @@ namespace tvm {
namespace script {
namespace printer {

String Script(ObjectRef obj, Map<String, String> ir_prefix, int indent_spaces,
bool print_line_numbers, int num_context_lines,
String Script(ObjectRef obj, int indent_spaces, bool print_line_numbers, int num_context_lines,
Optional<ObjectPath> path_to_underline) {
IRDocsifier d(ir_prefix);
Doc doc = d->AsDoc(obj, ObjectPath::Root());
return DocToPythonScript(doc, indent_spaces, print_line_numbers, num_context_lines,
path_to_underline);
return DocToPythonScript(IRDocsifier()->AsDoc(obj, ObjectPath::Root()), indent_spaces,
print_line_numbers, num_context_lines, path_to_underline);
}

Default* Default::Instance() {
Expand Down
Loading

0 comments on commit bcf2252

Please sign in to comment.