From fd4aeb38506390f4574b5bc0ffc2966c21eb5a90 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Mon, 22 Aug 2022 18:46:40 +0100 Subject: [PATCH] [AOT] Refactor AOTExecutorCodegenModule using new AOT passes This commit refactors the AOTExecutorCodegenModule to make use of the newly introduced AOT passes AOTLowerMain, CreateFunctionMetadata and CreateExecutor metadata. Some modifications are additionally made to the 'Codegen' interface to make important code generation options explicit. --- src/relay/backend/aot/aot_executor_codegen.cc | 209 +++ src/relay/backend/aot_executor_codegen.cc | 1392 ----------------- src/relay/backend/build_module.cc | 43 +- tests/python/relay/aot/test_crt_aot.py | 2 +- 4 files changed, 250 insertions(+), 1396 deletions(-) create mode 100644 src/relay/backend/aot/aot_executor_codegen.cc delete mode 100644 src/relay/backend/aot_executor_codegen.cc diff --git a/src/relay/backend/aot/aot_executor_codegen.cc b/src/relay/backend/aot/aot_executor_codegen.cc new file mode 100644 index 0000000000000..826fd6fe82bd9 --- /dev/null +++ b/src/relay/backend/aot/aot_executor_codegen.cc @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/aot_executor_codegen.cc + * \brief AOT executor codegen + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../te_compiler.h" +#include "../utils.h" +#include "./aot_lower_main.h" +#include "./create_executor_metadata.h" +#include "./create_function_metadata.h" + +namespace tvm { +namespace relay { +namespace backend { +namespace aot { + +std::unordered_map CreateParamMap( + const IRModule& mod, const std::unordered_map& external_params) { + auto params = std::unordered_map(); + // Collect any constants extracted by external codegen. + Map const_name_to_constant = + mod->GetAttr>(tvm::attr::kConstNameToConstant).value_or({}); + for (const auto& kv : const_name_to_constant) { + params[kv.first] = kv.second; + } + + // Collect any constants extracted during lowering. + for (const auto& kv : external_params) { + params[kv.first] = kv.second; + } + + return params; +} + +LoweredOutput Codegen(IRModule mod, String mod_name, CompilationConfig config, Executor executor, + CallType call_type) { + Integer workspace_byte_alignment = + executor->GetAttr("workspace-byte-alignment").value_or(1); + Integer constant_byte_alignment = + executor->GetAttr("constant-byte-alignment").value_or(1); + // Required Relay passes prior to AOT codegen (should be refactored out of executors) + mod = transform::ToANormalForm()(mod); + mod = transform::InferType()(mod); + mod = transform::AnnotateUsedMemory()(mod); // TODO(mbaret) Move into Ethos-U hook + std::unordered_map external_params; + mod = tec::LowerTE(mod_name, config, [&external_params](BaseFunc func) { + if (func->GetAttr(attr::kCompiler).defined()) { + UpdateConstants(func, &external_params); + } + })(mod); + + // Lower the main Relay function to a TIR PrimFunc + // After this point the entire module is composed of PrimFuncs + mod = AOTLowerMain(mod_name, config, call_type)(mod); + + mod = tir::transform::ConvertForLoopsToSerial()(mod); // TODO(mbaret) Make this optional + transform::PassContext pass_ctx = transform::PassContext::Current(); + bool enable_usmp = pass_ctx->GetConfig(kUSMPEnableOption, Bool(false)).value(); + if (enable_usmp) { + mod = tir::transform::UnifiedStaticMemoryPlanner()(mod); + } else { + tir::PrimFunc tir_main_func = + Downcast(mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); + IRModule main_func_mod; + main_func_mod->Update(mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main), + tir_main_func); + main_func_mod = tir::transform::StorageRewrite()(main_func_mod); + mod->Update(mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main), + main_func_mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); + } + mod = tir::transform::LegalizePackedCalls()(mod); + + // Collect the various functions, params and metadata into a LoweredOutput + LoweredOutput ret; + ret.params = CreateParamMap(mod, external_params); + ret.external_mods = + mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); + ret.function_metadata = + std::move(CreateFunctionMetadata(mod, workspace_byte_alignment, constant_byte_alignment)); + ret.lowered_funcs = tec::GetPerTargetModules(mod); + ret.metadata = CreateExecutorMetadata(mod, mod_name, executor, workspace_byte_alignment, + constant_byte_alignment); + return LoweredOutput(std::move(ret)); +} + +class AOTExecutorCodegenModule : public runtime::ModuleNode { + public: + AOTExecutorCodegenModule() {} + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "init") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + // Do nothing + }); + } else if (name == "codegen") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + IRModule mod = args[0]; + Function func = args[1]; + String mod_name = args[2]; + CompilationConfig config = args[3]; + Executor executor = args[4]; + Integer call_type = args[5]; + this->output_ = + Codegen(mod, mod_name, config, executor, static_cast(call_type->value)); + }); + } else if (name == "list_params_name") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = list_params_name(); }); + } else if (name == "get_param_by_name") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + String key = args[0]; + *rv = get_param_by_name(key); + }); + } else if (name == "get_irmodule") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); }); + } else if (name == "get_external_modules") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_external_modules(); }); + } else if (name == "get_function_metadata") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->output_.function_metadata; + }); + } else if (name == "get_devices") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = Array(); }); + } else if (name == "get_executor_codegen_metadata") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = output_.metadata; }); + } else { + return PackedFunc([](TVMArgs args, TVMRetValue* rv) {}); + } + } + + const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } + + private: + Array list_params_name() { + Array ret; + for (const auto& kv : this->output_.params) { + ret.push_back(kv.first); + } + return ret; + } + + runtime::NDArray get_param_by_name(String key) { + auto it = this->output_.params.find(key); + CHECK(it != this->output_.params.end()) << "no such parameter " << key; + return (*it).second; + } + + Array get_external_modules() { return output_.external_mods; } + + Map get_irmodule() { return this->output_.lowered_funcs; } + + LoweredOutput output_; +}; + +runtime::Module CreateAOTExecutorCodegenMod() { + auto ptr = make_object(); + return runtime::Module(ptr); +} + +TVM_REGISTER_GLOBAL("relay.build_module._AOTExecutorCodegen") + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateAOTExecutorCodegenMod(); }); + +} // namespace aot +} // namespace backend +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc deleted file mode 100644 index 786b3f81a5ae8..0000000000000 --- a/src/relay/backend/aot_executor_codegen.cc +++ /dev/null @@ -1,1392 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/relay/backend/aot_executor_codegen.cc - * \brief AOT executor codegen - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "../../target/source/codegen_source_base.h" -#include "../op/annotation/annotation.h" -#include "../op/call/call.h" -#include "../op/memory/device_copy.h" -#include "../transforms/device_aware_visitors.h" -#include "./name_transforms.h" -#include "./te_compiler.h" -#include "./utils.h" - -namespace tvm { -namespace relay { -namespace backend { - -using StorageMap = - std::unordered_map; - -/** - * This is an on demand allocator for AOT. A new temporary - * (storage allocator identifier) is allocated for each operation. - */ -class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { - public: - AOTOnDemandAllocator() : transform::DeviceAwareExprVisitor(Optional()) {} - - // run the visitor on a global function. - void Run(const Function& func) { VisitExpr(func); } - - std::vector GetReturnIds() const { return return_ids_; } - std::vector GetReturnTtypes() const { return return_ttypes_; } - - StorageMap GetStorageMap() const { return storage_device_map_; } - - using ExprVisitor::VisitExpr_; - - void VisitExpr_(const ConstantNode* op) final { - CreateStorage(op); - AssignReturnSid(GetRef(op)); - } - - void DeviceAwareVisitExpr_(const CallNode* call_node) final { - // AOTOnDemandAllocator is run both before and after lowering, so we need to handle the case - // where the op of the call is a generic function - - Expr func; - Array args; - - CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); - if (call_lowered_props.lowered_func.defined()) { - func = call_lowered_props.lowered_func; - args = call_lowered_props.arguments; - } else { // Relay functions that have not been lowered and lowered extern functions - func = call_node->op; - args = call_node->args; - if (call_node->op.as()) { // Lowered extern function - ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes."; - } else { // Relay function which has not been lowered yet - ICHECK(call_node->op.as()) - << "Expected the call to be to a lowered primfunc, a lowered extern function or a " - "unlowered Relay function."; - } - } - VisitExpr(func); - CreateStorage(call_node); - for (const Expr& arg : args) { - VisitExpr(arg); - } - AssignReturnSid(GetRef(call_node)); - } - - void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef(op)); } - - void DeviceAwareVisitExpr_(const FunctionNode* func_node) final { - if (function_nesting() > 1) { - // do not recurse into sub functions. - return; - } - if (func_node->HasNonzeroAttr(attr::kPrimitive)) { - // No storage needed for primitive functions. - return; - } - for (const auto& param : func_node->params) { - CreateStorage(param.get()); - } - VisitExpr(func_node->body); - } - - void VisitExpr_(const GlobalVarNode* op) final { - // Do nothing. - } - - void VisitExpr_(const OpNode* op) final { - // Do nothing. - } - - void VisitExpr_(const TupleNode* op) final { - std::vector storage_ids; - std::vector virtual_devices; - std::vector storage_sizes_in_bytes; - Expr expr = GetRef(op); - for (Expr field : op->fields) { - auto sid = GetStorage(field); - storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end()); - virtual_devices.insert(virtual_devices.end(), sid->virtual_devices.begin(), - sid->virtual_devices.end()); - storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(), - sid->storage_sizes_in_bytes.begin(), - sid->storage_sizes_in_bytes.end()); - } - storage_device_map_[expr] = StorageInfo(storage_ids, virtual_devices, storage_sizes_in_bytes); - AssignReturnSid(expr); - } - - void VisitExpr_(const TupleGetItemNode* op) final { - Expr expr = GetRef(op); - auto sids = GetStorage(op->tuple); - ICHECK_LT(static_cast(op->index), sids->storage_ids.size()); - storage_device_map_[expr] = - StorageInfo({sids->storage_ids[op->index]}, {sids->virtual_devices[op->index]}, - {sids->storage_sizes_in_bytes[op->index]}); - AssignReturnSid(expr); - } - - void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; } - - void PreVisitLetBinding_(const Var& var, const Expr& value) final { - VisitExpr(value); - StorageInfo si = GetStorage(value); - storage_device_map_[var] = si; - } - - private: - void AssignReturnSid(Expr e) { - if (storage_device_map_.find(e) != storage_device_map_.end()) { - StorageInfo& sinfo = storage_device_map_[e]; - return_ids_.clear(); - for (auto sid : sinfo->storage_ids) { - return_ids_.push_back(sid); - } - return_ttypes_.clear(); - return_ttypes_ = FlattenTupleType(e->checked_type()); - } - } - /*! - * \brief ceil(size/word_size) to get number of words. - * \param size The original size. - * \param word_size The element size. - */ - static size_t DivRoundUp(size_t size, size_t word_size) { - return (size + word_size - 1) / word_size; - } - /*! - * \brief Get the memory requirement. - * \param prototype The prototype token. - * \return The required memory size. - * - * TODO(mbs): Cf CalculateRelayExprSizeBytes in utils.cc, GetMemorySize is graph_plan_memory.cc - */ - size_t GetMemorySizeBytes(const TensorType& ttype) { - size_t size = 1; - for (IndexExpr dim : ttype->shape) { - const int64_t* pval = tir::as_const_int(dim); - ICHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << ttype->shape; - ICHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape" << *pval; - size *= static_cast(pval[0]); - } - size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8); - return size; - } - /*! - * \brief Get the necessary storage for the expression. - * \param expr The expression. - * \return The corresponding token. - */ - StorageInfo GetStorage(const Expr& expr) { - // See through "on_device" calls. - Expr true_expr = IgnoreOnDevice(expr); - VisitExpr(true_expr); - auto it = storage_device_map_.find(true_expr); - ICHECK(it != storage_device_map_.end()) << "Could not find " << true_expr->GetTypeKey() << " " - << PrettyPrint(true_expr) << " in storage device map"; - return it->second; - } - - /*! - * \brief Create storage for the expression. - */ - void CreateStorage(const ExprNode* op) { - Expr expr = GetRef(op); - return CreateStorage(expr, GetVirtualDevice(expr)); - } - - /*! - * \brief Create storage to hold the result of evaluating \p expr in \p virtual_device. - */ - void CreateStorage(const Expr& expr, const VirtualDevice& virtual_device) { - ICHECK(!virtual_device->IsFullyUnconstrained()) - << "invalid virtual device for expr:" << std::endl - << PrettyPrint(expr); - std::vector storage_ids; - std::vector virtual_devices; - std::vector storage_sizes_in_bytes; - for (const auto& ttype : FlattenTupleType(expr->checked_type())) { - storage_ids.push_back(next_available_sid_++); - virtual_devices.push_back(virtual_device); - storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype)); - } - storage_device_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices), - std::move(storage_sizes_in_bytes)); - } - - /*! \brief mapping of expression -> storageInfo */ - StorageMap storage_device_map_; - /*! \brief current id of the temporary allocated */ - int next_available_sid_{0}; - /*! \brief the set of intermediate tensors that are return variables */ - std::vector return_ids_; - /*! \brief the data types of the return values */ - std::vector return_ttypes_; -}; - -/*! \brief Code generator for AOT executor */ -class AOTExecutorCodegen : public MixedModeVisitor { - protected: - /*! \brief Describes the type of kernel call emitted. */ - enum CallType { - /*! - * \brief Emit PackedFunc calls bound just-in-time using TVMBackend* functions. - * - * When this type is selected, assumes all operators must be called via TVMFuncCall. Given the - * implementation of TVMFuncCall in the C++ runtime, this in practice implies that those - * functions are of type TVMBackendPackedCFunc. - * - * The following code is emitted at call sites to call a function named `func`: - * void* func_ptr = TVMBackendGetFuncFromEnv("func"); - * TVMFuncCall(func_ptr, values, tcodes, num_args, ret_values, ret_tcodes) - * - * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args` - * by LowerTVMBuiltin TIR transform. - * - * If `resource_handle` is passed to `func`, it is determined by TVMFuncCall (often, - * `resource_handle` is registered with the C++ runtime to provide a `this` equivalent when - * `func` is implemented in C). - * - * Compatible with both C++ and C runtimes, implemented with the C runtime only. - */ - kPacked, // Emit tir.call_packed and wrap all arguments in DLTensor. - - /*! - * \brief Directly call a TVMBackendPackedCFunc named according to the tir::Call. - * - * When this type is selected, assumes all operators are implemented in functions of type - * `TVMBackendPackedCFunc` and should be called directly. That is, presumes at the time of - * downstream compilation that there is a symbol named after the 0th arg to tir::Call of - * type `TVMBackendPackedCFunc`. This situation should occur when target_host == target. - * - * The following code is emitted at call sites to call a function named `func`: - * func(values, tcodes, num_args, ret_values, ret_tcodes, resource_handle) - * - * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args` - * by LowerTVMBuiltin TIR transform. - * - * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is - * always the device context parameter when not null. At present, the implementation does not - * support forwarding device context parameters to CPacked. - * - * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented - * in the same scenarios. - */ - kCPacked, // Emit tir.call_cpacked and wrap all arguments in DLTensor. - - /*! \brief Directly call a function accepting the `data` arrays as args. - * - * When this type is selected, assumes all operaotrs are implemented in C functions whose - * arguments are 1-to-1 with those in the tir::Call. DLTensor arguments are encoded as just the - * `data` parameters (i.e. no DLTensor object is passed along). - * - * The following code is emitted at call sites to a function named `func`: - * func(void* arg0, void* arg1, ..., void* argN) // no resource_handle - * -or- - * func(void* arg0, void* arg1, ..., void* argN, void* resource_handle) // with resource_handle - * - * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is - * always the device context parameter when not null. - * - * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented - * with the C runtime only. - */ - kUnpacked, // Emit tir.call_extern passing only the `data` part of DLTensors. - }; - - /*! - * \brief Return a vector of variables that represents the sids for the given Relay Expr - */ - std::vector PackSid(Expr expr) { - std::vector buffer_vars; - - ICHECK(storage_device_map_.find(expr) != storage_device_map_.end()) - << "Storage map did not contain constant expr " << PrettyPrint(expr); - StorageInfo& sinfo = storage_device_map_[expr]; - - // Note that an expression can have multiple sids associated with it - // e.g., returning multiple values from a function - for (auto sid : sinfo->storage_ids) { - // Determine if an sid is an output buffer - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid); - if (output_iter != return_sid_.end()) { - int output_index = std::distance(return_sid_.begin(), output_iter); - buffer_vars.push_back(GetBufferVarForIO(input_vars_.size() + output_index)); - continue; - } - - auto sid_value = sids_table_[sid]; - buffer_vars.push_back(sid_value); - } - return buffer_vars; - } - - /*! - * brief Given an expression return the variable(s) associated with that expression - */ - std::vector FindExpr(Expr arg) { - auto input_iter = std::find(input_vars_.begin(), input_vars_.end(), arg); - if (input_iter != input_vars_.end()) { - // Input variable - int main_index = std::distance(input_vars_.begin(), input_iter); - return {GetBufferVarForIO(main_index)}; - } else { - // Storage identifier (i.e., intermediate memory) - return PackSid(arg); - } - } - - /*! - * \brief Reverse lookup the device name in devices_ map. - * \param device_context Value in devices_ to find. - * \return Key matching device_context in devices_. - */ - std::string FindDeviceName(tir::Var device_context) { - for (std::pair kv : devices_) { - if (kv.second->name_hint == device_context->name_hint) { - return kv.first; - } - } - ICHECK(false) << "Did not find a device name associated with " << device_context; - return ""; - } - - void PushArgs(const Expr& expr, const std::vector& sids, Array* args) { - const TupleNode* t = expr.as(); - if (t != nullptr) { - CHECK_EQ(sids.size(), t->fields.size()) << "Relay tuple does not map 1:1 into TIR; AOT can't " - "handle this type of Relay Expr in a CallNode."; - } - - args->insert(args->end(), sids.begin(), sids.end()); - } - - /* - * Wraps a call_extern with a tvm_check_return annotation if required otherwise - * returns the passed Call - */ - tir::Call AddCheckReturn(tir::Call existing_call) { - Array args = {tir::make_const(DataType::Int(32, 1), 0, Span()), - tir::make_const(DataType::Int(32, 1), -1, Span()), existing_call}; - return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args); - } - - /*! - * brief Create a function call - * \param call_lowered_props The lowered function and the arguments to call it with - * \param result_expr The call we got func and args from (so as to recover the storage - * ids to hold the result). - */ - void CreateFuncCall(CallLoweredProps call_lowered_props, const Expr& result_expr) { - std::string func_name = call_lowered_props.lowered_func->name_hint; - tvm::Array args{tvm::tir::StringImm(func_name)}; - std::vector create_func_call_stmts; - - // Pack the inputs - for (const Expr& arg : call_lowered_props.arguments) { - if (params_by_expr_.find(arg) != params_by_expr_.end()) { - auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), - {tir::StringImm(params_by_expr_[arg])}); - // NOTE: this cast looks like a no-op, but is required for compilation downstream. - // Because DataType::Handle has default bits=64, but CodeGenC does not observe this field, - // adding this cast forces the codegen to insert the cast. In this case, a cast is required - // because param_handle is actually code-generated as `const void*`, and the `const` piece - // needs to be removed. - args.push_back(tvm::tir::Cast(DataType::Handle(32, 1), param_handle)); - } else { - auto sids = FindExpr(arg); - PushArgs(arg, sids, &args); - } - } - - // Pack the return(s) value. A call node can produce multiple outputs - auto result_expr_sid = PackSid(result_expr); - PushArgs(result_expr, result_expr_sid, &args); - - GlobalVar global_var = call_lowered_props.lowered_func; - bool has_c_device_api_context = device_contexts_.count(global_var) != 0; - tir::Var device_context; - tir::Stmt func_call; - - switch (call_type_) { - case CallType::kUnpacked: { - // call_extern calling convention with optional context - if (has_c_device_api_context) { - device_context = device_contexts_.Get(global_var).value(); - args.push_back(device_context); - } - func_call = tir::Evaluate(AddCheckReturn( - tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args))); - break; - } - case CallType::kCPacked: { - if (has_c_device_api_context) { - device_context = device_contexts_.Get(global_var).value(); - args.push_back(device_context); - } else { - // NOTE: LowerTVMBuiltin expects some device_context placeholder. - args.push_back(tir::make_zero(DataType::Handle())); - } - func_call = tir::Evaluate( - tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_cpacked(), args)); - create_func_call_stmts.push_back(func_call); - break; - } - case CallType::kPacked: { - // call_packed does not accept a device context. - CHECK(!has_c_device_api_context) << "CallType::kPacked does not accept a device context"; - func_call = tir::Evaluate(AddCheckReturn( - tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args))); - create_func_call_stmts.push_back(func_call); - break; - } - default: - ICHECK(false) << "Unknown CallType: " << call_type_; - } - - ICHECK(func_call.defined()) << "Must define func_call"; - - if (has_c_device_api_context) { - func_call = tir::SeqStmt(Array({ - GenerateDeviceHook(device_context, "Open"), - func_call, - GenerateDeviceHook(device_context, "Close"), - })); - } - - tir::Stmt body = tir::SeqStmt({func_call}); - stmts_.push_back(body); - } - - /*! - * \brief Copy a variable to the output. This function is mainly used in edge cases - * when we want to return an input or a parameter. - * TODO(giuseros): we should try to avoid unnecessary copy to the output, e.g., in a - * copy-on-write fashion. - */ - void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) { - // Define intermediate DLTensor to load/store the data - tir::Buffer tmp_read = - tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read"); - tir::Buffer tmp_write = - tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write"); - te::Var loop_idx("i", DataType::Int(32)); - auto retval_i = tir::BufferLoad(tmp_read, {loop_idx}); - // Copy the variable from the input to the output - tir::Stmt copy = tir::For( - loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial, - tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx})); - stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy)); - } - - /* - * \brief Collects device context variables for passing to operators - */ - void CollectDeviceVariables(const Map& device_contexts) { - Map target_contexts; - TargetKindAttrMap target_attr_map = tvm::TargetKind::GetAttrMap("use_device_api"); - - for (const auto& it : device_contexts) { - const GlobalVar& global_var = it.first; - const std::string device_context_name = it.second; - - Optional target_kind = tvm::TargetKind::Get(device_context_name); - if (!target_kind || !target_attr_map.count(target_kind.value())) { - return; - } - if (target_attr_map[target_kind.value()]) { - std::string context_name = tvm::runtime::SanitizeName(device_context_name); - tir::Var device_context_var("device_context_" + context_name, DataType::Handle()); - - auto pair = target_contexts.find(target_kind.value()); - if (pair != target_contexts.end()) { - device_context_var = (*pair).second; - } else { - main_signature_.push_back(device_context_var); - devices_.Set(context_name, device_context_var); - target_contexts.Set(target_kind.value(), device_context_var); - } - - device_contexts_.Set(global_var, device_context_var); - } - } - } - - /** - * \brief Generates a call to a given hook for all Devices found for C Device API - * \param Name of hook to generate statements for - * \return Statement with function calls for each device - */ - tir::Stmt GenerateAllDeviceHook(const String& hook) { - std::vector device_hooks; - for (const auto& it : devices_) { - const String& device_name = it.first; - const tir::Var& context = it.second; - Array sections = {"Device", device_name, hook}; - String device_hook_name = ToCFunctionStyle(PrefixName(sections)); - - tir::Evaluate device_hook( - AddCheckReturn(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), - {tvm::tir::StringImm(device_hook_name), context}))); - device_hooks.push_back(device_hook); - } - return tir::SeqStmt(device_hooks); - } - - /** - * \brief Generates a call to a given hook for a single Device function - * \param Var Device context to call hook on - * \param Name of hook to generate statements for - * \return Statement with function call to Device API - */ - tir::Stmt GenerateDeviceHook(const tir::Var& context, const String& hook) { - const auto& it = std::find_if(std::begin(devices_), std::end(devices_), [&](const auto& it) { - return it.second->name_hint == context->name_hint; - }); - const String& device_name = (*it).first; - Array sections = {"Device", device_name, hook}; - String device_hook = ToCFunctionStyle(PrefixName(sections)); - - return tir::Evaluate( - AddCheckReturn(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), - {tvm::tir::StringImm(device_hook), context}))); - } - - /*! - * Utility function to string together different arguments - */ - template - std::string MakeString(Args const&... args) { - std::ostringstream ss; - using List = int[]; - (void)List{0, ((void)(ss << args), 0)...}; - - return ss.str(); - } - - void VisitExpr_(const CallNode* call_node) override { - OnDeviceProps on_device_props = GetOnDeviceProps(call_node); - if (on_device_props.body.defined()) { - VisitExpr(on_device_props.body); - return; - } - - DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node); - CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); - - if (device_copy_props.body.defined()) { - // TODO(mbs): device_copy cleaunp - // Suspect treating as no-op is better since already built into the StorageInfo? - LOG(FATAL) << "The AOT executor does not currently support device_copy"; - return; - } - - // At this point we should only see calls of the form call_lowered(@callee, (args...)), - // where @callee can be a PrimFunc we've compiled or an external function supplied via - // some other mechanism. - ICHECK(call_lowered_props.lowered_func.defined()) - << "AOT does not support calling Relay functions. Attempting to call:" << std::endl - << PrettyPrint(GetRef(call_node)); - for (const auto& arg : call_lowered_props.arguments) { - // Evaluate the args - VisitExpr(arg); - } - CreateFuncCall(call_lowered_props, GetRef(call_node)); - } - - void VisitExpr_(const VarNode* op) override { - Expr expr = GetRef(op); - StorageInfo& sinfo = storage_device_map_[expr]; - - // Let bound vars refer to a value, so these should not be considered "output" vars. - if (let_bound_vars_.find(GetRef(op)) != let_bound_vars_.end()) { - return; - } - - // If the Var node is an output node we need to copy the content of the variable to the output - // It's safe to check the SID here because Var StorageToken are never reallocated - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]); - if (output_iter != return_sid_.end()) { - int output_index = std::distance(return_sid_.begin(), output_iter); - if (params_by_expr_.find(expr) != params_by_expr_.end()) { - auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), - {tir::StringImm(params_by_expr_[expr])}); - CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), param_handle, - /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]); - } else { - auto var_expr = FindExpr(expr); - CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), var_expr[0], - /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]); - } - } - } - - void VisitExpr_(const ConstantNode* op) override { - Expr expr = GetRef(op); - ICHECK(storage_device_map_.find(expr) != storage_device_map_.end()) - << "Storage map did not contain constant expr " << PrettyPrint(expr); - StorageInfo& sinfo = storage_device_map_[expr]; - std::stringstream ss; - ss << "constant_" << constant_map_.size(); - - tir::Var constant(ss.str(), PointerType(PrimType(DataType(op->data->dtype)))); - constant_map_[constant] = op; - auto sid = sinfo->storage_ids[0]; - sids_table_[sid] = constant; - - // If the Constant node is an output node we need to copy the content of the parameter to the - // output. A node can only produce a single output - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid); - if (output_iter != return_sid_.end()) { - int output_index = std::distance(return_sid_.begin(), output_iter); - auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), - {tir::StringImm(ss.str())}); - CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), constant, - /* pack_input */ false, sinfo->storage_sizes_in_bytes[0]); - } - } - - void VisitExpr_(const TupleNode* op) override { - for (auto field : op->fields) { - VisitExpr(field); - } - } - - void VisitExpr_(const LetNode* op) override { - auto pre_visit = [this](const LetNode* op) { - let_bound_vars_.insert(op->var); - this->VisitExpr(op->value); - }; - auto post_visit = [this](const LetNode* op) { - this->VisitExpr(op->body); - this->visit_counter_[op] += 1; - }; - ExpandANormalForm(op, pre_visit, post_visit); - } - - void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); } - void VisitExpr_(const OpNode* op) override { - if (GetRef(op) != CallLoweredOp() && GetRef(op) != OnDeviceOp()) { - LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded"; - } - } - void VisitExpr_(const IfNode* op) override { - LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's Codegen is called"; - } - void VisitExpr_(const FunctionNode* op) override { - ICHECK(op->GetAttr(attr::kCompiler).defined()) - << "FunctionNode only supported by custom codegen"; - } - void VisitExpr_(const RefCreateNode* op) override { - LOG(FATAL) << "AOT executor does not support references (found RefCreateNode)"; - } - void VisitExpr_(const RefReadNode* op) override { - LOG(FATAL) << "AOT executor does not support references (found RefReadNode)"; - } - void VisitExpr_(const RefWriteNode* op) override { - LOG(FATAL) << "AOT executor does not support references (found RefWriteNode)"; - } - void VisitExpr_(const ConstructorNode* op) override { - LOG(FATAL) << "AOT executor does not support ADTs (found ConstructorNode)"; - } - void VisitExpr_(const MatchNode* op) override { - LOG(FATAL) << "AOT executor does not support matching (found MatchNode)"; - } - - // Create the main PrimFunc to execute the graph. Please note that - // the packed function calls don't pack their arguments. The AOT - // runner function needs to be legalized by the LegalizePackedCalls pass. - tir::PrimFunc CreateMainFunc(String mod_name, unsigned int relay_params) { - tir::Stmt body = tir::SeqStmt(stmts_); - // Allocate the sids - std::unordered_map allocated; - - for (auto kv : storage_device_map_) { - // Only allocate sids that are needed - const bool is_input = - (std::find(input_vars_.begin(), input_vars_.end(), kv.first) != input_vars_.end()); - const bool is_param = (params_by_expr_.find(kv.first) != params_by_expr_.end()); - if (is_input || is_param) { - continue; - } - - for (unsigned int i = 0; i < kv.second->storage_ids.size(); i++) { - int size = kv.second->storage_sizes_in_bytes[i]; - int sid = kv.second->storage_ids[i]; - - if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) { - continue; - } - - // Make sure it hasn't already been allocated, this can happen - // with let-bound var/value pairs. - if (allocated.find(sid) != allocated.end()) { - continue; - } - - allocated[sid] = constant_map_.count(sids_table_[sid]); - - // TODO(giuseros): we should allocate this once outside the PrimFunc - // so we don't pay the price of allocation for every inference - if (!allocated[sid]) { - PointerType ptype = Downcast(sids_table_[sid]->type_annotation); - DataType element_type = Downcast(ptype->element_type)->dtype; - body = tir::Allocate(sids_table_[sid], element_type, {size}, tir::const_true(), body); - } - allocated[sid] = true; - } - } - - for (auto kv : constant_map_) { - auto buffer_var = kv.first; - auto dtype = DataType(kv.second->data->dtype); - - int ndim = kv.second->data->ndim; - Array extents; - - for (int i = 0; i < ndim; i++) { - int shape = kv.second->data->shape[i]; - extents.push_back(tir::make_const(DataType::Int(32), shape, Span())); - } - body = tir::AllocateConst(buffer_var, dtype, extents, kv.second->data, body); - } - - // Define the PrimFunc attributes - Map dict_attrs; - String run_func_name = runtime::get_name_mangled(mod_name, runtime::symbol::tvm_module_main); - dict_attrs.Set("global_symbol", run_func_name); - dict_attrs.Set("runner_function", Bool(true)); - dict_attrs.Set(tvm::attr::kTarget, config_->host_target); - - tir::Stmt device_activations = GenerateAllDeviceHook("Activate"); - tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate"); - tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations}); - - // Make the PrimFunc - return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, {}, - DictAttrs(dict_attrs)); - } - - /*! - * \brief Access IO vars using the buffer vars and - * not the actual var. - */ - tir::Var GetBufferVarForIO(int index) { return main_buffer_map_[main_signature_[index]]->data; } - - /*! - * \brief Create tir::Var for input/output while updating the buffer_maps. - * - * \param expr The expression to evaluate. - * \param original_name The name of the tir::Var. - * \param use_unique_name Whether to generate a new unique name where a name conflicts. - */ - void CreateIOVar(const Expr& expr, const std::string& original_name, - bool use_unique_name = true) { - CreateIOVar(expr->checked_type(), original_name, use_unique_name); - } - - /*! - * \brief Create tir::Var for input/output while updating the buffer_maps. - * - * \param expr The expression to evaluate. - * \param original_name The name of the tir::Var. - * \param use_unique_name Whether to generate a new unique name where a name conflicts. - */ - void CreateIOVar(const Type& type, const std::string& original_name, - bool use_unique_name = true) { - if (type->IsInstance()) { - TupleType tuple_type = Downcast(type); - for (unsigned i = 0; i < tuple_type->fields.size(); i++) { - CreateIOVar(tuple_type->fields[i], original_name); - } - } else { - std::string name = original_name; - if (use_unique_name) { - name = GetUniqueIOVarName(original_name); - } - tir::Var var = tir::Var(name, DataType::Handle()); - main_signature_.push_back(var); - auto tensor_type = type.as(); - ICHECK(tensor_type) << "Expected TensorType node but was " << type->GetTypeKey(); - DataType elem_type = tensor_type->dtype; - tir::Var buffer_var = - tir::Var(name + "_buffer_var", PointerType(PrimType(elem_type), "global")); - tir::Buffer buffer = tir::Buffer(buffer_var, elem_type, tensor_type->shape, {}, 0, - name + "_buffer", 16, 1, tir::BufferType::kDefault); - main_buffer_map_.Set(var, buffer); - io_tensor_types_.Set(var, Downcast(type)); - } - } - - /*! - * \brief Create a unique name for I/O Var - */ - std::string GetUniqueIOVarName(std::string name) { - if (io_var_names_.find(name) == io_var_names_.end()) { - io_var_names_[name] = 1; - return name; - } else { - io_var_names_[name] = io_var_names_[name] + 1; - return name + std::to_string(io_var_names_[name]); - } - } - - /*! - * \brief Calculate workspace sizes for PrimFuncs in the IRModule - */ - Map CalculateWorkspaceSizes( - const IRModule& lowered_mod, const Map& function_metadata) { - Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(lowered_mod); - Map updated_function_metadata; - for (const auto& kv : lowered_mod->functions) { - GlobalVar global_var = kv.first; - BaseFunc base_func = kv.second; - if (base_func->IsInstance()) { - tir::PrimFunc pfunc = Downcast(base_func); - Target tgt = pfunc->GetAttr(tvm::attr::kTarget).value(); - const auto& ws = CalculateWorkspaceBytes(pfunc, workspace_byte_alignment); - if (function_metadata.count(global_var->name_hint)) { - updated_function_metadata.Set(global_var->name_hint, - function_metadata[global_var->name_hint]); - updated_function_metadata[global_var->name_hint]->workspace_sizes.Set(tgt, ws); - } else { - FunctionInfo finfo{{{tgt, ws}}, {}, {}, {{tgt, pfunc}}, {}}; - updated_function_metadata.Set(global_var->name_hint, finfo); - } - } - } - return updated_function_metadata; - } - - /*! - * \brief Run USMP to plan memory for lowered IRModule. - */ - IRModule PlanMemoryWithUSMP(const IRModule& mod) { - VLOG(1) << "Planning memory with USMP for module:" << std::endl << PrettyPrint(mod); - Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod); - IRModule lowered_mod = mod->ShallowCopy(); - lowered_mod = tir::transform::UnifiedStaticMemoryPlanner()(lowered_mod); - function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_); - Optional> allocated_pool_infos = - lowered_mod->GetAttr>(tvm::attr::kPoolArgs); - backend::FunctionInfo main_func_info = - lowered_mod->GetAttr("main_func_info").value(); - main_func_info->workspace_sizes.clear(); - if (allocated_pool_infos) { - for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) { - for (const auto& tgt : allocated_pool_info->pool_info->targets) { - VLOG(1) << "USMP requires target " << tgt->ToDebugString() << " to have pool size " - << allocated_pool_info->allocated_size->value; - size_t size = allocated_pool_info->allocated_size->value; - if (allocated_pool_info->pool_info->IsInstance()) { - size += main_func_info->constant_sizes.count(tgt) - ? main_func_info->constant_sizes[tgt]->value - : 0; - main_func_info->constant_sizes.Set(tgt, size); - } else if (allocated_pool_info->pool_info->IsInstance()) { - size += main_func_info->workspace_sizes.count(tgt) - ? main_func_info->workspace_sizes[tgt]->value - : 0; - main_func_info->workspace_sizes.Set(tgt, size); - } else { - LOG(FATAL) << "Unknown pool type: " << allocated_pool_info->pool_info->GetTypeKey(); - } - } - } - } - function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info); - return lowered_mod; - } - - /*! - * \brief Run StorageRewrite to plan memory for lowered IRModule. - */ - IRModule PlanMemoryWithStorageRewrite(const IRModule& mod) { - Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod); - IRModule lowered_mod = mod->ShallowCopy(); - function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_); - // Running StorageRewrite just on the main function - tir::PrimFunc tir_main_func = - Downcast(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); - IRModule main_func_mod; - main_func_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main), - tir_main_func); - main_func_mod = tir::transform::StorageRewrite()(main_func_mod); - lowered_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main), - main_func_mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); - tir_main_func = - Downcast(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); - // Use the PrimFunc to calculate the workspace required to service the allocates - Integer main_workspace_size_bytes = - CalculateWorkspaceBytes(tir_main_func, workspace_byte_alignment); - backend::FunctionInfo main_func_info = - lowered_mod->GetAttr("main_func_info").value(); - main_func_info->workspace_sizes.Set(config_->host_target, main_workspace_size_bytes); - function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info); - return lowered_mod; - } - - /*! - * \brief Gets module workspace alignment from supplied executor or defaults to 16 - */ - Integer GetModuleWorkspaceByteAlignment(const IRModule& mod) { - Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); - return executor_config->GetAttr("workspace-byte-alignment").value_or(16); - } - - /*! - * \brief Gets module constant alignment from supplied executor or defaults to 16 - */ - Integer GetModuleConstantByteAlignment(const IRModule& mod) { - Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); - return executor_config->GetAttr("constant-byte-alignment").value_or(16); - } - - protected: - /*! \brief mod */ - runtime::Module* mod_; - /*! \brief list of input expressions (i.e., variable passed by the user) */ - std::vector input_vars_; - /*! \brief map of device contexts variables */ - Map devices_; - /*! \brief map of GlobalVars to C Device API contexts */ - Map device_contexts_; - /*! \brief input and output variables belonging to the main function signature */ - Array main_signature_; - /*! \brief input and output variables belonging to the main function signature */ - Map main_buffer_map_; - /*! \brief maps input and output variables to TensorType which describe them */ - Map io_tensor_types_; - /*! \brief All available targets. */ - CompilationConfig config_; - /*! - * \brief The type of kernel call to be emitted. - * See CallType for more documentation. - */ - CallType call_type_; - - /*! - * \brief parameters (i.e. ConstantNodes found in the graph). - * These are take as inputs to the GraphRuntime. - * Maps param name to a pair of storage_id and NDArray. At runtime, the storage_id can be - * used to lookup the parameter. - */ - std::unordered_map params_; - /*! \brief mapping between expression and parameters */ - Map params_by_expr_; - /*! \brief mapping between parameter names ("p0", "p1", etc..) and storage identifiers*/ - std::unordered_map param_storage_ids_; - std::unordered_map - constant_map_; - - /*! \brief plan memory of device result */ - StorageMap storage_device_map_; - /*! \brief mapping sid -> tir::Var */ - std::unordered_map sids_table_; - /*! \brief lowered funcs */ - Map function_metadata_; - /*! \brief the set of statements that make the program */ - std::vector stmts_; - /*! \brief the list of return sids (note that the function might return more then one output */ - std::vector return_sid_; - /*! \brief This is per IO var name counter to aid the generating unique names */ - std::unordered_map io_var_names_; - /*! \brief A set of variables that are let bound. */ - std::unordered_set let_bound_vars_; - - public: - AOTExecutorCodegen(runtime::Module* mod, const Array& targets) - : mod_(mod), config_(transform::PassContext::Current(), targets) {} - - LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) { - VLOG_CONTEXT << "AOT"; - - Runtime runtime_config = mod->GetAttr(tvm::attr::kRuntime).value(); - Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod); - - Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); - std::string interface_api = - executor_config->GetAttr("interface-api").value_or("packed"); - bool unpacked_api = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); - - // Validate choice of unpacked_api and use_call_cpacked_ - if (runtime_config->name == kTvmRuntimeCrt) { - if (unpacked_api == true) { - call_type_ = CallType::kUnpacked; - } else if (unpacked_api == false && interface_api == "packed") { - call_type_ = CallType::kCPacked; - } else { - CHECK(interface_api == "packed" || unpacked_api == true) - << "Either need interface_api == \"packed\" (got: " << interface_api - << ") or unpacked-api == true (got: " << unpacked_api << ") when targeting c runtime"; - ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api - << ", unpacked-api=" << unpacked_api; - } - } else if (runtime_config->name == kTvmRuntimeCpp) { - if (unpacked_api == false && interface_api == "packed") { - call_type_ = CallType::kCPacked; - } else { - CHECK(static_cast(unpacked_api) == false && interface_api == "packed") - << "Need unpacked-api == false (got: " << unpacked_api - << ") and interface-api == \"packed\" (got: " << interface_api - << ") when targeting c++ runtime"; - ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api - << ", unpacked-api=" << unpacked_api; - } - } else { - ICHECK(false) << "runtime_config (" << runtime_config->name - << ") is not one of the expected values"; - } - - mod = transform::ToANormalForm()(mod); - mod = transform::InferType()(mod); - mod = transform::AnnotateUsedMemory()(mod); - - IRModule lowered_mod = - tec::LowerTE(mod_name, config_, [this, workspace_byte_alignment](BaseFunc func) { - // We need to maintain the constant map for external - // functions so we pass this processing function which - // allows us to process each function as we lower it. - if (func->GetAttr(attr::kCompiler).defined()) { - UpdateConstants(func, ¶ms_); - } - - // TODO(@areusch, @jroesch): We should refactor this to - // execute as a further pass, instead writing data to the - // lowering process directly. - tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment); - })(mod); - - transform::PassContext pass_ctx = transform::PassContext::Current(); - bool enable_remove_reshapes = - pass_ctx->GetConfig("relay.remove_standalone_reshapes.enable", Bool(true)).value(); - if (enable_remove_reshapes) { - lowered_mod = transform::RemoveStandaloneReshapes()(lowered_mod); - } - auto lowered_main = lowered_mod->Lookup("main"); - auto lowered_main_func = GetRef(lowered_main.as()); - - // Post-lowering storage map for writing main func - AOTOnDemandAllocator final_aot_allocator; - final_aot_allocator.Run(lowered_main_func); - storage_device_map_ = final_aot_allocator.GetStorageMap(); - - // TODO(@electriclilies, @jroesch, @Mousius): remove UpdateMainWorkspaceSize - StaticMemoryPlan memory_plan(storage_device_map_); - backend::FunctionInfo func_info = - tec::UpdateMainWorkspaceSize(lowered_mod, config_, memory_plan->expr_to_storage_info); - lowered_mod = WithAttr(lowered_mod, "main_func_info", func_info); - - for (auto input : lowered_main_func->params) { - input_vars_.push_back(input); - std::string input_name = SanitizeName(input->name_hint()); - // We dont want the compiler changing input names in the - // event of a sanitization collision. Therefore, enforcing - // the var created to use the input_name strictly. - CreateIOVar(input, input_name, /*use_unique_name = */ false); - } - - // Define the storage allocator ids - for (auto kv : storage_device_map_) { - for (auto sid : kv.second->storage_ids) { - // The buffer_var is created with storage_scope to be global.workspace to be serviced by - // TVMBackendAllocWorkspace(TVMBAW) calls, explicitly. The reasoning being the executor - // allocates should be serviced by TVMBAWs as the data could be accessed by many devices and - // should not be lowered to the stack. For more details please refer to the discussion here: - // https://github.com/apache/tvm/issues/9022 - te::Var buffer_var(MakeString("sid_", sid), - PointerType(PrimType(DataType::Int(8)), "global.workspace")); - sids_table_[sid] = buffer_var; - } - } - - // Retrieve the return sids - return_sid_ = final_aot_allocator.GetReturnIds(); - // Insert outputs to main func signature - // If output tensor names were provided use them - if (auto opt = func->GetAttr>("output_tensor_names")) { - Array output_tensor_names = opt.value(); - Expr output_expr = lowered_main_func->body; - if (output_expr->checked_type()->IsInstance()) { - TupleType output_tuple_type = Downcast(output_expr->checked_type()); - for (unsigned i = 0; i < output_tuple_type->fields.size(); i++) { - // AoT Executor Codegen does not create these names, - // thus should be used as they are provided. - CreateIOVar(output_tuple_type->fields[i], output_tensor_names[i], - /*use_unique_name = */ false); - } - } else { - // AoT Executor Codegen does not create these names, - // thus should be used as they are provided. - CreateIOVar(lowered_main_func->body, output_tensor_names[0], /*use_unique_name = */ false); - } - } else { - // If output tensor names are not provided we will generate output(x) - // where x is a counter to create unique names. - CreateIOVar(lowered_main_func->body, "output"); - } - - CollectDeviceVariables(lowered_mod->GetAttr>("device_contexts").value()); - VisitExpr(lowered_main_func->body); - - // Create the runner function. Please note that the function is not legal yet - // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need - // to run the LegalizePackedCalls pass. - LoweredOutput ret; - - // Collect any constants extracted by external codegen. - ret.params = std::unordered_map(); - Map const_name_to_constant = - lowered_mod->GetAttr>(tvm::attr::kConstNameToConstant) - .value_or({}); - for (const auto& kv : const_name_to_constant) { - ICHECK(ret.params.emplace(kv.first, kv.second).second); - } - - // Collect any constants extracted during lowering. - for (const auto& kv : params_) { - ICHECK(ret.params.emplace(kv.first, kv.second).second); - } - - // AoT Executor codegen works completely on TIR beyond this point, hence removing relay main - // function and replacing it with its TIR version. We should try to make this a Pass. - lowered_mod->Remove(lowered_mod->GetGlobalVar("main")); - auto tir_main_func = CreateMainFunc(mod_name, lowered_main_func->params.size()); - // Extract additional information around main TIR PrimFunc arguments - Array devices = ListDevices(); - const auto main_func_params_end_iterator = - tir_main_func->params.begin() + tir_main_func->params.size(); - const auto outputs_begin_iterator = - main_func_params_end_iterator - return_sid_.size() - devices.size(); - Array inputs = Array(tir_main_func->params.begin(), outputs_begin_iterator); - Array input_tensor_types; - for (auto i : inputs) { - input_tensor_types.push_back(io_tensor_types_[i]); - } - Array outputs = - Array(outputs_begin_iterator, main_func_params_end_iterator - devices.size()); - - lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_module_main), tir_main_func); - // Parallel for loops are not supported in AoT codegen. - lowered_mod = tir::transform::ConvertForLoopsToSerial()(lowered_mod); - - bool enable_usmp = pass_ctx->GetConfig(kUSMPEnableOption, Bool(false)).value(); - if (enable_usmp) { - lowered_mod = PlanMemoryWithUSMP(lowered_mod); - } else { - lowered_mod = PlanMemoryWithStorageRewrite(lowered_mod); - } - ret.function_metadata = std::move(function_metadata_); - - // Legalize AOT if needed. This means that all the packed calls - // need to be wrapped in TVMValues (unless unpacked_api is set) - if (call_type_ == CallType::kCPacked || call_type_ == CallType::kPacked) { - auto pack_calls = tir::transform::LegalizePackedCalls(); - lowered_mod = pack_calls(lowered_mod); - } - - // Collect any runtime modules generated by external codegen. - ret.external_mods = - lowered_mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); - - // This is the point where we separate the functions in the module by target - VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod); - ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod); - VLOG(1) << "per-target modules:"; - for (const auto& kv : ret.lowered_funcs) { - VLOG(1) << "target:" << std::endl - << kv.first->ToDebugString() << std::endl - << "maps to:" << std::endl - << PrettyPrint(kv.second); - } - - // Extract USMP metadata to pass onto metadata sources - Map pool_var_info; - std::vector pool_vars; - tir_main_func = - Downcast(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); - Optional> allocated_pool_infos = - tir_main_func->GetAttr>(tvm::attr::kPoolArgs); - if (allocated_pool_infos) { - for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) { - int pool_var_index = allocated_pool_info->pool_var_idx.value()->value; - pool_vars.push_back(tir_main_func->params[pool_var_index]); - pool_var_info.Set(tir_main_func->params[pool_var_index], allocated_pool_info); - } - } - Map io_pool_allocations = - lowered_mod - ->GetAttr>(tvm::attr::kIOTensorPoolAllocations) - .value_or({}); - - std::vector output_var_names; - if (auto opt = func->GetAttr>("output_tensor_names")) { - Array output_tensor_names = opt.value(); - for (size_t i = 0; i < output_tensor_names.size(); ++i) { - output_var_names.push_back(output_tensor_names[i]); - } - } - - // If output names have not been specified then generate default output names - if (output_var_names.size() == 0) { - if (return_sid_.size() == 1) { - output_var_names.push_back(String("output")); - } else { - for (size_t i = 0; i < return_sid_.size(); ++i) { - output_var_names.push_back(String("output" + std::to_string(i))); - } - } - } - - Array output_tensor_types{final_aot_allocator.GetReturnTtypes()}; - - ret.metadata = ExecutorCodegenMetadata( - inputs, input_tensor_types, output_var_names, output_tensor_types, pool_vars, devices, - runtime::kTvmExecutorAot, mod_name, interface_api, unpacked_api, - GetModuleWorkspaceByteAlignment(mod), GetModuleConstantByteAlignment(mod), pool_var_info, - io_pool_allocations); - return ret; - } - - /*! - * \brief Get list of devices found - * \return List of devices - */ - Array ListDevices() { - std::vector device_names(devices_.size()); - std::transform(devices_.begin(), devices_.end(), device_names.begin(), - [](const auto& it) -> String { return it.first; }); - return device_names; - } -}; // namespace backend - -class AOTExecutorCodegenModule : public runtime::ModuleNode { - public: - AOTExecutorCodegenModule() {} - virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { - if (name == "init") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " - << "runtime::Module mod and Array targets"; - void* mod = args[0]; - Array targets = args[1]; - init(mod, targets); - }); - } else if (name == "codegen") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - IRModule mod = args[0]; - Function func = args[1]; - String mod_name = args[2]; - this->output_ = this->codegen_->Codegen(mod, func, mod_name); - }); - } else if (name == "list_params_name") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = list_params_name(); }); - } else if (name == "get_param_by_name") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - String key = args[0]; - *rv = get_param_by_name(key); - }); - } else if (name == "get_irmodule") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); }); - } else if (name == "get_external_modules") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_external_modules(); }); - } else if (name == "get_function_metadata") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->output_.function_metadata; - }); - } else if (name == "get_devices") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->codegen_->ListDevices(); - }); - } else if (name == "get_executor_codegen_metadata") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = output_.metadata; }); - } else { - return PackedFunc([](TVMArgs args, TVMRetValue* rv) {}); - } - } - - const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } - - private: - void init(void* mod, const Array& targets) { - codegen_ = - std::make_shared(reinterpret_cast(mod), targets); - } - - Array list_params_name() { - Array ret; - for (const auto& kv : this->output_.params) { - ret.push_back(kv.first); - } - return ret; - } - - runtime::NDArray get_param_by_name(String key) { - auto it = this->output_.params.find(key); - CHECK(it != this->output_.params.end()) << "no such parameter " << key; - return (*it).second; - } - - Array get_external_modules() { return output_.external_mods; } - - Map get_irmodule() { return this->output_.lowered_funcs; } - - std::shared_ptr codegen_; - LoweredOutput output_; -}; - -runtime::Module CreateAOTExecutorCodegenMod() { - auto ptr = make_object(); - return runtime::Module(ptr); -} - -TVM_REGISTER_GLOBAL("relay.build_module._AOTExecutorCodegen") - .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateAOTExecutorCodegenMod(); }); - -} // namespace backend -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index bca524794a200..3f0872de479b4 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -65,8 +65,10 @@ struct ExecutorCodegen { CallFunc("init", m, raw_targets); } - void Codegen(IRModule mod, const Function& func, String mod_name) { - CallFunc("codegen", mod, func, mod_name); + void Codegen(IRModule mod, const Function& func, String mod_name, CompilationConfig config, + Executor executor, CallType call_type) { + CallFunc("codegen", mod, func, mod_name, config, executor, + Integer(static_cast(call_type))); } virtual void UpdateOutput(BuildOutput* ret) = 0; @@ -302,11 +304,45 @@ class RelayBuildModule : public runtime::ModuleNode { workspace_memory_pools_ = workspace_memory_pools; constant_memory_pools_ = constant_memory_pools; config_ = CompilationConfig(PassContext::Current(), raw_targets); + SetCallType(executor, runtime); VLOG(1) << "Using compilation config:" << std::endl << config_; BuildRelay(std::move(mod), mod_name); } protected: + void SetCallType(const Executor& executor, const Runtime& runtime) { + std::string interface_api = executor->GetAttr("interface-api").value_or("packed"); + bool unpacked_api = executor->GetAttr("unpacked-api").value_or(Bool(false)); + + // Validate choice of unpacked_api and use_call_cpacked_ + if (runtime->name == kTvmRuntimeCrt) { + if (unpacked_api == true) { + call_type_ = CallType::kUnpacked; + } else if (unpacked_api == false && interface_api == "packed") { + call_type_ = CallType::kCPacked; + } else { + CHECK(interface_api == "packed" || unpacked_api == true) + << "Either need interface_api == \"packed\" (got: " << interface_api + << ") or unpacked-api == true (got: " << unpacked_api << ") when targeting c runtime"; + ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api + << ", unpacked-api=" << unpacked_api; + } + } else if (runtime->name == kTvmRuntimeCpp) { + if (unpacked_api == false && interface_api == "packed") { + call_type_ = CallType::kCPacked; + } else { + CHECK(static_cast(unpacked_api) == false && interface_api == "packed") + << "Need unpacked-api == false (got: " << unpacked_api + << ") and interface-api == \"packed\" (got: " << interface_api + << ") when targeting c++ runtime"; + ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api + << ", unpacked-api=" << unpacked_api; + } + } else { + ICHECK(false) << "runtime (" << runtime->name << ") is not one of the expected values"; + } + } + /*! * \brief Optimize a Relay IRModule. * @@ -428,7 +464,7 @@ class RelayBuildModule : public runtime::ModuleNode { // Generate code for the updated function. executor_codegen_ = MakeExecutorCodegen(executor_->name); executor_codegen_->Init(nullptr, config_->primitive_targets); - executor_codegen_->Codegen(func_module, func, mod_name); + executor_codegen_->Codegen(func_module, func, mod_name, config_, executor_, call_type_); executor_codegen_->UpdateOutput(&ret_); ret_.params = executor_codegen_->GetParams(); @@ -484,6 +520,7 @@ class RelayBuildModule : public runtime::ModuleNode { Executor executor_; /*! \brief Runtime to codegen for */ Runtime runtime_; + CallType call_type_; /*! \brief Workspace memory pools to codegen for */ WorkspaceMemoryPools workspace_memory_pools_; /*! \brief Constant memory pools to codegen for */ diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index c3426f147e0d1..9d255b39c2b70 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -984,7 +984,7 @@ def test_aot_codegen_checks_returns(): # Check operator call is wrapped properly assert ( - str(main_func.body[1]) + str(main_func.body) == "tir.tvm_check_return(0, -1, tir.call_extern(" + '"tvmgen_default_fused_add",' + " x_buffer_var, y_buffer_var, output_buffer_var))\n"