diff --git a/src/relay/backend/aot/aot_executor_codegen.cc b/src/relay/backend/aot/aot_executor_codegen.cc new file mode 100644 index 000000000000..9fd6a0b4c645 --- /dev/null +++ b/src/relay/backend/aot/aot_executor_codegen.cc @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/aot/aot_executor_codegen.cc + * \brief AOT executor codegen + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../te_compiler.h" +#include "../utils.h" +#include "./aot_lower_main.h" +#include "./create_executor_metadata.h" +#include "./create_function_metadata.h" + +namespace tvm { +namespace relay { +namespace backend { +namespace aot { + +std::unordered_map CreateParamMap( + const IRModule& mod, const std::unordered_map& external_params) { + auto params = std::unordered_map(); + // Collect any constants extracted by external codegen. 
+ Map const_name_to_constant = + mod->GetAttr>(tvm::attr::kConstNameToConstant).value_or({}); + for (const auto& kv : const_name_to_constant) { + params[kv.first] = kv.second; + } + + // Collect any constants extracted during lowering. + for (const auto& kv : external_params) { + params[kv.first] = kv.second; + } + + return params; +} + +LoweredOutput Codegen(IRModule mod, String mod_name, CompilationConfig config, Executor executor, + CallType call_type) { + Integer workspace_byte_alignment = + executor->GetAttr("workspace-byte-alignment").value_or(16); + Integer constant_byte_alignment = + executor->GetAttr("constant-byte-alignment").value_or(16); + // Required Relay passes prior to AOT codegen (should be refactored out of executors) + mod = transform::ToANormalForm()(mod); + mod = transform::InferType()(mod); + mod = transform::AnnotateUsedMemory()(mod); // TODO(mbaret) Move into Ethos-U hook + std::unordered_map external_params; + mod = tec::LowerTE(mod_name, config, [&external_params](BaseFunc func) { + if (func->GetAttr(attr::kCompiler).defined()) { + UpdateConstants(func, &external_params); + } + })(mod); + + transform::PassContext pass_ctx = transform::PassContext::Current(); + bool enable_remove_reshapes = + pass_ctx->GetConfig("relay.remove_standalone_reshapes.enable", Bool(true)).value(); + if (enable_remove_reshapes) { + mod = transform::RemoveStandaloneReshapes()(mod); + } + + // Lower the main Relay function to a TIR PrimFunc + // After this point the entire module is composed of PrimFuncs + mod = AOTLowerMain(mod_name, config, call_type)(mod); + + mod = tir::transform::ConvertForLoopsToSerial()(mod); // TODO(mbaret) Make this optional + bool enable_usmp = pass_ctx->GetConfig(kUSMPEnableOption, Bool(false)).value(); + if (enable_usmp) { + mod = tir::transform::UnifiedStaticMemoryPlanner()(mod); + } else { + tir::PrimFunc tir_main_func = + Downcast(mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); + IRModule main_func_mod; + 
main_func_mod->Update(mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main), + tir_main_func); + main_func_mod = tir::transform::StorageRewrite()(main_func_mod); + mod->Update(mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main), + main_func_mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); + } + mod = tir::transform::LegalizePackedCalls()(mod); + + // Collect the various functions, params and metadata into a LoweredOutput + LoweredOutput ret; + ret.params = CreateParamMap(mod, external_params); + ret.external_mods = + mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); + ret.function_metadata = + CreateFunctionMetadata(mod, workspace_byte_alignment, constant_byte_alignment); + ret.lowered_funcs = tec::GetPerTargetModules(mod); + ret.metadata = CreateExecutorMetadata(mod, mod_name, executor, workspace_byte_alignment, + constant_byte_alignment); + return ret; +} + +class AOTExecutorCodegenModule : public runtime::ModuleNode { + public: + AOTExecutorCodegenModule() {} + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "init") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + // Do nothing + }); + } else if (name == "codegen") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + IRModule mod = args[0]; + Function func = args[1]; + String mod_name = args[2]; + CompilationConfig config = args[3]; + Executor executor = args[4]; + Integer call_type = args[5]; + this->output_ = + Codegen(mod, mod_name, config, executor, static_cast(call_type->value)); + }); + } else if (name == "list_params_name") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = list_params_name(); }); + } else if (name == "get_param_by_name") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + String key = args[0]; + *rv = get_param_by_name(key); + }); + } else if (name == 
"get_irmodule") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); }); + } else if (name == "get_external_modules") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_external_modules(); }); + } else if (name == "get_function_metadata") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->output_.function_metadata; + }); + } else if (name == "get_devices") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = output_.metadata->devices; }); + } else if (name == "get_executor_codegen_metadata") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = output_.metadata; }); + } else { + return PackedFunc([](TVMArgs args, TVMRetValue* rv) {}); + } + } + + const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } + + private: + Array list_params_name() { + Array ret; + for (const auto& kv : this->output_.params) { + ret.push_back(kv.first); + } + return ret; + } + + runtime::NDArray get_param_by_name(String key) { + auto it = this->output_.params.find(key); + CHECK(it != this->output_.params.end()) << "no such parameter " << key; + return (*it).second; + } + + Array get_external_modules() { return output_.external_mods; } + + Map get_irmodule() { return this->output_.lowered_funcs; } + + LoweredOutput output_; +}; + +runtime::Module CreateAOTExecutorCodegenMod() { + auto ptr = make_object(); + return runtime::Module(ptr); +} + +TVM_REGISTER_GLOBAL("relay.build_module._AOTExecutorCodegen") + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateAOTExecutorCodegenMod(); }); + +} // namespace aot +} // namespace backend +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/aot/create_function_metadata.cc b/src/relay/backend/aot/create_function_metadata.cc index 54fd270c1b25..2ef5e495abca 100644 --- a/src/relay/backend/aot/create_function_metadata.cc 
+++ b/src/relay/backend/aot/create_function_metadata.cc @@ -62,8 +62,7 @@ Map CalculateFunctionInfos(const IRModule& mod, auto params = pfunc->params; int64_t total_io_bytes = 0; for (const auto& param : params) { - // Inputs/outputs will be handles, workspaces are pointers - if (param->dtype.is_handle()) { + if (pfunc->buffer_map.find(param) != pfunc->buffer_map.end()) { auto buffer = pfunc->buffer_map[param]; total_io_bytes += GetMemorySizeBytes(buffer->shape, buffer->dtype); } diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc deleted file mode 100644 index 786b3f81a5ae..000000000000 --- a/src/relay/backend/aot_executor_codegen.cc +++ /dev/null @@ -1,1392 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file src/relay/backend/aot_executor_codegen.cc - * \brief AOT executor codegen - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "../../target/source/codegen_source_base.h" -#include "../op/annotation/annotation.h" -#include "../op/call/call.h" -#include "../op/memory/device_copy.h" -#include "../transforms/device_aware_visitors.h" -#include "./name_transforms.h" -#include "./te_compiler.h" -#include "./utils.h" - -namespace tvm { -namespace relay { -namespace backend { - -using StorageMap = - std::unordered_map; - -/** - * This is an on demand allocator for AOT. A new temporary - * (storage allocator identifier) is allocated for each operation. - */ -class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { - public: - AOTOnDemandAllocator() : transform::DeviceAwareExprVisitor(Optional()) {} - - // run the visitor on a global function. 
- void Run(const Function& func) { VisitExpr(func); } - - std::vector GetReturnIds() const { return return_ids_; } - std::vector GetReturnTtypes() const { return return_ttypes_; } - - StorageMap GetStorageMap() const { return storage_device_map_; } - - using ExprVisitor::VisitExpr_; - - void VisitExpr_(const ConstantNode* op) final { - CreateStorage(op); - AssignReturnSid(GetRef(op)); - } - - void DeviceAwareVisitExpr_(const CallNode* call_node) final { - // AOTOnDemandAllocator is run both before and after lowering, so we need to handle the case - // where the op of the call is a generic function - - Expr func; - Array args; - - CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); - if (call_lowered_props.lowered_func.defined()) { - func = call_lowered_props.lowered_func; - args = call_lowered_props.arguments; - } else { // Relay functions that have not been lowered and lowered extern functions - func = call_node->op; - args = call_node->args; - if (call_node->op.as()) { // Lowered extern function - ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes."; - } else { // Relay function which has not been lowered yet - ICHECK(call_node->op.as()) - << "Expected the call to be to a lowered primfunc, a lowered extern function or a " - "unlowered Relay function."; - } - } - VisitExpr(func); - CreateStorage(call_node); - for (const Expr& arg : args) { - VisitExpr(arg); - } - AssignReturnSid(GetRef(call_node)); - } - - void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef(op)); } - - void DeviceAwareVisitExpr_(const FunctionNode* func_node) final { - if (function_nesting() > 1) { - // do not recurse into sub functions. - return; - } - if (func_node->HasNonzeroAttr(attr::kPrimitive)) { - // No storage needed for primitive functions. 
- return; - } - for (const auto& param : func_node->params) { - CreateStorage(param.get()); - } - VisitExpr(func_node->body); - } - - void VisitExpr_(const GlobalVarNode* op) final { - // Do nothing. - } - - void VisitExpr_(const OpNode* op) final { - // Do nothing. - } - - void VisitExpr_(const TupleNode* op) final { - std::vector storage_ids; - std::vector virtual_devices; - std::vector storage_sizes_in_bytes; - Expr expr = GetRef(op); - for (Expr field : op->fields) { - auto sid = GetStorage(field); - storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end()); - virtual_devices.insert(virtual_devices.end(), sid->virtual_devices.begin(), - sid->virtual_devices.end()); - storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(), - sid->storage_sizes_in_bytes.begin(), - sid->storage_sizes_in_bytes.end()); - } - storage_device_map_[expr] = StorageInfo(storage_ids, virtual_devices, storage_sizes_in_bytes); - AssignReturnSid(expr); - } - - void VisitExpr_(const TupleGetItemNode* op) final { - Expr expr = GetRef(op); - auto sids = GetStorage(op->tuple); - ICHECK_LT(static_cast(op->index), sids->storage_ids.size()); - storage_device_map_[expr] = - StorageInfo({sids->storage_ids[op->index]}, {sids->virtual_devices[op->index]}, - {sids->storage_sizes_in_bytes[op->index]}); - AssignReturnSid(expr); - } - - void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; } - - void PreVisitLetBinding_(const Var& var, const Expr& value) final { - VisitExpr(value); - StorageInfo si = GetStorage(value); - storage_device_map_[var] = si; - } - - private: - void AssignReturnSid(Expr e) { - if (storage_device_map_.find(e) != storage_device_map_.end()) { - StorageInfo& sinfo = storage_device_map_[e]; - return_ids_.clear(); - for (auto sid : sinfo->storage_ids) { - return_ids_.push_back(sid); - } - return_ttypes_.clear(); - return_ttypes_ = FlattenTupleType(e->checked_type()); - } - } - /*! 
- * \brief ceil(size/word_size) to get number of words. - * \param size The original size. - * \param word_size The element size. - */ - static size_t DivRoundUp(size_t size, size_t word_size) { - return (size + word_size - 1) / word_size; - } - /*! - * \brief Get the memory requirement. - * \param prototype The prototype token. - * \return The required memory size. - * - * TODO(mbs): Cf CalculateRelayExprSizeBytes in utils.cc, GetMemorySize is graph_plan_memory.cc - */ - size_t GetMemorySizeBytes(const TensorType& ttype) { - size_t size = 1; - for (IndexExpr dim : ttype->shape) { - const int64_t* pval = tir::as_const_int(dim); - ICHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << ttype->shape; - ICHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape" << *pval; - size *= static_cast(pval[0]); - } - size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8); - return size; - } - /*! - * \brief Get the necessary storage for the expression. - * \param expr The expression. - * \return The corresponding token. - */ - StorageInfo GetStorage(const Expr& expr) { - // See through "on_device" calls. - Expr true_expr = IgnoreOnDevice(expr); - VisitExpr(true_expr); - auto it = storage_device_map_.find(true_expr); - ICHECK(it != storage_device_map_.end()) << "Could not find " << true_expr->GetTypeKey() << " " - << PrettyPrint(true_expr) << " in storage device map"; - return it->second; - } - - /*! - * \brief Create storage for the expression. - */ - void CreateStorage(const ExprNode* op) { - Expr expr = GetRef(op); - return CreateStorage(expr, GetVirtualDevice(expr)); - } - - /*! - * \brief Create storage to hold the result of evaluating \p expr in \p virtual_device. 
- */ - void CreateStorage(const Expr& expr, const VirtualDevice& virtual_device) { - ICHECK(!virtual_device->IsFullyUnconstrained()) - << "invalid virtual device for expr:" << std::endl - << PrettyPrint(expr); - std::vector storage_ids; - std::vector virtual_devices; - std::vector storage_sizes_in_bytes; - for (const auto& ttype : FlattenTupleType(expr->checked_type())) { - storage_ids.push_back(next_available_sid_++); - virtual_devices.push_back(virtual_device); - storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype)); - } - storage_device_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices), - std::move(storage_sizes_in_bytes)); - } - - /*! \brief mapping of expression -> storageInfo */ - StorageMap storage_device_map_; - /*! \brief current id of the temporary allocated */ - int next_available_sid_{0}; - /*! \brief the set of intermediate tensors that are return variables */ - std::vector return_ids_; - /*! \brief the data types of the return values */ - std::vector return_ttypes_; -}; - -/*! \brief Code generator for AOT executor */ -class AOTExecutorCodegen : public MixedModeVisitor { - protected: - /*! \brief Describes the type of kernel call emitted. */ - enum CallType { - /*! - * \brief Emit PackedFunc calls bound just-in-time using TVMBackend* functions. - * - * When this type is selected, assumes all operators must be called via TVMFuncCall. Given the - * implementation of TVMFuncCall in the C++ runtime, this in practice implies that those - * functions are of type TVMBackendPackedCFunc. - * - * The following code is emitted at call sites to call a function named `func`: - * void* func_ptr = TVMBackendGetFuncFromEnv("func"); - * TVMFuncCall(func_ptr, values, tcodes, num_args, ret_values, ret_tcodes) - * - * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args` - * by LowerTVMBuiltin TIR transform. 
- * - * If `resource_handle` is passed to `func`, it is determined by TVMFuncCall (often, - * `resource_handle` is registered with the C++ runtime to provide a `this` equivalent when - * `func` is implemented in C). - * - * Compatible with both C++ and C runtimes, implemented with the C runtime only. - */ - kPacked, // Emit tir.call_packed and wrap all arguments in DLTensor. - - /*! - * \brief Directly call a TVMBackendPackedCFunc named according to the tir::Call. - * - * When this type is selected, assumes all operators are implemented in functions of type - * `TVMBackendPackedCFunc` and should be called directly. That is, presumes at the time of - * downstream compilation that there is a symbol named after the 0th arg to tir::Call of - * type `TVMBackendPackedCFunc`. This situation should occur when target_host == target. - * - * The following code is emitted at call sites to call a function named `func`: - * func(values, tcodes, num_args, ret_values, ret_tcodes, resource_handle) - * - * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args` - * by LowerTVMBuiltin TIR transform. - * - * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is - * always the device context parameter when not null. At present, the implementation does not - * support forwarding device context parameters to CPacked. - * - * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented - * in the same scenarios. - */ - kCPacked, // Emit tir.call_cpacked and wrap all arguments in DLTensor. - - /*! \brief Directly call a function accepting the `data` arrays as args. - * - * When this type is selected, assumes all operaotrs are implemented in C functions whose - * arguments are 1-to-1 with those in the tir::Call. DLTensor arguments are encoded as just the - * `data` parameters (i.e. no DLTensor object is passed along). 
- * - * The following code is emitted at call sites to a function named `func`: - * func(void* arg0, void* arg1, ..., void* argN) // no resource_handle - * -or- - * func(void* arg0, void* arg1, ..., void* argN, void* resource_handle) // with resource_handle - * - * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is - * always the device context parameter when not null. - * - * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented - * with the C runtime only. - */ - kUnpacked, // Emit tir.call_extern passing only the `data` part of DLTensors. - }; - - /*! - * \brief Return a vector of variables that represents the sids for the given Relay Expr - */ - std::vector PackSid(Expr expr) { - std::vector buffer_vars; - - ICHECK(storage_device_map_.find(expr) != storage_device_map_.end()) - << "Storage map did not contain constant expr " << PrettyPrint(expr); - StorageInfo& sinfo = storage_device_map_[expr]; - - // Note that an expression can have multiple sids associated with it - // e.g., returning multiple values from a function - for (auto sid : sinfo->storage_ids) { - // Determine if an sid is an output buffer - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid); - if (output_iter != return_sid_.end()) { - int output_index = std::distance(return_sid_.begin(), output_iter); - buffer_vars.push_back(GetBufferVarForIO(input_vars_.size() + output_index)); - continue; - } - - auto sid_value = sids_table_[sid]; - buffer_vars.push_back(sid_value); - } - return buffer_vars; - } - - /*! 
- * brief Given an expression return the variable(s) associated with that expression - */ - std::vector FindExpr(Expr arg) { - auto input_iter = std::find(input_vars_.begin(), input_vars_.end(), arg); - if (input_iter != input_vars_.end()) { - // Input variable - int main_index = std::distance(input_vars_.begin(), input_iter); - return {GetBufferVarForIO(main_index)}; - } else { - // Storage identifier (i.e., intermediate memory) - return PackSid(arg); - } - } - - /*! - * \brief Reverse lookup the device name in devices_ map. - * \param device_context Value in devices_ to find. - * \return Key matching device_context in devices_. - */ - std::string FindDeviceName(tir::Var device_context) { - for (std::pair kv : devices_) { - if (kv.second->name_hint == device_context->name_hint) { - return kv.first; - } - } - ICHECK(false) << "Did not find a device name associated with " << device_context; - return ""; - } - - void PushArgs(const Expr& expr, const std::vector& sids, Array* args) { - const TupleNode* t = expr.as(); - if (t != nullptr) { - CHECK_EQ(sids.size(), t->fields.size()) << "Relay tuple does not map 1:1 into TIR; AOT can't " - "handle this type of Relay Expr in a CallNode."; - } - - args->insert(args->end(), sids.begin(), sids.end()); - } - - /* - * Wraps a call_extern with a tvm_check_return annotation if required otherwise - * returns the passed Call - */ - tir::Call AddCheckReturn(tir::Call existing_call) { - Array args = {tir::make_const(DataType::Int(32, 1), 0, Span()), - tir::make_const(DataType::Int(32, 1), -1, Span()), existing_call}; - return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args); - } - - /*! - * brief Create a function call - * \param call_lowered_props The lowered function and the arguments to call it with - * \param result_expr The call we got func and args from (so as to recover the storage - * ids to hold the result). 
- */ - void CreateFuncCall(CallLoweredProps call_lowered_props, const Expr& result_expr) { - std::string func_name = call_lowered_props.lowered_func->name_hint; - tvm::Array args{tvm::tir::StringImm(func_name)}; - std::vector create_func_call_stmts; - - // Pack the inputs - for (const Expr& arg : call_lowered_props.arguments) { - if (params_by_expr_.find(arg) != params_by_expr_.end()) { - auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), - {tir::StringImm(params_by_expr_[arg])}); - // NOTE: this cast looks like a no-op, but is required for compilation downstream. - // Because DataType::Handle has default bits=64, but CodeGenC does not observe this field, - // adding this cast forces the codegen to insert the cast. In this case, a cast is required - // because param_handle is actually code-generated as `const void*`, and the `const` piece - // needs to be removed. - args.push_back(tvm::tir::Cast(DataType::Handle(32, 1), param_handle)); - } else { - auto sids = FindExpr(arg); - PushArgs(arg, sids, &args); - } - } - - // Pack the return(s) value. 
A call node can produce multiple outputs - auto result_expr_sid = PackSid(result_expr); - PushArgs(result_expr, result_expr_sid, &args); - - GlobalVar global_var = call_lowered_props.lowered_func; - bool has_c_device_api_context = device_contexts_.count(global_var) != 0; - tir::Var device_context; - tir::Stmt func_call; - - switch (call_type_) { - case CallType::kUnpacked: { - // call_extern calling convention with optional context - if (has_c_device_api_context) { - device_context = device_contexts_.Get(global_var).value(); - args.push_back(device_context); - } - func_call = tir::Evaluate(AddCheckReturn( - tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args))); - break; - } - case CallType::kCPacked: { - if (has_c_device_api_context) { - device_context = device_contexts_.Get(global_var).value(); - args.push_back(device_context); - } else { - // NOTE: LowerTVMBuiltin expects some device_context placeholder. - args.push_back(tir::make_zero(DataType::Handle())); - } - func_call = tir::Evaluate( - tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_cpacked(), args)); - create_func_call_stmts.push_back(func_call); - break; - } - case CallType::kPacked: { - // call_packed does not accept a device context. - CHECK(!has_c_device_api_context) << "CallType::kPacked does not accept a device context"; - func_call = tir::Evaluate(AddCheckReturn( - tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args))); - create_func_call_stmts.push_back(func_call); - break; - } - default: - ICHECK(false) << "Unknown CallType: " << call_type_; - } - - ICHECK(func_call.defined()) << "Must define func_call"; - - if (has_c_device_api_context) { - func_call = tir::SeqStmt(Array({ - GenerateDeviceHook(device_context, "Open"), - func_call, - GenerateDeviceHook(device_context, "Close"), - })); - } - - tir::Stmt body = tir::SeqStmt({func_call}); - stmts_.push_back(body); - } - - /*! - * \brief Copy a variable to the output. 
This function is mainly used in edge cases - * when we want to return an input or a parameter. - * TODO(giuseros): we should try to avoid unnecessary copy to the output, e.g., in a - * copy-on-write fashion. - */ - void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) { - // Define intermediate DLTensor to load/store the data - tir::Buffer tmp_read = - tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read"); - tir::Buffer tmp_write = - tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write"); - te::Var loop_idx("i", DataType::Int(32)); - auto retval_i = tir::BufferLoad(tmp_read, {loop_idx}); - // Copy the variable from the input to the output - tir::Stmt copy = tir::For( - loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial, - tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx})); - stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy)); - } - - /* - * \brief Collects device context variables for passing to operators - */ - void CollectDeviceVariables(const Map& device_contexts) { - Map target_contexts; - TargetKindAttrMap target_attr_map = tvm::TargetKind::GetAttrMap("use_device_api"); - - for (const auto& it : device_contexts) { - const GlobalVar& global_var = it.first; - const std::string device_context_name = it.second; - - Optional target_kind = tvm::TargetKind::Get(device_context_name); - if (!target_kind || !target_attr_map.count(target_kind.value())) { - return; - } - if (target_attr_map[target_kind.value()]) { - std::string context_name = tvm::runtime::SanitizeName(device_context_name); - tir::Var device_context_var("device_context_" + context_name, DataType::Handle()); - - auto pair = target_contexts.find(target_kind.value()); - if (pair != target_contexts.end()) { - device_context_var = (*pair).second; - } else { - main_signature_.push_back(device_context_var); - devices_.Set(context_name, device_context_var); - 
target_contexts.Set(target_kind.value(), device_context_var); - } - - device_contexts_.Set(global_var, device_context_var); - } - } - } - - /** - * \brief Generates a call to a given hook for all Devices found for C Device API - * \param Name of hook to generate statements for - * \return Statement with function calls for each device - */ - tir::Stmt GenerateAllDeviceHook(const String& hook) { - std::vector device_hooks; - for (const auto& it : devices_) { - const String& device_name = it.first; - const tir::Var& context = it.second; - Array sections = {"Device", device_name, hook}; - String device_hook_name = ToCFunctionStyle(PrefixName(sections)); - - tir::Evaluate device_hook( - AddCheckReturn(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), - {tvm::tir::StringImm(device_hook_name), context}))); - device_hooks.push_back(device_hook); - } - return tir::SeqStmt(device_hooks); - } - - /** - * \brief Generates a call to a given hook for a single Device function - * \param Var Device context to call hook on - * \param Name of hook to generate statements for - * \return Statement with function call to Device API - */ - tir::Stmt GenerateDeviceHook(const tir::Var& context, const String& hook) { - const auto& it = std::find_if(std::begin(devices_), std::end(devices_), [&](const auto& it) { - return it.second->name_hint == context->name_hint; - }); - const String& device_name = (*it).first; - Array sections = {"Device", device_name, hook}; - String device_hook = ToCFunctionStyle(PrefixName(sections)); - - return tir::Evaluate( - AddCheckReturn(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), - {tvm::tir::StringImm(device_hook), context}))); - } - - /*! - * Utility function to string together different arguments - */ - template - std::string MakeString(Args const&... 
args) { - std::ostringstream ss; - using List = int[]; - (void)List{0, ((void)(ss << args), 0)...}; - - return ss.str(); - } - - void VisitExpr_(const CallNode* call_node) override { - OnDeviceProps on_device_props = GetOnDeviceProps(call_node); - if (on_device_props.body.defined()) { - VisitExpr(on_device_props.body); - return; - } - - DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node); - CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); - - if (device_copy_props.body.defined()) { - // TODO(mbs): device_copy cleaunp - // Suspect treating as no-op is better since already built into the StorageInfo? - LOG(FATAL) << "The AOT executor does not currently support device_copy"; - return; - } - - // At this point we should only see calls of the form call_lowered(@callee, (args...)), - // where @callee can be a PrimFunc we've compiled or an external function supplied via - // some other mechanism. - ICHECK(call_lowered_props.lowered_func.defined()) - << "AOT does not support calling Relay functions. Attempting to call:" << std::endl - << PrettyPrint(GetRef(call_node)); - for (const auto& arg : call_lowered_props.arguments) { - // Evaluate the args - VisitExpr(arg); - } - CreateFuncCall(call_lowered_props, GetRef(call_node)); - } - - void VisitExpr_(const VarNode* op) override { - Expr expr = GetRef(op); - StorageInfo& sinfo = storage_device_map_[expr]; - - // Let bound vars refer to a value, so these should not be considered "output" vars. 
- if (let_bound_vars_.find(GetRef(op)) != let_bound_vars_.end()) { - return; - } - - // If the Var node is an output node we need to copy the content of the variable to the output - // It's safe to check the SID here because Var StorageToken are never reallocated - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]); - if (output_iter != return_sid_.end()) { - int output_index = std::distance(return_sid_.begin(), output_iter); - if (params_by_expr_.find(expr) != params_by_expr_.end()) { - auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), - {tir::StringImm(params_by_expr_[expr])}); - CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), param_handle, - /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]); - } else { - auto var_expr = FindExpr(expr); - CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), var_expr[0], - /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]); - } - } - } - - void VisitExpr_(const ConstantNode* op) override { - Expr expr = GetRef(op); - ICHECK(storage_device_map_.find(expr) != storage_device_map_.end()) - << "Storage map did not contain constant expr " << PrettyPrint(expr); - StorageInfo& sinfo = storage_device_map_[expr]; - std::stringstream ss; - ss << "constant_" << constant_map_.size(); - - tir::Var constant(ss.str(), PointerType(PrimType(DataType(op->data->dtype)))); - constant_map_[constant] = op; - auto sid = sinfo->storage_ids[0]; - sids_table_[sid] = constant; - - // If the Constant node is an output node we need to copy the content of the parameter to the - // output. 
A node can only produce a single output - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid); - if (output_iter != return_sid_.end()) { - int output_index = std::distance(return_sid_.begin(), output_iter); - auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), - {tir::StringImm(ss.str())}); - CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), constant, - /* pack_input */ false, sinfo->storage_sizes_in_bytes[0]); - } - } - - void VisitExpr_(const TupleNode* op) override { - for (auto field : op->fields) { - VisitExpr(field); - } - } - - void VisitExpr_(const LetNode* op) override { - auto pre_visit = [this](const LetNode* op) { - let_bound_vars_.insert(op->var); - this->VisitExpr(op->value); - }; - auto post_visit = [this](const LetNode* op) { - this->VisitExpr(op->body); - this->visit_counter_[op] += 1; - }; - ExpandANormalForm(op, pre_visit, post_visit); - } - - void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); } - void VisitExpr_(const OpNode* op) override { - if (GetRef(op) != CallLoweredOp() && GetRef(op) != OnDeviceOp()) { - LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded"; - } - } - void VisitExpr_(const IfNode* op) override { - LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's Codegen is called"; - } - void VisitExpr_(const FunctionNode* op) override { - ICHECK(op->GetAttr(attr::kCompiler).defined()) - << "FunctionNode only supported by custom codegen"; - } - void VisitExpr_(const RefCreateNode* op) override { - LOG(FATAL) << "AOT executor does not support references (found RefCreateNode)"; - } - void VisitExpr_(const RefReadNode* op) override { - LOG(FATAL) << "AOT executor does not support references (found RefReadNode)"; - } - void VisitExpr_(const RefWriteNode* op) override { - LOG(FATAL) << "AOT executor does not support references (found RefWriteNode)"; - } - void VisitExpr_(const ConstructorNode* 
op) override { - LOG(FATAL) << "AOT executor does not support ADTs (found ConstructorNode)"; - } - void VisitExpr_(const MatchNode* op) override { - LOG(FATAL) << "AOT executor does not support matching (found MatchNode)"; - } - - // Create the main PrimFunc to execute the graph. Please note that - // the packed function calls don't pack their arguments. The AOT - // runner function needs to be legalized by the LegalizePackedCalls pass. - tir::PrimFunc CreateMainFunc(String mod_name, unsigned int relay_params) { - tir::Stmt body = tir::SeqStmt(stmts_); - // Allocate the sids - std::unordered_map allocated; - - for (auto kv : storage_device_map_) { - // Only allocate sids that are needed - const bool is_input = - (std::find(input_vars_.begin(), input_vars_.end(), kv.first) != input_vars_.end()); - const bool is_param = (params_by_expr_.find(kv.first) != params_by_expr_.end()); - if (is_input || is_param) { - continue; - } - - for (unsigned int i = 0; i < kv.second->storage_ids.size(); i++) { - int size = kv.second->storage_sizes_in_bytes[i]; - int sid = kv.second->storage_ids[i]; - - if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) { - continue; - } - - // Make sure it hasn't already been allocated, this can happen - // with let-bound var/value pairs. 
- if (allocated.find(sid) != allocated.end()) { - continue; - } - - allocated[sid] = constant_map_.count(sids_table_[sid]); - - // TODO(giuseros): we should allocate this once outside the PrimFunc - // so we don't pay the price of allocation for every inference - if (!allocated[sid]) { - PointerType ptype = Downcast(sids_table_[sid]->type_annotation); - DataType element_type = Downcast(ptype->element_type)->dtype; - body = tir::Allocate(sids_table_[sid], element_type, {size}, tir::const_true(), body); - } - allocated[sid] = true; - } - } - - for (auto kv : constant_map_) { - auto buffer_var = kv.first; - auto dtype = DataType(kv.second->data->dtype); - - int ndim = kv.second->data->ndim; - Array extents; - - for (int i = 0; i < ndim; i++) { - int shape = kv.second->data->shape[i]; - extents.push_back(tir::make_const(DataType::Int(32), shape, Span())); - } - body = tir::AllocateConst(buffer_var, dtype, extents, kv.second->data, body); - } - - // Define the PrimFunc attributes - Map dict_attrs; - String run_func_name = runtime::get_name_mangled(mod_name, runtime::symbol::tvm_module_main); - dict_attrs.Set("global_symbol", run_func_name); - dict_attrs.Set("runner_function", Bool(true)); - dict_attrs.Set(tvm::attr::kTarget, config_->host_target); - - tir::Stmt device_activations = GenerateAllDeviceHook("Activate"); - tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate"); - tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations}); - - // Make the PrimFunc - return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, {}, - DictAttrs(dict_attrs)); - } - - /*! - * \brief Access IO vars using the buffer vars and - * not the actual var. - */ - tir::Var GetBufferVarForIO(int index) { return main_buffer_map_[main_signature_[index]]->data; } - - /*! - * \brief Create tir::Var for input/output while updating the buffer_maps. - * - * \param expr The expression to evaluate. 
- * \param original_name The name of the tir::Var. - * \param use_unique_name Whether to generate a new unique name where a name conflicts. - */ - void CreateIOVar(const Expr& expr, const std::string& original_name, - bool use_unique_name = true) { - CreateIOVar(expr->checked_type(), original_name, use_unique_name); - } - - /*! - * \brief Create tir::Var for input/output while updating the buffer_maps. - * - * \param expr The expression to evaluate. - * \param original_name The name of the tir::Var. - * \param use_unique_name Whether to generate a new unique name where a name conflicts. - */ - void CreateIOVar(const Type& type, const std::string& original_name, - bool use_unique_name = true) { - if (type->IsInstance()) { - TupleType tuple_type = Downcast(type); - for (unsigned i = 0; i < tuple_type->fields.size(); i++) { - CreateIOVar(tuple_type->fields[i], original_name); - } - } else { - std::string name = original_name; - if (use_unique_name) { - name = GetUniqueIOVarName(original_name); - } - tir::Var var = tir::Var(name, DataType::Handle()); - main_signature_.push_back(var); - auto tensor_type = type.as(); - ICHECK(tensor_type) << "Expected TensorType node but was " << type->GetTypeKey(); - DataType elem_type = tensor_type->dtype; - tir::Var buffer_var = - tir::Var(name + "_buffer_var", PointerType(PrimType(elem_type), "global")); - tir::Buffer buffer = tir::Buffer(buffer_var, elem_type, tensor_type->shape, {}, 0, - name + "_buffer", 16, 1, tir::BufferType::kDefault); - main_buffer_map_.Set(var, buffer); - io_tensor_types_.Set(var, Downcast(type)); - } - } - - /*! - * \brief Create a unique name for I/O Var - */ - std::string GetUniqueIOVarName(std::string name) { - if (io_var_names_.find(name) == io_var_names_.end()) { - io_var_names_[name] = 1; - return name; - } else { - io_var_names_[name] = io_var_names_[name] + 1; - return name + std::to_string(io_var_names_[name]); - } - } - - /*! 
- * \brief Calculate workspace sizes for PrimFuncs in the IRModule - */ - Map CalculateWorkspaceSizes( - const IRModule& lowered_mod, const Map& function_metadata) { - Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(lowered_mod); - Map updated_function_metadata; - for (const auto& kv : lowered_mod->functions) { - GlobalVar global_var = kv.first; - BaseFunc base_func = kv.second; - if (base_func->IsInstance()) { - tir::PrimFunc pfunc = Downcast(base_func); - Target tgt = pfunc->GetAttr(tvm::attr::kTarget).value(); - const auto& ws = CalculateWorkspaceBytes(pfunc, workspace_byte_alignment); - if (function_metadata.count(global_var->name_hint)) { - updated_function_metadata.Set(global_var->name_hint, - function_metadata[global_var->name_hint]); - updated_function_metadata[global_var->name_hint]->workspace_sizes.Set(tgt, ws); - } else { - FunctionInfo finfo{{{tgt, ws}}, {}, {}, {{tgt, pfunc}}, {}}; - updated_function_metadata.Set(global_var->name_hint, finfo); - } - } - } - return updated_function_metadata; - } - - /*! - * \brief Run USMP to plan memory for lowered IRModule. 
- */ - IRModule PlanMemoryWithUSMP(const IRModule& mod) { - VLOG(1) << "Planning memory with USMP for module:" << std::endl << PrettyPrint(mod); - Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod); - IRModule lowered_mod = mod->ShallowCopy(); - lowered_mod = tir::transform::UnifiedStaticMemoryPlanner()(lowered_mod); - function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_); - Optional> allocated_pool_infos = - lowered_mod->GetAttr>(tvm::attr::kPoolArgs); - backend::FunctionInfo main_func_info = - lowered_mod->GetAttr("main_func_info").value(); - main_func_info->workspace_sizes.clear(); - if (allocated_pool_infos) { - for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) { - for (const auto& tgt : allocated_pool_info->pool_info->targets) { - VLOG(1) << "USMP requires target " << tgt->ToDebugString() << " to have pool size " - << allocated_pool_info->allocated_size->value; - size_t size = allocated_pool_info->allocated_size->value; - if (allocated_pool_info->pool_info->IsInstance()) { - size += main_func_info->constant_sizes.count(tgt) - ? main_func_info->constant_sizes[tgt]->value - : 0; - main_func_info->constant_sizes.Set(tgt, size); - } else if (allocated_pool_info->pool_info->IsInstance()) { - size += main_func_info->workspace_sizes.count(tgt) - ? main_func_info->workspace_sizes[tgt]->value - : 0; - main_func_info->workspace_sizes.Set(tgt, size); - } else { - LOG(FATAL) << "Unknown pool type: " << allocated_pool_info->pool_info->GetTypeKey(); - } - } - } - } - function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info); - return lowered_mod; - } - - /*! - * \brief Run StorageRewrite to plan memory for lowered IRModule. 
- */ - IRModule PlanMemoryWithStorageRewrite(const IRModule& mod) { - Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod); - IRModule lowered_mod = mod->ShallowCopy(); - function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_); - // Running StorageRewrite just on the main function - tir::PrimFunc tir_main_func = - Downcast(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); - IRModule main_func_mod; - main_func_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main), - tir_main_func); - main_func_mod = tir::transform::StorageRewrite()(main_func_mod); - lowered_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main), - main_func_mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); - tir_main_func = - Downcast(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); - // Use the PrimFunc to calculate the workspace required to service the allocates - Integer main_workspace_size_bytes = - CalculateWorkspaceBytes(tir_main_func, workspace_byte_alignment); - backend::FunctionInfo main_func_info = - lowered_mod->GetAttr("main_func_info").value(); - main_func_info->workspace_sizes.Set(config_->host_target, main_workspace_size_bytes); - function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info); - return lowered_mod; - } - - /*! - * \brief Gets module workspace alignment from supplied executor or defaults to 16 - */ - Integer GetModuleWorkspaceByteAlignment(const IRModule& mod) { - Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); - return executor_config->GetAttr("workspace-byte-alignment").value_or(16); - } - - /*! - * \brief Gets module constant alignment from supplied executor or defaults to 16 - */ - Integer GetModuleConstantByteAlignment(const IRModule& mod) { - Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); - return executor_config->GetAttr("constant-byte-alignment").value_or(16); - } - - protected: - /*! 
\brief mod */ - runtime::Module* mod_; - /*! \brief list of input expressions (i.e., variable passed by the user) */ - std::vector input_vars_; - /*! \brief map of device contexts variables */ - Map devices_; - /*! \brief map of GlobalVars to C Device API contexts */ - Map device_contexts_; - /*! \brief input and output variables belonging to the main function signature */ - Array main_signature_; - /*! \brief input and output variables belonging to the main function signature */ - Map main_buffer_map_; - /*! \brief maps input and output variables to TensorType which describe them */ - Map io_tensor_types_; - /*! \brief All available targets. */ - CompilationConfig config_; - /*! - * \brief The type of kernel call to be emitted. - * See CallType for more documentation. - */ - CallType call_type_; - - /*! - * \brief parameters (i.e. ConstantNodes found in the graph). - * These are take as inputs to the GraphRuntime. - * Maps param name to a pair of storage_id and NDArray. At runtime, the storage_id can be - * used to lookup the parameter. - */ - std::unordered_map params_; - /*! \brief mapping between expression and parameters */ - Map params_by_expr_; - /*! \brief mapping between parameter names ("p0", "p1", etc..) and storage identifiers*/ - std::unordered_map param_storage_ids_; - std::unordered_map - constant_map_; - - /*! \brief plan memory of device result */ - StorageMap storage_device_map_; - /*! \brief mapping sid -> tir::Var */ - std::unordered_map sids_table_; - /*! \brief lowered funcs */ - Map function_metadata_; - /*! \brief the set of statements that make the program */ - std::vector stmts_; - /*! \brief the list of return sids (note that the function might return more then one output */ - std::vector return_sid_; - /*! \brief This is per IO var name counter to aid the generating unique names */ - std::unordered_map io_var_names_; - /*! \brief A set of variables that are let bound. 
*/ - std::unordered_set let_bound_vars_; - - public: - AOTExecutorCodegen(runtime::Module* mod, const Array& targets) - : mod_(mod), config_(transform::PassContext::Current(), targets) {} - - LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) { - VLOG_CONTEXT << "AOT"; - - Runtime runtime_config = mod->GetAttr(tvm::attr::kRuntime).value(); - Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod); - - Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); - std::string interface_api = - executor_config->GetAttr("interface-api").value_or("packed"); - bool unpacked_api = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); - - // Validate choice of unpacked_api and use_call_cpacked_ - if (runtime_config->name == kTvmRuntimeCrt) { - if (unpacked_api == true) { - call_type_ = CallType::kUnpacked; - } else if (unpacked_api == false && interface_api == "packed") { - call_type_ = CallType::kCPacked; - } else { - CHECK(interface_api == "packed" || unpacked_api == true) - << "Either need interface_api == \"packed\" (got: " << interface_api - << ") or unpacked-api == true (got: " << unpacked_api << ") when targeting c runtime"; - ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api - << ", unpacked-api=" << unpacked_api; - } - } else if (runtime_config->name == kTvmRuntimeCpp) { - if (unpacked_api == false && interface_api == "packed") { - call_type_ = CallType::kCPacked; - } else { - CHECK(static_cast(unpacked_api) == false && interface_api == "packed") - << "Need unpacked-api == false (got: " << unpacked_api - << ") and interface-api == \"packed\" (got: " << interface_api - << ") when targeting c++ runtime"; - ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api - << ", unpacked-api=" << unpacked_api; - } - } else { - ICHECK(false) << "runtime_config (" << runtime_config->name - << ") is not one of the expected values"; - } - - mod = 
transform::ToANormalForm()(mod); - mod = transform::InferType()(mod); - mod = transform::AnnotateUsedMemory()(mod); - - IRModule lowered_mod = - tec::LowerTE(mod_name, config_, [this, workspace_byte_alignment](BaseFunc func) { - // We need to maintain the constant map for external - // functions so we pass this processing function which - // allows us to process each function as we lower it. - if (func->GetAttr(attr::kCompiler).defined()) { - UpdateConstants(func, ¶ms_); - } - - // TODO(@areusch, @jroesch): We should refactor this to - // execute as a further pass, instead writing data to the - // lowering process directly. - tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment); - })(mod); - - transform::PassContext pass_ctx = transform::PassContext::Current(); - bool enable_remove_reshapes = - pass_ctx->GetConfig("relay.remove_standalone_reshapes.enable", Bool(true)).value(); - if (enable_remove_reshapes) { - lowered_mod = transform::RemoveStandaloneReshapes()(lowered_mod); - } - auto lowered_main = lowered_mod->Lookup("main"); - auto lowered_main_func = GetRef(lowered_main.as()); - - // Post-lowering storage map for writing main func - AOTOnDemandAllocator final_aot_allocator; - final_aot_allocator.Run(lowered_main_func); - storage_device_map_ = final_aot_allocator.GetStorageMap(); - - // TODO(@electriclilies, @jroesch, @Mousius): remove UpdateMainWorkspaceSize - StaticMemoryPlan memory_plan(storage_device_map_); - backend::FunctionInfo func_info = - tec::UpdateMainWorkspaceSize(lowered_mod, config_, memory_plan->expr_to_storage_info); - lowered_mod = WithAttr(lowered_mod, "main_func_info", func_info); - - for (auto input : lowered_main_func->params) { - input_vars_.push_back(input); - std::string input_name = SanitizeName(input->name_hint()); - // We dont want the compiler changing input names in the - // event of a sanitization collision. Therefore, enforcing - // the var created to use the input_name strictly. 
- CreateIOVar(input, input_name, /*use_unique_name = */ false); - } - - // Define the storage allocator ids - for (auto kv : storage_device_map_) { - for (auto sid : kv.second->storage_ids) { - // The buffer_var is created with storage_scope to be global.workspace to be serviced by - // TVMBackendAllocWorkspace(TVMBAW) calls, explicitly. The reasoning being the executor - // allocates should be serviced by TVMBAWs as the data could be accessed by many devices and - // should not be lowered to the stack. For more details please refer to the discussion here: - // https://github.com/apache/tvm/issues/9022 - te::Var buffer_var(MakeString("sid_", sid), - PointerType(PrimType(DataType::Int(8)), "global.workspace")); - sids_table_[sid] = buffer_var; - } - } - - // Retrieve the return sids - return_sid_ = final_aot_allocator.GetReturnIds(); - // Insert outputs to main func signature - // If output tensor names were provided use them - if (auto opt = func->GetAttr>("output_tensor_names")) { - Array output_tensor_names = opt.value(); - Expr output_expr = lowered_main_func->body; - if (output_expr->checked_type()->IsInstance()) { - TupleType output_tuple_type = Downcast(output_expr->checked_type()); - for (unsigned i = 0; i < output_tuple_type->fields.size(); i++) { - // AoT Executor Codegen does not create these names, - // thus should be used as they are provided. - CreateIOVar(output_tuple_type->fields[i], output_tensor_names[i], - /*use_unique_name = */ false); - } - } else { - // AoT Executor Codegen does not create these names, - // thus should be used as they are provided. - CreateIOVar(lowered_main_func->body, output_tensor_names[0], /*use_unique_name = */ false); - } - } else { - // If output tensor names are not provided we will generate output(x) - // where x is a counter to create unique names. 
- CreateIOVar(lowered_main_func->body, "output"); - } - - CollectDeviceVariables(lowered_mod->GetAttr>("device_contexts").value()); - VisitExpr(lowered_main_func->body); - - // Create the runner function. Please note that the function is not legal yet - // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need - // to run the LegalizePackedCalls pass. - LoweredOutput ret; - - // Collect any constants extracted by external codegen. - ret.params = std::unordered_map(); - Map const_name_to_constant = - lowered_mod->GetAttr>(tvm::attr::kConstNameToConstant) - .value_or({}); - for (const auto& kv : const_name_to_constant) { - ICHECK(ret.params.emplace(kv.first, kv.second).second); - } - - // Collect any constants extracted during lowering. - for (const auto& kv : params_) { - ICHECK(ret.params.emplace(kv.first, kv.second).second); - } - - // AoT Executor codegen works completely on TIR beyond this point, hence removing relay main - // function and replacing it with its TIR version. We should try to make this a Pass. - lowered_mod->Remove(lowered_mod->GetGlobalVar("main")); - auto tir_main_func = CreateMainFunc(mod_name, lowered_main_func->params.size()); - // Extract additional information around main TIR PrimFunc arguments - Array devices = ListDevices(); - const auto main_func_params_end_iterator = - tir_main_func->params.begin() + tir_main_func->params.size(); - const auto outputs_begin_iterator = - main_func_params_end_iterator - return_sid_.size() - devices.size(); - Array inputs = Array(tir_main_func->params.begin(), outputs_begin_iterator); - Array input_tensor_types; - for (auto i : inputs) { - input_tensor_types.push_back(io_tensor_types_[i]); - } - Array outputs = - Array(outputs_begin_iterator, main_func_params_end_iterator - devices.size()); - - lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_module_main), tir_main_func); - // Parallel for loops are not supported in AoT codegen. 
- lowered_mod = tir::transform::ConvertForLoopsToSerial()(lowered_mod); - - bool enable_usmp = pass_ctx->GetConfig(kUSMPEnableOption, Bool(false)).value(); - if (enable_usmp) { - lowered_mod = PlanMemoryWithUSMP(lowered_mod); - } else { - lowered_mod = PlanMemoryWithStorageRewrite(lowered_mod); - } - ret.function_metadata = std::move(function_metadata_); - - // Legalize AOT if needed. This means that all the packed calls - // need to be wrapped in TVMValues (unless unpacked_api is set) - if (call_type_ == CallType::kCPacked || call_type_ == CallType::kPacked) { - auto pack_calls = tir::transform::LegalizePackedCalls(); - lowered_mod = pack_calls(lowered_mod); - } - - // Collect any runtime modules generated by external codegen. - ret.external_mods = - lowered_mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); - - // This is the point where we separate the functions in the module by target - VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod); - ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod); - VLOG(1) << "per-target modules:"; - for (const auto& kv : ret.lowered_funcs) { - VLOG(1) << "target:" << std::endl - << kv.first->ToDebugString() << std::endl - << "maps to:" << std::endl - << PrettyPrint(kv.second); - } - - // Extract USMP metadata to pass onto metadata sources - Map pool_var_info; - std::vector pool_vars; - tir_main_func = - Downcast(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); - Optional> allocated_pool_infos = - tir_main_func->GetAttr>(tvm::attr::kPoolArgs); - if (allocated_pool_infos) { - for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) { - int pool_var_index = allocated_pool_info->pool_var_idx.value()->value; - pool_vars.push_back(tir_main_func->params[pool_var_index]); - pool_var_info.Set(tir_main_func->params[pool_var_index], allocated_pool_info); - } - } - Map io_pool_allocations = - lowered_mod - ->GetAttr>(tvm::attr::kIOTensorPoolAllocations) - 
.value_or({}); - - std::vector output_var_names; - if (auto opt = func->GetAttr>("output_tensor_names")) { - Array output_tensor_names = opt.value(); - for (size_t i = 0; i < output_tensor_names.size(); ++i) { - output_var_names.push_back(output_tensor_names[i]); - } - } - - // If output names have not been specified then generate default output names - if (output_var_names.size() == 0) { - if (return_sid_.size() == 1) { - output_var_names.push_back(String("output")); - } else { - for (size_t i = 0; i < return_sid_.size(); ++i) { - output_var_names.push_back(String("output" + std::to_string(i))); - } - } - } - - Array output_tensor_types{final_aot_allocator.GetReturnTtypes()}; - - ret.metadata = ExecutorCodegenMetadata( - inputs, input_tensor_types, output_var_names, output_tensor_types, pool_vars, devices, - runtime::kTvmExecutorAot, mod_name, interface_api, unpacked_api, - GetModuleWorkspaceByteAlignment(mod), GetModuleConstantByteAlignment(mod), pool_var_info, - io_pool_allocations); - return ret; - } - - /*! 
- * \brief Get list of devices found - * \return List of devices - */ - Array ListDevices() { - std::vector device_names(devices_.size()); - std::transform(devices_.begin(), devices_.end(), device_names.begin(), - [](const auto& it) -> String { return it.first; }); - return device_names; - } -}; // namespace backend - -class AOTExecutorCodegenModule : public runtime::ModuleNode { - public: - AOTExecutorCodegenModule() {} - virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { - if (name == "init") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " - << "runtime::Module mod and Array targets"; - void* mod = args[0]; - Array targets = args[1]; - init(mod, targets); - }); - } else if (name == "codegen") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - IRModule mod = args[0]; - Function func = args[1]; - String mod_name = args[2]; - this->output_ = this->codegen_->Codegen(mod, func, mod_name); - }); - } else if (name == "list_params_name") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = list_params_name(); }); - } else if (name == "get_param_by_name") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - String key = args[0]; - *rv = get_param_by_name(key); - }); - } else if (name == "get_irmodule") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); }); - } else if (name == "get_external_modules") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_external_modules(); }); - } else if (name == "get_function_metadata") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->output_.function_metadata; - }); - } else if (name == "get_devices") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = 
this->codegen_->ListDevices(); - }); - } else if (name == "get_executor_codegen_metadata") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = output_.metadata; }); - } else { - return PackedFunc([](TVMArgs args, TVMRetValue* rv) {}); - } - } - - const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } - - private: - void init(void* mod, const Array& targets) { - codegen_ = - std::make_shared(reinterpret_cast(mod), targets); - } - - Array list_params_name() { - Array ret; - for (const auto& kv : this->output_.params) { - ret.push_back(kv.first); - } - return ret; - } - - runtime::NDArray get_param_by_name(String key) { - auto it = this->output_.params.find(key); - CHECK(it != this->output_.params.end()) << "no such parameter " << key; - return (*it).second; - } - - Array get_external_modules() { return output_.external_mods; } - - Map get_irmodule() { return this->output_.lowered_funcs; } - - std::shared_ptr codegen_; - LoweredOutput output_; -}; - -runtime::Module CreateAOTExecutorCodegenMod() { - auto ptr = make_object(); - return runtime::Module(ptr); -} - -TVM_REGISTER_GLOBAL("relay.build_module._AOTExecutorCodegen") - .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateAOTExecutorCodegenMod(); }); - -} // namespace backend -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index bca524794a20..3f0872de479b 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -65,8 +65,10 @@ struct ExecutorCodegen { CallFunc("init", m, raw_targets); } - void Codegen(IRModule mod, const Function& func, String mod_name) { - CallFunc("codegen", mod, func, mod_name); + void Codegen(IRModule mod, const Function& func, String mod_name, CompilationConfig config, + Executor executor, CallType call_type) { + CallFunc("codegen", mod, func, mod_name, config, executor, + Integer(static_cast(call_type))); } virtual void 
UpdateOutput(BuildOutput* ret) = 0; @@ -302,11 +304,45 @@ class RelayBuildModule : public runtime::ModuleNode { workspace_memory_pools_ = workspace_memory_pools; constant_memory_pools_ = constant_memory_pools; config_ = CompilationConfig(PassContext::Current(), raw_targets); + SetCallType(executor, runtime); VLOG(1) << "Using compilation config:" << std::endl << config_; BuildRelay(std::move(mod), mod_name); } protected: + void SetCallType(const Executor& executor, const Runtime& runtime) { + std::string interface_api = executor->GetAttr("interface-api").value_or("packed"); + bool unpacked_api = executor->GetAttr("unpacked-api").value_or(Bool(false)); + + // Validate choice of unpacked_api and use_call_cpacked_ + if (runtime->name == kTvmRuntimeCrt) { + if (unpacked_api == true) { + call_type_ = CallType::kUnpacked; + } else if (unpacked_api == false && interface_api == "packed") { + call_type_ = CallType::kCPacked; + } else { + CHECK(interface_api == "packed" || unpacked_api == true) + << "Either need interface_api == \"packed\" (got: " << interface_api + << ") or unpacked-api == true (got: " << unpacked_api << ") when targeting c runtime"; + ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api + << ", unpacked-api=" << unpacked_api; + } + } else if (runtime->name == kTvmRuntimeCpp) { + if (unpacked_api == false && interface_api == "packed") { + call_type_ = CallType::kCPacked; + } else { + CHECK(static_cast(unpacked_api) == false && interface_api == "packed") + << "Need unpacked-api == false (got: " << unpacked_api + << ") and interface-api == \"packed\" (got: " << interface_api + << ") when targeting c++ runtime"; + ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api + << ", unpacked-api=" << unpacked_api; + } + } else { + ICHECK(false) << "runtime (" << runtime->name << ") is not one of the expected values"; + } + } + /*! * \brief Optimize a Relay IRModule. 
* @@ -428,7 +464,7 @@ class RelayBuildModule : public runtime::ModuleNode { // Generate code for the updated function. executor_codegen_ = MakeExecutorCodegen(executor_->name); executor_codegen_->Init(nullptr, config_->primitive_targets); - executor_codegen_->Codegen(func_module, func, mod_name); + executor_codegen_->Codegen(func_module, func, mod_name, config_, executor_, call_type_); executor_codegen_->UpdateOutput(&ret_); ret_.params = executor_codegen_->GetParams(); @@ -484,6 +520,7 @@ class RelayBuildModule : public runtime::ModuleNode { Executor executor_; /*! \brief Runtime to codegen for */ Runtime runtime_; + CallType call_type_; /*! \brief Workspace memory pools to codegen for */ WorkspaceMemoryPools workspace_memory_pools_; /*! \brief Constant memory pools to codegen for */ diff --git a/tests/python/relay/aot/test_aot_create_function_metadata.py b/tests/python/relay/aot/test_aot_create_function_metadata.py index ff2a522572c5..80137bd23f0c 100644 --- a/tests/python/relay/aot/test_aot_create_function_metadata.py +++ b/tests/python/relay/aot/test_aot_create_function_metadata.py @@ -264,7 +264,7 @@ def __tvm_main__(a: T.handle, output: T.handle) -> None: T.evaluate(T.tvm_call_cpacked("test_fused_add", a_buffer.data, a_buffer.data, output_buffer.data, T.reinterpret(T.uint64(0), dtype="handle"), dtype="int32")) @T.prim_func - def test_fused_add(a: T.handle, b: T.handle, output: T.handle) -> None: + def test_fused_add(a: T.handle, b: T.handle, output: T.handle, device_context_unused: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "test_mod_test_fused_add", "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]})}) a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index c3426f147e0d..9d255b39c2b7 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -984,7 +984,7 @@ def 
test_aot_codegen_checks_returns(): # Check operator call is wrapped properly assert ( - str(main_func.body[1]) + str(main_func.body) == "tir.tvm_check_return(0, -1, tir.call_extern(" + '"tvmgen_default_fused_add",' + " x_buffer_var, y_buffer_var, output_buffer_var))\n"