Skip to content

Commit

Permalink
[TVMScript] IR Fragment Printing (#13742)
Browse files Browse the repository at this point in the history
This PR introduces support for TIR fragment printing.
Fragment printing makes it possible to print TIR fragments in the text
format consistency with TVMScript PrimFunc/IRModule printing.

This PR still preserves the legacy ReprPrinter format by introducing an
API `LegacyTIRPrint` for TIR PrimExpr. This method is used in
AutoScheduler and TIR CSE for full backward compatibility.
  • Loading branch information
junrushao authored Jan 14, 2023
1 parent 60c723e commit c452e69
Show file tree
Hide file tree
Showing 38 changed files with 1,425 additions and 1,105 deletions.
2 changes: 1 addition & 1 deletion include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ inline std::string DLDataType2String(DLDataType t) {
inline DLDataType String2DLDataType(std::string s) {
DLDataType t;
// handle void type
if (s.length() == 0) {
if (s.length() == 0 || s == "void") {
t = DataType::Void();
return t;
}
Expand Down
12 changes: 2 additions & 10 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,6 @@ class IRDocsifierNode : public Object {
/*! \brief The name of the variable */
Optional<String> name;
};
/*!
* \brief This map connects IR dispatch token to the name of identifier.
*/
Map<String, String> ir_prefix;
/*!
* \brief The stack of frames.
* \sa FrameNode
Expand All @@ -152,7 +148,6 @@ class IRDocsifierNode : public Object {
std::unordered_map<const Object*, std::vector<const Object*>> common_prefix;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ir_prefix", &ir_prefix);
v->Visit("frames", &frames);
v->Visit("dispatch_tokens", &dispatch_tokens);
v->Visit("mod", &mod);
Expand Down Expand Up @@ -236,11 +231,8 @@ class IRDocsifierNode : public Object {
class IRDocsifier : public ObjectRef {
public:
using FType = IRDocsifierFunctor<printer::Doc, ObjectPath, IRDocsifier>;
/*!
* \brief Create a IRDocsifier.
* \param ir_prefix The ir_prefix to use for this IRDocsifier.
*/
explicit IRDocsifier(Map<String, String> ir_prefix);
/*! \brief Create a IRDocsifier. */
IRDocsifier();
/*! \brief The registration table for IRDocsifier. */
TVM_DLL static FType& vtable();

Expand Down
17 changes: 11 additions & 6 deletions include/tvm/script/printer/printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <tvm/node/node.h>
#include <tvm/script/printer/ir_docsifier.h>

#include <string>
#include <unordered_map>
#include <vector>

Expand All @@ -31,6 +32,8 @@ namespace printer {

/*! \brief Default values in the TVMScript printer */
struct Default {
/*! \brief The prefix of IR nodes */
std::unordered_map<std::string, std::string> ir_prefix = {{"ir", "I"}, {"tir", "T"}};
/*! \brief Default data type of TIR buffer */
DataType buffer_dtype = DataType::Float(32);
/*! \brief Default data type of integer literals */
Expand All @@ -41,28 +44,30 @@ struct Default {
* T.float32/T.float64 wrapper.
*/
DataType float_dtype = DataType::Void();
/*! \brief Whether or not to verbose print expressions. */
bool verbose_expr = false;
/*! \brief Returns a singleton of the configuration */
static Default* Instance();
static std::string& Prefix(const std::string& ir) { return Instance()->ir_prefix.at(ir); }
static DataType& BufferDType() { return Instance()->buffer_dtype; }
static DataType& IntDType() { return Instance()->int_dtype; }
static DataType& FloatDType() { return Instance()->float_dtype; }
static bool& VerboseExpr() { return Instance()->verbose_expr; }
};

/*!
* \brief The entry method for TVMScript printing
* \param obj The object to be printed
* \param ir_prefix The prefix of IR nodes
* \param indent_spaces Number of spaces used for indentation
* \param print_line_numbers Whether to print line numbers
* \param num_context_lines Number of context lines to print around the underlined text
* \param path_to_underline Object path to be underlined
* \return The TVMScript text format
*/
String Script(ObjectRef obj, //
Map<String, String> ir_prefix = {{"ir", "I"}, {"tir", "T"}}, //
int indent_spaces = 4, //
bool print_line_numbers = false, //
int num_context_lines = -1, //
String Script(ObjectRef obj, //
int indent_spaces = 4, //
bool print_line_numbers = false, //
int num_context_lines = -1, //
Optional<ObjectPath> path_to_underline = NullOpt);

/*!
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,9 @@ class Any : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode);
};

/*! \brief Legacy ReprPrint format for TIR */
std::string LegacyTIRPrint(const ObjectRef& obj);

/*
* \brief Template function to convert Map to unordered_map
* Sometimes useful for API gluing when internal uses unordered_map
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,17 +1210,17 @@ def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr,
)


def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None:
def prefetch(buffer: Buffer, bounds: List[Range]) -> None:
"""The prefetch hint for a buffer.
Parameters
----------
buffer : Buffer
The buffer to be prefetched.
indices : List[PrimExpr]
The indices of the buffer to extract.
bounds : List[Range]
The range to be prefetched.
"""
return _ffi_api.Prefetch(buffer, indices) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Prefetch(buffer, bounds) # type: ignore[attr-defined] # pylint: disable=no-member


def evaluate(value: PrimExpr) -> None:
Expand Down
1 change: 1 addition & 0 deletions python/tvm/script/printer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@
This package provides a set of APIs to print supported TVM IR into TVMScript
in a roundtrippable way.
"""
from . import default
from .printer import script
83 changes: 83 additions & 0 deletions python/tvm/script/printer/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The printer configuration"""
from typing_extensions import Literal

from . import _ffi_api


def ir_prefix( # pylint: disable=invalid-name
ir: Literal["ir", "tir"],
prefix: str,
) -> None:
"""Set the prefix for the IR. If not set, the prefix for "tvm.ir" is "I", and for "tir" is "T.
Parameters
----------
ir : str
The IR type, either "ir" or "tir".
prefix : str
The prefix to use.
"""
_ffi_api.DefaultIRPrefix(ir, prefix) # type: ignore # pylint: disable=no-member


def buffer_dtype(dtype: str) -> None:
"""Set the default dtype for buffer. If not set, it is "float32".
Parameters
----------
dtype : str
The default dtype for buffer.
"""
_ffi_api.DefaultBufferDtype(dtype) # type: ignore # pylint: disable=no-member


def int_dtype(dtype: str) -> None:
"""Set the default dtype for integers. If not set, it is "int32".
Parameters
----------
dtype : str
The default dtype for buffer.
"""
_ffi_api.DefaultBufferDtype(dtype) # type: ignore # pylint: disable=no-member


def float_dtype(dtype: str) -> None:
"""Set the default dtype for buffer. If not set, there is no default,
which means every floating point numbers will be wrapped with its precise dtype.
Parameters
----------
dtype : str
The default dtype for buffer.
"""
_ffi_api.DefaultFloatDtype(dtype) # type: ignore # pylint: disable=no-member


def verbose_expr(verbose: bool) -> None:
"""Whether or not to verbose print expressions. If not, the definition of every variable in an
expression will be printed as separate statements. Otherwise, the result will be a one-liner.
Parameters
----------
dtype : str
The default dtype for buffer.
"""
_ffi_api.VerboseExpr(verbose) # type: ignore # pylint: disable=no-member
14 changes: 2 additions & 12 deletions python/tvm/script/printer/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""The printer interface"""

from typing import Mapping, Optional
from typing import Optional

from tvm.runtime.object_path import ObjectPath

Expand All @@ -25,7 +24,6 @@

def script(
obj,
ir_prefix: Optional[Mapping[str, str]] = None,
indent_space: int = 4,
print_line_number: bool = False,
num_context_lines: int = -1,
Expand All @@ -37,9 +35,6 @@ def script(
----------
obj : object
An TVM object representing TVM IR
ir_prefix : Optional[Mapping[str, str]]
A mapping from IR type to the prefix of the script.
Default to {"ir": "I", "tir": T}
indent_space : int = 4
The number of spaces to indent
print_line_number : bool = False
Expand All @@ -54,11 +49,6 @@ def script(
script : str
The TVMScript text format
"""
if ir_prefix is None:
ir_prefix = {
"ir": "I",
"tir": "T",
}
return _ffi_api.Script( # type: ignore # pylint: disable=no-member
obj, ir_prefix, indent_space, print_line_number, num_context_lines, path_to_underline
obj, indent_space, print_line_number, num_context_lines, path_to_underline
)
27 changes: 15 additions & 12 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1270,29 +1270,32 @@ String ComputeDAG::PrintDAG(bool simple_mode) const {
if (pop->body.size() > 1) {
ss << ".v" << k;
}
if (auto preduce = pop->body[k].as<ReduceNode>()) {
ICHECK_LT(k, preduce->combiner->result.size());
PrimExpr combiner = preduce->combiner->result[k];
if (auto p_reduce = pop->body[k].as<ReduceNode>()) {
ICHECK_LT(k, p_reduce->combiner->result.size());
PrimExpr combiner = p_reduce->combiner->result[k];
if (combiner->IsInstance<AddNode>()) {
ss << " += " << preduce->source[0] << "\n";
ss << " += " << LegacyTIRPrint(p_reduce->source[0]) << "\n";
} else if (combiner->IsInstance<MaxNode>()) {
ss << " max= " << preduce->source[0] << "\n";
ss << " max= " << LegacyTIRPrint(p_reduce->source[0]) << "\n";
} else if (combiner->IsInstance<MinNode>()) {
ss << " min= " << preduce->source[0] << "\n";
ss << " min= " << LegacyTIRPrint(p_reduce->source[0]) << "\n";
} else if (combiner->IsInstance<SelectNode>()) {
const auto& select = combiner.as<SelectNode>();
ss << " select(" << select->condition << ", " << select->true_value << ", "
<< select->false_value << ")= " << '(' << preduce->source[0] << ','
<< preduce->source[1] << ")\n";
ss << " select(" << LegacyTIRPrint(select->condition) //
<< ", " << LegacyTIRPrint(select->true_value) //
<< ", " << LegacyTIRPrint(select->false_value) //
<< ")= (" << LegacyTIRPrint(p_reduce->source[0]) //
<< ',' << LegacyTIRPrint(p_reduce->source[1]) //
<< ")\n";
} else {
ss << "reduce" << combiner << "\n";
ss << "reduce" << LegacyTIRPrint(combiner) << "\n";
}
} else {
auto call = pop->body[k].as<CallNode>();
if (simple_mode && call) {
ss << " = " << call->op << "\n";
ss << " = " << LegacyTIRPrint(call->op) << "\n";
} else {
ss << " = " << pop->body[k] << "\n";
ss << " = " << LegacyTIRPrint(pop->body[k]) << "\n";
}
}
}
Expand Down
35 changes: 0 additions & 35 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,6 @@ TVM_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value

TVM_REGISTER_NODE_TYPE(IntImmNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntImmNode*>(node.get());
if (op->dtype == DataType::Int(32)) {
p->stream << op->value;
} else {
p->stream << "(" << op->dtype << ")" << op->value;
}
});

FloatImm::FloatImm(DataType dtype, double value, Span span) {
ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";

Expand Down Expand Up @@ -149,25 +139,6 @@ TVM_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double valu

TVM_REGISTER_NODE_TYPE(FloatImmNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloatImmNode*>(node.get());
auto& stream = p->stream;
switch (op->dtype.bits()) {
case 64:
stream << op->value;
break;
case 32:
stream << op->value << 'f';
break;
case 16:
stream << op->value << 'h';
break;
default:
LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits();
}
});

Range::Range(PrimExpr begin, PrimExpr end, Span span)
: Range(make_object<RangeNode>(begin, tir::is_zero(begin) ? end : (end - begin), span)) {}

Expand All @@ -183,12 +154,6 @@ TVM_REGISTER_GLOBAL("ir.Range").set_body([](TVMArgs args, TVMRetValue* ret) {

TVM_REGISTER_NODE_TYPE(RangeNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});

GlobalVar::GlobalVar(String name_hint, Type type, Span span) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
Expand Down
Loading

0 comments on commit c452e69

Please sign in to comment.