Skip to content

Commit

Permalink
[AOT] Refactor AOTExecutorCodegenModule using new AOT passes
Browse files Browse the repository at this point in the history
This commit refactors the AOTExecutorCodegenModule to make use
of the newly introduced AOT passes AOTLowerMain,
CreateFunctionMetadata and CreateExecutor metadata. Some
modifications are additionally made to the 'Codegen' interface
to make important code generation options explicit.
  • Loading branch information
mbaret committed Oct 24, 2022
1 parent 02431ec commit fd4aeb3
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 1,396 deletions.
209 changes: 209 additions & 0 deletions src/relay/backend/aot/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/backend/aot_executor_codegen.cc
* \brief AOT executor codegen
*/

#include <tvm/ir/module.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/executor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/runtime.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/usmp/utils.h>

#include <algorithm>
#include <list>
#include <string>
#include <vector>

#include "../te_compiler.h"
#include "../utils.h"
#include "./aot_lower_main.h"
#include "./create_executor_metadata.h"
#include "./create_function_metadata.h"

namespace tvm {
namespace relay {
namespace backend {
namespace aot {

std::unordered_map<std::string, tvm::runtime::NDArray> CreateParamMap(
const IRModule& mod, const std::unordered_map<std::string, runtime::NDArray>& external_params) {
auto params = std::unordered_map<std::string, tvm::runtime::NDArray>();
// Collect any constants extracted by external codegen.
Map<String, runtime::NDArray> const_name_to_constant =
mod->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant).value_or({});
for (const auto& kv : const_name_to_constant) {
params[kv.first] = kv.second;
}

// Collect any constants extracted during lowering.
for (const auto& kv : external_params) {
params[kv.first] = kv.second;
}

return params;
}

LoweredOutput Codegen(IRModule mod, String mod_name, CompilationConfig config, Executor executor,
CallType call_type) {
Integer workspace_byte_alignment =
executor->GetAttr<Integer>("workspace-byte-alignment").value_or(1);
Integer constant_byte_alignment =
executor->GetAttr<Integer>("constant-byte-alignment").value_or(1);
// Required Relay passes prior to AOT codegen (should be refactored out of executors)
mod = transform::ToANormalForm()(mod);
mod = transform::InferType()(mod);
mod = transform::AnnotateUsedMemory()(mod); // TODO(mbaret) Move into Ethos-U hook
std::unordered_map<std::string, runtime::NDArray> external_params;
mod = tec::LowerTE(mod_name, config, [&external_params](BaseFunc func) {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
UpdateConstants(func, &external_params);
}
})(mod);

// Lower the main Relay function to a TIR PrimFunc
// After this point the entire module is composed of PrimFuncs
mod = AOTLowerMain(mod_name, config, call_type)(mod);

mod = tir::transform::ConvertForLoopsToSerial()(mod); // TODO(mbaret) Make this optional
transform::PassContext pass_ctx = transform::PassContext::Current();
bool enable_usmp = pass_ctx->GetConfig<Bool>(kUSMPEnableOption, Bool(false)).value();
if (enable_usmp) {
mod = tir::transform::UnifiedStaticMemoryPlanner()(mod);
} else {
tir::PrimFunc tir_main_func =
Downcast<tir::PrimFunc>(mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
IRModule main_func_mod;
main_func_mod->Update(mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main),
tir_main_func);
main_func_mod = tir::transform::StorageRewrite()(main_func_mod);
mod->Update(mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main),
main_func_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
}
mod = tir::transform::LegalizePackedCalls()(mod);

// Collect the various functions, params and metadata into a LoweredOutput
LoweredOutput ret;
ret.params = CreateParamMap(mod, external_params);
ret.external_mods =
mod->GetAttr<Array<tvm::runtime::Module>>(tvm::attr::kExternalMods).value_or({});
ret.function_metadata =
std::move(CreateFunctionMetadata(mod, workspace_byte_alignment, constant_byte_alignment));
ret.lowered_funcs = tec::GetPerTargetModules(mod);
ret.metadata = CreateExecutorMetadata(mod, mod_name, executor, workspace_byte_alignment,
constant_byte_alignment);
return LoweredOutput(std::move(ret));
}

class AOTExecutorCodegenModule : public runtime::ModuleNode {
public:
AOTExecutorCodegenModule() {}
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
// Do nothing
});
} else if (name == "codegen") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
IRModule mod = args[0];
Function func = args[1];
String mod_name = args[2];
CompilationConfig config = args[3];
Executor executor = args[4];
Integer call_type = args[5];
this->output_ =
Codegen(mod, mod_name, config, executor, static_cast<CallType>(call_type->value));
});
} else if (name == "list_params_name") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = list_params_name(); });
} else if (name == "get_param_by_name") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
String key = args[0];
*rv = get_param_by_name(key);
});
} else if (name == "get_irmodule") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); });
} else if (name == "get_external_modules") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_external_modules(); });
} else if (name == "get_function_metadata") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->output_.function_metadata;
});
} else if (name == "get_devices") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = Array<String>(); });
} else if (name == "get_executor_codegen_metadata") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = output_.metadata; });
} else {
return PackedFunc([](TVMArgs args, TVMRetValue* rv) {});
}
}

const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; }

private:
Array<runtime::String> list_params_name() {
Array<runtime::String> ret;
for (const auto& kv : this->output_.params) {
ret.push_back(kv.first);
}
return ret;
}

runtime::NDArray get_param_by_name(String key) {
auto it = this->output_.params.find(key);
CHECK(it != this->output_.params.end()) << "no such parameter " << key;
return (*it).second;
}

Array<tvm::runtime::Module> get_external_modules() { return output_.external_mods; }

Map<Target, IRModule> get_irmodule() { return this->output_.lowered_funcs; }

LoweredOutput output_;
};

runtime::Module CreateAOTExecutorCodegenMod() {
auto ptr = make_object<AOTExecutorCodegenModule>();
return runtime::Module(ptr);
}

TVM_REGISTER_GLOBAL("relay.build_module._AOTExecutorCodegen")
.set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateAOTExecutorCodegenMod(); });

} // namespace aot
} // namespace backend
} // namespace relay
} // namespace tvm
Loading

0 comments on commit fd4aeb3

Please sign in to comment.