diff --git a/python/tvm/contrib/pipeline_executor.py b/python/tvm/contrib/pipeline_executor.py index dc276b1b0285f..3072d871d4201 100644 --- a/python/tvm/contrib/pipeline_executor.py +++ b/python/tvm/contrib/pipeline_executor.py @@ -164,10 +164,7 @@ def set_input(self, key, value): value : array_like. The input value """ - v = self._get_input(key) - if v is None: - raise RuntimeError("Could not find '%s' in pipeline's inputs" % key) - v.copyfrom(value) + self._set_input(key, tvm.nd.array(value)) def set_params(self, params_group_name, params_data): """Set the parameter group value given the parameter group name. Note that the parameter diff --git a/src/runtime/pipeline/pipeline_executor.cc b/src/runtime/pipeline/pipeline_executor.cc index aff7e5205c948..a191f816f7159 100644 --- a/src/runtime/pipeline/pipeline_executor.cc +++ b/src/runtime/pipeline/pipeline_executor.cc @@ -94,11 +94,7 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name, * \param data_in The input data. */ void PipelineExecutor::SetInput(std::string input_name, DLTensor* data_in) { - std::pair indexs = this->GetInputIndex(input_name); - if (indexs.first < 0 || indexs.first >= static_cast(runtimes_.size())) { - LOG(FATAL) << "input name " << input_name << " not found."; - } - runtimes_[indexs.first]->SetInput(indexs.second, data_in); + global_runtime_->SetPipelineInput(input_name, data_in); } /*! * \brief get input from the runtime module. @@ -118,7 +114,7 @@ NDArray PipelineExecutor::GetInput(std::string input_name) { * \return int The module index. */ int PipelineExecutor::GetParamModuleIndex(const std::string& name) { - return param_connection_config[name]; + return param_connection_config_[name]; } /*! * \brief Using the global input name to get the index, and also get the input interface name @@ -127,7 +123,7 @@ int PipelineExecutor::GetParamModuleIndex(const std::string& name) { * \return Returning the index and the input interface name of corresponding subgraph. */ Array PipelineExecutor::GetInputPipeplineMap(std::string input_name) { - std::pair map = input_connection_config[input_name]; + std::pair map = input_connection_config_[input_name]; return {std::to_string(map.first), map.second}; } @@ -137,11 +133,11 @@ Array PipelineExecutor::GetInputPipeplineMap(std::string input_name) { * \return int The module index. */ int PipelineExecutor::GetParamsGroupPipelineMap(const std::string& name) { - return param_connection_config[name]; + return param_connection_config_[name]; } /*!\brief Run the pipeline executor.*/ -void PipelineExecutor::Run() { pipeline_scheduler_.PipelineRun(runtimes_, pipeline_config_); } +void PipelineExecutor::Run() { pipeline_scheduler_.PipelineRun(runtimes_); } /*! * \brief return A list of global output data. */ @@ -226,7 +222,7 @@ void PipelineExecutor::SetParam(std::string param_group_name, std::string param_ * \return std::pair A pair of module index and the input index. */ std::pair PipelineExecutor::GetInputIndex(const std::string& name) { - std::pair index = input_connection_config[name]; + std::pair index = input_connection_config_[name]; auto gruntime = runtimes_[index.first]; return std::make_pair(index.first, gruntime->GetInputIndex(index.second)); } @@ -250,7 +246,9 @@ void PipelineExecutor::Init(const std::vector& modules, const std::strin num_outputs_ = pipeline_config_.GetGlobalOutputNum(); // Initialize the pipeline function class used for pipeline thread pool management // and schedule etc. This function returns a list of runtime. - runtimes_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config_); + global_runtime_ = + pipeline_scheduler_.PipelineInit(modules, pipeline_config_, input_connection_config_); + runtimes_ = global_runtime_->GetRuntimeList(); return; } diff --git a/src/runtime/pipeline/pipeline_executor.h b/src/runtime/pipeline/pipeline_executor.h index 9a24acdc2741a..9f9b24bdf0bec 100644 --- a/src/runtime/pipeline/pipeline_executor.h +++ b/src/runtime/pipeline/pipeline_executor.h @@ -176,15 +176,16 @@ class TVM_DLL PipelineExecutor : public ModuleNode { /*!\brief The dependency information of each graph runtime module of the pipeline.*/ ConfigPipelineExecution pipeline_config_; /*!\brief The map of global input and subgraph input.*/ - InputConnectionConfig input_connection_config; + InputConnectionConfig input_connection_config_; /*!\brief The map includes global parameters groups and runtime modules.*/ - ParamConnectionConfig param_connection_config; + ParamConnectionConfig param_connection_config_; /*!\brief The module information used to create the graph runtimes.*/ ModuleConfig mod_config_; /*!\brief How many outputs are in this pipeline executor.*/ size_t num_outputs_ = 0; /*!The list of backend runtime module.*/ std::vector> runtimes_; + std::shared_ptr global_runtime_; /*!\brief Json loader.*/ void LoadConfig(dmlc::JSONReader* reader) { reader->BeginObject(); @@ -193,9 +194,9 @@ class TVM_DLL PipelineExecutor : public ModuleNode { if (key == "module_connection") { reader->Read(&pipeline_config_); } else if (key == "input_connection") { - reader->Read(&input_connection_config); + reader->Read(&input_connection_config_); } else if (key == "param_connection") { - reader->Read(¶m_connection_config); + reader->Read(¶m_connection_config_); } else { LOG(FATAL) << "do not support key " << key; } diff --git a/src/runtime/pipeline/pipeline_scheduler.cc b/src/runtime/pipeline/pipeline_scheduler.cc index a417feb683017..63603de226279 100644 --- a/src/runtime/pipeline/pipeline_scheduler.cc +++ b/src/runtime/pipeline/pipeline_scheduler.cc @@ -28,16 +28,20 @@ namespace runtime { * \param modules The list of graph executor modules. * \param pipeline_conf The dependency information of each graph executor module. */ -std::vector> PipelineScheduler::PipelineInit( - const std::vector& modules, const ConfigPipelineExecution& pipeline_config) { +std::shared_ptr PipelineScheduler::PipelineInit( + const std::vector& modules, const ConfigPipelineExecution& pipeline_config, + const InputConnectionConfig& input_connection_config) { std::vector> runtimes; graph_modules_ = modules; - global_runtime_ = std::make_shared(GLOBAL_MODULE_INDEX); // Creating a list of runtimes. for (size_t i = 0; i < graph_modules_.size(); i++) { auto run_item = std::make_shared(graph_modules_[i], i); runtimes.push_back(run_item); } + // Creating the global runtime to represent the pipeline executor. + global_runtime_ = std::make_shared(GLOBAL_MODULE_INDEX); + // Initialize the data structures that are used by pipeline logic. + global_runtime_->InitializePipeline(input_connection_config, runtimes); // Creating a list of NDArray in order to storage the outputs data. auto global_output_map = pipeline_config.GetGlobalConfigOutputBindings(); for (size_t i = 0; i < global_output_map.size(); i++) { @@ -52,15 +56,14 @@ std::vector> PipelineScheduler::PipelineInit( for (auto runtime : runtimes) { runtime->InitializePipeline(pipeline_config, &runtimes, global_runtime_); } - return runtimes; + return global_runtime_; } /*! * \brief Running pipeline logic. * \param runtimes A list of backend runtime modules. * \param pipeline_config The dependency configuration of each runtime module. */ -void PipelineScheduler::PipelineRun(const std::vector>& runtimes, - ConfigPipelineExecution pipeline_config) { +void PipelineScheduler::PipelineRun(const std::vector>& runtimes) { runtimes.front()->RunPipeline(); } /*! diff --git a/src/runtime/pipeline/pipeline_scheduler.h b/src/runtime/pipeline/pipeline_scheduler.h index 9fb357b8e9f0a..1141af26f57b9 100644 --- a/src/runtime/pipeline/pipeline_scheduler.h +++ b/src/runtime/pipeline/pipeline_scheduler.h @@ -41,15 +41,14 @@ class PipelineScheduler { * \param modules The list of graph executor module. * \param pipeline_config The dependency information of each graph executor module. */ - std::vector> PipelineInit( - const std::vector& modules, const ConfigPipelineExecution& pipeline_config); + std::shared_ptr PipelineInit(const std::vector& modules, + const ConfigPipelineExecution& pipeline_config, + const InputConnectionConfig& input_connection_config); /*! * \brief Running the pipeline logic. * \param runtimes A list of backend runtime modules. - * \param pipeline_config The dependency configuration of each runtime module. */ - void PipelineRun(const std::vector>& runtimes, - ConfigPipelineExecution pipeline_config); + void PipelineRun(const std::vector>& runtimes); /*! * \brief Get a list of outputs. */ diff --git a/src/runtime/pipeline/pipeline_struct.h b/src/runtime/pipeline/pipeline_struct.h index 82dc6f53c90c5..bdfeafe2afbba 100644 --- a/src/runtime/pipeline/pipeline_struct.h +++ b/src/runtime/pipeline/pipeline_struct.h @@ -547,17 +547,44 @@ struct InputConnectionConfig { * includes the index of graph module and the name of a graph module input interface. */ std::unordered_map> input_connection; + /*!\brief The map includes the global input name and global input index.*/ + std::unordered_map input_name_index_map; + /*!\brief The map includes the runtime index and the pair of global and runtime input name.*/ + std::unordered_map>> input_runtime_map; std::pair operator[](const std::string key) { if (input_connection.find(key) == input_connection.end()) { LOG(FATAL) << "Not find the key " << key; } return input_connection[key]; } + /*! + * \brief Get the global input index through the input name. + * \param input_name The global input name. + */ + int GetInputIndex(std::string input_name) { + auto input_index_iter = input_name_index_map.find(input_name); + if (input_index_iter == input_name_index_map.end()) { + LOG(FATAL) << "Do not finding the input name! " << input_name; + } + return input_index_iter->second; + } + /*!\brief Enumerating the input binding configuration for a specified runtime.*/ + void VisitConfig(BindingConfigParseFunc parse_function, int runtime_index) { + auto config = input_runtime_map.find(runtime_index); + // Only do the processing when there are input configuration in the runtime. + if (config != input_runtime_map.end()) { + for (auto x : config->second) { + int input_index = GetInputIndex(x.first); + parse_function(input_index, runtime_index, x.second); + } + } + } /*! * \brief Create an input connection config from JSONReader. * \param reader Json reader. */ void Load(dmlc::JSONReader* reader) { + int global_interface_index = 0; reader->BeginArray(); while (reader->NextArrayItem()) { reader->BeginObject(); @@ -568,6 +595,7 @@ struct InputConnectionConfig { while (reader->NextObjectItem(&key)) { if (key == "global_interface_name") { reader->Read(&global_interface_name); + input_name_index_map[global_interface_name] = global_interface_index++; } else if (key == "mod_idx") { reader->Read(&mod_idx); } else if (key == "module_interface_name") { @@ -580,6 +608,9 @@ struct InputConnectionConfig { ICHECK(!global_interface_name.empty()) << "Invalid global interface name value"; ICHECK(!module_interface_name.empty()) << "Invalid module interface name value"; input_connection[global_interface_name] = make_pair(mod_idx, module_interface_name); + // Creating a map includes the runtime index and the pair of gloal and runtime interface. + input_runtime_map[mod_idx].push_back( + std::make_pair(global_interface_name, module_interface_name)); } } }; @@ -640,6 +671,13 @@ class BasicRuntime { explicit BasicRuntime(int runtime_idx) : runtime_idx_(runtime_idx) {} /*!\brief Return the index of the current module.*/ int GetModuleIndex() { return runtime_idx_; } + /*!\brief Setting the data to this runtime via input index.*/ + virtual void SetInput(const int index, DLTensor* data_in) {} + /*! + * \brief Notifying an input is ready. + * \param input_index The index of 'input interface' which is ready for data. + */ + virtual void ParentNotify(int input_index) {} /*! *\brief Creating a parent notification. *\param input_index The input index of the 'current runtime'. @@ -647,18 +685,23 @@ class BasicRuntime { *\param parent_output_idx The output index of the 'parent runtime' which will send * the notification. */ - virtual void CreateParentsNotify(int input_index, int parent_idx, int parent_output_idx) {} - /*! - * \brief Notifying an input is ready. - * \param input_index The index of 'input interface' which is ready for data. - */ - virtual void ParentNotify(int input_index) {} + void CreateParentsNotify(int input_index, int parent_idx, int parent_output_idx) { + if (parents_notify_.find(input_index) != parents_notify_.end()) { + LOG(FATAL) << "The notification associated with the input interface " << input_index + << " in runtime " << runtime_idx_ << " already been created!"; + return; + } + parents_notify_[input_index] = + std::make_shared(ModuleInterfaceID(parent_idx, parent_output_idx, OUTPUT)); + } protected: /*!\brief The index of runtime indicates the runtime position in the pipeline.*/ int runtime_idx_; /*!\brief A list of runtime which depends on the current runtime.*/ std::unordered_map children_; + /*!\brief A map including the runtime input index and the notification data structure.*/ + std::unordered_map> parents_notify_; /*! * \brief A list of SPSC input queues in which the input interface will poll the data sent from * other backend cores. @@ -666,10 +709,12 @@ class BasicRuntime { std::unordered_map> input_queue_; /*! - * \brief A list of SPSC output queues in which the output interface will push the data to + * \brief A list of SPSC forward queues in which the parent interface will push the data to * other backend cores. */ - std::unordered_map output_queue_; + std::unordered_map forward_queue_; + /*!\brief The state of the pipeline.*/ + std::atomic pipeline_state_{STOPPED}; /*! * \brief Generate the ID of an input queue. * \param runtime_index The index of backend runtime. @@ -679,17 +724,47 @@ class BasicRuntime { ModuleInterfaceID GenerateQueueID(int runtime_index, int interface_index, InterfaceType type) { return ModuleInterfaceID(runtime_index, interface_index, type); } + /*! + * \brief Forwarding the data into the child runtimes. + * \param forward_queue_map The map includes the id of queue and the forwarding queue. + * \param child_runtime The child runtime. + * \param child_input_index The child runtime index. + * \param data The data used for forwarding. + */ + bool ForwardData(const ForwardQueueMap* forward_queue_map, + std::shared_ptr child_runtime, int child_input_index, + const DLTensor* data) { + auto child_runtime_index = child_runtime->GetModuleIndex(); + auto queue_id = GenerateQueueID(child_runtime_index, child_input_index, INPUT); + if (forward_queue_map->find(queue_id) == forward_queue_map->end()) { + LOG(FATAL) << "Not find the associated queue of the runtime(" << child_runtime_index + << ").input(" << child_input_index << ") which is connected with runtime(" + << runtime_idx_; + } + auto forward_queue = forward_queue_map->at(queue_id); + // If the queue is full, keep try until the push get success or the pipeline run into + // a STOP state. + while (!forward_queue->Push(data)) { + if (PipelineIsStop()) { + LOG(INFO) << "The forwarding process is stopped after the pipeline status is changed" + << " into stop."; + return false; + } + } + child_runtime->ParentNotify(child_input_index); + return true; + } /*! * \brief Creating a forwarding queue for the pair of an output interface and an input interface. - * \param output_idx The index of an output interface which will send the forwarding data. + * \param forward_inf_idx The index of an interface which will send the forwarding data. * \param child_runtime The backend runtime which owns the input interface. * \param input_index The index of an input interface which will receive the forwarding data. */ - void CreateForwardingQueue(int output_idx, std::shared_ptr child_runtime, + void CreateForwardingQueue(int forward_inf_idx, std::shared_ptr child_runtime, int input_index) { auto queue_id = GenerateQueueID(child_runtime->GetModuleIndex(), input_index, INPUT); // The forwarding queue map of a specified output interface. - auto& queue_map = output_queue_[output_idx]; + auto& queue_map = forward_queue_[forward_inf_idx]; if (queue_map.find(queue_id) != queue_map.end()) { LOG(FATAL) << "The queue " << queue_id.runtime_idx << "." << queue_id.runtime_interface_idx << " is already created!"; @@ -709,43 +784,10 @@ class BasicRuntime { void AppendInputQueue(int input_index, std::shared_ptr queue) { input_queue_[input_index] = queue; } -}; -/*! - * \brief This global runtime represents the pipeline executor and exposes the input and output - * interface. - */ -class GlobalRuntime : public BasicRuntime { - public: - explicit GlobalRuntime(int runtime_idx) : BasicRuntime(runtime_idx) {} - /*!\brief Whether the output data is ready.*/ - bool DataIsReady(bool wait_data) { - bool data_ready = true; - for (auto queue_pair : input_queue_) { - auto queue = queue_pair.second; - if (queue->Empty()) { - data_ready = false; - break; - } - } - if (!data_ready && wait_data) { - // TODO(huajsj): Waitting the data ready message. - } - return data_ready; - } - /*!\brief Get the output data.*/ - bool GetOutput(Array* outputs, bool wait_data = false) { - if (!DataIsReady(wait_data)) { - return false; - } - for (auto queue_pair : input_queue_) { - auto output_index = queue_pair.first; - auto queue = queue_pair.second; - QueueData data(const_cast(((*outputs)[output_index]).operator->())); - if (!queue->Poll(&data)) { - LOG(FATAL) << "There is no data in the data queue, it should not happen!"; - } - } - return true; + /*!\brief Checking if the pipeline is stopped or stopping.*/ + const bool PipelineIsStop() const { + auto state = pipeline_state_.load(std::memory_order_acquire); + return state == STOPPING || state == STOPPED; } }; /* @@ -759,10 +801,6 @@ class BackendRuntime : public BasicRuntime { Module module_; /*\brief The thread is associated with the current runtime*/ std::thread thread_; - /*!\brief The state of the pipeline.*/ - std::atomic pipeline_state_{STOPPED}; - /*!\brief A map including the runtime input index and the notification data structure.*/ - std::unordered_map> parents_notify_; /*!\brief The execution count of the 'RunPipeline' function. */ uint32_t pipeline_execution_count_ = 0; /*! @@ -783,7 +821,6 @@ class BackendRuntime : public BasicRuntime { void StartWorkThread() { SetPipelineState(RUNNING); if (runtime_idx_ == 0) { - this->CreateParentsNotify(0, GLOBAL_MODULE_INDEX, 0); this->SetCPUAffinity(); } else { // Only launching the worker thread for the runtimes after the first runtime. @@ -799,11 +836,6 @@ class BackendRuntime : public BasicRuntime { } return; } - /*!\brief Checking if the pipeline is stopped or stopping.*/ - const bool PipelineIsStop() const { - auto state = pipeline_state_.load(std::memory_order_acquire); - return state == STOPPING || state == STOPPED; - } /*!\brief Setting the state of the pipeline.*/ void SetPipelineState(PipelineState state) { pipeline_state_.store(state, std::memory_order_release); @@ -871,34 +903,20 @@ class BackendRuntime : public BasicRuntime { bool ForwardingOutputDataToChildren(void) { for (auto child : children_) { auto output_idx = child.first; - if (output_queue_.find(output_idx) == output_queue_.end()) { + if (forward_queue_.find(output_idx) == forward_queue_.end()) { LOG(FATAL) << "Not find the forwarding queue map for output(" << output_idx << ")!"; return false; } NDArray output = GetOutput(output_idx); - auto forward_queue_map = output_queue_[output_idx]; + auto forward_queue_map = forward_queue_[output_idx]; // Notifying the 'children runtime' that the forwarding data are ready. for (auto module_pair : child.second) { auto child_runtime = module_pair.first; - auto child_runtime_index = child_runtime->GetModuleIndex(); auto child_input_index = module_pair.second; - auto queue_id = GenerateQueueID(child_runtime_index, child_input_index, INPUT); - if (forward_queue_map.find(queue_id) == forward_queue_map.end()) { - LOG(FATAL) << "Not find the associated queue of the runtime(" << child_runtime_index - << ").input(" << child_input_index << ") which is connected with runtime(" - << runtime_idx_ << ").output(" << output_idx << ")"; - } - auto forward_queue = forward_queue_map[queue_id]; - // If the queue is full, keep try until the push get success or the pipeline run into - // a STOP state. - while (!forward_queue->Push(output)) { - if (PipelineIsStop()) { - LOG(INFO) << "The forwarding process is stopped after the pipeline status is changed" - << " into stop."; - return false; - } + auto output_data = const_cast(output.operator->()); + if (!ForwardData(&forward_queue_map, child_runtime, child_input_index, output_data)) { + return false; } - child_runtime->ParentNotify(child_input_index); } } return true; @@ -974,22 +992,6 @@ class BackendRuntime : public BasicRuntime { } StopPipeline(); } - /*! - *\brief Creating a parent notification. - *\param input_index The input index of the 'current runtime'. - *\param parent_idx The index of 'parent runtime' which will send the notification. - *\param parent_output_idx The output index of the 'parent runtime' which will send - * the notification. - */ - void CreateParentsNotify(int input_index, int parent_idx, int parent_output_idx) { - if (parents_notify_.find(input_index) != parents_notify_.end()) { - LOG(FATAL) << "The notification associated with the input interface " << input_index - << " in runtime " << runtime_idx_ << " already been created!"; - return; - } - parents_notify_[input_index] = - std::make_shared(ModuleInterfaceID(parent_idx, parent_output_idx, OUTPUT)); - } /*! * \brief Getting the times of using pipeline function. * \return The times of using pipeline function. @@ -1002,7 +1004,7 @@ class BackendRuntime : public BasicRuntime { */ void InitializePipeline(ConfigPipelineExecution config, std::vector>* runtimes, - std::shared_ptr global_runtime) { + std::shared_ptr global_runtime) { // Getting the current BackendRuntime's cpu affinity setting. cpu_affinity_ = config.GetCPUAffinity(runtime_idx_); // Getting the 'binding configuration' for each child runtime. @@ -1061,7 +1063,7 @@ class BackendRuntime : public BasicRuntime { int NumOutputs() const { return get_num_output_(); } /*!\brief Return the number of input*/ int NumInputs() const { return get_num_inputs_(); } - /*!\brief Setting the data to this module via input index.*/ + /*!\brief Setting the data to this runtime via input index.*/ void SetInput(const int index, DLTensor* data_in) { NDArray input = get_input_(index); DLTensor* dltensor_input = const_cast(input.operator->()); @@ -1091,6 +1093,99 @@ class BackendRuntime : public BasicRuntime { return ret; } }; +/*! + * \brief This global runtime represents the pipeline executor and exposes the input and output + * interface. + */ +class GlobalRuntime : public BasicRuntime { + public: + explicit GlobalRuntime(int runtime_idx) : BasicRuntime(runtime_idx) {} + /**/ + std::vector> GetRuntimeList() { return runtimes_; } + /*!\brief Push the data into the queue for the current runtime.*/ + void SetPipelineInput(const std::string input_name, DLTensor* data_in) { + auto input_index = input_config_.GetInputIndex(input_name); + auto child_iter = children_.find(input_index); + if (child_iter == children_.end()) { + return; + } + auto forward_queue_map = forward_queue_[input_index]; + // Notifying the 'children runtime' that the forwarding data are ready. + for (auto module_pair : child_iter->second) { + auto child_runtime = module_pair.first; + auto child_input_index = module_pair.second; + // No need to go through the forward queue when the runtime is the first one. + if (child_runtime->GetModuleIndex() == 0) { + child_runtime->SetInput(child_input_index, data_in); + } else { + if (!ForwardData(&forward_queue_map, child_runtime, child_input_index, data_in)) { + return; + } + } + } + return; + } + /*!\brief Whether the output data is ready.*/ + bool DataIsReady(bool wait_data) { + bool data_ready = true; + for (auto queue_pair : input_queue_) { + auto queue = queue_pair.second; + if (queue->Empty()) { + data_ready = false; + break; + } + } + if (!data_ready && wait_data) { + // TODO(huajsj): Waitting the data ready message. + } + return data_ready; + } + /*!\brief Get the output data.*/ + bool GetOutput(Array* outputs, bool wait_data = false) { + if (!DataIsReady(wait_data)) { + return false; + } + for (auto queue_pair : input_queue_) { + auto output_index = queue_pair.first; + auto queue = queue_pair.second; + QueueData data(const_cast(((*outputs)[output_index]).operator->())); + if (!queue->Poll(&data)) { + LOG(FATAL) << "There is no data in the data queue, it should not happen!"; + } + } + return true; + } + /*!\brief Initialized the data structures for pipeline.*/ + void InitializePipeline(InputConnectionConfig input_config, + const std::vector> runtimes) { + input_config_ = input_config; + runtimes_ = runtimes; + for (auto child_runtime : runtimes) { + int runtime_idx = child_runtime->GetModuleIndex(); + input_config.VisitConfig( + [&](int input_index, int child_idx, std::string child_input_name) { + auto child_input_index = child_runtime->GetInputIndex(child_input_name); + if (child_input_index < 0) { + LOG(FATAL) << "Can not find the input " << child_input_name << "in runtime " + << child_idx; + } + children_[input_index].push_back(std::make_pair(child_runtime, child_input_index)); + // Only create notify and queue for the runtime after the first runtime. + if (runtime_idx != 0) { + child_runtime->CreateParentsNotify(input_index, GLOBAL_MODULE_INDEX, + child_input_index); + // Creating the pipeline forwarding queue. + this->CreateForwardingQueue(input_index, child_runtime, child_input_index); + } + }, + runtime_idx); + } + } + + private: + std::vector> runtimes_; + InputConnectionConfig input_config_; +}; /*! * \brief The information used to initialize the graph executor module, the information * come from the export library function call. diff --git a/tests/python/relay/test_pipeline_executor.py b/tests/python/relay/test_pipeline_executor.py index cc58b8128e24c..71be7c72cc098 100644 --- a/tests/python/relay/test_pipeline_executor.py +++ b/tests/python/relay/test_pipeline_executor.py @@ -372,6 +372,7 @@ def test_pipeline(): assert module_index == 0 # Using the parameters group name to set parameters. pipeline_module_test.set_params("param_0", customized_parameters) + normal_outputs = [] for round in range(0, len(datas)): data = datas[round] # Getting the result without setting customized parameters. @@ -398,27 +399,34 @@ def test_pipeline(): customized_parameters_mod, customized_parameters, ) + # Append the normal output into a list to do future correctness verification. + normal_outputs.append(normal_output) + # Set the input data into pipeline executor. pipeline_module_test.set_input("data_a", data) pipeline_module_test.set_input("data_b", data) - input_data = pipeline_module_test.get_input("data_a") - tvm.testing.assert_allclose(data, input_data.numpy()) + input_map = pipeline_module_test.get_input_pipeline_map("data_a") + # The input data will be set into runtime directly when the index of runtime is 0 + if input_map[0] == '0': + input_data = pipeline_module_test.get_input("data_a") + tvm.testing.assert_allclose(data, input_data.numpy()) # Running the pipeline executor in the pipeline mode. pipeline_module_test.run() + for k in range(0, len(datas)): statistic_time = 0 outputs = pipeline_module_test.get_output() while len(outputs) == 0: outputs = pipeline_module_test.get_output() statistic_time = statistic_time + 1 # Setting the timeout to 10 seconds. - assert statistic_time < 10 + assert statistic_time < 5 time.sleep(1) for i in range(len(outputs)): - tvm.testing.assert_allclose(normal_output[i], outputs[i].numpy()) + tvm.testing.assert_allclose(normal_outputs[k][i], outputs[i].numpy()) assert not (normal_output[i] == wrong_output[i]).all() - assert pipeline_module_test.num_executing_pipeline == round + 1 + assert pipeline_module_test.num_executing_pipeline == round + 1 # Reset the cpu affinity after a test. reset_cpu_affinity(affinity)