Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Runtime][PipelineExecutor]Add forwarding queue logic for set input. #10990

Merged
merged 2 commits into from
Apr 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions python/tvm/contrib/pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,7 @@ def set_input(self, key, value):
value : array_like.
The input value
"""
v = self._get_input(key)
if v is None:
raise RuntimeError("Could not find '%s' in pipeline's inputs" % key)
v.copyfrom(value)
self._set_input(key, tvm.nd.array(value))

def set_params(self, params_group_name, params_data):
"""Set the parameter group value given the parameter group name. Note that the parameter
Expand Down
20 changes: 9 additions & 11 deletions src/runtime/pipeline/pipeline_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,7 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
* \param data_in The input data.
*/
void PipelineExecutor::SetInput(std::string input_name, DLTensor* data_in) {
std::pair<int, int> indexs = this->GetInputIndex(input_name);
if (indexs.first < 0 || indexs.first >= static_cast<int>(runtimes_.size())) {
LOG(FATAL) << "input name " << input_name << " not found.";
}
runtimes_[indexs.first]->SetInput(indexs.second, data_in);
global_runtime_->SetPipelineInput(input_name, data_in);
}
/*!
* \brief get input from the runtime module.
Expand All @@ -118,7 +114,7 @@ NDArray PipelineExecutor::GetInput(std::string input_name) {
* \return int The module index.
*/
int PipelineExecutor::GetParamModuleIndex(const std::string& name) {
return param_connection_config[name];
return param_connection_config_[name];
}
/*!
* \brief Using the global input name to get the index, and also get the input interface name
Expand All @@ -127,7 +123,7 @@ int PipelineExecutor::GetParamModuleIndex(const std::string& name) {
* \return Returning the index and the input interface name of corresponding subgraph.
*/
Array<String> PipelineExecutor::GetInputPipeplineMap(std::string input_name) {
std::pair<int, std::string> map = input_connection_config[input_name];
std::pair<int, std::string> map = input_connection_config_[input_name];
return {std::to_string(map.first), map.second};
}

Expand All @@ -137,11 +133,11 @@ Array<String> PipelineExecutor::GetInputPipeplineMap(std::string input_name) {
* \return int The module index.
*/
int PipelineExecutor::GetParamsGroupPipelineMap(const std::string& name) {
return param_connection_config[name];
return param_connection_config_[name];
}

/*!\brief Run the pipeline executor.*/
void PipelineExecutor::Run() { pipeline_scheduler_.PipelineRun(runtimes_, pipeline_config_); }
void PipelineExecutor::Run() { pipeline_scheduler_.PipelineRun(runtimes_); }
/*!
* \brief return A list of global output data.
*/
Expand Down Expand Up @@ -226,7 +222,7 @@ void PipelineExecutor::SetParam(std::string param_group_name, std::string param_
* \return std::pair<int, int> A pair of module index and the input index.
*/
std::pair<int, int> PipelineExecutor::GetInputIndex(const std::string& name) {
std::pair<int, std::string> index = input_connection_config[name];
std::pair<int, std::string> index = input_connection_config_[name];
auto gruntime = runtimes_[index.first];
return std::make_pair(index.first, gruntime->GetInputIndex(index.second));
}
Expand All @@ -250,7 +246,9 @@ void PipelineExecutor::Init(const std::vector<Module>& modules, const std::strin
num_outputs_ = pipeline_config_.GetGlobalOutputNum();
// Initialize the pipeline function class used for pipeline thread pool management
// and schedule etc. This function returns a list of runtime.
runtimes_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config_);
global_runtime_ =
pipeline_scheduler_.PipelineInit(modules, pipeline_config_, input_connection_config_);
runtimes_ = global_runtime_->GetRuntimeList();
return;
}

Expand Down
9 changes: 5 additions & 4 deletions src/runtime/pipeline/pipeline_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,16 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
/*!\brief The dependency information of each graph runtime module of the pipeline.*/
ConfigPipelineExecution pipeline_config_;
/*!\brief The map of global input and subgraph input.*/
InputConnectionConfig input_connection_config;
InputConnectionConfig input_connection_config_;
/*!\brief The map includes global parameters groups and runtime modules.*/
ParamConnectionConfig param_connection_config;
ParamConnectionConfig param_connection_config_;
/*!\brief The module information used to create the graph runtimes.*/
ModuleConfig mod_config_;
/*!\brief How many outputs are in this pipeline executor.*/
size_t num_outputs_ = 0;
/*!The list of backend runtime module.*/
std::vector<std::shared_ptr<BackendRuntime>> runtimes_;
std::shared_ptr<GlobalRuntime> global_runtime_;
/*!\brief Json loader.*/
void LoadConfig(dmlc::JSONReader* reader) {
reader->BeginObject();
Expand All @@ -193,9 +194,9 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
if (key == "module_connection") {
reader->Read(&pipeline_config_);
} else if (key == "input_connection") {
reader->Read(&input_connection_config);
reader->Read(&input_connection_config_);
} else if (key == "param_connection") {
reader->Read(&param_connection_config);
reader->Read(&param_connection_config_);
} else {
LOG(FATAL) << "do not support key " << key;
}
Expand Down
15 changes: 9 additions & 6 deletions src/runtime/pipeline/pipeline_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,20 @@ namespace runtime {
* \param modules The list of graph executor modules.
* \param pipeline_conf The dependency information of each graph executor module.
*/
std::vector<std::shared_ptr<BackendRuntime>> PipelineScheduler::PipelineInit(
const std::vector<Module>& modules, const ConfigPipelineExecution& pipeline_config) {
std::shared_ptr<GlobalRuntime> PipelineScheduler::PipelineInit(
const std::vector<Module>& modules, const ConfigPipelineExecution& pipeline_config,
const InputConnectionConfig& input_connection_config) {
std::vector<std::shared_ptr<BackendRuntime>> runtimes;
graph_modules_ = modules;
global_runtime_ = std::make_shared<GlobalRuntime>(GLOBAL_MODULE_INDEX);
// Creating a list of runtimes.
for (size_t i = 0; i < graph_modules_.size(); i++) {
auto run_item = std::make_shared<BackendRuntime>(graph_modules_[i], i);
runtimes.push_back(run_item);
}
// Creating the global runtime to represent the pipeline executor.
global_runtime_ = std::make_shared<GlobalRuntime>(GLOBAL_MODULE_INDEX);
// Initializing the data structures used by pipeline logic.
global_runtime_->InitializePipeline(input_connection_config, runtimes);
// Creating a list of NDArray in order to storage the outputs data.
auto global_output_map = pipeline_config.GetGlobalConfigOutputBindings();
for (size_t i = 0; i < global_output_map.size(); i++) {
Expand All @@ -52,15 +56,14 @@ std::vector<std::shared_ptr<BackendRuntime>> PipelineScheduler::PipelineInit(
for (auto runtime : runtimes) {
runtime->InitializePipeline(pipeline_config, &runtimes, global_runtime_);
}
return runtimes;
return global_runtime_;
}
/*!
* \brief Running pipeline logic.
* \param runtimes A list of backend runtime modules.
* \param pipeline_config The dependency configuration of each runtime module.
*/
void PipelineScheduler::PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
ConfigPipelineExecution pipeline_config) {
void PipelineScheduler::PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes) {
runtimes.front()->RunPipeline();
}
/*!
Expand Down
9 changes: 4 additions & 5 deletions src/runtime/pipeline/pipeline_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,14 @@ class PipelineScheduler {
* \param modules The list of graph executor module.
* \param pipeline_config The dependency information of each graph executor module.
*/
std::vector<std::shared_ptr<BackendRuntime>> PipelineInit(
const std::vector<Module>& modules, const ConfigPipelineExecution& pipeline_config);
std::shared_ptr<GlobalRuntime> PipelineInit(const std::vector<Module>& modules,
const ConfigPipelineExecution& pipeline_config,
const InputConnectionConfig& input_connection_config);
/*!
* \brief Running the pipeline logic.
* \param runtimes A list of backend runtime modules.
* \param pipeline_config The dependency configuration of each runtime module.
*/
void PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
ConfigPipelineExecution pipeline_config);
void PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes);
/*!
* \brief Get a list of outputs.
*/
Expand Down
Loading