[BYOC] RelayToTIR custom codegen passes can still depend on dynamic shape functions

In apache#11474 I got ready to switch CUTLASS from function-at-a-time to IRModule-at-a-time compilation.
However, my approach didn't handle dynamic shape functions, so I adjust it here.

The idea is still that such passes will leave behind
calls to 'extern' functions. However, converting those
calls to 'call_lowered' form in
MarkCompilerFunctionsAsExtern is too soon since only
the TECompiler knows how to capture all the attributes
necessary to support dynamic shape functions.

So stop doing that in MarkCompilerFunctionsAsExtern and
instead support this case properly in the TECompiler.

While there, try to chip away at the chronic lack of structure in te_compiler.cc. Every little bit helps.

Add a basic unit test.
mbs-octoml committed Jun 9, 2022
1 parent 81b42e6 commit 78d92c8
Showing 10 changed files with 503 additions and 228 deletions.
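
The ordering described in the commit message (mark "Compiler" functions as 'extern' first, leave calls alone, and rewrite calls only in the later lowering step) can be sketched with a toy model. This is only an illustration under stated assumptions: `ToyFunction`, `ToyCall`, `ToyModule`, `ToyMarkAsExtern`, `ToyLowerCalls` and `ToyExample` are hypothetical names, not TVM types or APIs, and the `is_call_lowered` flag merely stands in for the 'call_lowered' convention.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for illustration only; none of these are TVM types.
struct ToyFunction {
  std::map<std::string, std::string> attrs;  // e.g. "Compiler" -> "cutlass"
};

struct ToyCall {
  std::string callee;    // name of the global function being called
  bool is_call_lowered;  // true once the call uses the lowered convention
};

struct ToyModule {
  std::map<std::string, ToyFunction> functions;
  std::vector<ToyCall> calls;
};

// Step 1: replace all attributes of matching "Compiler" functions with a single
// "Extern" attribute. An empty filter matches every "Compiler" function.
// Calls are deliberately left untouched.
ToyModule ToyMarkAsExtern(ToyModule mod, const std::string& compiler_filter) {
  for (auto& kv : mod.functions) {
    auto& attrs = kv.second.attrs;
    auto it = attrs.find("Compiler");
    if (it == attrs.end()) continue;
    const bool matches = compiler_filter.empty() || it->second == compiler_filter;
    if (!matches) continue;
    attrs.clear();
    attrs["Extern"] = "1";
  }
  return mod;
}

// Step 2: the later lowering step rewrites calls whose callee is marked "Extern".
// In the toy model this just flips a flag; the point is only that it happens here.
ToyModule ToyLowerCalls(ToyModule mod) {
  for (auto& call : mod.calls) {
    auto it = mod.functions.find(call.callee);
    if (it != mod.functions.end() && it->second.attrs.count("Extern") > 0) {
      call.is_call_lowered = true;
    }
  }
  return mod;
}

// Usage: only the "cutlass" function is marked, and only its call is rewritten.
bool ToyExample() {
  ToyModule mod;
  mod.functions["f_cutlass"].attrs["Compiler"] = "cutlass";
  mod.functions["f_other"].attrs["Compiler"] = "other";
  mod.calls.push_back(ToyCall{"f_cutlass", false});
  mod.calls.push_back(ToyCall{"f_other", false});

  mod = ToyMarkAsExtern(mod, "cutlass");
  assert(!mod.calls[0].is_call_lowered);  // marking alone does not touch calls

  mod = ToyLowerCalls(mod);
  return mod.calls[0].is_call_lowered && !mod.calls[1].is_call_lowered;
}
```

Keeping the two steps separate means the marking pass needs no knowledge of how calls are eventually lowered, which is the property the commit message relies on.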
8 changes: 3 additions & 5 deletions src/relay/backend/aot_executor_codegen.cc
@@ -1064,9 +1064,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {

mod = transform::ToANormalForm()(mod);

- IRModule lowered_mod = tec::LowerTEPass(
- mod_name,
- [this, workspace_byte_alignment](BaseFunc func) {
+ IRModule lowered_mod =
+ tec::LowerTE(mod_name, config_, [this, workspace_byte_alignment](BaseFunc func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
@@ -1078,8 +1077,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment);
- },
- config_)(mod);
+ })(mod);

auto lowered_main = lowered_mod->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
27 changes: 12 additions & 15 deletions src/relay/backend/graph_executor_codegen.cc
@@ -217,22 +217,19 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
mod = WithAttr(mod, "main_func_info", func_info);
}

- IRModule lowered_mod = tec::LowerTEPass(
- mod_name_,
- [this](BaseFunc func) {
- // We need to maintain the constant map for external
- // functions so we pass this processing function which
- // allows us to process each function as we lower it.
- if (func->GetAttr<String>(attr::kCompiler).defined()) {
- UpdateConstants(func, &params_);
- }
+ IRModule lowered_mod = tec::LowerTE(mod_name_, config_, [this](BaseFunc func) {
+ // We need to maintain the constant map for external
+ // functions so we pass this processing function which
+ // allows us to process each function as we lower it.
+ if (func->GetAttr<String>(attr::kCompiler).defined()) {
+ UpdateConstants(func, &params_);
+ }

- // TODO(@areusch, @jroesch): We should refactor this to
- // execute as a further pass, instead writing data to the
- // lowering process directly.
- tec::UpdateFunctionMetadata(func, this->function_metadata_);
- },
- config_)(mod);
+ // TODO(@areusch, @jroesch): We should refactor this to
+ // execute as a further pass, instead writing data to the
+ // lowering process directly.
+ tec::UpdateFunctionMetadata(func, this->function_metadata_);
+ })(mod);

Optional<backend::FunctionInfo> main_func_info =
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
3 changes: 1 addition & 2 deletions src/relay/backend/interpreter.cc
@@ -960,8 +960,7 @@ IRModule Prepare(IRModule mod, const CompilationConfig& config) {
// eta expand to support constructors in argument position.
transform::EtaExpand(
/*expand_constructor=*/true, /*expand_global_var=*/false),
- transform::InferType(),
- tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ }, config)});
+ transform::InferType(), tec::LowerTE(/*module_name=*/"intrp", config)});

transform::PassContext pass_ctx = transform::PassContext::Current();
With<transform::PassContext> ctx(pass_ctx);
329 changes: 221 additions & 108 deletions src/relay/backend/te_compiler.cc

Large diffs are not rendered by default.

32 changes: 9 additions & 23 deletions src/relay/backend/te_compiler.h
@@ -18,8 +18,8 @@
*/

/*!
- * \file relay/backend/tir_compiler.h
- * * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns.
+ * \file relay/backend/te_compiler.h
+ * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns.
*
*
* This represents the new design of the Relay compilation flow and will replace the interface
@@ -173,36 +173,22 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, const Compila
*/
Map<Target, IRModule> GetPerTargetModules(IRModule mod);

- /*! \brief Lower an IRModule's primitive functions to TIR.
- *
- * This is the "back half" of the Relay compiler which lowers "primitive functions"
- * to TE expressions, schedules them, and then to TIR.
- *
- * \param module The IRModule.
- * \param memory_plan The memory plan used during lowering
- * \param module_name The name of this module
- * \param process_fn Callback allowing one-level up code generators to process
- * each function that we lower
- * \return The lowered module, see above.
- */
- IRModule LowerTE(
- const IRModule& module, backend::StaticMemoryPlan memory_plan, const String& module_name,
- ProcessFn process_fn = [](BaseFunc f) {});
+ inline void DefaultProcessFn(BaseFunc) {}

/*!
* \brief Pass to lower an IRModule's primitive functions to TIR.
*
* This is the "back half" of the Relay compiler which lowers "primitive functions"
- * to TE expressions, schedules them, and then to TIR. It annotates all functions
- * with their target.
+ * to TE expressions, schedules them, and emits PrimFuncs.
*
- * \param module_name The name of this module
- * \param process_fn Callback allowing one-level up code generators to process
- * each function that we lower
+ * \param module_name The name of this module, used as a prefix for generated globals.
+ * \param config All available targets.
+ * \param process_fn Callback allowing one-level up code generators to process
+ * each function that we lower (default is no-op).
* \returns The pass which lowers primitive functions to TIR
*/
- transform::Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig config);
+ transform::Pass LowerTE(String module_name, CompilationConfig config,
+ ProcessFn process_fn = DefaultProcessFn);

} // namespace tec
} // namespace relay
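
The effect of the new `LowerTE` signature above (callback moved to the last position with a no-op default) can be sketched in isolation. The following is a minimal toy, not TVM code: `ToyFunc`, `ToyConfig`, `ToyFuncList`, `ToyProcessFn`, `ToyPass`, `ToyDefaultProcessFn`, `ToyLowerTE` and `ToyCountProcessed` are hypothetical stand-ins, and joining the module name and function name with `_` is an assumption made only so the sketch produces something observable.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-ins for illustration only; none of these are TVM types.
struct ToyFunc { std::string name; };
struct ToyConfig {};  // plays the role of the compilation config
using ToyFuncList = std::vector<ToyFunc>;
using ToyProcessFn = std::function<void(const ToyFunc&)>;
using ToyPass = std::function<std::vector<std::string>(const ToyFuncList&)>;

// No-op default, analogous in spirit to a named default callback.
inline void ToyDefaultProcessFn(const ToyFunc&) {}

// Required arguments come first; the optional callback is last so that callers
// with nothing to observe can simply omit it.
ToyPass ToyLowerTE(std::string module_name, ToyConfig /*config*/,
                   ToyProcessFn process_fn = ToyDefaultProcessFn) {
  return [module_name, process_fn](const ToyFuncList& funcs) {
    std::vector<std::string> lowered;
    for (const ToyFunc& func : funcs) {
      process_fn(func);  // let the caller observe each function
      lowered.push_back(module_name + "_" + func.name);  // assumed naming scheme
    }
    return lowered;
  };
}

// Usage: one caller omits the callback, another passes a lambda that counts.
int ToyCountProcessed() {
  ToyFuncList funcs = {ToyFunc{"add"}, ToyFunc{"mul"}};
  int processed = 0;
  std::vector<std::string> plain = ToyLowerTE("intrp", ToyConfig{})(funcs);
  assert(plain.size() == funcs.size());
  ToyLowerTE("vm_mod", ToyConfig{}, [&processed](const ToyFunc&) { ++processed; })(funcs);
  return processed;
}
```

A trailing defaulted parameter is what allows a call site to shrink to two arguments, as the interpreter.cc change in this commit does.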
24 changes: 10 additions & 14 deletions src/relay/backend/vm/compiler.cc
@@ -1039,13 +1039,11 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const CompilationConfig&
// Give each "primitive" Function a hash.
pass_seqs.push_back(LabelOps());
// Lower "primitive" Functions to PrimFuncs and rewrite calls.
- pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod",
- [this](const BaseFunc& func) {
- if (func->GetAttr<String>(attr::kCompiler).defined()) {
- backend::UpdateConstants(func, &params_);
- }
- },
- config));
+ pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config, [this](const BaseFunc& func) {
+ if (func->GetAttr<String>(attr::kCompiler).defined()) {
+ backend::UpdateConstants(func, &params_);
+ }
+ }));
// Since lowered functions are bound in the IRModule, we can now eliminate any unused
// let-bound functions.
pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
@@ -1090,13 +1088,11 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
pass_seqs.push_back(transform::LabelOps());

// Lower all functions annotated as "primitive" by FuseOps.
- pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod",
- [this](const BaseFunc& func) {
- if (func->GetAttr<String>(attr::kCompiler).defined()) {
- backend::UpdateConstants(func, &params_);
- }
- },
- config_));
+ pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config_, [this](const BaseFunc& func) {
+ if (func->GetAttr<String>(attr::kCompiler).defined()) {
+ backend::UpdateConstants(func, &params_);
+ }
+ }));

// Since lowered functions are bound in the IRModule, we can now eliminate any unused
// let-bound functions.
51 changes: 0 additions & 51 deletions src/relay/transforms/compiler_function_utils.cc
@@ -81,42 +81,6 @@ class Outliner : public MixedModeMutator {
IRModule mod_;
};

- /*!
- * \brief Rewrite calls to global "Compiler" functions to use the 'call_lowered' convention.
- */
- class CallRewriter : public MixedModeMutator {
- public:
- CallRewriter(std::string compiler_filter, IRModule mod)
- : compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {}
-
- Expr Rewrite_(const CallNode* pre, const Expr& post) final {
- Call new_call = Downcast<Call>(post);
- if (const auto* global_var_node = new_call->op.as<GlobalVarNode>()) {
- if (const auto* function_node =
- mod_->Lookup(GetRef<GlobalVar>(global_var_node)).as<FunctionNode>()) {
- Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
- if (opt_compiler.defined() &&
- (compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) {
- Optional<String> opt_global_symbol =
- function_node->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(opt_global_symbol.defined());
- GlobalVar global_symbol = mod_->GetGlobalVar(opt_global_symbol.value());
- CallLoweredAttrs attrs;
- attrs.metadata.Set("relay_attrs", new_call->attrs);
- return CallLowered(global_symbol, new_call->args, attrs, new_call->span);
- }
- }
- }
- return post;
- }
-
- private:
- /*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */
- std::string compiler_filter_;
- /*! \brief Module being rewritten. */
- IRModule mod_;
- };
-
} // namespace

GlobalSymbolCache::~GlobalSymbolCache() = default;
@@ -169,20 +133,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) {
IRModule output_mod = mod->ShallowCopy();
-
- // First pass, rewrite the calls.
- // We have to do this before marking functions as 'extern' to know which calls to rewrite!
- for (const auto& kv : mod->functions) {
- if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
- Expr new_body =
- CallRewriter(compiler_filter, output_mod).VisitExpr(function_node->body);
- Function new_function =
- WithFields(GetRef<Function>(function_node), /*opt_params=*/{}, new_body);
- output_mod->Update(kv.first, new_function);
- }
- }
-
- // Second pass, mark functions as 'extern'.
for (const auto& kv : mod->functions) {
if (const auto* function_node = kv.second.as<FunctionNode>()) {
Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
@@ -197,7 +147,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
}
}
}
-
return output_mod;
};

11 changes: 4 additions & 7 deletions src/relay/transforms/compiler_function_utils.h
@@ -43,11 +43,8 @@
*
* - \p MarkCompilerFunctionsAsExtern will replace global functions with a matching "Compiler"
* attribute with the same function with just an "Extern" attribute, signalling the function
- * has been dealt with. Calls to such functions will be rewritten to use the 'call_lowered'
- * calling convention. Can be used after lowering to cleanup the IRModule.
- *
- * Note that the above behaviour is hard coded within the TECompiler, but is only available to
- * external codegen using the Function-at-a-time "relay.ext.toolchain" extension point.
+ * has been dealt with. However calls to such functions will be left unchanged. Can be used
+ * after lowering to cleanup the IRModule.
*/

#ifndef TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_
@@ -118,8 +115,8 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co

/*!
* \brief A pass to mark all global functions which have a "Compiler" attribute matching
- * compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute, and
- * rewrite all calls to such functions to use the 'call_lowered' calling convention.
+ * compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute.
+ * Calls to such functions are not changed.
*
* If \p compiler_filter is non-empty only functions with that as their attribute value are
* outlined.