Skip to content

Commit

Permalink
address review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
huajsj committed Dec 16, 2021
1 parent c9ea604 commit fe6a79a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 35 deletions.
30 changes: 15 additions & 15 deletions python/tvm/contrib/pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,16 @@ def check_binding_dict(self, connection_dict):
It is a dictionary of module connections.
"""
if "interface_name" not in connection_dict:
raise RuntimeError(f'"inteface_name" is missing in global config!"')
raise RuntimeError('"inteface_name" is missing in global config!"')
if "connection" not in connection_dict:
raise RuntimeError(f'"connection" is missing!"')
# The global interface mapping should be one-to-one.
if not connection_dict["connection"]:
raise RuntimeError(f"The global interface map is empty!")
raise RuntimeError("The global interface map is empty!")
if len(connection_dict["connection"]) > 1:
raise RuntimeError(f"A global interface maps multiple module interfaces!")
raise RuntimeError("A global interface maps multiple module interfaces!")
if "mod_idx" not in connection_dict["connection"][0]:
raise RuntimeError(f'"mod_idx" is missing!')
raise RuntimeError('"mod_idx" is missing!')

def get_binding_dict(self):
"""Returning the binding information in the form of dictionary.
Expand Down Expand Up @@ -309,34 +309,34 @@ def connect(self, binding):

# Check whether the binding setting is correct or not.
if self.io_owner == binding.io_owner:
raise RuntimeError(f"Can not bind itself.")
raise RuntimeError("Can not bind itself.")

if not self.is_pipeline_executor_interface() and self.io_type == "input":
raise RuntimeError(f"Module can only bind from output interface!")
raise RuntimeError("Module can only bind from output interface!")

if (
not self.is_pipeline_executor_interface()
and not binding.is_pipeline_executor_interface()
and binding.io_type == "output"
):
raise RuntimeError(f"Can not bind module output with another module output!")
raise RuntimeError("Can not bind module output with another module output!")

if (
not self.is_pipeline_executor_interface()
and binding.is_pipeline_executor_interface()
and binding.io_type == "input"
):
raise RuntimeError(f"Can not bind module output with pipeline input!")
raise RuntimeError("Can not bind module output with pipeline input!")

if self.is_pipeline_executor_interface() and self.io_type == "output":
raise RuntimeError(f"Global output can not be used as binding start point.")
raise RuntimeError("Global output can not be used as binding start point.")

if (
self.is_pipeline_executor_interface()
and self.io_type == "input"
and binding.io_type != "input"
):
raise RuntimeError(f"Global input can only bind with module input.")
raise RuntimeError("Global input can only bind with module input.")

self.bindings.append(binding)
if not self.is_pipeline_executor_interface():
Expand All @@ -358,7 +358,7 @@ def connect(self, binding):
if not self.check_dag_acyclic(
binding.io_owner, self.io_owner.input_bindings.bindings
):
raise RuntimeError(f"Illegal connection: Cause a cycle!")
raise RuntimeError("Illegal connection: Cause a cycle!")

class BindingList:
"""Container for bindings(input or output interface).
Expand Down Expand Up @@ -429,7 +429,7 @@ def __getitem__(self, key):

raise RuntimeError(f"{key} not found!")

raise RuntimeError(f'The data type of "key" is not supported!')
raise RuntimeError('The data type of "key" is not supported!')

def get_data_type(self, key, interface_type):
"""Get the module interface data type according to the key value and interface type.
Expand Down Expand Up @@ -584,7 +584,7 @@ def get_config(self):
for input_name in self.input_bindings.bindings:
input_dict = self.input_bindings.bindings[input_name].get_binding_dict()
if "interface_name" not in input_dict["connection"][0]:
raise RuntimeError(f"interface_name is missing in connection config!")
raise RuntimeError("interface_name is missing in connection config!")
# Creating the map of global interface and subgraph interface.
input_map = {
"global_interface_name": input_dict["interface_name"],
Expand Down Expand Up @@ -691,11 +691,11 @@ def export_library(self, directory_path):
Export the files to this directory.
"""
if not self.pipeline_mods:
raise RuntimeError(f"The pipeline executor has not been initialized.")
raise RuntimeError("The pipeline executor has not been initialized.")

# Check if the directory_path exists.
if not os.path.exists(directory_path):
raise RuntimeError(f"The directory {directory_path} does not exist.")
raise RuntimeError("The directory {directory_path} does not exist.")
# Create an load configuration.
load_config_file_name = "{}/load_config".format(directory_path)
pipeline_config_file_name = "{}/pipeline_config".format(directory_path)
Expand Down
40 changes: 20 additions & 20 deletions src/runtime/pipeline/pipeline_struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,6 @@
* \brief All binding information of a output interface.
*/
class ConfigBindings {
private:
/*!\brief Output interface binding information, 'int' is the index of the module that
* uses this output data as the input interface data, 'string' is the input interface name
* of the module.
*/
std::unordered_map<int, std::string> bindings_;
/*! The index value of the global interface to which the current output are bound.*/
int global_output_index_ = std::numeric_limits<int>::min();

public:
/*!\brief Whether this binding is bound to the PipelineExecutor output interface.*/
bool IsGlobalOutput() const { return global_output_index_ > GLOBAL_MODULE_INDEX; }
Expand Down Expand Up @@ -86,15 +77,20 @@ class ConfigBindings {
}
}
}

private:
/*!\brief Output interface binding information, 'int' is the index of the module that
* uses this output data as the input interface data, 'string' is the input interface name
* of the module.
*/
std::unordered_map<int, std::string> bindings_;
/*! The index value of the global interface to which the current output are bound.*/
int global_output_index_ = std::numeric_limits<int>::min();
};
/*!
* \brief The binding information of all outputs of a module.
*/
class ConfigOutputBindings {
private:
/*!\brief The map of output binding, 'int' is the output interface index.*/
std::unordered_map<int, ConfigBindings> output_binding_map_;

public:
ConfigOutputBindings& operator=(const ConfigOutputBindings& output) {
output_binding_map_ = output.GetOutBindings();
Expand Down Expand Up @@ -148,19 +144,16 @@ class ConfigOutputBindings {
output_binding_map_[output_idx] = binding;
}
}

private:
/*!\brief The map of output binding, 'int' is the output interface index.*/
std::unordered_map<int, ConfigBindings> output_binding_map_;
};

/*!
* \brief The binding or dependency information of each module output interface.
*/
class ConfigPipelineExecution {
private:
/*
*!\brief The key is the module index, this variable records all module pipeline configuration
* information.
*/
std::unordered_map<int, ConfigOutputBindings> config_;

public:
/*
*!\brief This function is used to verify whether config is loaded successfully.
Expand Down Expand Up @@ -208,6 +201,13 @@ class ConfigPipelineExecution {
config_[mod_idx] = output;
}
}

private:
/*
*!\brief The key is the module index, this variable records all module pipeline configuration
* information.
*/
std::unordered_map<int, ConfigOutputBindings> config_;
};

struct InputConnectionConfig {
Expand Down

0 comments on commit fe6a79a

Please sign in to comment.