0074-tvmscript-unified-printer.md

File metadata and controls

544 lines (443 loc) · 20.6 KB

Summary

This RFC proposes to modularize the existing TVMScript printer and turn it into infrastructure, providing a unified printing mechanism across the TVM stack in which TIR, Relax and any future vendor-specific IR are all treated equally as dialects and can be printed together without engineering conflicts.

Motivation

TVMScript, a roundtrippable Python-based text format, is central to TVM's performance productivity. As the frontend of TVM, it lets end users construct TVM IR (either TIR or Relax) directly and pragmatically. From Relax to MetaSchedule and TIR, TVMScript enables inspectability and reproducibility at every level of compilation and optimization. Furthermore, TVMScript empowers developers to intercept, manipulate and customize compiler behavior in a principled way.

While TVMScript is gaining traction and buy-in from the open source community, the TVMScript printer suffers from multiple profound design issues:

  • Because IR fragments cannot be printed, users have to jump between TVMScript syntax and TIRText syntax
  • The lack of modularity makes it practically impossible to scale up to, and maintain, multiple IRs without engineering conflicts
  • Supporting the co-existence of multi-level IRs often requires re-engineering existing features.

Goal. This RFC introduces the TVMScript UNIfied Printer (TUNIP), a systematic redesign that addresses the engineering, usability and scalability issues above. The goals of this redesign are:

Goal 1 [Unified Representation]. Become the unified roundtrippable representation of TIR and Relax, allowing systematic mixing of IRs or IR fragments (Relax + TIR) in the same IRModule in the target language (for example, Python or C++).

The current TVMScript printer was designed specifically for TIR; printing multiple dialects together was not a design goal at the time. As a result, supporting Relax requires ad-hoc hacks around the system (for instance, relax#149 added support for printing T.cast and T.max in an ad-hoc way, without reusing the TIR printing code). The unified printer in this RFC addresses this issue with a unified approach to printing IR trees as TVMScript. Engineers will be able to implement a fully-fledged printer for Relax, TIR and any future IR with minimal effort.

The folder structure that we want to pursue is:

include/tvm/script/printer/
└── ... # Public headers for the core infra
src/script/printer/
├── core # Core infra, which is IR-agnostic
│   ├── ir_docsifier.cc
│   └── ...
├── tir # TIR dialect 
│   ├── expr.cc
│   ├── stmt.cc
│   └── ...
└── relax # Hypothetical Relax dialect (not part of our RFC)
    └── ...

Goal 2 [Third-Party IRs in Multi-Stage Compilation]. Modularize the printer and turn it into reusable infrastructure, so that future IRs and third-party IRs at any level can be supported maintainably; for example, IRs at a lower level than TIR, or the Relax VM executable.

The current TVMScript printer is tightly coupled with TIR because it is a subclass of TIR-specific functors (link). This design does not scale when more IRs need to be supported. More importantly, the current approach makes it impossible to support a third-party IR registered from a dynamic library.

Goal 3 [Reproducibility and Error Reporting]. Improve reproducibility and allow flexible rendering of diagnostic messages at any level of IR transformation.

For example, the following snippet runs and produces an error.

import tvm
from tvm.script import tir as T

@T.prim_func
def func_a(A: T.Buffer[(1,), "int32"]):
    A[0] = 0

@T.prim_func
def func_b(A: T.Buffer[(8,), "int32"]):
    A[0] = 0

tvm.ir.assert_structural_equal(func_a, func_b)

The current error message indicates what the difference was, but not where it occurred. This can sometimes be inferred from a stack trace, but becomes increasingly difficult with larger IR graphs.

ValueError: StructuralEqual check failed, caused by lhs:
1
and rhs:
8

TUNIP should enable individual utilities and IR passes to emit error messages that point the user to the exact location in the printed IR, for example:

ValueError: StructuralEqual check failed, first delta highlighted below

@T.prim_func
def func_a(A: T.Buffer[(1,), "int32"]) -> None:
                       ^^^^
    A[0] = 0

@T.prim_func
def func_b(A: T.Buffer[(8,), "int32"]) -> None:
                       ^^^^
    A[0] = 0

Guide-level explanation

This section introduces the design philosophy of the printer and demonstrates the proposed user-facing APIs, where "users" means IR developers.

Two-Stage Translation

Traditionally in the TVM stack, printing is a single-stage process: each printer hard-codes the syntax of its target language. As a result, there are currently three different printers for TIR alone: ReprPrinter, TIRTextPrinter and TVMScriptPrinter.

We extend the idea of the existing Doc class at src/printer/doc.h#L67 to allow better consistency and scalability. An IR, which could be TIR, Relax or any other IR developed by third-party vendors, is first translated to an intermediate Doc node tree, and then the Doc tree is mapped to a target language, for example, Python, the C++ IRBuilder API, or Rust.

Stage 1 [TVM IR => Doc]. In the first stage, the printer translates TVM IR into a Doc tree. For example, tir.For is translated to a ForDoc without having to worry about the target language. Note that some complex nodes in TVM IR, for example PrimFunc, may be translated into multiple Doc elements, including a FunctionDoc and several StmtDocs.

During the translation from IR to a Doc tree, a statement may influence the syntax of its children, or vice versa, especially for syntactic sugar and for declaring undefined variables when printing IR fragments. Therefore, a generic data structure, Frame, is introduced to allow retrieval and manipulation of the relevant context information.

Stage 2 [Doc => target language]. In the second stage, the Doc tree is faithfully translated into the target language in text format. For example, when the target language is Python, ForDoc is translated into Python's for loop syntax:

for ... in ...:
  ...

When the target language is the Python IRBuilder API, ForDoc is translated into:

with T.For(...):
  ...

For generality, the Doc tree is designed around a minimal set of constructs that exist in the languages used to develop TVM. The full Doc spec is given in the Reference-level explanation below.
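
To make the split concrete, here is a toy Python sketch of both stages for a single loop. ToyFor, ir_to_doc, to_python and to_irbuilder are hypothetical names, not TVM APIs; the Doc fields are plain strings instead of ExprDoc/StmtDoc objects, and the exact IRBuilder spelling is assumed:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyFor:
    """Hypothetical toy IR node standing in for tir.For (not TVM's class)."""
    loop_var: str
    extent: int
    body: List[str] = field(default_factory=list)  # statements kept as text for brevity


@dataclass
class ForDoc:
    """Mirrors ForDoc(lhs, rhs, body); real Docs would hold ExprDoc/StmtDoc, not str."""
    lhs: str
    rhs: str
    body: List[str] = field(default_factory=list)


def ir_to_doc(node: ToyFor) -> ForDoc:
    # Stage 1: decide *what* to print (loop var, iteration domain), not *how*.
    return ForDoc(lhs=node.loop_var, rhs=f"T.serial({node.extent})", body=list(node.body))


def _render(header: str, body: List[str], indent: str) -> str:
    return "\n".join([header] + [indent + stmt for stmt in body])


def to_python(doc: ForDoc, indent: str = "  ") -> str:
    # Stage 2, target = Python syntax: `for ... in ...:`
    return _render(f"for {doc.lhs} in {doc.rhs}:", doc.body, indent)


def to_irbuilder(doc: ForDoc, indent: str = "  ") -> str:
    # Stage 2, target = IRBuilder-style Python; the `with T.For(...) as ...`
    # spelling is an assumption made for illustration.
    return _render(f"with T.For({doc.rhs}) as {doc.lhs}:", doc.body, indent)


doc = ir_to_doc(ToyFor("i", 16, ["A[i] = 0"]))
print(to_python(doc))
print(to_irbuilder(doc))
```

Only the stage-2 functions know about target syntax, so supporting another output language means adding a renderer rather than touching ir_to_doc.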

Distributed Registration

A major engineering obstacle to scaling TVMScript to multiple IRs is that the existing printing logic has to be engineered, maintained and re-engineered in a single file, which has caused significant confusion when developing multi-level IRs for TVM Unity.

Inspired by the pass infrastructure and by the ReprPrinter in TVM, we propose infrastructure for distributed registration. It allows printers for different levels of IR to be registered in separate translation units, while keeping the ability to mix them at various levels; for example, Relax uses TIR expressions in its function bodies, and TIR calls back into Relax functions.
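
A minimal Python sketch of distributed registration, assuming a single shared table keyed by node type. The node classes, register and as_doc are hypothetical stand-ins for TVM's vtable machinery, and strings stand in for Docs. Each "module" registers its own printers, yet the Relax printer prints its TIR children through the shared entry point:

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

# Hypothetical shared vtable (node type -> print function). In the RFC this role is
# played by TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable); strings stand in for Docs.
VTABLE: Dict[type, Callable[[object], str]] = {}


def register(node_type: type):
    """Decorator any module ("translation unit") can use without touching the others."""
    def wrap(fn):
        if node_type in VTABLE:
            raise KeyError(f"printer for {node_type.__name__} already registered")
        VTABLE[node_type] = fn
        return fn
    return wrap


def as_doc(node) -> str:
    # Children are printed by going back through the shared table, never by direct calls.
    return VTABLE[type(node)](node)


# ---- module 1: toy TIR nodes and their printers ----
@dataclass
class TirVar:
    name: str


@dataclass
class TirAdd:
    a: object
    b: object


@register(TirVar)
def _print_tir_var(v: TirVar) -> str:
    return v.name


@register(TirAdd)
def _print_tir_add(e: TirAdd) -> str:
    return f"{as_doc(e.a)} + {as_doc(e.b)}"


# ---- module 2: a toy Relax node whose fields are TIR expressions ----
@dataclass
class RelaxShape:
    dims: List[object]


@register(RelaxShape)
def _print_relax_shape(s: RelaxShape) -> str:
    # Reuses the TIR printers only through as_doc; "R.shape" is a hypothetical spelling.
    return "R.shape([" + ", ".join(as_doc(d) for d in s.dims) + "])"


print(as_doc(RelaxShape([TirVar("n"), TirAdd(TirVar("m"), TirVar("k"))])))
```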

Diagnostics and Reproducibility

Existing error reporting mechanisms do not take IR structure or reproducibility into consideration. They usually report a single-line error message without providing the context of what the IR looks like during compilation. For example, when comparing whether two TIRs are structurally equal, the system may report:

ValueError: StructuralEqual check failed, caused by lhs:
{slow_memory_3_var: buffer(slow_memory_3_buffer_var, 0x501bf80), fast_memory_2_var: buffer(fast_memory_2_buffer_var, 0x501bd80), placeholder_3: buffer(placeholder_5, 0x50138a0), placeholder_2: buffer(placeholder_4, 0x5012b60), T_subtract: buffer(T_subtract_1, 0x5014390)}
and rhs:
{}

This message lacks the information users need to understand where the mismatch is.

As a recent effort, structural error reporting in TIR scheduling provides relevant and reproducible context, as demonstrated below:

@tvm.script.ir_module
class Module:
    @tir.prim_func
    def main(a: tir.handle, b: tir.handle) -> None:
        A = tir.match_buffer(a, [128, 128, 128, 128], dtype="float32")
        B = tir.match_buffer(b, [128, 128, 128, 128], dtype="float32")
        # body
        # with tir.block("root")
        for i, j, k, l in tir.grid(128, 128, 128, 8):
            tir.Block#0
            with tir.block("B"):
            ^^^^^^^^^^^^^^^^^^^^
                vi, vj, vk = tir.axis.remap("SSS", [i, j, k])
                vl = tir.axis.spatial(128, l * 16)
                tir.reads([A[vi, vj, vk, vl]])
                tir.writes([B[vi, vj, vk, vl]])
                B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * tir.float32(2)

Error: ...

However, the underlying mechanism supports only S-TIR, reports errors only on tir.ForNode and tir.BlockNode, and is hard to extend to generic cases.

To generalize this UX across the TVM stack, the following steps are added to the two-stage translation:

  • Each Doc node is optionally attached to a node in the TVM IR.
  • After the first stage finishes, collect all IR nodes that are attached to Docs into a map, whose key is the IR node and whose value is the list of Doc nodes attached to it.
  • For each IR node that has a diagnostic message, trace back through its ancestors until reaching an IR node in the map collected in the previous step. This yields a map from Doc nodes to diagnostic messages.
  • In the second stage, each diagnostic message is printed alongside its Doc as the Doc is printed in the target language.
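
The four steps can be sketched in Python as follows. This is a simplified model, assuming each IR node knows its parent and each Doc renders as a single line; IRNode, Doc, collect_attached, map_diagnostics and render are illustrative names, not part of TVM:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass(eq=False)  # eq=False keeps identity hashing, like object references
class IRNode:
    name: str
    parent: Optional["IRNode"] = None


@dataclass(eq=False)
class Doc:
    text: str                         # one rendered line, to keep the sketch short
    source: Optional[IRNode] = None   # step 1: optional link back to an IR node


def collect_attached(docs: List[Doc]) -> Dict[IRNode, List[Doc]]:
    """Step 2: map every IR node that has Docs attached to those Docs."""
    table: Dict[IRNode, List[Doc]] = {}
    for doc in docs:
        if doc.source is not None:
            table.setdefault(doc.source, []).append(doc)
    return table


def map_diagnostics(table: Dict[IRNode, List[Doc]],
                    diagnostics: Dict[IRNode, str]) -> Dict[Doc, str]:
    """Step 3: walk up from each diagnosed node to the nearest node that has a Doc."""
    doc_msgs: Dict[Doc, str] = {}
    for node, msg in diagnostics.items():
        cur = node
        while cur is not None and cur not in table:
            cur = cur.parent
        if cur is not None:  # nodes with no printed ancestor are dropped in this sketch
            for doc in table[cur]:
                doc_msgs[doc] = msg
    return doc_msgs


def render(docs: List[Doc], doc_msgs: Dict[Doc, str]) -> str:
    """Step 4: print each Doc, underlining the ones that carry a diagnostic."""
    lines: List[str] = []
    for doc in docs:
        lines.append(doc.text)
        if doc in doc_msgs:
            lines.append("^" * len(doc.text) + "  " + doc_msgs[doc])
    return "\n".join(lines)


func = IRNode("PrimFunc")
buf = IRNode("Buffer A", parent=func)
extent = IRNode("A.shape[0]", parent=buf)  # has no Doc of its own
docs = [
    Doc("@T.prim_func", source=func),
    Doc('A: T.Buffer[(8,), "int32"]', source=buf),
    Doc("A[0] = 0"),
]
doc_msgs = map_diagnostics(collect_attached(docs), {extent: "extent differs: 1 vs 8"})
print(render(docs, doc_msgs))
```

The diagnostic is raised on a node that never gets its own Doc (the shape extent), so step 3 walks up to the buffer, which does.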

Reference-level explanation

Doc Spec

The Doc is designed to be a unified representation of TVMScript across different languages. Its overall structure is a simplified version of Python's ast, and the meaning of each node is straightforward.

Doc(Optional<ObjectRef> source) # Base class for doc

# Expression
ExprDoc() # Base class for expression
LiteralDoc(Union<IntImm, FloatImm, String, nullptr_t> value)
IdDoc(String name)
AttrAccessDoc(ExprDoc value, String attr)
IndexDoc(ExprDoc value, Array<Union<ExprDoc, SliceDoc>> indices)
CallDoc(ExprDoc callee, Array<ExprDoc> args, Array<String> kwargs_keys, Array<ExprDoc> kwargs_values)
OperationDoc(OperationKind kind, Array<ExprDoc> operands)
LambdaDoc(Array<IdDoc> args, ExprDoc body)
TupleDoc(Array<ExprDoc> elements)
ListDoc(Array<ExprDoc> elements)
DictDoc(Array<ExprDoc> keys, Array<ExprDoc> values)

# Statements
StmtDoc(Array<String> comments) # Base class
AssignDoc(ExprDoc lhs, Optional<ExprDoc> rhs, Optional<ExprDoc> annotation)
IfDoc(ExprDoc predicate, Array<StmtDoc> then_branch, Array<StmtDoc> else_branch)
WhileDoc(ExprDoc predicate, Array<StmtDoc> body)
ForDoc(ExprDoc lhs, ExprDoc rhs, Array<StmtDoc> body)
ScopeDoc(Optional<ExprDoc> lhs, ExprDoc rhs, Array<StmtDoc> body)
ExprStmtDoc(ExprDoc expr)

# Special Docs
SliceDoc(Optional<ExprDoc> start, Optional<ExprDoc> stop)
FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators, ExprDoc return_type, Array<StmtDoc> body)
ClassDoc(IdDoc name, Array<ExprDoc> decorators, Array<AssignDoc> aliases, Array<FunctionDoc> functions)
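
To show how little stage 2 needs to know, here is a Python mirror of a subset of the expression Docs above, plus a toy renderer to Python syntax. The class and field names follow the spec; to_py and the use of plain Python values for literals are assumptions of this sketch:

```python
import json
from dataclasses import dataclass, field
from typing import List, Optional, Union


class ExprDoc:
    """Base class for expression Docs (mirrors ExprDoc in the spec)."""


@dataclass
class IdDoc(ExprDoc):
    name: str


@dataclass
class LiteralDoc(ExprDoc):
    value: Union[int, float, str, None]  # IntImm / FloatImm / String / nullptr


@dataclass
class AttrAccessDoc(ExprDoc):
    value: ExprDoc
    attr: str


@dataclass
class SliceDoc:  # a "special" Doc: allowed as an index, but not an expression
    start: Optional[ExprDoc] = None
    stop: Optional[ExprDoc] = None


@dataclass
class IndexDoc(ExprDoc):
    value: ExprDoc
    indices: List[Union[ExprDoc, SliceDoc]]


@dataclass
class CallDoc(ExprDoc):
    callee: ExprDoc
    args: List[ExprDoc] = field(default_factory=list)
    kwargs_keys: List[str] = field(default_factory=list)
    kwargs_values: List[ExprDoc] = field(default_factory=list)


def to_py(doc) -> str:
    """Stage 2: render a Doc as Python source text."""
    if isinstance(doc, IdDoc):
        return doc.name
    if isinstance(doc, LiteralDoc):
        # Strings get double quotes (as TVMScript prints them); numbers and None use repr.
        return json.dumps(doc.value) if isinstance(doc.value, str) else repr(doc.value)
    if isinstance(doc, AttrAccessDoc):
        return f"{to_py(doc.value)}.{doc.attr}"
    if isinstance(doc, SliceDoc):
        start = "" if doc.start is None else to_py(doc.start)
        stop = "" if doc.stop is None else to_py(doc.stop)
        return f"{start}:{stop}"
    if isinstance(doc, IndexDoc):
        return f"{to_py(doc.value)}[{', '.join(to_py(i) for i in doc.indices)}]"
    if isinstance(doc, CallDoc):
        args = [to_py(a) for a in doc.args]
        args += [f"{k}={to_py(v)}" for k, v in zip(doc.kwargs_keys, doc.kwargs_values)]
        return f"{to_py(doc.callee)}({', '.join(args)})"
    raise TypeError(f"unsupported Doc: {type(doc).__name__}")


region = IndexDoc(IdDoc("A"), [SliceDoc(LiteralDoc(1), LiteralDoc(10)), LiteralDoc(2)])
call = CallDoc(AttrAccessDoc(IdDoc("T"), "match_buffer"), [IdDoc("a")], ["dtype"], [LiteralDoc("float32")])
print(to_py(region))
print(to_py(call))
```

The buffer region A[1:10, 2] used later in this RFC becomes an IndexDoc whose first index is a SliceDoc.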

IRDocsifier Spec

IRDocsifier is responsible for transforming an IR node tree into a Doc tree. Its API looks like:

class IRDocsifierNode : public Object {
 public:
  // ir_prefix maintains a map from dispatch_token to ir prefix
  // so that the print function can construct an expression with
  // the current ir prefix, like `T.xxx` in TIR and `R.xxx` in Relax 
  Map<String, String> ir_prefix;
  // TranslationTable maintains a map from IR node to Doc
  // It will be updated when new variable gets into the scope, 
  // like when print PrimFunc or BlockRealize
  // It will be looked up when printing variable nodes like tir::Var and tir::Buffer
  TranslationTable translation_table;
  Array<Frame> frames;
  Array<String> dispatch_tokens;

  /*!
   * \brief Transform the input object into TDoc
   */
  template <class TDoc>
  TDoc AsDoc(const ObjectRef& obj);

  /*!
   * \brief Push a new dispatch token into the stack
   * \details The top dispatch token decides which dispatch table to use
   *          when printing Object. This method returns a RAII guard which
   *          pops the token when going out of the scope.
   */
  WithCtx WithDispatchToken(const String& token);

  /*!
   * \brief Push a new frame the stack
   * \details Frame contains the contextual information that's needed during printing,
   *          for example, variables in the scope. This method returns a RAII guard which
   *          pops the frame and call the cleanup method of frame when going out of the scope.
   */
  WithCtx WithFrame(const Frame& frame);

  /*!
   * \brief Get the top frame with type FrameType
   */
  template <typename FrameType>
  Optional<FrameType> GetFrame() const;
};

To register a print function with the IRDocsifier, use the TVM_STATIC_IR_FUNCTOR macro and the set_dispatch method of the ObjectFunctor:

  • Registration of printing methods for IR nodes
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<PrimType>("tir", [](PrimType ty, IRDocsifier p) -> Doc {
      using runtime::DLDataType2String;
      return TIR(p)->Attr(DLDataType2String(ty->dtype));
    });

// Explanation:
// 1. Here we register the print function of the PrimType node in TIR
// 2. The first arg to the `set_dispatch` function is the dispatch token
//    It's optional and represents the name of the IR.
// 3. The first argument to the print function is the node to be printed
// 4. The second argument is an instance of `IRDocsifier`, which can be used
//    to recursively translate the child nodes.
// 5. The print method returns a subclass of Doc

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch<Range>([](Range e, IRDocsifier p) {
  return SliceDoc(p->AsExprDoc(e->min), p->AsExprDoc(e->min + e->extent));
});

// The first arg to `set_dispatch` can be omitted, in which case
// the print function is registered in the default layer.
// It is called by default and can be overridden by registering
// another print function under an IR name.

// This function is called instead of the previous one
// when the printer is printing Relax.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch<Range>("relax", [](Range e, IRDocsifier p) {
  ...
});
  • Dispatch
auto tir_dispatch_ctx = ir_docsifier->WithDispatchToken("tir");
Doc doc = ir_docsifier->AsDoc<Doc>(node);

// Here we set up the ir_docsifier to call print functions under
// the 'tir' dispatch token, and then call the AsDoc method to
// translate `node`, an ObjectRef, into a `Doc` using the
// print functions registered in the dispatch table.

template <class TDoc>
TDoc AsDoc(const ObjectRef& obj) const {
  return Downcast<TDoc>(AsDocImpl(obj));
}
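
The dispatch rules above (token-specific entry first, then the default layer, with a token stack popped by a scope guard) can be modeled in a few lines of Python. This is a toy model, not TVM's ObjectFunctor: set_dispatch, with_dispatch_token and as_doc only mimic the C++ names, Range is a stand-in with integer bounds, and the R.range spelling is hypothetical:

```python
from contextlib import contextmanager
from typing import Callable, Dict, List, Optional, Tuple


class IRDocsifier:
    """Toy model of IRDocsifier's dispatch logic; strings stand in for Docs."""

    # Shared across instances, like the static vtable in the RFC.
    # Key: (dispatch token, or None for the default layer; node type).
    vtable: Dict[Tuple[Optional[str], type], Callable] = {}

    def __init__(self, ir_prefix: Dict[str, str]):
        self.ir_prefix = ir_prefix            # dispatch token -> prefix, e.g. "tir" -> "T"
        self.dispatch_tokens: List[str] = []

    @classmethod
    def set_dispatch(cls, node_type: type, token: Optional[str] = None):
        def wrap(fn):
            cls.vtable[(token, node_type)] = fn
            return fn
        return wrap

    @contextmanager
    def with_dispatch_token(self, token: str):
        # Plays the role of the RAII guard returned by WithDispatchToken.
        self.dispatch_tokens.append(token)
        try:
            yield self
        finally:
            self.dispatch_tokens.pop()

    def as_doc(self, obj) -> str:
        token = self.dispatch_tokens[-1] if self.dispatch_tokens else None
        key = (token, type(obj))
        if key not in self.vtable:
            key = (None, type(obj))           # fall back to the default layer
        if key not in self.vtable:
            raise TypeError(f"no printer for {type(obj).__name__} (token={token!r})")
        return self.vtable[key](obj, self)


class Range:
    """Toy stand-in for tvm.ir.Range with integer bounds."""
    def __init__(self, min_value: int, extent: int):
        self.min, self.extent = min_value, extent


@IRDocsifier.set_dispatch(Range)              # default layer: slice syntax
def _range_default(r: Range, p: IRDocsifier) -> str:
    return f"{r.min}:{r.min + r.extent}"


@IRDocsifier.set_dispatch(Range, "relax")     # overrides the default under "relax"
def _range_relax(r: Range, p: IRDocsifier) -> str:
    # Hypothetical Relax spelling, built with the prefix registered for "relax".
    return f"{p.ir_prefix['relax']}.range({r.min}, {r.min + r.extent})"


p = IRDocsifier({"tir": "T", "relax": "R"})
with p.with_dispatch_token("tir"):
    print(p.as_doc(Range(0, 16)))   # no "tir" entry for Range, so the default layer is used
with p.with_dispatch_token("relax"):
    print(p.as_doc(Range(0, 16)))
```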

Frame Spec

Frame provides contextual information during printing. Most commonly, a frame contains the variables defined in the current scope (such as a TIR function, TIR block or TIR loop). A subclass of Frame can be created to store more specific information. For instance, tir::ForLoopFrame should contain information about the TIR for loop in order to print the iter var remapping when printing BlockRealize.

class FrameNode : public Object {
 public:
  Array<ObjectRef> objs;
  TranslationTableNode* translation_table;

  /*!
   * \brief Set the name of a variable IR node
   */
  virtual IdDoc DefByName(const ObjectRef& obj, const String& name);
  /*!
   * \brief Set the doc of a variable IR node
   * \details This is useful when the variable is implicitly defined in the TVMScript.
   *          For example, when defining a `tir::Buffer buf`, buf->data is also a tir::Var,
   *          which should be printed as `buf.data`, rather than an identifier
   *          in the TVMScript.
   */
  virtual ExprDoc DefByDoc(const ObjectRef& obj, const ExprDoc& doc);
};
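
A Python sketch of how frames and the translation table could interact, assuming the translation table is a plain dict shared by all frames. Frame, ForLoopFrame, Printer and their methods are hypothetical analogues of FrameNode, DefByName, DefByDoc, WithFrame and GetFrame:

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List, Optional, Type, TypeVar


class Frame:
    """Toy Frame: remembers which IR objects it defined so they can be undefined on exit."""

    def __init__(self, table: Dict[object, str]):
        self.objs: List[object] = []
        self.table = table  # shared translation table: IR object -> printed text

    def def_by_name(self, obj: object, name: str) -> str:
        # Mirrors DefByName: the object will be printed as a fresh identifier.
        self.table[obj] = name
        self.objs.append(obj)
        return name

    def def_by_doc(self, obj: object, doc: str) -> str:
        # Mirrors DefByDoc: e.g. a buffer's data var printed as "buf.data".
        self.table[obj] = doc
        self.objs.append(obj)
        return doc

    def exit(self) -> None:
        # Cleanup: variables defined by this frame go out of scope.
        for obj in self.objs:
            self.table.pop(obj, None)
        self.objs.clear()


class ForLoopFrame(Frame):
    """Hypothetical subclass a TIR printer could use to keep loop-specific context."""


FrameT = TypeVar("FrameT", bound=Frame)


class Printer:
    def __init__(self) -> None:
        self.table: Dict[object, str] = {}
        self.frames: List[Frame] = []

    @contextmanager
    def with_frame(self, frame: FrameT) -> Iterator[FrameT]:
        # Plays the role of the RAII guard returned by WithFrame.
        self.frames.append(frame)
        try:
            yield frame
        finally:
            self.frames.pop()
            frame.exit()

    def get_frame(self, frame_type: Type[FrameT]) -> Optional[FrameT]:
        # Innermost (top) frame of the requested type, like GetFrame<FrameType>().
        for frame in reversed(self.frames):
            if isinstance(frame, frame_type):
                return frame
        return None


p = Printer()
buf, buf_data, i = object(), object(), object()  # stand-ins for tir::Buffer, buf->data, a loop var
with p.with_frame(Frame(p.table)) as func_frame:
    func_frame.def_by_name(buf, "buf")
    func_frame.def_by_doc(buf_data, "buf.data")
    with p.with_frame(ForLoopFrame(p.table)) as loop_frame:
        loop_frame.def_by_name(i, "i")
        print(p.table[i], p.table[buf_data], p.get_frame(ForLoopFrame) is loop_frame)
    print(i in p.table)
print(p.table)
```

Because cleanup is tied to the scope guard, names defined inside a loop are removed from the table as soon as printing leaves that loop.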

Upgrade Plan

IRModule.script() is the current way to print TIR into TVMScript. It calls the script.AsTVMScript function registered in src/printer/tvmscript_printer.cc. We plan to split the upgrade process into 5 steps:

  1. Without breaking existing functionality, upstream the system components piece by piece in small PRs under a tracking issue. The new system is mainly located in src/script and does not affect the functionality of the existing TVMScript printer.
  2. Expose the unified printer as a global TVM function script.printer.Script, which is parallel to the existing printer.
  3. Add a boolean flag use_legacy_printer to the Python IRModule.script, which defaults to True. IRModule.script calls script.printer.Script if use_legacy_printer is explicitly turned off.
  4. After stabilizing the new infra, change the default value of use_legacy_printer to False.
  5. Finally, deprecate the use_legacy_printer flag and clean up legacy code.
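
Steps 3 and 4 amount to a single boolean switch in IRModule.script. A rough Python sketch, where the lambdas in GLOBAL_FUNCS are hypothetical stand-ins for the two registered global functions rather than real TVM calls:

```python
from typing import Callable, Dict

# Hypothetical stand-ins for the two registered global functions; the real ones
# would be looked up through TVM's FFI registry, not a Python dict.
GLOBAL_FUNCS: Dict[str, Callable[[str], str]] = {
    "script.AsTVMScript": lambda mod: "# legacy printer\n" + mod,
    "script.printer.Script": lambda mod: "# unified printer\n" + mod,
}


def script(mod: str, use_legacy_printer: bool = True) -> str:
    """Step 3: the legacy printer stays the default; TUNIP is opt-in."""
    name = "script.AsTVMScript" if use_legacy_printer else "script.printer.Script"
    return GLOBAL_FUNCS[name](mod)


print(script("<IRModule>"))
print(script("<IRModule>", use_legacy_printer=False))
```

Step 4 then only changes the default to False, and step 5 deletes the flag together with the legacy branch.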

Drawbacks

N/A

Rationale and alternatives

Compared with the existing single-stage way of printing TVMScript, introducing two-stage printing will certainly increase the amount of code that needs to be written. However, we believe two-stage printing is the right choice because it reduces the complexity of each IR dialect's printing logic by removing unnecessary details about target-language syntax and string operations. It is therefore more scalable when printing multiple kinds of IR (TIR, Relax, and any third-party IRs in the future).

For example, printing a buffer region (like A[1:10, 2]) in the current printer looks like this:

Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) {
  Doc doc;
  if (op->region.size() == 0) {
    doc << Print(op->buffer) << "[()]";
  } else {
    doc << Print(op->buffer) << "[";
    for (size_t i = 0; i < op->region.size(); ++i) {
      if (i != 0) doc << ", ";
      const auto& range = op->region[i];
      if (!is_one(range->extent)) {
        doc << Print(range->min) << " : " << Print(ana_.Simplify(range->min + range->extent));
      } else {
        doc << Print(range->min);
      }
    }
    doc << "]";
  }
  return doc;
}

while in the unified printer, with two-stage printing, it becomes:

ExprDoc PrintBufferRegion(tir::BufferRegion buffer_region, IRDocsifier p) {
  Array<Doc> indices;

  for (const Range& range : buffer_region->region) {
    if (tir::is_one(range->extent)) {
      indices.push_back(p->AsExprDoc(range->min));
    } else {
      indices.push_back(p->AsDoc<SliceDoc>(range));  // Range is printed as a SliceDoc, not an ExprDoc
    }
  }

  return p->AsExprDoc(buffer_region->buffer)->Index(indices);
}

The latter is much simpler because it is free of the noisy code needed to produce valid Python index syntax.

Assume the printer needs to support k IRs, and it takes m time to develop the logic around IR semantics and n time to develop the logic around target language syntax. It will take k*(m+n) time if we use single-stage printing and km + n time if we adopt two-stage printing. We believe the cost of extending the Doc class will be paid off as soon as k is larger than one, based on our PoC on using two-stage printing for TIR.
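
The cost argument can be written out explicitly. Here d is a fixed, one-off cost of building the Doc infrastructure itself; it is our own addition to the model above, not a number taken from the PoC:

```latex
\begin{aligned}
T_{\text{single}}(k) &= k(m + n) = km + kn \\
T_{\text{two}}(k)    &= km + n + d \\
T_{\text{single}}(k) - T_{\text{two}}(k) &= (k - 1)\,n - d
\end{aligned}
```

With d = 0 this reduces to the claim above: two-stage printing is cheaper for every k > 1. A non-zero d only moves the break-even point to (k - 1) n > d.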

Additionally, two-stage printing makes it easy to change the output language from Python to other languages. Although we will still focus on TVMScript in Python for the foreseeable future, this flexibility is a nice additional benefit.

Prior art

RFC for TVMScript: https://discuss.tvm.apache.org/t/rfc-hybrid-script-support-for-tir/7516

Unresolved questions

N/A

Future possibilities

With the unified TVMScript printer, we have one of the building blocks towards a more open architecture, where the community can author their own IR and plug into the TVM stack, interacting with other components and layers.

As a mirror of this RFC, we will send out another RFC on the unified TVMScript parser, to support parsing TVMScript into different kinds of IR.