Skip to content

Commit

Permalink
[BYOC] Two helper passes for external codegen using RelayToTIR custom…
Browse files Browse the repository at this point in the history
… pass machinery

(See https://discuss.tvm.apache.org/t/byoc-supporting-cutlass-byoc-with-collage/12796/6 for
context, which in turn is part of Collage (https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md).

For reasons explained in the above thread I'm moving CUTLASS to be IRModule-at-a-time external codegen
using a custom RelayToTIR pass instead of the traditional function-at-a-time external codegen using
a relay.ext.cutlass registered function. This means some of the rewriing done on-the-fly by LowerTEPass now
needs to be done by the custom pass directly. This PR supplies two passes which ease that burden:
 - Before starting the CUTLASS-specific processing, make sure all "Compiler" attributed functions have
   unique global definitions (ie are outlined). Though functions start in this form after BYOC partitioning,
   under Graph and AOT compilation flows those functions are then inlined to pass through the 'codegen' keyhole
   which assumes the whole model is just one self-contained main function. This pass will undo that. (I gave up
   trying to just remove the inlining in the first place.)
 - After the CUTLASS-specific processing the now compiled "Compiler" attributed functions need to marked as
   'extern'. The te_compiler.cc uses the "ExternalSymbol" attribute for that, but since a) the symbol name
   is never needed, on the presense of the attribute is significant downstream and b) "ExternalSymbol" is
   easy to confuse with "global_symbol", I just replaced "ExternalSymbol" with "Extern" with an Integer(1)
   (cf "Primitive").

 The outlining pass is a little more general than necessary because it (will also) be used by Collage to
 rewrite the IRModule into optimally partitioned form while making maximal reuse of partition functions.
 Hence the abstract GlobalSymbolCache.
  • Loading branch information
mbs-octoml committed May 27, 2022
1 parent bc492ac commit 5ece9dd
Show file tree
Hide file tree
Showing 15 changed files with 585 additions and 39 deletions.
3 changes: 2 additions & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,10 @@ class GlobalVarNode : public RelayExprNode {
*/
class GlobalVar : public RelayExpr {
public:
TVM_DLL explicit GlobalVar(String name_hint, Type type = {});
TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {});

TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode);
};

// PrimExprs that are useful as runtime containers.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/call.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace relay {
* \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR.
*/
struct CallLoweredAttrs : public tvm::AttrsNode<CallLoweredAttrs> {
/*! \brief The metadata attached to the call node. */
/*! \brief Additional metadata attached to the call node. Should be replaced by explict fields. */
Map<String, ObjectRef> metadata;

TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") {
Expand Down
24 changes: 20 additions & 4 deletions include/tvm/relay/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,19 +170,34 @@ const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func);
* \brief namespace of the attributes that can be attached to a relay::Function.
*/
namespace attr {
/*! \brief Mark the function as a primitive function. */

/*!
* \brief Mark the function as a primitive function. Should be bound to \p Integer(1).
*
* Type: Integer
*/
constexpr const char* kPrimitive = "Primitive";

/*!
* \brief Mark the function as being 'extern', ie implemented in a runtime::Module. Should be bound
* to \p Integer(1). Typically accompanied by "Primitive".
*
* Type: Integer
*/
constexpr const char* kExtern = "Extern";

/*!
* \brief Indicate the compiler that should be used for building this function.
* \brief Indicate the external codegen 'compiler' that should be used for building this function.
* When this is unset or set to "default", the default compilation pipeline will be used.
*
* Type: String
*/
constexpr const char* kCompiler = "Compiler";

/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the unique external symbol for external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Mark if the function should be avoided being optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
/*! \brief Treat the function as a composite operator. */
Expand All @@ -193,6 +208,7 @@ constexpr const char* kInline = "Inline";
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
/*! \brief Mark the function as only composed of reshape operations. */
constexpr const char* kReshapeOnly = "relay.reshape_only";

} // namespace attr

} // namespace relay
Expand Down
66 changes: 48 additions & 18 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,24 +802,6 @@ def Inline():
return _ffi_api.Inline()


def InlineComposites(target):
"""Perform inlining on the given Relay IR module. The functions originate
from the MergeComposite pass based on an input pattern table will fold back
to main. Currently, this is used for the TRT BYOC which expects a single
primitive function to operate on.
Parameters
----------
target: str
The byoc target for which ops need to fold back to primitive function.
Returns
-------
ret: tvm.transform.Pass
The registered pass that performs inlining for a Relay IR module.
"""
return _ffi_api.InlineComposites(target)


def gradient(expr, mod=None, mode="higher_order"):
"""
Transform the input function,
Expand Down Expand Up @@ -1386,3 +1368,51 @@ def SplitArgs(max_function_args):
The registered pass for constant folding.
"""
return _ffi_api.SplitArgs(max_function_args)


def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""):
"""A pass to outline all literal functions in direct call positions which have a "Compiler"
attribute. The functions are bound to unique global vars according to their existing
"global_symbol" attribute. At most one function with the same global symbol is outlined.
If compiler_filter is non-empty only functions with that as their attribute value are
outlined.
This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism
to prepare the IRModule before custom lowering.
Parameters
----------
compiler_filter : String
If non-empty, the 'compiler' attribute to filter on.
Returns
-------
ret : tvm.transform.Pass
The pass.
"""
return _ffi_api.OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter)


def MarkCompilerFunctionsAsExtern(compiler_filter=""):
"""A pass to mark all global functions which have a "Compiler" attribute matching
compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute, and
rewrite all calls to such functions to use the 'call_lowered' calling convention.
If compiler_filter is non-empty only functions with that as their attribute value are
outlined.
This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism to
cleanup the IRModule after custom lowering.
Parameters
----------
compiler_filter : String
If non-empty, the 'compiler' attribute to filter on.
Returns
-------
ret : tvm.transform.Pass
The pass.
"""
return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter)
3 changes: 2 additions & 1 deletion src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});

GlobalVar::GlobalVar(String name_hint, Type type) {
GlobalVar::GlobalVar(String name_hint, Type type, Span span) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
n->checked_type_ = std::move(type);
n->span = std::move(span);
data_ = std::move(n);
}

Expand Down
4 changes: 1 addition & 3 deletions src/parser/tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,6 @@ struct Tokenizer {
int line = this->line;
int column = this->col;

ICHECK_EQ(Peek(), '[');
Next();
std::stringstream type_key;
while (More() && Peek() != ']') {
type_key << Next();
Expand Down Expand Up @@ -498,7 +496,7 @@ struct Tokenizer {
auto token = NewToken(TokenType::kQuestion);
Next();
return token;
} else if (MatchString("meta")) {
} else if (MatchString("meta[")) {
return TokenizeMetaRef();
} else if (next == '#') {
return TokenizeAttr();
Expand Down
6 changes: 3 additions & 3 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class TECompilerImpl : public TECompilerNode {
if (const auto* function_node = kv2.second.as<FunctionNode>()) {
// Abandon the existing function annotations.

// Unfortuantely, Optional<DictAttrs>() is indistinguishable from
// Unfortunately, Optional<DictAttrs>() is indistinguishable from
// NullValue<DictAttrs>(), and DictAttrs() is nullptr, so to erase the attributes, we
// need pass in DictAttrs<Map<String, ObjectRef>()), which is a DictAttrs containing no
// attributes.
Expand All @@ -177,7 +177,7 @@ class TECompilerImpl : public TECompilerNode {
function_node->body, function_node->ret_type, function_node->type_params,
/* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
// Mark function as 'extern' using the "ExternalSymbol" attribute.
function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint);
function = WithAttr(std::move(function), attr::kExtern, Integer(1));
module->Add(kv2.first, function);
}
}
Expand Down Expand Up @@ -689,7 +689,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {

Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override {
if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
function_node->GetAttr<String>(attr::kExternalSymbol)) {
function_node->HasNonzeroAttr(attr::kExtern)) {
// Nothing to lower inside primitive/external functions.
return GetRef<Function>(function_node);
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ void VMCompiler::LowerImpl(IRModule mod) {
for (const auto& pair : context_.module->functions) {
auto gvar = pair.first;
if (auto* n = pair.second.as<FunctionNode>()) {
if (n->GetAttr<String>(attr::kExternalSymbol).defined()) {
if (n->HasNonzeroAttr(attr::kExtern)) {
// Already compiled during lowering.
continue;
}
Expand Down Expand Up @@ -1129,7 +1129,7 @@ size_t VMCompiler::PopulateGlobalMap() {
// Excludes PrimFuncs and externs, which are managed by the primitive_map_.
for (const auto& kv : context_.module->functions) {
if (const auto* function_node = kv.second.as<FunctionNode>()) {
if (!function_node->GetAttr<String>(attr::kExternalSymbol)) {
if (!function_node->HasNonzeroAttr(attr::kExtern)) {
context_.global_map.emplace(kv.first, context_.global_map.size());
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ FuncType FunctionNode::func_type_annotation() const {
const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func) {
if (const auto* function_node = base_func.as<FunctionNode>()) {
if (!function_node->GetAttr<String>(attr::kCompiler).defined() &&
!function_node->GetAttr<String>(attr::kExternalSymbol).defined() &&
!function_node->HasNonzeroAttr(attr::kExtern) &&
!function_node->HasNonzeroAttr(attr::kSkipOptimization)) {
return function_node;
}
Expand Down
1 change: 1 addition & 0 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,7 @@ Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT f
- **out**: `(b, m, n)`.
)code" TVM_ADD_FILELINE)
.set_attrs_type<BatchMatmulAttrs>()
.set_num_inputs(2)
.add_argument("tensor_a", "3D Tensor", "The first input.")
.add_argument("tensor_b", "3D Tensor", "The second input.")
Expand Down
Loading

0 comments on commit 5ece9dd

Please sign in to comment.