Skip to content

Commit

Permalink
[Runtime][PipelineExecutor]Add forwarding queue logic for set input.
Browse files Browse the repository at this point in the history
When the set_input function get called, a runtime of pipeline may not
yet finish the former computation work then the new set_input call would
break the current computation logic, to avoid such issue, we add the
forwarding queue logic to guarantee the order of input data consuming.
  • Loading branch information
huajsj committed Apr 13, 2022
1 parent bf7a27b commit 33540ad
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 130 deletions.
5 changes: 1 addition & 4 deletions python/tvm/contrib/pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,7 @@ def set_input(self, key, value):
value : array_like.
The input value
"""
v = self._get_input(key)
if v is None:
raise RuntimeError("Could not find '%s' in pipeline's inputs" % key)
v.copyfrom(value)
self._set_input(key, tvm.nd.array(value))

def set_params(self, params_group_name, params_data):
"""Set the parameter group value given the parameter group name. Note that the parameter
Expand Down
20 changes: 9 additions & 11 deletions src/runtime/pipeline/pipeline_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,7 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
* \param data_in The input data.
*/
void PipelineExecutor::SetInput(std::string input_name, DLTensor* data_in) {
std::pair<int, int> indexs = this->GetInputIndex(input_name);
if (indexs.first < 0 || indexs.first >= static_cast<int>(runtimes_.size())) {
LOG(FATAL) << "input name " << input_name << " not found.";
}
runtimes_[indexs.first]->SetInput(indexs.second, data_in);
global_runtime_->SetPipelineInput(input_name, data_in);
}
/*!
* \brief get input from the runtime module.
Expand All @@ -118,7 +114,7 @@ NDArray PipelineExecutor::GetInput(std::string input_name) {
* \return int The module index.
*/
int PipelineExecutor::GetParamModuleIndex(const std::string& name) {
return param_connection_config[name];
return param_connection_config_[name];
}
/*!
* \brief Using the global input name to get the index, and also get the input interface name
Expand All @@ -127,7 +123,7 @@ int PipelineExecutor::GetParamModuleIndex(const std::string& name) {
* \return Returning the index and the input interface name of corresponding subgraph.
*/
Array<String> PipelineExecutor::GetInputPipeplineMap(std::string input_name) {
std::pair<int, std::string> map = input_connection_config[input_name];
std::pair<int, std::string> map = input_connection_config_[input_name];
return {std::to_string(map.first), map.second};
}

Expand All @@ -137,11 +133,11 @@ Array<String> PipelineExecutor::GetInputPipeplineMap(std::string input_name) {
* \return int The module index.
*/
int PipelineExecutor::GetParamsGroupPipelineMap(const std::string& name) {
return param_connection_config[name];
return param_connection_config_[name];
}

/*!\brief Run the pipeline executor.*/
void PipelineExecutor::Run() { pipeline_scheduler_.PipelineRun(runtimes_, pipeline_config_); }
void PipelineExecutor::Run() { pipeline_scheduler_.PipelineRun(runtimes_); }
/*!
* \brief return A list of global output data.
*/
Expand Down Expand Up @@ -226,7 +222,7 @@ void PipelineExecutor::SetParam(std::string param_group_name, std::string param_
* \return std::pair<int, int> A pair of module index and the input index.
*/
std::pair<int, int> PipelineExecutor::GetInputIndex(const std::string& name) {
std::pair<int, std::string> index = input_connection_config[name];
std::pair<int, std::string> index = input_connection_config_[name];
auto gruntime = runtimes_[index.first];
return std::make_pair(index.first, gruntime->GetInputIndex(index.second));
}
Expand All @@ -250,7 +246,9 @@ void PipelineExecutor::Init(const std::vector<Module>& modules, const std::strin
num_outputs_ = pipeline_config_.GetGlobalOutputNum();
// Initialize the pipeline function class used for pipeline thread pool management
// and schedule etc. This function returns a list of runtime.
runtimes_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config_);
global_runtime_ =
pipeline_scheduler_.PipelineInit(modules, pipeline_config_, input_connection_config_);
runtimes_ = global_runtime_->GetRuntimeList();
return;
}

Expand Down
9 changes: 5 additions & 4 deletions src/runtime/pipeline/pipeline_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,16 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
/*!\brief The dependency information of each graph runtime module of the pipeline.*/
ConfigPipelineExecution pipeline_config_;
/*!\brief The map of global input and subgraph input.*/
InputConnectionConfig input_connection_config;
InputConnectionConfig input_connection_config_;
/*!\brief The map includes global parameters groups and runtime modules.*/
ParamConnectionConfig param_connection_config;
ParamConnectionConfig param_connection_config_;
/*!\brief The module information used to create the graph runtimes.*/
ModuleConfig mod_config_;
/*!\brief How many outputs are in this pipeline executor.*/
size_t num_outputs_ = 0;
/*!The list of backend runtime module.*/
std::vector<std::shared_ptr<BackendRuntime>> runtimes_;
std::shared_ptr<GlobalRuntime> global_runtime_;
/*!\brief Json loader.*/
void LoadConfig(dmlc::JSONReader* reader) {
reader->BeginObject();
Expand All @@ -193,9 +194,9 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
if (key == "module_connection") {
reader->Read(&pipeline_config_);
} else if (key == "input_connection") {
reader->Read(&input_connection_config);
reader->Read(&input_connection_config_);
} else if (key == "param_connection") {
reader->Read(&param_connection_config);
reader->Read(&param_connection_config_);
} else {
LOG(FATAL) << "do not support key " << key;
}
Expand Down
15 changes: 9 additions & 6 deletions src/runtime/pipeline/pipeline_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,20 @@ namespace runtime {
* \param modules The list of graph executor modules.
* \param pipeline_conf The dependency information of each graph executor module.
*/
std::vector<std::shared_ptr<BackendRuntime>> PipelineScheduler::PipelineInit(
const std::vector<Module>& modules, const ConfigPipelineExecution& pipeline_config) {
std::shared_ptr<GlobalRuntime> PipelineScheduler::PipelineInit(
const std::vector<Module>& modules, const ConfigPipelineExecution& pipeline_config,
const InputConnectionConfig& input_connection_config) {
std::vector<std::shared_ptr<BackendRuntime>> runtimes;
graph_modules_ = modules;
global_runtime_ = std::make_shared<GlobalRuntime>(GLOBAL_MODULE_INDEX);
// Creating a list of runtimes.
for (size_t i = 0; i < graph_modules_.size(); i++) {
auto run_item = std::make_shared<BackendRuntime>(graph_modules_[i], i);
runtimes.push_back(run_item);
}
// Creating the global runtime to represent the pipeline executor.
global_runtime_ = std::make_shared<GlobalRuntime>(GLOBAL_MODULE_INDEX);
// Initialize the data structures that are used by pipeline logic.
global_runtime_->InitializePipeline(input_connection_config, runtimes);
// Creating a list of NDArray in order to storage the outputs data.
auto global_output_map = pipeline_config.GetGlobalConfigOutputBindings();
for (size_t i = 0; i < global_output_map.size(); i++) {
Expand All @@ -52,15 +56,14 @@ std::vector<std::shared_ptr<BackendRuntime>> PipelineScheduler::PipelineInit(
for (auto runtime : runtimes) {
runtime->InitializePipeline(pipeline_config, &runtimes, global_runtime_);
}
return runtimes;
return global_runtime_;
}
/*!
* \brief Running pipeline logic.
* \param runtimes A list of backend runtime modules.
* \param pipeline_config The dependency configuration of each runtime module.
*/
void PipelineScheduler::PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
ConfigPipelineExecution pipeline_config) {
void PipelineScheduler::PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes) {
runtimes.front()->RunPipeline();
}
/*!
Expand Down
9 changes: 4 additions & 5 deletions src/runtime/pipeline/pipeline_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,14 @@ class PipelineScheduler {
* \param modules The list of graph executor module.
* \param pipeline_config The dependency information of each graph executor module.
*/
std::vector<std::shared_ptr<BackendRuntime>> PipelineInit(
const std::vector<Module>& modules, const ConfigPipelineExecution& pipeline_config);
std::shared_ptr<GlobalRuntime> PipelineInit(const std::vector<Module>& modules,
const ConfigPipelineExecution& pipeline_config,
const InputConnectionConfig& input_connection_config);
/*!
* \brief Running the pipeline logic.
* \param runtimes A list of backend runtime modules.
* \param pipeline_config The dependency configuration of each runtime module.
*/
void PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
ConfigPipelineExecution pipeline_config);
void PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes);
/*!
* \brief Get a list of outputs.
*/
Expand Down
Loading

0 comments on commit 33540ad

Please sign in to comment.