From 56a9d4e328ea5e6ba3770e69c8c866c5aa9b6572 Mon Sep 17 00:00:00 2001 From: Hua Jiang Date: Fri, 17 Dec 2021 11:05:39 -0800 Subject: [PATCH] [Runtime][Pipeline Executor] Add the map logic of global input and subgraph input. (#9751) * [Runtime][Pipeline Executor] Add the map logic of global input and subgraph input. User can use "global input name" to feed input data for pipeline runtime. The name like "data_a" will be mapped into a input interface of subgraph. In this PR, we create the related logic to do the following things. 1. building the input map configuration 2. in runtime c++ module, parseing the input connection configuration then creating related data structure to record the said connection map. 3. providing the function to return the map information for verification. * address review comments. * addres review comments. * address review comments. --- python/tvm/contrib/pipeline_executor.py | 138 +++++++++++--- src/runtime/pipeline/pipeline_executor.cc | 25 ++- src/runtime/pipeline/pipeline_executor.h | 48 +++-- src/runtime/pipeline/pipeline_scheduler.cc | 2 +- src/runtime/pipeline/pipeline_scheduler.h | 3 +- src/runtime/pipeline/pipeline_struct.h | 181 ++++++++++++++----- tests/python/relay/test_pipeline_executor.py | 16 +- 7 files changed, 306 insertions(+), 107 deletions(-) diff --git a/python/tvm/contrib/pipeline_executor.py b/python/tvm/contrib/pipeline_executor.py index 37b9fed8eb91..f8f9ed1c4125 100644 --- a/python/tvm/contrib/pipeline_executor.py +++ b/python/tvm/contrib/pipeline_executor.py @@ -49,16 +49,26 @@ def build(pipe_configs): Common interface for pipeline executor factory modules. """ libs = {} - mod_n_configs = pipe_configs.get_config() + config = pipe_configs.get_config() + if "module_connection" not in config: + raise RuntimeError('"module_connection" is missing') + if "input_connection" not in config: + raise RuntimeError('"input_connection" is missing') + + mod_n_configs = config["module_connection"] config_len = len(mod_n_configs) - string_config = [{} for _ in range(config_len)] + module_string_config = [{} for _ in range(config_len)] + # Use hardware configurations to build backend modules for each subgraph. for ir_mod, mod_config in mod_n_configs.items(): - mconf = mod_config["pipeline"].copy() - mod_idx = mconf["mod_idx"] + pipe_config = mod_config["pipeline"].copy() + mod_idx = pipe_config["mod_idx"] dev = mod_config["dev"] target = mod_config["target"] build_func = relay.build - # Check whether there is a customized build function. + # Callers may need to use a customized building function to wrap the pre-building logic + # and the backend building logic. For example, in order to support a backend which only + # can do "int8" computation, the caller may need to merge the "quantization" logic + # into the building logic to creat a customized building function. if "build" in mod_config and mod_config["build"]: build_func = mod_config["build"] @@ -70,11 +80,20 @@ def build(pipe_configs): mod_name=mod_config["mod_name"], ) - mconf["dev"] = "{},{}".format(dev.device_type, dev.device_id) - # Create a pipeline configuration. - string_config[mod_idx] = mconf + pipe_config["dev"] = "{},{}".format(dev.device_type, dev.device_id) + # Use "mod_idx" as the key to create a "module_connection" map which is not only + # for the module index but also for the module connection used to build the pipeline. + module_string_config[mod_idx] = pipe_config libs[mod_idx] = {"lib": lib, "dev": dev} + # Creating a text form configuration to record the "input_connection" and the + # "module_connection" information. The "input_connection" is used to record the + # map of global input and subgraph input, and the "module_connection" is used to + # record module dependency. + string_config = {} + string_config["input_connection"] = config["input_connection"] + string_config["module_connection"] = module_string_config + return PipelineExecutorFactoryModule(libs, string_config) @@ -94,6 +113,17 @@ def __init__(self, module): self.module = module # Get the packed functions from the pipeline executor. self._get_num_outputs = self.module["get_num_outputs"] + self._get_input_pipeline_map = self.module["get_input_pipeline_map"] + + def get_input_pipeline_map(self, name): + """Using the "name" to get the corresponding subgraph index and also get the "input name" + of the corresponding subgraph interface. + Returns + ------- + input map: Array[str] + Returning the index and "input name" of the subgraph. + """ + return self._get_input_pipeline_map(name) @property def num_outputs(self): @@ -199,12 +229,48 @@ def is_pipeline_executor_interface(self): return not isinstance(self.io_owner, PipelineConfig.ModuleWrapper) def __repr__(self): - # Get all binding information. - ret = " |{}: ".format(self.name) + # Geting the binding information in the form of text. + str_format = " |{}: ".format(self.name) for binding in self.bindings: mname, dname = binding.get_name() - ret += "{0}:{1} ".format(mname, dname) - return ret + str_format += "{0}:{1} ".format(mname, dname) + + return str_format + + def check_binding_dict(self, connection_dict): + """Checking the binding dictionary. + Parameter + --------- + connection_dict : Dict[str, Any] + It is a dictionary of module connections. + """ + if "interface_name" not in connection_dict: + raise RuntimeError('"inteface_name" is missing in global config!"') + if "connection" not in connection_dict: + raise RuntimeError(f'"connection" is missing!"') + # The global interface mapping should be one-to-one. + if not connection_dict["connection"]: + raise RuntimeError("The global interface map is empty!") + if len(connection_dict["connection"]) > 1: + raise RuntimeError("A global interface maps multiple module interfaces!") + if "mod_idx" not in connection_dict["connection"][0]: + raise RuntimeError('"mod_idx" is missing!') + + def get_binding_dict(self): + """Returning the binding information in the form of dictionary. + Returns + ------- + data : Dict[str, Any] + The binding information is in the form of dictionary. + """ + dict_format = {"interface_name": self.name, "connection": []} + for binding in self.bindings: + _, dname = binding.get_name() + midx = binding.get_owner_idx() + dict_format["connection"].append({"mod_idx": midx, "interface_name": dname}) + + self.check_binding_dict(dict_format) + return dict_format def check_dag_acyclic(self, start, inputs): """This is to check whether the DAG containing these input interfaces is acyclic. @@ -243,30 +309,34 @@ def connect(self, binding): # Check whether the binding setting is correct or not. if self.io_owner == binding.io_owner: - raise RuntimeError(f"Can not bind itself.") + raise RuntimeError("Can not bind itself.") if not self.is_pipeline_executor_interface() and self.io_type == "input": - raise RuntimeError(f"Module can only bind from output interface!") + raise RuntimeError("Module can only bind from output interface!") if ( not self.is_pipeline_executor_interface() and not binding.is_pipeline_executor_interface() and binding.io_type == "output" ): - raise RuntimeError(f"Can not bind module output with another module output!") + raise RuntimeError("Can not bind module output with another module output!") if ( not self.is_pipeline_executor_interface() and binding.is_pipeline_executor_interface() and binding.io_type == "input" ): - raise RuntimeError(f"Can not bind module output with pipeline input!") + raise RuntimeError("Can not bind module output with pipeline input!") if self.is_pipeline_executor_interface() and self.io_type == "output": - raise RuntimeError(f"Global output can not be used as binding start point.") + raise RuntimeError("Global output can not be used as binding start point.") - if self.is_pipeline_executor_interface() and binding.io_type != "input": - raise RuntimeError(f"Global input can only bind with module input.") + if ( + self.is_pipeline_executor_interface() + and self.io_type == "input" + and binding.io_type != "input" + ): + raise RuntimeError("Global input can only bind with module input.") self.bindings.append(binding) if not self.is_pipeline_executor_interface(): @@ -288,7 +358,7 @@ def connect(self, binding): if not self.check_dag_acyclic( binding.io_owner, self.io_owner.input_bindings.bindings ): - raise RuntimeError(f"Illegal connection: Cause a cycle!") + raise RuntimeError("Illegal connection: Cause a cycle!") class BindingList: """Container for bindings(input or output interface). @@ -357,7 +427,9 @@ def __getitem__(self, key): if key == "output": return self.output_bindings - raise RuntimeError(f"{key} not found!") + raise RuntimeError(f"{key} not found!") + + raise RuntimeError('The data type of "key" is not supported!') def get_data_type(self, key, interface_type): """Get the module interface data type according to the key value and interface type. @@ -468,6 +540,8 @@ def get_config(self): # Use topological sort to get the correct order of modules. self.dag_topology_sort() mconfig = {} + module_connection = {} + input_connection = {} for mod in self.mod_wrapper: # Generate pipeline configuration. mconf = {} @@ -495,7 +569,7 @@ def get_config(self): mconf["mod_idx"] = module.idx mconf["output"] = output_conf - mconfig[mod] = { + module_connection[mod] = { "pipeline": mconf, "target_host": module.target_host, "mod_name": "default", @@ -505,6 +579,22 @@ def get_config(self): "dev": module.dev, } + # Create a map of pipeline input and subgraph input. + input_connection = [] + for input_name in self.input_bindings.bindings: + input_dict = self.input_bindings.bindings[input_name].get_binding_dict() + if "interface_name" not in input_dict["connection"][0]: + raise RuntimeError("interface_name is missing in connection config!") + # Creating the map of global interface and subgraph interface. + input_map = { + "global_interface_name": input_dict["interface_name"], + "mod_idx": input_dict["connection"][0]["mod_idx"], + "module_interface_name": input_dict["connection"][0]["interface_name"], + } + input_connection.append(input_map) + + mconfig["module_connection"] = module_connection + mconfig["input_connection"] = input_connection return mconfig def dag_topology_sort(self): @@ -601,11 +691,11 @@ def export_library(self, directory_path): Export the files to this directory. """ if not self.pipeline_mods: - raise RuntimeError(f"The pipeline executor has not been initialized.") + raise RuntimeError("The pipeline executor has not been initialized.") # Check if the directory_path exists. if not os.path.exists(directory_path): - raise RuntimeError(f"The directory {directory_path} does not exist.") + raise RuntimeError("The directory {directory_path} does not exist.") # Create an load configuration. load_config_file_name = "{}/load_config".format(directory_path) pipeline_config_file_name = "{}/pipeline_config".format(directory_path) diff --git a/src/runtime/pipeline/pipeline_executor.cc b/src/runtime/pipeline/pipeline_executor.cc index 3820ce942af0..32414c607df6 100644 --- a/src/runtime/pipeline/pipeline_executor.cc +++ b/src/runtime/pipeline/pipeline_executor.cc @@ -34,6 +34,14 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name, if (name == "get_num_outputs") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); }); + } else if (name == "get_input_pipeline_map") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + *rv = this->GetInputPipeplineMapping(args[0].operator String()); + } else { + LOG(FATAL) << "Function only support the input name value in the form of string"; + } + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc(); @@ -41,6 +49,17 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name, return nullptr; } +/*! + * \brief Using the global input name to get the index, and also get the input interface name + of corresponding subgraph from the input connection configuration. + * \param The global input name. + * \return Returning the index and the input interface name of corresponding subgraph. + */ +Array PipelineExecutor::GetInputPipeplineMapping(std::string input_name) { + std::pair map = input_connection_config[input_name]; + return {std::to_string(map.first), map.second}; +} + /*! * \brief Use the mod_config information to create a graph runtime list. * \param mod_config The config information that generates by the export library function call. @@ -108,11 +127,11 @@ void PipelineExecutor::Init(const std::vector& modules, const std::strin // Use JSONReader to load pipeline configuration. std::istringstream is(pipeline_json); dmlc::JSONReader reader(&is); - PipelineConfig& pipeline_config = this->LoadPipelineConfig(&reader); - ICHECK(!pipeline_config.Empty()) << "The pipeline config information is empty."; + this->LoadConfig(&reader); + ICHECK(!pipeline_config_.Empty()) << "The pipeline config information is empty."; // Initialize the pipeline function class used for pipeline thread pool management // and schedule etc. This function returns the number of output. - num_outputs_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config); + num_outputs_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config_); return; } diff --git a/src/runtime/pipeline/pipeline_executor.h b/src/runtime/pipeline/pipeline_executor.h index a883ba25ec08..1ae52e07c260 100644 --- a/src/runtime/pipeline/pipeline_executor.h +++ b/src/runtime/pipeline/pipeline_executor.h @@ -24,12 +24,14 @@ #ifndef TVM_RUNTIME_PIPELINE_PIPELINE_EXECUTOR_H_ #define TVM_RUNTIME_PIPELINE_PIPELINE_EXECUTOR_H_ +#include #include #include #include #include #include +#include #include #include "pipeline_scheduler.h" @@ -67,7 +69,13 @@ class TVM_DLL PipelineExecutor : public ModuleNode { * \return The corresponding packed function. */ virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); - + /*! + * \brief Using the global input name to get the index, and also get the input interface name + of corresponding subgraph from the input connection configuration. + * \param The global input name. + * \return Returning the index and the input interface name of corresponding subgraph. + */ + Array GetInputPipeplineMapping(std::string input_name); /*! * \brief Get the number of outputs. * @@ -115,37 +123,27 @@ class TVM_DLL PipelineExecutor : public ModuleNode { /*!\brief The class used to execute and schedule the pipeline logic.*/ PipelineScheduler pipeline_scheduler_; /*!\brief The dependency information of each graph runtime module of the pipeline.*/ - PipelineConfig pipeline_config_; + ConfigPipelineExecution pipeline_config_; + /*!\brief The map of global input and subgraph input.*/ + InputConnectionConfig input_connection_config; /*!\brief The module information used to create the graph runtimes.*/ ModuleConfig mod_config_; /*!\brief How many outputs are in this pipeline executor.*/ size_t num_outputs_ = 0; /*!\brief Json loader.*/ - PipelineConfig& LoadPipelineConfig(dmlc::JSONReader* reader) { - reader->BeginArray(); - while (reader->NextArrayItem()) { - std::string key; - reader->BeginObject(); - int mod_idx = -1; - OutputMap output; - std::string dev; - while (reader->NextObjectItem(&key)) { - if (key == "mod_idx") { - reader->Read(&mod_idx); - } else if (key == "dev") { - reader->Read(&dev); - } else if (key == "output") { - reader->Read(&output); - } else { - LOG(FATAL) << "do not support key " << key; - } + void LoadConfig(dmlc::JSONReader* reader) { + reader->BeginObject(); + std::string key; + while (reader->NextObjectItem(&key)) { + if (key == "module_connection") { + reader->Read(&pipeline_config_); + } else if (key == "input_connection") { + reader->Read(&input_connection_config); + } else { + LOG(FATAL) << "do not support key " << key; } - ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx; - // Check if the output is successfully read. - ICHECK(!output.Empty()) << "Invalid output binding result."; - pipeline_config_.Insert(mod_idx, output); } - return pipeline_config_; + return; } }; } // namespace runtime diff --git a/src/runtime/pipeline/pipeline_scheduler.cc b/src/runtime/pipeline/pipeline_scheduler.cc index 82caf855a479..67a9795c47d4 100644 --- a/src/runtime/pipeline/pipeline_scheduler.cc +++ b/src/runtime/pipeline/pipeline_scheduler.cc @@ -28,7 +28,7 @@ namespace runtime { * \param pipeline_conf The dependency information of each graph executor module. */ size_t PipelineScheduler::PipelineInit(const std::vector& modules, - const PipelineConfig& pipeline_config) { + const ConfigPipelineExecution& pipeline_config) { graph_modules_ = modules; int num_output = pipeline_config.GetGlobalOutputNum(); return num_output; diff --git a/src/runtime/pipeline/pipeline_scheduler.h b/src/runtime/pipeline/pipeline_scheduler.h index 5ee127edffa3..0572e060a1b8 100644 --- a/src/runtime/pipeline/pipeline_scheduler.h +++ b/src/runtime/pipeline/pipeline_scheduler.h @@ -41,7 +41,8 @@ class PipelineScheduler { * \param modules The list of graph executor module. * \param pipeline_config The dependency information of each graph executor module. */ - size_t PipelineInit(const std::vector& modules, const PipelineConfig& pipeline_config); + size_t PipelineInit(const std::vector& modules, + const ConfigPipelineExecution& pipeline_config); private: /*!\brief The list of graph executors.*/ diff --git a/src/runtime/pipeline/pipeline_struct.h b/src/runtime/pipeline/pipeline_struct.h index 3cc9621702c1..52422b764564 100644 --- a/src/runtime/pipeline/pipeline_struct.h +++ b/src/runtime/pipeline/pipeline_struct.h @@ -25,20 +25,16 @@ #include #include #include +#include #include /*! * \brief All binding information of a output interface. */ -struct OutputBindings { - /*!\brief Output interface binding information, 'int' is the index of the module that - * uses this output data as the input interface data, 'string' is the input interface name - * of the module. - */ - std::unordered_map bindings; - /*! The index value of the global interface to which the current output are bound.*/ - int global_output_index = std::numeric_limits::min(); +class ConfigBindings { + public: /*!\brief Whether this binding is bound to the PipelineExecutor output interface.*/ - bool IsGlobalOutput() const { return global_output_index >= 0; } + bool IsGlobalOutput() const { return global_output_index_ > -1; } + /*! * \brief Create a module interface map from JSONReader. * \param reader JSON reader. @@ -59,8 +55,8 @@ struct OutputBindings { reader->Read(&input_name); } else if (key == "global_output_index") { // There should be only one global binding. - ICHECK(global_output_index < 0); - reader->Read(&global_output_index); + ICHECK(global_output_index_ < 0); + reader->Read(&global_output_index_); // When the key value is 'global_output_index', it means that this output is bound to // a global interface. global_binding = true; @@ -71,44 +67,58 @@ struct OutputBindings { // When this output is bound to a global interface, check if the global interface index // start from 0. if (global_binding) { - ICHECK(global_output_index >= 0); + ICHECK(global_output_index_ >= 0); } else { // When this output is bound to a graph executor module interface, check if the module // index start from 0. ICHECK(mod_idx >= 0); - bindings[mod_idx] = input_name; + bindings_[mod_idx] = input_name; } } } -}; + private: + /*!\brief Output interface binding information, 'int' is the index of the module that + * uses this output data as the input interface data, 'string' is the input interface name + * of the module. + */ + std::unordered_map bindings_; + /*! The index value of the global interface to which the current output are bound.*/ + int global_output_index_ = std::numeric_limits::min(); +}; /*! * \brief The binding information of all outputs of a module. */ -struct OutputMap { - /*! \brief Output binding map, 'int' is output interface index.*/ - std::unordered_map output_binding_map; - OutputMap& operator=(const OutputMap& output) { - output_binding_map = output.output_binding_map; +class ConfigOutputBindings { + public: + ConfigOutputBindings& operator=(const ConfigOutputBindings& output) { + output_binding_map_ = output.GetOutBindings(); return *this; } - /*!\brief This function is used to verify whether OutputMap is successfully loaded. - * \return Return true to indicate that this class has not been successfully loaded. + ConfigBindings& operator[](const int key) { + ICHECK(output_binding_map_.find(key) != output_binding_map_.end()); + return output_binding_map_[key]; + } + /*!brief Return the variable "output_binding_map_".*/ + std::unordered_map GetOutBindings() const { return output_binding_map_; } + /*! + *\brief This function is used to verify whether ConfigOutputBindings is successfully loaded. + *\return Return true to indicate that this class has not been successfully loaded. */ - bool Empty() { return output_binding_map.empty(); } - /*! \brief The pipeline outputs is the final outputs of pipeline, this function is used to - * get how many pipeline outputs are in this Outputmap - * \return Number of pipeline outputs. + bool Empty() { return output_binding_map_.empty(); } + /*! + * \brief The pipeline outputs is the final outputs of pipeline, this function is used to + * get how many pipeline outputs are in this Outputmap + * \return Number of pipeline outputs. */ size_t GetGlobalOutputNum(void) const { size_t num_output = 0; - for (auto bindings : output_binding_map) { + for (auto bindings : output_binding_map_) { num_output += bindings.second.IsGlobalOutput() ? 1 : 0; } return num_output; } - /*! * \brief Create a output binding map from JSONReader. * \param reader Json reader. @@ -119,7 +129,7 @@ struct OutputMap { std::string key; reader->BeginObject(); int output_idx = -1; - OutputBindings binding; + ConfigBindings binding; while (reader->NextObjectItem(&key)) { if (key == "output_idx") { reader->Read(&output_idx); @@ -130,42 +140,117 @@ struct OutputMap { } } ICHECK(output_idx >= 0); - output_binding_map[output_idx] = binding; + output_binding_map_[output_idx] = binding; } } + + private: + /*!\brief The map of output binding, 'int' is the output interface index.*/ + std::unordered_map output_binding_map_; }; + /*! * \brief The binding or dependency information of each module output interface. */ -struct PipelineConfig { - /*!\brief The key is the module index, this variable records all module pipeline configuration - * information. +class ConfigPipelineExecution { + public: + /* + *!\brief This function is used to verify whether config is loaded successfully. + * \return Return "true" to indicate that this class has not been successfully loaded. */ - std::unordered_map config; - OutputMap& operator[](int key) { - ICHECK(config.find(key) != config.end()); - return config[key]; - } - - void Insert(int key, const OutputMap& map) { config[key] = map; } - - /*!\brief This function is used to verify whether config is loaded successfully. - * \return Return true to indicate that this class has not been successfully loaded. - */ - bool Empty() { return config.empty(); } - + bool Empty() { return config_.empty(); } /*! - * \brief Get the number of global outputs. - * \return The number of outputs the entire pipeline has. + * \brief Getting the number of global outputs. + * \return The number of outputs in the entire pipeline. */ size_t GetGlobalOutputNum() const { size_t num_output = 0; - for (auto mod_output : config) { + for (auto mod_output : config_) { num_output += mod_output.second.GetGlobalOutputNum(); } return num_output; } + /*! + * \brief Create a pipeline config from JSONReader. + * \param reader Json reader. + */ + void Load(dmlc::JSONReader* reader) { + reader->BeginArray(); + while (reader->NextArrayItem()) { + std::string key; + reader->BeginObject(); + int mod_idx = -1; + ConfigOutputBindings output; + std::string dev; + while (reader->NextObjectItem(&key)) { + if (key == "mod_idx") { + reader->Read(&mod_idx); + } else if (key == "dev") { + reader->Read(&dev); + } else if (key == "output") { + reader->Read(&output); + } else { + LOG(FATAL) << "do not support key " << key; + } + } + ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx; + // Check if the output is successfully read. + ICHECK(!output.Empty()) << "Invalid output binding result."; + // Build the mapping of mod_idx and "ConfigOutputBindings". + config_[mod_idx] = output; + } + } + + private: + /* + *!\brief The key is the module index, this variable records all module pipeline configuration + * information. + */ + std::unordered_map config_; }; + +struct InputConnectionConfig { + /*!\brief The key("string") is the name of global module input interfaces. The value("pair") + * includes the index of graph module and the name of a graph module input interface. + */ + std::unordered_map> input_connection; + std::pair operator[](const std::string key) { + if (input_connection.find(key) == input_connection.end()) { + LOG(FATAL) << "Not find the key " << key; + } + return input_connection[key]; + } + /*! + * \brief Create a input connection config from JSONReader. + * \param reader Json reader. + */ + void Load(dmlc::JSONReader* reader) { + reader->BeginArray(); + while (reader->NextArrayItem()) { + reader->BeginObject(); + std::string key; + std::string global_interface_name; + std::string module_interface_name; + int mod_idx = -1; + while (reader->NextObjectItem(&key)) { + if (key == "global_interface_name") { + reader->Read(&global_interface_name); + } else if (key == "mod_idx") { + reader->Read(&mod_idx); + } else if (key == "module_interface_name") { + reader->Read(&module_interface_name); + } else { + LOG(FATAL) << "do not support key " << key; + } + } + ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx; + ICHECK(!global_interface_name.empty()) << "Invalid global interface name value"; + ICHECK(!module_interface_name.empty()) << "Invalid module interface name value"; + input_connection[global_interface_name] = make_pair(mod_idx, module_interface_name); + } + } +}; + /*! * \brief The information used to initialize the graph executor module, the information * come from the export library function call. diff --git a/tests/python/relay/test_pipeline_executor.py b/tests/python/relay/test_pipeline_executor.py index 4a9b7eacdf65..4e51f873b3fa 100644 --- a/tests/python/relay/test_pipeline_executor.py +++ b/tests/python/relay/test_pipeline_executor.py @@ -182,11 +182,11 @@ def test_pipeline(): # The pipeline input named "data_0" will be connected to a input named "data_0" # of mod1. - pipe_config["input"]["data_0"].connect(pipe_config[mod1]["input"]["data_0"]) + pipe_config["input"]["data_a"].connect(pipe_config[mod1]["input"]["data_0"]) # The pipeline Input named "data_1" will be connected to a input named "data_1" # of mod2. - pipe_config["input"]["data_1"].connect(pipe_config[mod2]["input"]["data_1"]) + pipe_config["input"]["data_b"].connect(pipe_config[mod2]["input"]["data_1"]) # The mod1 output[0] will be connected to a input named "data_0" of mod2. pipe_config[mod1]["output"][0].connect(pipe_config[mod2]["input"]["data_0"]) @@ -205,8 +205,8 @@ def test_pipeline(): # Print configueration (print(pipe_config)), the result looks like following. # # Inputs - # |data_0: mod1:data_0 - # |data_1: mod2:data_1 + # |data_a: mod1:data_0 + # |data_b: mod2:data_1 # # output # |output(1) : mod1.output(2) @@ -228,7 +228,8 @@ def test_pipeline(): pipe_config[mod3].dev = tvm.cpu(0) # Here is to check the correctness of the configuration generated by API. - assert pipe_config.get_config() == get_manual_conf([mod1, mod2, mod3], target) + mconfig = pipe_config.get_config() + assert mconfig["module_connection"] == get_manual_conf([mod1, mod2, mod3], target) # Build and create a pipeline module. with tvm.transform.PassContext(opt_level=3): @@ -249,6 +250,11 @@ def test_pipeline(): pipeline_module_test = pipeline_executor.PipelineModule.load_library(config_file_name) assert pipeline_module_test.num_outputs == 2 + input_map = pipeline_module_test.get_input_pipeline_map("data_b") + assert input_map[0] == "1" and input_map[1] == "data_1" + input_map = pipeline_module_test.get_input_pipeline_map("data_a") + assert input_map[0] == "0" and input_map[1] == "data_0" + if __name__ == "__main__": pytest.main([__file__])