diff --git a/CMakeLists.txt b/CMakeLists.txt index f87a3a9f617f..9c6a7dddfdf6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,7 @@ tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" O tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF) tvm_option(USE_GRAPH_EXECUTOR "Build with tiny graph executor" ON) tvm_option(USE_GRAPH_EXECUTOR_CUDA_GRAPH "Build with tiny graph executor with CUDA Graph for GPUs" OFF) +tvm_option(USE_AOT_EXECUTOR "Build with AOT executor" ON) tvm_option(USE_PROFILER "Build profiler for the VM and graph executor" ON) tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF) tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF) @@ -395,6 +396,13 @@ if(USE_PROFILER) list(APPEND RUNTIME_SRCS ${RUNTIME_VM_PROFILER_SRCS}) endif(USE_PROFILER) +if(USE_AOT_EXECUTOR) + message(STATUS "Build with AOT Executor support...") + file(GLOB RUNTIME_AOT_EXECUTOR_SRCS src/runtime/aot_executor/*.cc) + list(APPEND RUNTIME_SRCS ${RUNTIME_AOT_EXECUTOR_SRCS}) + +endif(USE_AOT_EXECUTOR) + # Enable ctest if gtest is available if(USE_GTEST) # Check env var for backward compatibility. A better way to specify package diff --git a/include/tvm/relay/runtime.h b/include/tvm/relay/runtime.h index a925045f9f41..10e124bc339b 100644 --- a/include/tvm/relay/runtime.h +++ b/include/tvm/relay/runtime.h @@ -44,6 +44,12 @@ class AttrRegistry; namespace relay { +/*! \brief Value used with Runtime::name to indicate the C++ runtime. */ +static constexpr const char* kTvmRuntimeCpp = "cpp"; + +/*! \brief Value used with Runtime::name to indicate the C runtime. */ +static constexpr const char* kTvmRuntimeCrt = "crt"; + /*! * \brief Runtime information. 
* diff --git a/include/tvm/runtime/metadata.h b/include/tvm/runtime/metadata.h index b716d41c5d27..cd65f6fb7486 100644 --- a/include/tvm/runtime/metadata.h +++ b/include/tvm/runtime/metadata.h @@ -33,12 +33,13 @@ #include #ifdef __cplusplus #include -#endif #include +#endif // Version number recorded in emitted artifacts for runtime checking. #define TVM_METADATA_VERSION 1 +#ifdef __cplusplus namespace tvm { namespace runtime { namespace metadata { @@ -51,7 +52,6 @@ static const constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION; } // namespace runtime } // namespace tvm -#ifdef __cplusplus extern "C" { #endif @@ -75,6 +75,13 @@ struct TVMMetadata { const struct TVMTensorInfo* outputs; /*! \brief Number of elements in `outputs` array. */ int64_t num_outputs; + /*! \brief Memory Pools needed by the AOT main function. + * The order of the elements is the same as in the arguments to run_model. That is to say, + * this array specifies the last `num_pools` arguments to run_model. + */ + const struct TVMTensorInfo* pools; + /*! \brief Number of elements in `pools` array. */ + int64_t num_pools; /*! \brief Name of the model, as passed to tvm.relay.build. */ const char* mod_name; }; @@ -114,6 +121,8 @@ class MetadataNode : public MetadataBaseNode { ArrayAccessor inputs(); inline int64_t num_outputs() const { return data_->num_outputs; } ArrayAccessor outputs(); + inline int64_t num_pools() const { return data_->num_pools; } + ArrayAccessor pools(); inline ::tvm::runtime::String mod_name() const { return ::tvm::runtime::String(data_->mod_name); } const struct ::TVMMetadata* data() const { return data_; } TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, MetadataBaseNode); diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 2e2a79b1ca53..a93f1c66c395 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -217,6 +217,8 @@ TVM_DLL bool RuntimeEnabled(const std::string& target); /*! 
\brief namespace for constant symbols */ namespace symbol { +/*! \brief A PackedFunc that retrieves exported metadata. */ +constexpr const char* tvm_get_c_metadata = "get_c_metadata"; /*! \brief Global variable to store module context. */ constexpr const char* tvm_module_ctx = "__tvm_module_ctx"; /*! \brief Global variable to store device module blob */ diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index ac6803ca9842..e8e2798ef734 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -189,7 +189,7 @@ def set_input(self, key=None, value=None, **params): keys.sort(key=lambda x: -np.prod(params[x].shape)) for k in keys: # TODO(zhiics) Skip the weights for submodule in a better way. - # We should use MetadataModule for initialization and remove + # We should use ConstLoaderModule for initialization and remove # params from set_input val = self._get_input(k) if val: diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index d53c4ed49939..6b59b3443078 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -64,6 +64,10 @@ def generate_c_interface_header( return metadata_header +# List of type_key for modules which are ephemeral and do not need to be exported. +EPHEMERAL_MODULE_TYPE_KEYS = ("metadata_module",) + + def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None): """Populate the codegen sub-directory as part of a Model Library Format export. @@ -79,6 +83,11 @@ def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None): """ dso_modules = mod._collect_dso_modules() non_dso_modules = mod._collect_from_import_tree(lambda m: m not in dso_modules) + + # Filter ephemeral modules which cannot be exported. 
+ dso_modules = [m for m in dso_modules if m.type_key not in EPHEMERAL_MODULE_TYPE_KEYS] + non_dso_modules = [m for m in non_dso_modules if m.type_key not in EPHEMERAL_MODULE_TYPE_KEYS] + if non_dso_modules: raise UnsupportedInModelLibraryFormatError( f"Don't know how to export non-c or non-llvm modules; found: {non_dso_modules!r}" diff --git a/python/tvm/relay/backend/executor_factory.py b/python/tvm/relay/backend/executor_factory.py index 9ff7a7a8120b..eee3169400ff 100644 --- a/python/tvm/relay/backend/executor_factory.py +++ b/python/tvm/relay/backend/executor_factory.py @@ -109,6 +109,13 @@ def __init__( executor_codegen_metadata, devices, ): + fcreate = get_global_func("tvm.aot_executor_factory.create") + args = [] + for k, v in params.items(): + args.append(k) + args.append(ndarray.array(v)) + + self.module = fcreate(libmod, libmod_name, *args) self.ir_mod = ir_mod self.lowered_ir_mods = lowered_ir_mods self.target = target @@ -134,6 +141,9 @@ def get_executor_config(self): def get_lib(self): return self.lib + def export_library(self, file_name, fcompile=None, addons=None, **kwargs): + return self.module.export_library(file_name, fcompile, addons, **kwargs) + class GraphExecutorFactoryModule(ExecutorFactoryModule): """Graph executor factory module. diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 5cfd3a16c3bc..7872091f1a5d 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -27,8 +27,11 @@ from tvm.tir import expr as tvm_expr from tvm.target import Target from .. import nd as _nd, autotvm, register_func +from ..runtime import load_module +from ..runtime.executor import aot_executor as _aot_executor from ..target import Target -from ..contrib import graph_executor as _graph_rt +from ..contrib import graph_executor as _graph_executor +from ..contrib import utils as contrib_utils from . import _build_module from . import ty as _ty from . 
import expr as _expr @@ -612,7 +615,7 @@ def _make_executor(self, expr=None): "Graph Executor only supports static graphs, got output type", ret_type ) mod = build(self.mod, target=self.target) - gmodule = _graph_rt.GraphModule(mod["default"](self.device)) + gmodule = _graph_executor.GraphModule(mod["default"](self.device)) def _unflatten(flat_iter, cur_type): if isinstance(cur_type, _ty.TensorType): @@ -641,6 +644,74 @@ def _graph_wrapper(*args, **kwargs): return _graph_wrapper +class AotExecutor(_interpreter.Executor): + """Implements the Executor interface for AOT. + + Parameters + ---------- + mod : :py:class:`~tvm.IRModule` + The module to support the execution. + + device : :py:class:`Device` + The runtime device to run the code on. + + target : :py:class:`Target` + The target option to build the function. + """ + + def __init__(self, mod, device, target): + assert mod is not None + self.mod = mod + self.device = device + self.target = target + assert target.attrs.get("executor", "graph") == "aot" + + def _make_executor(self, expr=None): + if expr: + self.mod["main"] = expr + self.mod = InferType()(self.mod) + ret_type = self.mod["main"].checked_type.ret_type + if _ty.is_dynamic(ret_type): + raise ValueError("AOT Executor only supports static graphs, got output type", ret_type) + mod = build(self.mod, target=self.target) + + # NOTE: Given AOT requires use of the "c" backend, must export/import to compile the + # generated code. 
+ temp_so_dir = contrib_utils.TempDirectory() + temp_so = temp_so_dir / "temp.so" + mod.export_library(temp_so, cc="gcc", options=["-std=c11"]) + + mod = load_module(temp_so) + aot_mod = mod["default"](self.device) + gmodule = _aot_executor.AotModule(aot_mod) + + def _unflatten(flat_iter, cur_type): + if isinstance(cur_type, _ty.TensorType): + return next(flat_iter) + if isinstance(cur_type, _ty.TupleType): + fields = [] + for field_type in cur_type.fields: + field = _unflatten(flat_iter, field_type) + fields.append(field) + return fields + raise ValueError("Return type", ret_type, "contains unsupported type", cur_type) + + def _aot_wrapper(*args, **kwargs): + args = self._convert_args(self.mod["main"], args, kwargs) + # Create map of inputs. + for i, arg in enumerate(args): + gmodule.set_input(i, arg) + # Run the module, and fetch the output. + gmodule.run() + flattened = [] + for i in range(gmodule.get_num_outputs()): + flattened.append(gmodule.get_output(i).copyto(_nd.cpu(0))) + unflattened = _unflatten(iter(flattened), ret_type) + return unflattened + + return _aot_wrapper + + # TODO(mbs): Collapse the create_executor/evaluate phases together since a) most callers don't # reuse the executor for multiple expressions and b) any preparation necessary for the expression # evaluation needs to (currently) be done along with preparation for the module. @@ -664,9 +735,8 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N Parameters ---------- kind : str - The type of executor. Avaliable options are `debug` for the - interpreter, `graph` for the graph executor, and `vm` for the virtual - machine. + The type of executor. Available options are `debug` for the interpreter, `graph` for the + graph executor, `aot` for the aot executor, and `vm` for the virtual machine.
mod : :py:class:`~tvm.IRModule` The Relay module containing collection of functions @@ -703,4 +773,6 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N return GraphExecutor(mod, device, target) if kind == "vm": return VMExecutor(mod, device, target) + if kind == "aot": + return AotExecutor(mod, device, target) raise RuntimeError("unknown execution strategy: {0}".format(kind)) diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index b3504dbac506..ab0fc1709fa9 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -31,3 +31,5 @@ from .module import load_module, enabled, system_lib from .container import String, ShapeTuple from .params import save_param_dict, load_param_dict + +from . import executor diff --git a/python/tvm/runtime/executor/__init__.py b/python/tvm/runtime/executor/__init__.py new file mode 100644 index 000000000000..ecc4097dbaa0 --- /dev/null +++ b/python/tvm/runtime/executor/__init__.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains Python wrappers for the TVM C++ Executor implementations. + +NOTE: at present, only AOT Executor is contained here. 
The others are: + - GraphExecutor, in python/tvm/contrib/graph_executor.py + - VM Executor, in python/tvm/runtime/vm.py + +TODO(areusch): Consolidate these into this module. +""" +from .aot_executor import AotModule diff --git a/python/tvm/runtime/executor/aot_executor.py b/python/tvm/runtime/executor/aot_executor.py new file mode 100644 index 000000000000..9ef0d1dee894 --- /dev/null +++ b/python/tvm/runtime/executor/aot_executor.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A Python wrapper for the Module-based Model Runtime Interface for Ahead-of-Time compilation.""" + +import numpy as np + + +class AotModule(object): + """Wraps the AOT executor runtime.Module. + + This is a thin wrapper of the underlying TVM module. + you can also directly call set_input, run, and get_output + of underlying module functions + + Parameters + ---------- + module : tvm.runtime.Module + The internal tvm module that holds the implemented model functions. + + Attributes + ---------- + module : tvm.runtime.Module + The internal tvm module that holds the implemented model functions. + + Examples + -------- + + .. 
code-block:: python + + import tvm + from tvm import relay + from tvm.contrib import graph_executor + + # build the library using graph executor + lib = relay.build(...) + lib.export_library("compiled_lib.so") + # load it back as a runtime + lib: tvm.runtime.Module = tvm.runtime.load_module("compiled_lib.so") + # Call the library factory function for default and create + # a new runtime.Module, wrap with aot module. + gmod = tvm.runtime.executor.AotModule(lib["default"](dev)) + # use the aot module. + gmod.set_input("x", data) + gmod.run() + """ + + def __init__(self, module): + self.module = module + self._set_input = module["set_input"] + self._run = module["run"] + self._get_output = module["get_output"] + self._get_input = module["get_input"] + self._get_num_outputs = module["get_num_outputs"] + self._get_input_index = module["get_input_index"] + self._get_num_inputs = module["get_num_inputs"] + + def set_input(self, key=None, value=None, **params): + """Set inputs to the module via kwargs + + Parameters + ---------- + key : int or str + The input key + + value : the input value. + The input key + + params : dict of str to NDArray + Additional arguments + """ + if key is not None: + v = self._get_input(key) + if v is None: + raise RuntimeError("Could not find '%s' in model's inputs" % key) + v.copyfrom(value) + + if params: + # upload big arrays first to avoid memory issue in rpc mode + keys = list(params.keys()) + keys.sort(key=lambda x: -np.prod(params[x].shape)) + for k in keys: + # TODO(zhiics) Skip the weights for submodule in a better way. 
+ # We should use ConstLoaderModule for initialization and remove + # params from set_input + val = self._get_input(k) + if val: + self._get_input(k).copyfrom(params[k]) + + def run(self, **input_dict): + """Run forward execution of the model + + Parameters + ---------- + input_dict: dict of str to NDArray + List of input values to be fed to the model + """ + if input_dict: + self.set_input(**input_dict) + self._run() + + def get_num_outputs(self): + """Get the number of outputs from the model + + Returns + ------- + count : int + The number of outputs. + """ + return self._get_num_outputs() + + def get_num_inputs(self): + """Get the number of inputs to the model + + Returns + ------- + count : int + The number of inputs. + """ + return self._get_num_inputs() + + def get_input(self, index, out=None): + """Get index-th input to out + + Parameters + ---------- + index : int + The input index + + out : NDArray + The output array container + """ + if out: + self._get_input(index).copyto(out) + return out + + return self._get_input(index) + + def get_input_index(self, name): + """Get input index via input name. + + Parameters + ---------- + name : str + The input key name + + Returns + ------- + index: int + The input index. -1 will be returned if the given input name is not found.
+ """ + return self._get_input_index(name) + + def get_output(self, index, out=None): + """Get index-th output to out + + Parameters + ---------- + index : int + The output index + + out : NDArray + The output array container + """ + if out: + self._get_output(index, out) + return out + + return self._get_output(index) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index a25ef458906c..6e6c4b6dfeb2 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -70,6 +71,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { void Run(const Function& func) { VisitExpr(func); } std::vector GetReturnIds() const { return return_ids_; } + std::vector GetReturnTtypes() const { return return_ttypes_; } StorageMap GetStorageMap() const { return storage_device_map_; } @@ -177,6 +179,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { for (auto sid : sinfo->storage_ids) { return_ids_.push_back(sid); } + return_ttypes_.clear(); + return_ttypes_ = FlattenTupleType(e->checked_type()); } } /*! @@ -252,6 +256,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { int next_available_sid_{0}; /*! \brief the set of intermediate tensors that are return variables */ std::vector return_ids_; + /*! \brief the data types of the return values */ + std::vector return_ttypes_; }; /*! \brief Code generator for AOT executor */ @@ -317,6 +323,16 @@ class AOTExecutorCodegen : public MixedModeVisitor { } } + void PushArgs(const Expr& expr, const std::vector& sids, Array* args) { + const TupleNode* t = expr.as(); + if (t != nullptr) { + CHECK_EQ(sids.size(), t->fields.size()) << "Relay tuple does not map 1:1 into TIR; AOT can't " + "handle this type of Relay Expr in a CallNode."; + } + + args->insert(args->end(), sids.begin(), sids.end()); + } + /*! 
* brief Create a function call * \param call_lowered_props The lowered function and the arguments to call it with @@ -332,30 +348,35 @@ class AOTExecutorCodegen : public MixedModeVisitor { if (params_by_expr_.find(arg) != params_by_expr_.end()) { auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), {tir::StringImm(params_by_expr_[arg])}); - args.push_back(tvm::tir::Cast(DataType::Handle(), param_handle)); + // NOTE: this cast looks like a no-op, but is required for compilation downstream. + // Because DataType::Handle has default bits=64, but CodeGenC does not observe this field, + // adding this cast forces the codegen to insert the cast. In this case, a cast is required + // because param_handle is actually code-generated as `const void*`, and the `const` piece + // needs to be removed. + args.push_back(tvm::tir::Cast(DataType::Handle(32, 1), param_handle)); } else { - auto var_arg = FindExpr(arg); - for (const auto& var : var_arg) { - args.push_back(var); - } + auto sids = FindExpr(arg); + PushArgs(arg, sids, &args); } } // Pack the return(s) value. A call node can produce multiple outputs - for (const auto& var : PackSid(result_expr)) { - args.push_back(var); - } + auto result_expr_sid = PackSid(result_expr); + PushArgs(result_expr, result_expr_sid, &args); - // Use tvm_call_packed to execute the function unless we're calling directly - auto calling_pattern = tvm::tir::builtin::tvm_call_cpacked(); + // Choose call style based on Runtime/Executor config. 
+ Op calling_pattern; if (use_unpacked_api_) { calling_pattern = tvm::tir::builtin::call_extern(); + } else if (use_call_cpacked_) { + calling_pattern = tvm::tir::builtin::tvm_call_cpacked(); + } else { + calling_pattern = tvm::tir::builtin::tvm_call_packed(); } GlobalVar global_var = call_lowered_props.lowered_func; tir::Var empty_var("no_device_context", DataType::Handle()); bool has_c_device_api_context = device_contexts_.count(global_var) != 0; - bool use_cpacked_api = !use_unpacked_api_; // The device context is passed to the operator in one of the following calling patterns: // * Unpacked / direct function call with context: @@ -379,7 +400,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { func_call, GenerateDeviceHook(context, "Close"), })); - } else if (use_cpacked_api) { + } else if (use_call_cpacked_) { // call_cpacked calling convention needs a blank context args.push_back(tir::make_zero(DataType::Handle())); tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args)); @@ -698,6 +719,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { tir::Buffer buffer = tir::Buffer(buffer_var, elem_type, tensor_type->shape, {}, 0, name + "_buffer", 16, 1, tir::BufferType::kDefault); main_buffer_map_.Set(var, buffer); + io_tensor_types_.Set(var, Downcast(expr->checked_type())); } } @@ -806,19 +828,37 @@ class AOTExecutorCodegen : public MixedModeVisitor { Array main_signature_; /*! \brief input and output variables belonging to the main function signature */ Map main_buffer_map_; + /*! \brief maps input and output variables to TensorType which describe them */ + Map io_tensor_types_; /*! \brief target device */ tec::TargetMap targets_; /*! \brief target host */ Target target_host_; /*! 
* \brief unpacked api toggle - * When set to true the code generated will use unpacked calls to functions: + * When set to true, the generated code will use unpacked calls to functions: * func(void* arg0, void* arg1) - * Rather than packed calls: - * func(void* args) + * Rather than packed calls (in which arg0 and arg1 are in `arg_values`). + * func(TVMValue* arg_values, int* arg_type_codes, int num_args, ...) * Defaults to using the packed calling convention + * + * Unpacked API is supported when runtime == "c" and interface_api is "c". */ Bool use_unpacked_api_; + /*! + * \brief cpacked api toggle + * When set to true, the generated code will use call_cpacked to call functions directly, assuming + * they exist in a DSO-exportable module: + * func(...) + * Rather than through the traditional call_packed calls, which should use function pointers + * looked-up through TVMBackendGetFuncFromEnv: + * TVMBackendPackedCFunc* func_ptr = TVMBackendGetFuncFromEnv("func"); + * func_ptr(...) + * Defaults to using the packed calling convention + * + * call_cpacked is required when runtime is "c++" and supported when runtime is "c" + */ + Bool use_call_cpacked_; /*! * \brief parameters (i.e. ConstantNodes found in the graph). 
@@ -847,7 +887,11 @@ class AOTExecutorCodegen : public MixedModeVisitor { public: AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host) - : mod_(mod), targets_(targets), target_host_(target_host), use_unpacked_api_(Bool(false)) {} + : mod_(mod), + targets_(targets), + target_host_(target_host), + use_unpacked_api_(Bool(false)), + use_call_cpacked_(Bool(false)) {} LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) { VLOG_CONTEXT << "AOT"; @@ -857,11 +901,29 @@ class AOTExecutorCodegen : public MixedModeVisitor { ICHECK(target_host_.defined()) << "require a target_host to be given for AOT codegen"; VLOG(1) << "target host: " << target_host_->ToDebugString(); + Runtime runtime_config = mod->GetAttr(tvm::attr::kRuntime).value(); Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); String interface_api = executor_config->GetAttr("interface-api").value_or("packed"); Integer workspace_byte_alignment = executor_config->GetAttr("workspace-byte-alignment").value_or(16); use_unpacked_api_ = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); + use_call_cpacked_ = !use_unpacked_api_; + + // Validate choice of use_unpacked_api_ and use_call_cpacked_ + if (runtime_config->name == kTvmRuntimeCrt) { + ICHECK(interface_api == "packed" || static_cast(use_unpacked_api_) == true) + << "Either need interface_api == \"packed\" (got: " << interface_api + << ") or unpacked-api == true (got: " << use_unpacked_api_ + << ") when targeting c runtime"; + } else if (runtime_config->name == kTvmRuntimeCpp) { + ICHECK(static_cast(use_unpacked_api_) == false) + << "Need unpacked-api == false (got: " << use_unpacked_api_ + << ") and interface-api == \"packed\" (got: " << interface_api + << ") when targeting c++ runtime"; + } else { + ICHECK(false) << "runtime_config (" << runtime_config->name + << ") is not one of the expected values"; + } // TODO(mbs): Plumb from compiler config VirtualDevice 
host_virtual_device = VirtualDevice::ForTarget(target_host_); @@ -996,6 +1058,11 @@ class AOTExecutorCodegen : public MixedModeVisitor { tir_main_func->params.begin() + tir_main_func->params.size() - return_sid_.size() - pool_vars.size() - devices.size()); + Array input_tensor_types; + for (auto i : inputs) { + input_tensor_types.push_back(io_tensor_types_[i]); + } + std::vector output_var_names; if (auto opt = func->GetAttr>("output_tensor_names")) { Array output_tensor_names = opt.value(); @@ -1015,9 +1082,11 @@ class AOTExecutorCodegen : public MixedModeVisitor { } } - ret.metadata = ExecutorCodegenMetadata(inputs, pool_vars, devices, output_var_names, - runtime::kTvmExecutorAot, mod_name, interface_api, - use_unpacked_api_, pool_var_info); + Array output_tensor_types{final_aot_allocator.GetReturnTtypes()}; + + ret.metadata = ExecutorCodegenMetadata( + inputs, input_tensor_types, output_var_names, output_tensor_types, pool_vars, devices, + runtime::kTvmExecutorAot, mod_name, interface_api, use_unpacked_api_, pool_var_info); return ret; } diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 227f7bbfdf31..30bc0beeebce 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -290,25 +290,15 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator - * - * \param shape - * \return std::vector - */ - std::vector _ShapeToJSON(tvm::Array shape) { - std::vector ret; - for (IndexExpr dim : shape) { - const int64_t* pval = tir::as_const_int(dim); - ret.push_back(*pval); - } - return ret; - } - /*! 
* \brief Add node to graph * @@ -352,7 +342,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorfields.size(); ++i) { if (const auto* typ = tuple_type->fields[i].as()) { ret.push_back(GraphNodeRef(node_id, i)); - shape.emplace_back(_ShapeToJSON(typ->shape)); + shape.emplace_back(ShapeToJSON(typ->shape)); dtype.emplace_back(DType2String(typ->dtype)); } else { LOG(FATAL) << "type " << checked_type->GetTypeKey() << " not supported"; @@ -369,7 +359,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator()) { ShapeVector shape; std::vector dtype; - shape.emplace_back(_ShapeToJSON(tensor_type->shape)); + shape.emplace_back(ShapeToJSON(tensor_type->shape)); dtype.emplace_back(DType2String(tensor_type->dtype)); node->attrs_["shape"] = shape; node->attrs_["dtype"] = dtype; diff --git a/src/relay/backend/runtime.cc b/src/relay/backend/runtime.cc index 786d6f937f14..923c9b2d5f65 100644 --- a/src/relay/backend/runtime.cc +++ b/src/relay/backend/runtime.cc @@ -88,9 +88,9 @@ RuntimeRegEntry& RuntimeRegEntry::RegisterOrGet(const String& name) { /********** Register Runtimes and options **********/ -TVM_REGISTER_RUNTIME("crt").add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME(kTvmRuntimeCrt).add_attr_option("system-lib"); -TVM_REGISTER_RUNTIME("cpp").add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME(kTvmRuntimeCpp).add_attr_option("system-lib"); /********** Registry **********/ diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 7d0df672bf6f..990d619e4231 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -179,14 +179,17 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); ExecutorCodegenMetadata::ExecutorCodegenMetadata( - Array inputs, Array pools, Array devices, Array outputs, + Array inputs, Array input_tensor_types, Array outputs, + Array output_tensor_types, Array pools, Array devices, String executor, String mod_name, String interface_api, bool unpacked_api, Map pool_inputs) { auto n = 
make_object(); n->inputs = inputs; + n->input_tensor_types = input_tensor_types; + n->outputs = outputs; + n->output_tensor_types = output_tensor_types; n->pools = pools; n->devices = devices; - n->outputs = outputs; n->executor = executor; n->interface_api = interface_api; n->unpacked_api = unpacked_api; @@ -294,6 +297,15 @@ void UpdateAutoSchedulerOpWeights(const IRModule& module) { (*te_compiler_update_weights)(weight_map); } +std::vector ShapeToJSON(tvm::Array shape) { + std::vector ret; + for (IndexExpr dim : shape) { + const int64_t* pval = tir::as_const_int(dim); + ret.push_back(*pval); + } + return ret; +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 645940e7de89..3b4d4c18de89 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -45,6 +45,7 @@ #include #include "../../runtime/meta_data.h" +#include "../../target/metadata.h" namespace tvm { namespace relay { @@ -63,8 +64,12 @@ class ExecutorCodegenMetadataNode : public Object { public: /*! \brief input information for the main function */ Array inputs; + /*! \brief input tensor type information */ + Array input_tensor_types; /*! \brief output information for the main function */ Array outputs; + /*! \brief output tensor type information */ + Array output_tensor_types; /*! \brief pool information for the main function */ Array pools; /*! 
\brief device contexts information for the main function */ @@ -82,8 +87,10 @@ class ExecutorCodegenMetadataNode : public Object { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("inputs", &inputs); - v->Visit("pools", &pools); + v->Visit("input_tensor_types", &input_tensor_types); v->Visit("outputs", &outputs); + v->Visit("output_tensor_types", &output_tensor_types); + v->Visit("pools", &pools); v->Visit("devices", &devices); v->Visit("executor", &executor); v->Visit("unpacked_api", &unpacked_api); @@ -99,8 +106,9 @@ class ExecutorCodegenMetadataNode : public Object { */ class ExecutorCodegenMetadata : public ObjectRef { public: - TVM_DLL ExecutorCodegenMetadata(Array inputs, Array pools, - Array devices, Array outputs, String executor, + TVM_DLL ExecutorCodegenMetadata(Array inputs, Array input_tensor_types, + Array outputs, Array output_tensor_types, + Array pools, Array devices, String executor, String mod_name, String interface_api = "packed", bool unpacked_api = false, Map pool_inputs = @@ -587,6 +595,14 @@ Map TargetStrModuleMapToTargetModuleMap( */ void UpdateAutoSchedulerOpWeights(const IRModule& module); +/*! + * \brief Extract shape from expr to vector + * + * \param shape + * \return std::vector + */ +std::vector ShapeToJSON(tvm::Array shape); + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/runtime/aot_executor/aot_executor.cc b/src/runtime/aot_executor/aot_executor.cc new file mode 100644 index 000000000000..5bd0d80ff95c --- /dev/null +++ b/src/runtime/aot_executor/aot_executor.cc @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Defines an implementation of Module-based Model Runtime Interface that works with + * Ahead-of-Time compilation. + * \file aot_executor.cc + */ + +#include "aot_executor.h" + +#include + +#include + +#include "../meta_data.h" + +namespace tvm { +namespace runtime { + +AotExecutor::AotExecutor(tvm::runtime::Module module, const std::vector& devs) + : module_{module}, devices_{devs} { + auto fmetadata = module->GetFunction("get_metadata"); + CHECK(fmetadata != nullptr) << "Expected a module with PackedFunc get_metadata"; + auto ret_value = fmetadata(); + metadata_ = ret_value.AsObjectRef(); + + ICHECK_EQ(devices_.size(), 1) << "Expect exactly 1 device passed."; + DLDevice expected_device{kDLCPU, 0}; + ICHECK_EQ(devices_[0].device_id, expected_device.device_id) + << "At this time, AOTExecutor supports only execution on kDLCPU 0"; + ICHECK_EQ(devices_[0].device_type, expected_device.device_type) + << "At this time, AOTExecutor supports only execution on kDLCPU 0"; + + for (auto input : metadata_->inputs()) { + // TODO(areusch): Encode device information in Metadata. 
+ args_.emplace_back(NDArray::Empty(ShapeTuple(input->shape().begin(), input->shape().end()), + input->dtype(), devices_[0])); + } + + for (auto output : metadata_->outputs()) { + args_.emplace_back(NDArray::Empty(ShapeTuple(output->shape().begin(), output->shape().end()), + output->dtype(), devices_[0])); + } + + for (auto pool : metadata_->pools()) { + args_.emplace_back(NDArray::Empty(ShapeTuple(pool->shape().begin(), pool->shape().end()), + pool->dtype(), devices_[0])); + } +} + +PackedFunc AotExecutor::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + // Return member functions during query. + if (name == "set_input") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + int in_idx = this->GetInputIndex(args[0].operator String()); + if (in_idx >= 0) this->SetInput(in_idx, args[1]); + } else { + this->SetInput(args[0], args[1]); + } + }); + } else if (name == "set_input_zero_copy") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + int in_idx = this->GetInputIndex(args[0].operator String()); + if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]); + } else { + this->SetInputZeroCopy(args[0], args[1]); + } + }); + } else if (name == "set_output_zero_copy") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + int out_idx = this->GetOutputIndex(args[0].operator String()); + if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]); + } else { + this->SetOutputZeroCopy(args[0], args[1]); + } + }); + } else if (name == "get_output") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (args.num_args == 2) { + this->CopyOutputTo(args[0], args[1]); + } else { + *rv = this->GetOutput(args[0]); + } + }); + } else if (name == "get_input") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + int 
in_idx = 0; + if (String::CanConvertFrom(args[0])) { + in_idx = this->GetInputIndex(args[0].operator String()); + } else { + in_idx = args[0]; + } + if (in_idx >= 0) { + *rv = this->GetInput(in_idx); + } + }); + } else if (name == "get_num_outputs") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); }); + } else if (name == "get_num_inputs") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); }); + } else if (name == "run") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); }); + } else if (name == "get_input_index") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string"; + *rv = this->GetInputIndex(args[0].operator String()); + }); + } else { + return PackedFunc(); + } +} + +void AotExecutor::Run() { + auto pf = module_.GetFunction( + get_name_mangled(metadata_->mod_name(), ::tvm::runtime::symbol::tvm_module_main), + true /* query_imports */); + ICHECK(pf != nullptr) << "Module entrypoint is not defined"; + + const int num_args = args_.size(); + auto call_values = ::std::make_unique(num_args); + auto call_type_codes = ::std::make_unique(num_args); + for (int i = 0; i < num_args; ++i) { + auto managed = args_[i].ToDLPack(); + call_values.get()[i].v_handle = &managed->dl_tensor; + call_type_codes.get()[i] = kTVMDLTensorHandle; + } + + TVMArgs args{call_values.get(), call_type_codes.get(), num_args}; + TVMRetValue rv; + pf.CallPacked(args, &rv); +} + +int AotExecutor::GetInputIndex(const std::string& name) { + auto inputs = metadata_->inputs(); + for (unsigned int i = 0; i < inputs.size(); i++) { + if (inputs[i]->name() == name) { + return i; + } + } + return -1; +} + +int AotExecutor::GetOutputIndex(const std::string& name) { + auto outputs = metadata_->outputs(); + for (unsigned int i = 0; i < outputs.size(); i++) { + if 
(outputs[i]->name() == name) { + return i; + } + } + return -1; +} + +void AotExecutor::SetInput(int index, DLTensor* data_ref) { args_[index].CopyFrom(data_ref); } + +void AotExecutor::SetInputZeroCopy(int index, DLTensor* data_ref) { + ICHECK(false) << "not implemented"; +} + +void AotExecutor::SetOutputZeroCopy(int index, DLTensor* data_ref) { + ICHECK(false) << "not implemented"; +} + +int AotExecutor::NumOutputs() const { return metadata_->num_outputs(); } + +int AotExecutor::NumInputs() const { return metadata_->num_inputs(); } + +NDArray AotExecutor::GetInput(int index) const { return args_[index]; } + +NDArray AotExecutor::GetOutput(int index) const { return args_[metadata_->num_inputs() + index]; } + +void AotExecutor::CopyOutputTo(int index, DLTensor* data_out) { GetOutput(index).CopyTo(data_out); } + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/aot_executor/aot_executor.h b/src/runtime/aot_executor/aot_executor.h new file mode 100644 index 000000000000..ccbcf8fdf3d2 --- /dev/null +++ b/src/runtime/aot_executor/aot_executor.h @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \brief Defines an implementation of Module-based Model Runtime Interface that works with + * Ahead-of-Time compilation. + * \file aot_executor.h + */ +#ifndef TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_ +#define TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { + +class TVM_DLL AotExecutor : public ModuleNode { + public: + /*! + * \brief Implements member function lookup for this Module for the frontend. + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * \return The corresponding member function. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override; + + /*! + * \return The type key of the executor. + */ + const char* type_key() const final { return "AotExecutor"; } + + void Run(); + + /*! + * \brief Initialize the AOT executor with metadata, runtime::Module, and device. + * \param module The module containing the compiled functions for the host + * processor. + * \param devs A 1-element vector. The Device which AOT compute will run on. Currently, only + * Device(kDLCPU, 0) is supported. + */ + AotExecutor(tvm::runtime::Module module, const std::vector& devs); + + /*! + * \brief Get the input index given the name of input. + * \param name The name of the input. + * \return The index of input. + */ + int GetInputIndex(const std::string& name); + + /*! + * \brief Get the output index given the name of output. + * \param name The name of the output. + * \return The index of output. + */ + int GetOutputIndex(const std::string& name); + + /*! + * \brief set index-th input to the graph. + * \param index The input index. + * \param data_in The input data. + */ + void SetInput(int index, DLTensor* data_in); + /*! + * \brief set index-th input to the graph without copying the data + * \param index The input index. + * \param data_ref The input data that is referred. 
+ */ + void SetInputZeroCopy(int index, DLTensor* data_ref); + /*! + * \brief set index-th output to the graph without copying the data. + * \param index The output index. + * \param data_ref The output data that is referred. + */ + void SetOutputZeroCopy(int index, DLTensor* data_ref); + /*! + * \brief Get the number of outputs + * + * \return The number of outputs from graph. + */ + int NumOutputs() const; + /*! + * \brief Get the number of inputs + * + * \return The number of inputs to the graph. + */ + int NumInputs() const; + /*! + * \brief Return NDArray for given input index. + * \param index The input index. + * + * \return NDArray corresponding to given input node index. + */ + NDArray GetInput(int index) const; + /*! + * \brief Return NDArray for given output index. + * \param index The output index. + * + * \return NDArray corresponding to given output node index. + */ + NDArray GetOutput(int index) const; + /*! + * \brief Copy index-th output to data_out. + * \param index The output index. + * \param data_out the output data. + */ + void CopyOutputTo(int index, DLTensor* data_out); + + private: + /*! \brief Metadata provided to the runtime from the compiler. */ + metadata::Metadata metadata_; + + /*! \brief Runtime module which contains the AOT top-level function. */ + Module module_; + + /*! \brief The devices which should be used to execute the computations. */ + std::vector devices_; + + /*! \brief Holds one NDArray per function argument in the same order. */ + std::vector args_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_ diff --git a/src/runtime/aot_executor/aot_executor_factory.cc b/src/runtime/aot_executor/aot_executor_factory.cc new file mode 100644 index 000000000000..7760f0fe6c4d --- /dev/null +++ b/src/runtime/aot_executor/aot_executor_factory.cc @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file aot_executor_factory.cc + * \brief AOT executor factory implementations + */ + +#include "./aot_executor_factory.h" + +#include +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { + +AotExecutorFactory::AotExecutorFactory( + const std::unordered_map& params, + const std::string& module_name) { + params_ = params; + module_name_ = module_name; +} + +PackedFunc AotExecutorFactory::GetFunction( + const std::string& name, const tvm::runtime::ObjectPtr& sptr_to_self) { + if (name == module_name_) { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK_GT(args.num_args, 0) << "Must supply at least one device argument"; + std::vector devices; + for (int i = 0; i < args.num_args; ++i) { + devices.emplace_back(args[i].operator Device()); + } + *rv = this->ExecutorCreate(devices); + }); + } else if (name == "remove_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::unordered_map empty_params{}; + auto exec = make_object(empty_params, this->module_name_); + exec->Import(this->imports_[0]); + *rv = Module(exec); + }); + } else { + return PackedFunc(); + } +} + +void AotExecutorFactory::SaveToBinary(dmlc::Stream* stream) { + std::vector names; + std::vector arrays; + 
for (const auto& v : params_) { + names.emplace_back(v.first); + arrays.emplace_back(const_cast(v.second.operator->())); + } + uint64_t sz = arrays.size(); + ICHECK(sz == names.size()); + stream->Write(sz); + stream->Write(names); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::SaveDLTensor(stream, arrays[i]); + } + stream->Write(module_name_); +} + +Module AotExecutorFactory::ExecutorCreate(const std::vector& devs) { + auto exec = make_object(this->imports_[0], devs); + // set params + SetParams(exec.get(), this->params_); + return Module(exec); +} + +Module AotExecutorFactoryModuleLoadBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + std::unordered_map params; + std::string module_name; + uint64_t sz; + ICHECK(stream->Read(&sz)); + std::vector names; + ICHECK(stream->Read(&names)); + ICHECK(sz == names.size()); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::NDArray temp; + temp.Load(stream); + params[names[i]] = temp; + } + ICHECK(stream->Read(&module_name)); + auto exec = make_object(params, module_name); + return Module(exec); +} + +TVM_REGISTER_GLOBAL("tvm.aot_executor_factory.create").set_body([](TVMArgs args, TVMRetValue* rv) { + ICHECK_GE(args.num_args, 2) << "The expected number of arguments for " + "aot_executor_factory.create needs at least 2, " + "but it has " + << args.num_args; + // The argument order is module, module_name, param0_name, param0_tensor, + // [param1_name, param1_tensor], ... 
+ ICHECK_EQ((args.size() - 2) % 2, 0); + std::unordered_map params; + for (size_t i = 2; i < static_cast(args.size()); i += 2) { + std::string name = args[i].operator String(); + params[name] = args[i + 1].operator tvm::runtime::NDArray(); + } + auto exec = make_object(params, args[1]); + exec->Import(args[0]); + *rv = Module(exec); +}); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_AotExecutorFactory") + .set_body_typed(AotExecutorFactoryModuleLoadBinary); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/aot_executor/aot_executor_factory.h b/src/runtime/aot_executor/aot_executor_factory.h new file mode 100644 index 000000000000..1d6a0a62776e --- /dev/null +++ b/src/runtime/aot_executor/aot_executor_factory.h @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/aot_executor/aot_executor_factory.h + * \brief Aot executor factory creating aot executor. 
+ */ + +#ifndef TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_FACTORY_H_ +#define TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_FACTORY_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "./aot_executor.h" + +namespace tvm { +namespace runtime { + +class TVM_DLL AotExecutorFactory : public runtime::ModuleNode { + public: + /*! + * \brief Construct the AotExecutorFactory. + * \param params The params of aot. + * \param module_name The module name of aot. + */ + AotExecutorFactory(const std::unordered_map& params, + const std::string& module_name); + + /*! + * \brief Get member function to front-end + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * \return The corresponding member function. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + + /*! + * \return The type key of the executor. + */ + const char* type_key() const override { return "AotExecutorFactory"; } + + /*! + * \brief Save the module to binary stream. + * \param stream The binary stream to save to. + */ + void SaveToBinary(dmlc::Stream* stream) override; + + /*! + * \brief Create a specific executor module + * \param devs The device of the host and devices where the model will be + * executed. + * \return created executor module + */ + Module ExecutorCreate(const std::vector& devs); + + /*! + * \brief Set params. + * \param aot_executor The aot executor we want to set the params into. + * \param params The aot params value we want to set. 
+ */ + void SetParams(AotExecutor* aot_executor, + const std::unordered_map& params) const { + std::unordered_map value = params; + // upload big arrays first to avoid memory issue in rpc mode + std::vector keys; + for (const auto& p : value) { + keys.emplace_back(p.first); + } + std::sort(std::begin(keys), std::end(keys), + [&](const std::string& lhs, const std::string& rhs) -> bool { + auto lhs_size = GetDataSize(*value[lhs].operator->()); + auto rhs_size = GetDataSize(*value[rhs].operator->()); + return lhs_size > rhs_size; + }); + for (const auto& key : keys) { + int in_idx = aot_executor->GetInputIndex(key); + if (in_idx >= 0) { + aot_executor->SetInput(in_idx, const_cast(value[key].operator->())); + } + } + } + + protected: + /*! \brief The params. */ + std::unordered_map params_; + /*! \brief module name */ + std::string module_name_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_FACTORY_H_ diff --git a/src/runtime/metadata_module.cc b/src/runtime/const_loader_module.cc similarity index 67% rename from src/runtime/metadata_module.cc rename to src/runtime/const_loader_module.cc index 7cb986bba62c..5496e161e57f 100644 --- a/src/runtime/metadata_module.cc +++ b/src/runtime/const_loader_module.cc @@ -18,13 +18,13 @@ */ /*! - * \file src/runtime/metadata_module.cc - * \brief A wrapper for initializing imported modules using metadata. This + * \file src/runtime/const_loader_module.cc + * \brief A wrapper for initializing imported modules using constant NDArray. This * module is intended to be used by various runtime in the TVM stack, i.e. * graph executor, relay VM, AOT runtime, and various user defined runtimes. It * paves the way to separate the code and metedata, which makes compilation * and/or interpretation more convenient. 
In addition, the clear separation of - * code and metadata significantly reduces the efforts for handling external + * code and constants significantly reduces the efforts for handling external * codegen and runtimes. */ #include @@ -42,18 +42,19 @@ namespace tvm { namespace runtime { /*! - * \brief The metadata module is designed to manage initialization of the - * imported submodules. + * \brief The const-loader module is designed to manage initialization of the + * imported submodules for the C++ runtime. */ -class MetadataModuleNode : public ModuleNode { +class ConstLoaderModuleNode : public ModuleNode { public: - MetadataModuleNode(const std::unordered_map& metadata, - const std::unordered_map>& sym_vars) - : metadata_(metadata), sym_vars_(sym_vars) { + ConstLoaderModuleNode( + const std::unordered_map& const_var_ndarray, + const std::unordered_map>& const_vars_by_symbol) + : const_var_ndarray_(const_var_ndarray), const_vars_by_symbol_(const_vars_by_symbol) { // Only the related submodules are cached to reduce the number of runtime // symbol lookup for initialization. Otherwise, symbols/primitives in the // DSO module will also be cached but they never need to be initialized. - for (const auto& it : sym_vars_) { + for (const auto& it : const_vars_by_symbol_) { initialized_[it.first] = false; } } @@ -78,20 +79,20 @@ class MetadataModuleNode : public ModuleNode { return PackedFunc(nullptr); } - const char* type_key() const { return "metadata"; } + const char* type_key() const { return "const_loader"; } /*! - * \brief Get the list of metadata that is required by the given module. + * \brief Get the list of constants that is required by the given module. * \param symbol The symbol that is being queried. * \return The list of needed NDArray. 
*/ - Array GetRequiredMetadata(const std::string& symbol) { + Array GetRequiredConstants(const std::string& symbol) { Array ret; - ICHECK_GT(sym_vars_.count(symbol), 0U) << "No symbol is recorded for " << symbol; - std::vector vars = sym_vars_[symbol]; + ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U) << "No symbol is recorded for " << symbol; + std::vector vars = const_vars_by_symbol_[symbol]; for (const auto& it : vars) { - ICHECK_GT(metadata_.count(it), 0U) << "Found not recorded constant variable: " << it; - ret.push_back(metadata_[it]); + ICHECK_GT(const_var_ndarray_.count(it), 0U) << "Found not recorded constant variable: " << it; + ret.push_back(const_var_ndarray_[it]); } return ret; } @@ -102,12 +103,12 @@ class MetadataModuleNode : public ModuleNode { * for runtime lookup. * * \note A module could be like the following: - * MetadataModuleNode (contains all the metadata) + * ConstLoaderModuleNode (contains all the constants) * - CSourceModule * - JSON runtime module * * The initializer iterates through the imported modules and intilizes the - * found module accordingly by passing the needed metadata into it. + * found module accordingly by passing the needed constants into it. */ void InitSubModule(const std::string& symbol) { PackedFunc init(nullptr); @@ -116,8 +117,8 @@ class MetadataModuleNode : public ModuleNode { std::string init_name = "__init_" + symbol; init = it.GetFunction(init_name, false); if (init != nullptr) { - auto md = GetRequiredMetadata(symbol); - // Initialize the module with metadata. + auto md = GetRequiredConstants(symbol); + // Initialize the module with constants. int ret = init(md); // Report the error if initialization is failed. 
ICHECK_EQ(ret, 0) << TVMGetLastError(); @@ -128,32 +129,32 @@ class MetadataModuleNode : public ModuleNode { void SaveToBinary(dmlc::Stream* stream) final { std::vector variables; - std::vector metadata; - for (const auto& it : metadata_) { + std::vector const_var_ndarray; + for (const auto& it : const_var_ndarray_) { String var_name = it.first; variables.push_back(var_name); - metadata.push_back(it.second); + const_var_ndarray.push_back(it.second); } // Save all variables in the function. stream->Write(variables); // Save all constant data. - uint64_t sz = static_cast(metadata.size()); + uint64_t sz = static_cast(const_var_ndarray.size()); stream->Write(sz); for (uint64_t i = 0; i < sz; i++) { - metadata[i].Save(stream); + const_var_ndarray[i].Save(stream); } // Save the symbol to list of required constant variables mapping std::vector symbols; std::vector> const_vars; - for (const auto& it : sym_vars_) { + for (const auto& it : const_vars_by_symbol_) { symbols.push_back(it.first); const_vars.push_back(it.second); } stream->Write(symbols); - sz = static_cast(sym_vars_.size()); + sz = static_cast(const_vars_by_symbol_.size()); stream->Write(sz); for (uint64_t i = 0; i < sz; i++) { stream->Write(const_vars[i]); @@ -165,9 +166,9 @@ class MetadataModuleNode : public ModuleNode { // Load the variables. std::vector variables; - ICHECK(stream->Read(&variables)) << "Loading variables failed"; + ICHECK(stream->Read(&variables)) << "Loading variable names failed"; uint64_t sz; - ICHECK(stream->Read(&sz, sizeof(sz))) << "Loading metadata size failed"; + ICHECK(stream->Read(&sz, sizeof(sz))) << "Loading number of vars failed"; ICHECK_EQ(static_cast(sz), variables.size()) << "The number of variables and ndarray counts must match"; // Load the list of ndarray. 
@@ -178,10 +179,10 @@ class MetadataModuleNode : public ModuleNode { arrays.push_back(temp); } - std::unordered_map metadata; + std::unordered_map const_var_ndarray; for (uint64_t i = 0; i < sz; i++) { - ICHECK_EQ(metadata.count(variables[i]), 0U); - metadata[variables[i]] = arrays[i]; + ICHECK_EQ(const_var_ndarray.count(variables[i]), 0U); + const_var_ndarray[variables[i]] = arrays[i]; } // Load the symbol to list of required constant variables mapping @@ -196,12 +197,12 @@ class MetadataModuleNode : public ModuleNode { const_vars.push_back(vars); } - std::unordered_map> sym_vars; + std::unordered_map> const_vars_by_symbol; for (uint64_t i = 0; i < sz; i++) { - sym_vars[symbols[i]] = const_vars[i]; + const_vars_by_symbol[symbols[i]] = const_vars[i]; } - auto n = make_object(metadata, sym_vars); + auto n = make_object(const_var_ndarray, const_vars_by_symbol); return Module(n); } @@ -212,19 +213,21 @@ class MetadataModuleNode : public ModuleNode { */ std::unordered_map initialized_; /*! \brief Variable name to NDArray mapping. */ - std::unordered_map metadata_; + std::unordered_map const_var_ndarray_; /*! \brief Symbol name to required constant variables mapping. 
*/ - std::unordered_map> sym_vars_; + std::unordered_map> const_vars_by_symbol_; }; -Module MetadataModuleCreate( - const std::unordered_map& metadata, - const std::unordered_map>& sym_vars) { - auto n = make_object(metadata, sym_vars); +Module ConstLoaderModuleCreate( + const std::unordered_map& const_var_ndarray, + const std::unordered_map>& const_vars_by_symbol) { + auto n = make_object(const_var_ndarray, const_vars_by_symbol); return Module(n); } TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata") - .set_body_typed(MetadataModuleNode::LoadFromBinary); + .set_body_typed(ConstLoaderModuleNode::LoadFromBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_const_loader") + .set_body_typed(ConstLoaderModuleNode::LoadFromBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/const_loader_module.h b/src/runtime/const_loader_module.h new file mode 100644 index 000000000000..eb548dfcf370 --- /dev/null +++ b/src/runtime/const_loader_module.h @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file const_loader_module.h + * \brief Defines an interface to use the ConstLoaderModule. 
+ */ + +#ifndef TVM_RUNTIME_CONST_LOADER_MODULE_H_ +#define TVM_RUNTIME_CONST_LOADER_MODULE_H_ + +#include + +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief Create a ConstLoader module object. + * + * \param const_var_ndarray Maps consts var name to NDArray containing data for the var. + * \param const_vars_by_symbol Maps the name of a module init function to a list of names of + * const vars whose data will be passed to that init function. + * + * \return The created ConstLoaderModule. + */ +Module ConstLoaderModuleCreate( + const std::unordered_map& const_var_ndarray, + const std::unordered_map>& const_vars_by_symbol); + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONST_LOADER_MODULE_H_ diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index e83e1a3a7629..766b93261ac0 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -52,19 +53,19 @@ inline String get_name_mangled(const String& module_name, const String& name) { /*! * \brief Create a metadata module object. * - * \param metadata The variable name to ndarray mapping. - * \param sym_vars The symbol to the list of required constant variables - * mapping. + * \param metadata Exported metadata structure. * * \return The created metadata module. */ -Module MetadataModuleCreate( - const std::unordered_map& metadata, - const std::unordered_map>& sym_vars); +Module MetadataModuleCreate(metadata::Metadata metadata); + +namespace launch_param { /*! \brief A tag to specify whether or not dynamic shared memory is used */ constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; +} // namespace launch_param + /*! 
\brief function information needed by device */ struct FunctionInfo { std::string name; diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc index 7ca333b06c82..90469fabad2c 100644 --- a/src/runtime/metadata.cc +++ b/src/runtime/metadata.cc @@ -18,22 +18,32 @@ */ /*! - * \file metadata.cc - * \brief Implementations of the runtime component of Metadata. + * \file tvm/runtime/metadata.h + * \brief Defines implementations of TVM metadata which can exist in the runtime. */ +#include +#include #include +#include + +#include namespace tvm { namespace runtime { namespace metadata { +TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode); + ArrayAccessor MetadataNode::inputs() { return ArrayAccessor(data_->inputs, data_->num_inputs); } ArrayAccessor MetadataNode::outputs() { return ArrayAccessor(data_->outputs, data_->num_outputs); } +ArrayAccessor MetadataNode::pools() { + return ArrayAccessor(data_->pools, data_->num_pools); +} TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode); @@ -52,5 +62,62 @@ TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data) TVM_REGISTER_OBJECT_TYPE(TensorInfoNode); } // namespace metadata + +class MetadataModuleNode : public ::tvm::runtime::ModuleNode { + public: + explicit MetadataModuleNode(runtime::metadata::Metadata metadata) + : metadata_{::std::move(metadata)} {} + + const char* type_key() const { return "metadata_module"; } + + static Module LoadFromBinary() { + return Module(make_object(runtime::metadata::Metadata())); + } + + void SaveToBinary(dmlc::Stream* stream) final {} + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "get_metadata") { + return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { + if (!metadata_.defined()) { + TVMFunctionHandle f_handle; + int32_t ret_code = TVMBackendGetFuncFromEnv(this, symbol::tvm_get_c_metadata, &f_handle); + ICHECK_EQ(ret_code, 0) << "Unable to locate " << symbol::tvm_get_c_metadata + << " PackedFunc"; + + TVMValue ret_value; + int 
ret_type_code; + ret_code = TVMFuncCall(f_handle, nullptr, nullptr, 0, &ret_value, &ret_type_code); + ICHECK_EQ(ret_code, 0) << "Invoking " << symbol::tvm_get_c_metadata + << ": TVMFuncCall returned " << ret_code; + + ICHECK_EQ(ret_type_code, kTVMOpaqueHandle) + << "Expected kOpaqueHandle returned; got " << ret_type_code; + ICHECK(ret_value.v_handle != nullptr) + << symbol::tvm_get_c_metadata << " returned nullptr"; + + metadata_ = runtime::metadata::Metadata( + static_cast(ret_value.v_handle)); + } + + *rv = metadata_; + return; + }); + } + + return PackedFunc(); + } + + private: + runtime::metadata::Metadata metadata_; +}; + +Module MetadataModuleCreate(metadata::Metadata metadata) { + return Module(make_object(metadata)); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata_module") + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = MetadataModuleNode::LoadFromBinary(); }); + } // namespace runtime } // namespace tvm diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index d577770db1a9..4122f9d0798e 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -24,6 +24,7 @@ #ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ +#include #include #include @@ -205,7 +206,7 @@ class LaunchParamConfig { std::vector filled(6, false); for (size_t i = 0; i < launch_param_tags.size(); ++i) { const std::string& tag = launch_param_tags[i]; - if (tag == kUseDynamicSharedMemoryTag) { + if (tag == launch_param::kUseDynamicSharedMemoryTag) { ICHECK_EQ(i, launch_param_tags.size() - 1) << "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags."; use_dyn_shared_memory_ = true; diff --git a/src/target/build_common.h b/src/target/build_common.h index c66c2b52822e..6c94ec8703b7 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -58,7 +58,7 @@ inline std::unordered_map ExtractFuncInfo(co } if (auto opt = 
f->GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { if (opt.value()) { - info.launch_param_tags.push_back(runtime::kUseDynamicSharedMemoryTag); + info.launch_param_tags.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag); } } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); diff --git a/src/target/metadata.h b/src/target/metadata.h index 2621d5d4e65d..b8ca24580f15 100644 --- a/src/target/metadata.h +++ b/src/target/metadata.h @@ -71,6 +71,17 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { v->Visit("outputs", &outputs_metadata_array); int64_t num_outputs_cpp = num_outputs(); v->Visit("num_outputs", &num_outputs_cpp); + auto pools_array = Array(); + auto pools_accessor = pools(); + pools_array.reserve(num_pools()); + for (int64_t i = 0; i < num_pools(); ++i) { + pools_array.push_back(::tvm::runtime::metadata::TensorInfo{pools_accessor[i]}); + } + ::tvm::runtime::metadata::MetadataArray pools_metadata_array{ + pools_array, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"}; + v->Visit("pools", &pools_metadata_array); + int64_t num_pools_cpp = num_pools(); + v->Visit("num_pools", &num_pools_cpp); ::std::string mod_name_cpp{data()->mod_name}; v->Visit("mod_name", &mod_name_cpp); } @@ -86,19 +97,22 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNode { public: InMemoryMetadataNode() - : InMemoryMetadataNode(0 /* version */, {} /* inputs */, {} /* outputs */, + : InMemoryMetadataNode(0 /* version */, {} /* inputs */, {} /* outputs */, {} /* pools */, "" /* mod_name */) {} InMemoryMetadataNode(int64_t version, const ::std::vector<::tvm::runtime::metadata::TensorInfo>& inputs, const ::std::vector<::tvm::runtime::metadata::TensorInfo>& outputs, + const ::std::vector<::tvm::runtime::metadata::TensorInfo>& pools, const ::tvm::runtime::String mod_name) : VisitableMetadataNode{&storage_}, - 
inputs_{new struct TVMTensorInfo[inputs.size()]()}, + inputs_{new struct TVMTensorInfo[inputs.size()]}, inputs_objs_{inputs}, - outputs_{new struct TVMTensorInfo[outputs.size()]()}, + outputs_{new struct TVMTensorInfo[outputs.size()]}, outputs_objs_{outputs}, + pools_{new struct TVMTensorInfo[pools.size()]}, + pools_objs_{pools}, mod_name_{mod_name}, - storage_{version, nullptr, 0, nullptr, 0, mod_name_.c_str()} { + storage_{version, nullptr, 0, nullptr, 0, nullptr, 0, mod_name_.c_str()} { storage_.inputs = inputs_.get(); storage_.num_inputs = inputs.size(); for (unsigned int i = 0; i < inputs.size(); ++i) { @@ -109,6 +123,11 @@ class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNo for (unsigned int i = 0; i < outputs.size(); ++i) { outputs_.get()[i] = *outputs[i]->data(); } + storage_.pools = pools_.get(); + storage_.num_pools = pools.size(); + for (unsigned int i = 0; i < pools.size(); ++i) { + pools_.get()[i] = *pools[i]->data(); + } } private: @@ -116,6 +135,8 @@ class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNo std::vector<::tvm::runtime::metadata::TensorInfo> inputs_objs_; ::std::unique_ptr outputs_; std::vector<::tvm::runtime::metadata::TensorInfo> outputs_objs_; + ::std::unique_ptr pools_; + std::vector<::tvm::runtime::metadata::TensorInfo> pools_objs_; ::std::string mod_name_; struct ::TVMMetadata storage_; }; diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 2facf1de64d5..8abd18c1d8f3 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -25,8 +25,10 @@ #include +#include #include +#include "../runtime/const_loader_module.h" #include "../runtime/meta_data.h" #include "llvm/llvm_module.h" #include "source/source_module.h" @@ -34,8 +36,136 @@ namespace tvm { namespace codegen { +static runtime::Module CreateCrtMetadataModule( + runtime::Module target_module, Target target, relay::Runtime runtime, + relay::backend::ExecutorCodegenMetadata 
metadata, + Array non_crt_exportable_modules, + Array crt_exportable_modules, + const std::unordered_map& const_var_ndarray) { + if (!non_crt_exportable_modules.empty()) { + std::string non_exportable_modules; + for (unsigned int i = 0; i < non_crt_exportable_modules.size(); i++) { + if (i > 0) { + non_exportable_modules += ", "; + } + auto mod = non_crt_exportable_modules[i]; + auto pf_sym = mod.GetFunction("get_symbol"); + if (pf_sym != nullptr) { + non_exportable_modules += pf_sym().operator std::string(); + } else { + non_exportable_modules += + std::string{"(module type_key="} + mod->type_key() + std::string{")"}; + } + } + CHECK(false) << "These " << non_crt_exportable_modules.size() + << " modules are not exportable to C-runtime: " << non_exportable_modules; + } + + if (target->kind->name == "c") { + crt_exportable_modules.push_back(target_module); + target_module = + CreateCSourceCrtMetadataModule(crt_exportable_modules, target, runtime, metadata); + } else if (target->kind->name == "llvm") { +#ifdef TVM_LLVM_VERSION + crt_exportable_modules.push_back(target_module); + target_module = CreateLLVMCrtMetadataModule(crt_exportable_modules, target, runtime); +#else // TVM_LLVM_VERSION + LOG(FATAL) << "TVM was not built with LLVM enabled."; +#endif // TVM_LLVM_VERSION + } + + return target_module; +} + +// TODO(areusch,masahi): Unify metadata representation and remove the need for this function +static runtime::metadata::Metadata ConvertMetaData( + relay::backend::ExecutorCodegenMetadata metadata) { + ICHECK(metadata.defined()); + ICHECK_NOTNULL(metadata->pool_inputs); + + std::vector inputs; + for (size_t i = 0; i < metadata->inputs.size(); ++i) { + auto v = metadata->inputs[i]; + auto ttype = metadata->input_tensor_types[i]; + inputs.push_back( + runtime::metadata::TensorInfo(make_object( + v->name_hint, relay::backend::ShapeToJSON(ttype->shape), ttype->dtype))); + } + + std::vector outputs; + auto output_ttypes = metadata->output_tensor_types; + for (size_t i 
= 0; i < output_ttypes.size(); ++i) { + auto ttype = output_ttypes[i]; + std::stringstream name; + name << "output" << i; + outputs.push_back( + runtime::metadata::TensorInfo(make_object( + name.str(), relay::backend::ShapeToJSON(ttype->shape), ttype->dtype))); + } + + std::vector pools; + for (size_t i = 0; i < metadata->pools.size(); ++i) { + auto var = metadata->pools[i]; + pools.push_back( + runtime::metadata::TensorInfo(make_object( + var->name_hint, + std::vector{metadata->pool_inputs.value()[var]->allocated_size}, + tvm::runtime::DataType{kDLUInt, 8, 1}))); + } + + auto n = make_object( + runtime::metadata::kMetadataVersion, inputs, outputs, pools, metadata->mod_name); + + return runtime::metadata::Metadata(std::move(n)); +} + +static runtime::Module CreateCppMetadataModule( + runtime::Module target_module, Target target, relay::Runtime runtime, + relay::backend::ExecutorCodegenMetadata metadata, + const std::unordered_map>& const_vars_by_symbol, + Array non_crt_exportable_modules, + Array crt_exportable_modules, + const std::unordered_map& const_var_ndarray) { + if (!non_crt_exportable_modules.empty()) { + runtime::Module const_loader_mod = + runtime::ConstLoaderModuleCreate(const_var_ndarray, const_vars_by_symbol); + const_loader_mod.Import(target_module); + for (const auto& it : non_crt_exportable_modules) { + const_loader_mod.Import(it); + } + target_module = const_loader_mod; + } + + if (metadata.defined()) { + runtime::metadata::Metadata runtime_metadata = ConvertMetaData(metadata); + + if (metadata->executor == runtime::kTvmExecutorAot && runtime->name == relay::kTvmRuntimeCpp) { + if (target->kind->name == "c") { + auto metadata_module = CreateCSourceCppMetadataModule(runtime_metadata); + metadata_module->Import(target_module); + target_module = metadata_module; + } else { + CHECK(false) << "Don't know how to create MetadataModule for target type " << target->str(); + } + } + } + + return target_module; +} + +/*! 
+ * \brief Create a metadata module wrapper. The helper is used by different + * codegens, such as graph executor codegen and the vm compiler. + * + * \param params The metadata for initialization of all modules. + * \param target_module the internal module that is compiled by tvm. + * \param ext_modules The external modules that needs to be imported inside the metadata + * module(s). + * \param target The target that all the modules are compiled for + * \return The created metadata module that manages initialization of metadata. + */ runtime::Module CreateMetadataModule( - const std::unordered_map& params, + const std::unordered_map& const_var_ndarray, tvm::runtime::Module target_module, const Array& ext_modules, Target target, tvm::relay::Runtime runtime, relay::backend::ExecutorCodegenMetadata metadata) { // Here we split modules into two groups: @@ -52,19 +182,19 @@ runtime::Module CreateMetadataModule( bool is_targeting_crt = runtime->name == "crt"; // Wrap all submodules in the initialization wrapper. 
- std::unordered_map> sym_metadata; + std::unordered_map> const_vars_by_symbol; for (tvm::runtime::Module mod : ext_modules) { auto pf_sym = mod.GetFunction("get_symbol"); auto pf_var = mod.GetFunction("get_const_vars"); - std::vector arrays; + std::vector symbol_const_vars; if (pf_sym != nullptr && pf_var != nullptr) { String symbol = pf_sym(); Array variables = pf_var(); for (size_t i = 0; i < variables.size(); i++) { - arrays.push_back(variables[i].operator std::string()); + symbol_const_vars.push_back(variables[i].operator std::string()); } - ICHECK_EQ(sym_metadata.count(symbol), 0U) << "Found duplicated symbol: " << symbol; - sym_metadata[symbol] = arrays; + ICHECK_EQ(const_vars_by_symbol.count(symbol), 0U) << "Found duplicated symbol: " << symbol; + const_vars_by_symbol[symbol] = symbol_const_vars; } // We only need loading of serialized constant data // if there are constants present and required by the @@ -74,7 +204,7 @@ runtime::Module CreateMetadataModule( // TODO(@manupa-arm) : we should be able to use csource_metadata // if the variables are empty when all the runtime modules implement get_func_names - if (arrays.empty() && is_targeting_crt && DSOExportable(mod) && + if (symbol_const_vars.empty() && is_targeting_crt && DSOExportable(mod) && (target->kind->name == "c" || target->kind->name == "llvm")) { crt_exportable_modules.push_back(mod); } else { @@ -83,49 +213,16 @@ runtime::Module CreateMetadataModule( } if (is_targeting_crt) { - if (!non_crt_exportable_modules.empty()) { - std::string non_exportable_modules; - for (unsigned int i = 0; i < non_crt_exportable_modules.size(); i++) { - if (i > 0) { - non_exportable_modules += ", "; - } - auto mod = non_crt_exportable_modules[i]; - auto pf_sym = mod.GetFunction("get_symbol"); - if (pf_sym != nullptr) { - non_exportable_modules += pf_sym().operator std::string(); - } else { - non_exportable_modules += - std::string{"(module type_key="} + mod->type_key() + std::string{")"}; - } - } - CHECK(false) << 
"These " << non_crt_exportable_modules.size() - << " modules are not exportable to C-runtime: " << non_exportable_modules; - } - - if (target->kind->name == "c") { - crt_exportable_modules.push_back(target_module); - target_module = - CreateCSourceCrtMetadataModule(crt_exportable_modules, target, runtime, metadata); - } else if (target->kind->name == "llvm") { -#ifdef TVM_LLVM_VERSION - crt_exportable_modules.push_back(target_module); - target_module = CreateLLVMCrtMetadataModule(crt_exportable_modules, target, runtime); -#else // TVM_LLVM_VERSION - LOG(FATAL) << "TVM was not built with LLVM enabled."; -#endif // TVM_LLVM_VERSION - } + return CreateCrtMetadataModule(target_module, target, runtime, metadata, + non_crt_exportable_modules, crt_exportable_modules, + const_var_ndarray); } else { - if (!non_crt_exportable_modules.empty()) { - runtime::Module binary_meta_mod = runtime::MetadataModuleCreate(params, sym_metadata); - binary_meta_mod.Import(target_module); - for (const auto& it : non_crt_exportable_modules) { - binary_meta_mod.Import(it); - } - return binary_meta_mod; - } + return CreateCppMetadataModule(target_module, target, runtime, metadata, const_vars_by_symbol, + non_crt_exportable_modules, crt_exportable_modules, + const_var_ndarray); } - return target_module; } } // namespace codegen + } // namespace tvm diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index d7fb3dcf6d80..7ddea46c07bd 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -23,6 +23,7 @@ #include "codegen_c_host.h" #include +#include #include #include #include @@ -51,6 +52,10 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, std::string target_s CodeGenC::Init(output_ssa); } +void CodeGenCHost::InitGlobalContext() { + decl_stream << "void* " << tvm::runtime::symbol::tvm_module_ctx << " = NULL;\n"; +} + void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; 
} void CodeGenCHost::AddFunction(const PrimFunc& f) { @@ -366,6 +371,19 @@ runtime::Module BuildCHost(IRModule mod, Target target) { cg.AddFunction(aot_executor_fn); } + // NOTE: it's possible that kRuntime attr is not attached when the mod was built with tvm.build(). + // See issue #10373. + auto opt_runtime = mod->GetAttr(tvm::attr::kRuntime); + relay::Runtime runtime; + if (opt_runtime.get() != nullptr) { + runtime = opt_runtime.value(); + } else { + runtime = relay::Runtime::Create("cpp", {}); + } + if (aot_executor_fn.defined() && runtime->name == relay::kTvmRuntimeCpp) { + cg.InitGlobalContext(); + } + if (target->GetAttr("system-lib").value_or(Bool(false))) { ICHECK_EQ(target->GetAttr("runtime").value_or(""), "c") << "c target only supports generating C runtime SystemLibs"; diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 44e791ef7bc3..c0e4ee9a263c 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -40,6 +40,7 @@ class CodeGenCHost : public CodeGenC { CodeGenCHost(); void Init(bool output_ssa, bool emit_asserts, std::string target_str); + void InitGlobalContext(); void AddFunction(const PrimFunc& f); void DefineModuleName(); diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 8f8f9e1b8bf2..66287f9ad181 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -25,6 +25,7 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ #define TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ +#include #include #include #include @@ -157,6 +158,20 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, const Array& func_names, const Array& const_vars = {}); +/*! + * \brief Wrap the submodules in a metadata module. + * \param params The variable to constant mapping that is collected by the host + * module. 
+ * \param target_module The main TIR-lowered internal runtime module + * \param modules All the external modules that needs to be imported inside the metadata module(s). + * \param target The target that all the modules are compiled for + * \param metadata Metadata which should be exported to the runtime. + * \return The wrapped module. + */ +runtime::Module CreateMetadataModule( + const std::unordered_map& params, runtime::Module target_module, + const Array& ext_modules, Target target, runtime::metadata::Metadata metadata); + /*! * \brief Create a source module for viewing and limited saving for device. * \param data The code data to be viewed. @@ -169,6 +184,16 @@ runtime::Module DeviceSourceModuleCreate( std::string data, std::string fmt, std::unordered_map fmap, std::string type_key, std::function fget_source = nullptr); +/*! + * \brief Wrap the submodules that are to be wrapped in a c-source metadata module for C runtime. + * \param modules The modules to be wrapped. + * \param target the target the modules are compiled for. + * \param metadata the metadata needed for code generation. + * \return The wrapped module. 
+ */ +runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, + runtime::metadata::Metadata metadata); + } // namespace codegen } // namespace tvm #endif // TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 97461ca2091f..7db5d8c83a84 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,18 +23,23 @@ */ #include "source_module.h" +#include #include #include #include #include +#include #include +#include #include +#include #include "../../relay/backend/name_transforms.h" #include "../../runtime/file_utils.h" #include "../../support/str_escape.h" #include "../func_registry_generator.h" +#include "../metadata.h" #include "codegen_source_base.h" namespace tvm { @@ -518,6 +523,254 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } }; +static std::string address_from_parts(const std::vector& parts) { + std::stringstream ss; + for (unsigned int i = 0; i < parts.size(); ++i) { + if (i > 0) { + ss << "_"; + } + ss << parts[i]; + } + return ss.str(); +} + +class MetadataQueuer : public AttrVisitor { + public: + using QueueItem = std::tuple; + explicit MetadataQueuer(std::vector* queue) : queue_{queue} {} + + void Visit(const char* key, double* value) final {} + void Visit(const char* key, int64_t* value) final {} + void Visit(const char* key, uint64_t* value) final {} + void Visit(const char* key, int* value) final {} + void Visit(const char* key, bool* value) final {} + void Visit(const char* key, std::string* value) final {} + void Visit(const char* key, DataType* value) final {} + void Visit(const char* key, runtime::NDArray* value) final {} + void Visit(const char* key, void** value) final {} + + void Visit(const char* key, ObjectRef* value) final { + address_parts_.push_back(key); + if (value->as() != nullptr) { + auto metadata = Downcast(*value); + const runtime::metadata::MetadataArrayNode* arr = + 
value->as(); + if (arr != nullptr) { + for (unsigned int i = 0; i < arr->array.size(); i++) { + ObjectRef o = arr->array[i]; + if (o.as() != nullptr) { + std::stringstream ss; + ss << i; + address_parts_.push_back(ss.str()); + runtime::metadata::MetadataBase metadata = Downcast(o); + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + address_parts_.pop_back(); + } + } + } else { + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + } + + queue_->push_back(std::make_tuple(address_from_parts(address_parts_), + Downcast(*value))); + } + address_parts_.pop_back(); + } + + private: + std::vector* queue_; + std::vector address_parts_; +}; + +class MetadataSerializer : public AttrVisitor { + public: + static constexpr const char* kGlobalSymbol = "kTvmgenMetadata"; + using MetadataTypeIndex = ::tvm::runtime::metadata::MetadataTypeIndex; + + MetadataSerializer() : is_first_item_{true} {} + + void WriteComma() { + if (is_first_item_) { + is_first_item_ = false; + } else { + code_ << ", " << std::endl; + } + } + + void WriteKey(const char* key) { + if (key != nullptr) { + code_ << " /* " << key << "*/"; + } + } + + void Visit(const char* key, double* value) final { + WriteComma(); + code_.setf(std::ios::hex | std::ios::showbase | std::ios::fixed | std::ios::scientific, + std::ios::basefield | std::ios::showbase | std::ios::floatfield); + code_ << *value; + WriteKey(key); + } + + void Visit(const char* key, int64_t* value) final { + WriteComma(); + code_ << *value << "L"; + WriteKey(key); + } + + void Visit(const char* key, uint64_t* value) final { + WriteComma(); + code_ << *value << "UL"; + WriteKey(key); + } + void Visit(const char* key, int* value) final { + WriteComma(); + code_ << *value; + WriteKey(key); + } + void Visit(const char* key, bool* value) final { + WriteComma(); + code_ << *value; + WriteKey(key); + } + void Visit(const char* key, std::string* value) final { + WriteComma(); + code_ << "\"" << *value << "\""; + 
WriteKey(key); + } + void Visit(const char* key, void** value) final { + WriteComma(); + code_ << *value; + WriteKey(key); + } + void Visit(const char* key, DataType* value) final { + WriteComma(); + code_ << "{" << value->code() << ", " << value->bits() << ", " << value->lanes() << "}"; + WriteKey(key); + } + + void Visit(const char* key, runtime::NDArray* value) final { + // TODO(areusch): probably we could consolidate --link-params here, tho... + ICHECK(false) << "do not support serializing NDArray as metadata"; + } + + void VisitArray(const runtime::metadata::MetadataArrayNode* array) { + auto old_is_first_item = is_first_item_; + is_first_item_ = true; + for (unsigned int i = 0; i < array->array.size(); ++i) { + ObjectRef o = array->array[i]; + if (o->IsInstance()) { + int64_t i = Downcast(o); + Visit(nullptr, &i); + continue; + } + + if (o->IsInstance()) { + std::string s = Downcast(o); + Visit(nullptr, &s); + continue; + } + + runtime::metadata::MetadataBase metadata = Downcast(o); + std::stringstream i_str; + i_str << i; + address_.push_back(i_str.str()); + Visit(nullptr, &metadata); + address_.pop_back(); + } + is_first_item_ = old_is_first_item; + } + + void Visit(const char* key, ObjectRef* value) final { + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + if (arr != nullptr) { + WriteComma(); + if (key != nullptr) { + address_.push_back(key); + } + code_ << address_from_parts(address_); + if (key != nullptr) { + address_.pop_back(); + } + return; + } + + runtime::metadata::MetadataBase metadata = Downcast(*value); + if (key != nullptr) { // NOTE: outermost call passes nullptr key + address_.push_back(key); + } + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + if (key != nullptr) { // NOTE: outermost call passes nullptr key + address_.pop_back(); + } + } + + void CodegenMetadata(::tvm::runtime::metadata::Metadata metadata) { + decl_ << "#include " << std::endl + << "#include " << std::endl + << "#include " << 
std::endl; + std::vector queue; + MetadataQueuer queuer{&queue}; + queuer.Visit(kGlobalSymbol, &metadata); + + for (MetadataQueuer::QueueItem item : queue) { + auto struct_name = std::get<0>(item); + auto obj = std::get<1>(item); + auto arr = obj.as(); + is_first_item_ = true; + address_.push_back(struct_name); + if (arr != nullptr) { + const char* const_part = "const "; + if (arr->type_index == MetadataTypeIndex::kString) { + const_part = ""; + } + code_ << const_part; + switch (arr->type_index) { + case MetadataTypeIndex::kUint64: + code_ << "uint64_t"; + break; + case MetadataTypeIndex::kInt64: + code_ << "int64_t"; + break; + case MetadataTypeIndex::kBool: + code_ << "bool"; + break; + case MetadataTypeIndex::kString: + code_ << "const char*"; + break; + case MetadataTypeIndex::kHandle: + code_ << "void*"; + break; + case MetadataTypeIndex::kMetadata: + code_ << "struct " << arr->struct_name; + break; + default: + CHECK(false) << "Unknown type_index in array: " << arr->type_index + << " (struct_name=" << arr->struct_name << ")"; + break; + } + code_ << " " << struct_name << "[" << arr->array.size() << "] = {" << std::endl; + VisitArray(arr); + } else { + code_ << "const struct TVMMetadata " << struct_name << " = {" << std::endl; + Visit(nullptr, &obj); + } + address_.pop_back(); + code_ << "};" << std::endl; + } + } + + std::string GetOutput() { return decl_.str() + code_.str(); } + + private: + std::vector address_; + std::stringstream decl_; + std::stringstream code_; + bool is_first_item_; + std::unordered_set generated_struct_decls_; + std::vector is_defining_struct_; +}; + runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, relay::Runtime runtime, relay::backend::ExecutorCodegenMetadata metadata) { @@ -539,6 +792,32 @@ runtime::Module CreateCSourceCrtMetadataModule(const Array& mod return std::move(csrc_metadata_module); } +runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata metadata) { + 
MetadataSerializer serializer; + serializer.CodegenMetadata(metadata); + std::stringstream lookup_func; + lookup_func << "#ifdef __cplusplus\n" + << "extern \"C\"\n" + << "#endif\n"; + + lookup_func << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_get_c_metadata + << "(TVMValue* arg_values, int* arg_tcodes, int " + "num_args, TVMValue* ret_values, int* ret_tcodes, void* resource_handle) {" + << std::endl; + lookup_func << " ret_values[0].v_handle = (void*) &" << MetadataSerializer::kGlobalSymbol + << ";" << std::endl; + lookup_func << " ret_tcodes[0] = kTVMOpaqueHandle;" << std::endl; + lookup_func << " return 0;" << std::endl; + lookup_func << "};" << std::endl; + + auto mod = MetadataModuleCreate(metadata); + std::vector func_names{::tvm::runtime::symbol::tvm_get_c_metadata}; + auto c = CSourceModuleCreate(serializer.GetOutput() + lookup_func.str(), "c", func_names, + Array()); + mod->Import(c); + return mod; +} + // supports limited save without cross compile class DeviceSourceModuleNode final : public runtime::ModuleNode { public: diff --git a/src/target/source/source_module.h b/src/target/source/source_module.h index 3b482a107600..2a63a8eeb814 100644 --- a/src/target/source/source_module.h +++ b/src/target/source/source_module.h @@ -26,6 +26,7 @@ #define TVM_TARGET_SOURCE_SOURCE_MODULE_H_ #include +#include #include #include @@ -36,17 +37,24 @@ namespace tvm { namespace codegen { /*! + * \brief Wrap the submodules that are to be wrapped in a c-source metadata module for C runtime. * \param modules The modules to be wrapped. * \param target the target the modules are compiled for. * \param runtime the runtime to code generate against - * \param metadata the metadata needed for code generation. + * \param metadata Compiler-generated metadata exported to runtime. * \return The wrapped module. */ runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, relay::Runtime runtime, relay::backend::ExecutorCodegenMetadata metadata); +/*! 
+ * \brief Create C++-runtime targeted metadata module for "c" backend. + * \param metadata Compiler-generated metadata. + */ +runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata metadata); + } // namespace codegen } // namespace tvm diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc index cb2b50260326..2d8b6681fa84 100644 --- a/src/tir/transforms/legalize_packed_calls.cc +++ b/src/tir/transforms/legalize_packed_calls.cc @@ -75,6 +75,12 @@ class PackedCallLegalizer : public StmtExprMutator { new_stmts.push_back(tir::Evaluate( tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), {sid_array, 0, tir::builtin::kArrData, call->args[i]}))); + new_stmts.push_back(tir::Evaluate( + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {sid_array, 0, tir::builtin::kArrDeviceType, kDLCPU}))); + new_stmts.push_back(tir::Evaluate( + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {sid_array, 0, tir::builtin::kArrDeviceId, 0}))); packed_args.push_back(sid_array); } } diff --git a/tests/cpp/aot_metadata_test.cc b/tests/cpp/aot_metadata_test.cc index 730762237488..abf37ce4569a 100644 --- a/tests/cpp/aot_metadata_test.cc +++ b/tests/cpp/aot_metadata_test.cc @@ -37,8 +37,11 @@ const int64_t kNormalOutput1Shape[3] = {3, 8, 8}; const struct TVMTensorInfo kNormalOutputs[1] = { {"output1", kNormalOutput1Shape, 3, DLDataType{3, 4, 5}}}; +const int64_t kNormalPool1Shape[3] = {3, 8, 8}; +const struct TVMTensorInfo kNormalPools[1] = {{"pool1", kNormalPool1Shape, 3, DLDataType{3, 4, 7}}}; + const struct TVMMetadata kNormal = { - TVM_METADATA_VERSION, kNormalInputs, 2, kNormalOutputs, 1, "default", + TVM_METADATA_VERSION, kNormalInputs, 2, kNormalOutputs, 1, kNormalPools, 1, "default", }; } // namespace @@ -74,6 +77,14 @@ TEST(Metadata, ParseStruct) { EXPECT_THAT(output1->shape(), ElementsAre(3, 8, 8)); EXPECT_THAT(output1->dtype(), 
Eq(tvm::runtime::DataType(DLDataType{3, 4, 5}))); + auto pools = md->pools(); + EXPECT_THAT(pools.size(), Eq(1)); + + auto pool1 = pools[0]; + EXPECT_THAT(pool1->name(), Eq("pool1")); + EXPECT_THAT(pool1->shape(), ElementsAre(3, 8, 8)); + EXPECT_THAT(pool1->dtype(), Eq(tvm::runtime::DataType(DLDataType{3, 4, 7}))); + EXPECT_THAT(md->mod_name(), Eq("default")); } @@ -131,7 +142,8 @@ TEST(Metadata, Visitor) { ::tvm::ReflectionVTable::Global()->VisitAttrs(md.operator->(), &v); EXPECT_THAT(v.keys, ElementsAre(StrEq("version"), StrEq("inputs"), StrEq("num_inputs"), - StrEq("outputs"), StrEq("num_outputs"), StrEq("mod_name"))); + StrEq("outputs"), StrEq("num_outputs"), StrEq("pools"), + StrEq("num_pools"), StrEq("mod_name"))); EXPECT_THAT(Downcast(v.values[0])->value, Eq(TVM_METADATA_VERSION)); EXPECT_THAT(Downcast(v.values[0])->value, Eq(TVM_METADATA_VERSION)); @@ -164,6 +176,19 @@ TEST(Metadata, Visitor) { auto num_outputs = Downcast(v.values[4]); EXPECT_THAT(num_outputs->value, Eq(1)); + + auto pool_array = Downcast(v.values[5]); + EXPECT_THAT(pool_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); + EXPECT_THAT(pool_array->struct_name, StrEq("TVMTensorInfo")); + auto pool1 = Downcast(pool_array->array[0]); + + EXPECT_THAT(pool1->name(), Eq("pool1")); + + auto num_pools = Downcast(v.values[6]); + EXPECT_THAT(num_pools->value, Eq(1)); + + auto mod_name = Downcast(v.values[7]); + EXPECT_THAT(mod_name, Eq("default")); } using ::tvm::runtime::make_object; @@ -184,6 +209,10 @@ TEST(Metadata, InMemory) { make_object( tvm::String("Output1"), std::vector{3, 8, 8}, tvm::runtime::DataType(DLDataType{3, 4, 5})))}), + std::vector({tvm::runtime::metadata::TensorInfo( + make_object( + tvm::String("Pool1"), std::vector{5, 10, 10}, + tvm::runtime::DataType(DLDataType{3, 4, 7})))}), "default")); auto md_data = md->data(); @@ -211,6 +240,13 @@ TEST(Metadata, InMemory) { EXPECT_THAT(tvm::runtime::DataType(output0->dtype), 
Eq(tvm::runtime::DataType(DLDataType({3, 4, 5})))); + auto pool0 = &md_data->pools[0]; + EXPECT_THAT(pool0->name, StrEq("Pool1")); + EXPECT_THAT(std::vector(pool0->shape, pool0->shape + pool0->num_shape), + ElementsAre(5, 10, 10)); + EXPECT_THAT(tvm::runtime::DataType(pool0->dtype), + Eq(tvm::runtime::DataType(DLDataType({3, 4, 7})))); + EXPECT_THAT(md_data->mod_name, StrEq("default")); } @@ -222,7 +258,7 @@ TEST(Metadata, ZeroElementLists) { make_object( tvm::String("Output1"), std::vector{}, tvm::runtime::DataType(DLDataType{3, 4, 5})))}), - "default")); + std::vector({}), "default")); EXPECT_THAT(md->data()->num_inputs, Eq(0)); EXPECT_THAT(md->inputs().size(), Eq(0)); @@ -233,4 +269,8 @@ TEST(Metadata, ZeroElementLists) { EXPECT_THAT(output0.num_shape, Eq(0)); EXPECT_THAT(md->outputs()[0]->shape().size(), Eq(0)); EXPECT_THAT(md->outputs()[0]->shape(), ElementsAre()); + + EXPECT_THAT(md->pools().size(), Eq(0)); + EXPECT_THAT(md->num_pools(), Eq(0)); + EXPECT_THAT(md->pools(), ElementsAre()); } diff --git a/tests/python/relay/aot/test_c_device_api.py b/tests/python/relay/aot/test_c_device_api.py index 8252ee68ade8..b84e5eb6d775 100644 --- a/tests/python/relay/aot/test_c_device_api.py +++ b/tests/python/relay/aot/test_c_device_api.py @@ -225,17 +225,24 @@ def test_without_device_api_packed_api(non_device_api_main_func): """Test a graph without the Device API with the packed internal calls""" main_func = non_device_api_main_func(interface_api="packed", use_unpacked_api=False) - assert ( - str(main_func.body) - == 'let tvm_value_3 = tir.tvm_stack_alloca("array", 1)\n' - + 'let tvm_value_2 = tir.tvm_stack_alloca("array", 1)\n' - + 'let tvm_value_1 = tir.tvm_stack_alloca("array", 1)\n' - + 'let tvm_value_0 = tir.tvm_stack_alloca("array", 1)\n' - + "tir.tvm_struct_set(tvm_value_0, 0, 1, x_buffer_var)\n" - + "tir.tvm_struct_set(tvm_value_1, 0, 1, y_buffer_var)\n" - + "tir.tvm_struct_set(tvm_value_2, 0, 1, output_buffer_var)\n" - + "tir.tvm_struct_set(tvm_value_3, 0, 1, 
tir.reinterpret((uint64)0))\n" - + 'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", tvm_value_0, tvm_value_1, tvm_value_2, tvm_value_3)\n' + assert str(main_func.body) == ( + 'let tvm_value_3 = tir.tvm_stack_alloca("array", 1)\n' + 'let tvm_value_2 = tir.tvm_stack_alloca("array", 1)\n' + 'let tvm_value_1 = tir.tvm_stack_alloca("array", 1)\n' + 'let tvm_value_0 = tir.tvm_stack_alloca("array", 1)\n' + "tir.tvm_struct_set(tvm_value_0, 0, 1, x_buffer_var)\n" + "tir.tvm_struct_set(tvm_value_0, 0, 10, 1)\n" + "tir.tvm_struct_set(tvm_value_0, 0, 9, 0)\n" + "tir.tvm_struct_set(tvm_value_1, 0, 1, y_buffer_var)\n" + "tir.tvm_struct_set(tvm_value_1, 0, 10, 1)\n" + "tir.tvm_struct_set(tvm_value_1, 0, 9, 0)\n" + "tir.tvm_struct_set(tvm_value_2, 0, 1, output_buffer_var)\n" + "tir.tvm_struct_set(tvm_value_2, 0, 10, 1)\n" + "tir.tvm_struct_set(tvm_value_2, 0, 9, 0)\n" + "tir.tvm_struct_set(tvm_value_3, 0, 1, tir.reinterpret((uint64)0))\n" + "tir.tvm_struct_set(tvm_value_3, 0, 10, 1)\n" + "tir.tvm_struct_set(tvm_value_3, 0, 9, 0)\n" + 'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", tvm_value_0, tvm_value_1, tvm_value_2, tvm_value_3)\n' ) diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py new file mode 100644 index 000000000000..48057404dd4c --- /dev/null +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -0,0 +1,197 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import re +import sys +import textwrap + +import numpy as np +import pytest + +import tvm +from tvm import relay, TVMError +from tvm.ir.module import IRModule +from tvm.relay import backend, testing, transform +from tvm.relay.testing import byoc +from tvm.relay.op.annotation import compiler_begin, compiler_end +from aot_test_utils import ( + AOTTestModel, + AOT_DEFAULT_RUNNER, + generate_ref_data, + convert_to_relay, + compile_and_run, + compile_models, + parametrize_aot_options, +) + + +def test_error_c_interface(): + interface_api = "c" + use_unpacked_api = False + test_runner = AOT_DEFAULT_RUNNER + + two = relay.add(relay.const(1), relay.const(1)) + func = relay.Function([], two) + + with pytest.raises( + tvm.TVMError, + match=re.escape( + 'Either need interface_api == "packed" (got: c) or ' + "unpacked-api == true (got: (bool)0) when targeting " + "c runtime" + ), + ): + compile_and_run( + AOTTestModel( + module=IRModule.from_expr(func), inputs={}, outputs=generate_ref_data(func, {}) + ), + test_runner, + interface_api, + use_unpacked_api, + ) + + +enable_usmp = tvm.testing.parameter(True, False) + + +def test_conv2d(enable_usmp): + RELAY_MODEL = textwrap.dedent( + """\ + #[version = "0.0.5"] + def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), int8]) { + %1 = nn.conv2d( + %data, + %weight, + padding=[2, 2], + channels=3, + kernel_size=[5, 5], + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32"); + %2 = cast(nn.max_pool2d(%1, pool_size=[3, 3]), dtype="int8"); + %3 = nn.conv2d( + %2, + %weight, + 
padding=[2, 2], + channels=3, + kernel_size=[5, 5], + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32"); + %4 = nn.max_pool2d(%3, pool_size=[3, 3]); + %4 + } + """ + ) + ir_mod = tvm.parser.fromtext(RELAY_MODEL) + + main_func = ir_mod["main"] + shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params} + type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params} + + weight_data = np.ones(shape_dict["weight"]).astype(type_dict["weight"]) + input_data = np.ones(shape_dict["data"]).astype(type_dict["data"]) + + params = {"weight": weight_data} + inputs = {"data": input_data} + ref_outputs = generate_ref_data(ir_mod, inputs, params) + + with tvm.transform.PassContext( + opt_level=3, config={"tir.disable_vectorize": True, "tir.usmp.enable": enable_usmp} + ): + mod = tvm.relay.build( + ir_mod, + params=params, + target="c", + executor=backend.Executor("aot", {"interface-api": "packed"}), + ) + + temp_dir = tvm.contrib.utils.TempDirectory() + test_so_path = temp_dir / "test.so" + mod.export_library(test_so_path, cc="gcc", options=["-std=c11"]) + loaded_mod = tvm.runtime.load_module(test_so_path) + runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0))) + runner.set_input(**inputs) + runner.run() + assert (runner.get_output(0).asnumpy() == list(ref_outputs.values())[0]).all() + + +def test_mobilenet(): + ir_mod, params = testing.mobilenet.get_workload(batch_size=1) + data_shape = [int(x) for x in ir_mod["main"].checked_type.arg_types[0].shape] + data = np.random.uniform(size=data_shape).astype("float32") + inputs = {"data": data} + ref_outputs = generate_ref_data(ir_mod, inputs, params) + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.relay.build( + ir_mod, + params=params, + target="c", + executor=backend.Executor("aot", {"interface-api": "packed"}), + ) + + temp_dir = tvm.contrib.utils.TempDirectory() + test_so_path = temp_dir / "test.so" + 
mod.export_library(test_so_path, cc="gcc", options=["-std=c11"]) + loaded_mod = tvm.runtime.load_module(test_so_path) + runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0))) + runner.set_input(**inputs) + runner.run() + assert (runner.get_output(0).asnumpy() == list(ref_outputs.values())[0]).all() + + +def test_create_executor(): + x = tvm.relay.var("x", tvm.relay.TensorType([1], dtype="float32")) + expr = tvm.relay.add(x, tvm.relay.Constant(tvm.nd.array(np.array([1], dtype="float32")))) + actual = relay.create_executor( + "aot", mod=tvm.IRModule.from_expr(tvm.relay.Function([x], expr)), target="c -executor=aot" + ).evaluate()(np.array([2], dtype="float32")) + + np.isfinite(np.array([3], dtype="float32")) + + np.testing.assert_allclose(actual.numpy(), np.array([3], dtype="float32")) + + +def test_pass_wrong_device_arg(): + x = tvm.relay.var("x", tvm.relay.TensorType([1], dtype="float32")) + expr = tvm.relay.add(x, tvm.relay.Constant(tvm.nd.array(np.array([1], dtype="float32")))) + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.relay.build( + tvm.IRModule.from_expr(tvm.relay.Function([x], expr)), + target="c", + executor=backend.Executor("aot", {"interface-api": "packed"}), + ) + + temp_dir = tvm.contrib.utils.TempDirectory() + test_so_path = temp_dir / "test.so" + mod.export_library(test_so_path, cc="gcc", options=["-std=c11"]) + loaded_mod = tvm.runtime.load_module(test_so_path) + + with pytest.raises(tvm.TVMError) as cm: + tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0), tvm.cpu(0))) + + assert ( + "Check failed: devices_.size() == 1 (2 vs. 1) : Expect exactly 1 device passed." + in str(cm.exception) + ) + # TODO write asserts for # and type of device. 
+ + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 0147b8cf755a..2ce36f19fcc8 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -16,6 +16,7 @@ # under the License. from collections import OrderedDict +import re import sys import numpy as np @@ -49,7 +50,14 @@ def test_error_c_interface_with_packed_api(): two = relay.add(relay.const(1), relay.const(1)) func = relay.Function([], two) - with pytest.raises(tvm.TVMError, match="Packed interface required for packed operators"): + with pytest.raises( + tvm.TVMError, + match=re.escape( + 'Either need interface_api == "packed" (got: c) or ' + "unpacked-api == true (got: (bool)0) when targeting " + "c runtime" + ), + ): compile_and_run( AOTTestModel( module=IRModule.from_expr(func), inputs={}, outputs=generate_ref_data(func, {}) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index ebfec0fa23a0..e4666c63c8c5 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -922,12 +922,14 @@ def test_get_input_index(target, dev): assert vm_factory.get_input_index(data_0) == 0 assert vm_factory.get_input_index("invalid") == -1 + def get_one_input_relay_mod(tensor_type, shape, data_name): - x = relay.var(data_name, shape = shape, dtype = tensor_type) + x = relay.var(data_name, shape=shape, dtype=tensor_type) y = relay.exp(x) f = relay.Function([x], y) return IRModule.from_expr(f) + @tvm.testing.parametrize_targets("llvm") def test_one_set_input(target, dev): dtype = "float32" @@ -956,11 +958,13 @@ def test_one_set_input(target, dev): assert output.dtype == ref_res.dtype tvm.testing.assert_allclose(ref_res_core, output.numpy()) + def get_multiple_input_relay_mod(tensor_type, shape, data_name0, data_name1): - x, y = [relay.var(c, shape=shape, dtype = tensor_type) for c in [data_name0, data_name1]] + x, y = 
[relay.var(c, shape=shape, dtype=tensor_type) for c in [data_name0, data_name1]] f = relay.Function([x, y], x + y) return IRModule.from_expr(f) + @tvm.testing.parametrize_targets("llvm") def test_multiple_set_input(target, dev): dtype = "float32" @@ -992,6 +996,7 @@ def test_multiple_set_input(target, dev): assert output.dtype == ref_res.dtype tvm.testing.assert_allclose(ref_res_core, output.numpy()) + @tvm.testing.parametrize_targets("llvm") def test_one_set_one_input(target, dev): dtype = "float32" @@ -1025,6 +1030,7 @@ def test_one_set_one_input(target, dev): assert output.dtype == ref_res.dtype tvm.testing.assert_allclose(ref_res_core, output.numpy()) + @tvm.testing.parametrize_targets("llvm") def test_multiple_set_one_input(target, dev): dtype = "float32" @@ -1065,6 +1071,7 @@ def test_multiple_set_one_input(target, dev): assert output.dtype == ref_res.dtype tvm.testing.assert_allclose(ref_res_core, output.numpy()) + @tvm.testing.parametrize_targets("llvm") def test_benchmark(target, dev): mod, params = mlp.get_workload(1) diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py index 222d647f4ea7..54561ade23e4 100644 --- a/tests/python/unittest/test_aot_legalize_packed_call.py +++ b/tests/python/unittest/test_aot_legalize_packed_call.py @@ -57,8 +57,17 @@ def tir_packed_call() -> None: with T.let(tvm_value_1, T.tvm_stack_alloca("array", 1, dtype="handle")): with T.let(tvm_value_0, T.tvm_stack_alloca("array", 1, dtype="handle")): T.evaluate(T.tvm_struct_set(tvm_value_0, 0, 1, A, dtype="handle")) + T.evaluate(T.tvm_struct_set(tvm_value_0, 0, 10, 1, dtype="handle")) + T.evaluate(T.tvm_struct_set(tvm_value_0, 0, 9, 0, dtype="handle")) + T.evaluate(T.tvm_struct_set(tvm_value_1, 0, 1, B, dtype="handle")) + T.evaluate(T.tvm_struct_set(tvm_value_1, 0, 10, 1, dtype="handle")) + T.evaluate(T.tvm_struct_set(tvm_value_1, 0, 9, 0, dtype="handle")) + T.evaluate(T.tvm_struct_set(tvm_value_2, 0, 1, C, 
dtype="handle")) + T.evaluate(T.tvm_struct_set(tvm_value_2, 0, 10, 1, dtype="handle")) + T.evaluate(T.tvm_struct_set(tvm_value_2, 0, 9, 0, dtype="handle")) + T.evaluate( T.tvm_call_cpacked( "tvm_test_cpacked",