Skip to content

Commit

Permalink
Introduce memory efficient topological sort (microsoft#20258)
Browse files Browse the repository at this point in the history
### Introduce memory efficient topo sort (for training)

~~and lazily initialize the Priority-Based and Memory-Efficient topo sorts.
In most cases they are not needed, so this removes their overhead from
GraphViewer construction for most use cases.~~

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
pengwa authored Apr 23, 2024
1 parent 9372e9a commit a7787a0
Show file tree
Hide file tree
Showing 17 changed files with 655 additions and 156 deletions.
3 changes: 0 additions & 3 deletions include/onnxruntime/core/graph/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,4 @@ constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";
constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path";
constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point";

// For Priority based graph topology sorting.
constexpr const char* kBackwardNodeAttributeName = "__backwardpass";

} // namespace onnxruntime
13 changes: 13 additions & 0 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,19 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi

#endif

#ifdef ENABLE_TRAINING
/**
* @brief Performs a topological sort of the graph using a customized Kahn's algorithm.
* This is a specialized version for training, where a memory-efficient topological order is needed.
* It is only declared when ENABLE_TRAINING is defined.
* @param yield_op The YieldOp node used in ORTModule training.
* @param shape_size_parents Map from a NodeIndex to a list of NodeIndex values describing the
*        "shape size parent" nodes. NOTE(review): the exact key/value meaning is not visible from
*        this declaration; confirm against the implementation.
* @param node_orders Output parameter that receives the resulting node order.
*/
void MemoryEfficientTopologicalSort(const Node* yield_op,
const InlinedHashMap<NodeIndex, InlinedVector<NodeIndex>>& shape_size_parents,
std::vector<NodeIndex>& node_orders) const;
#endif

/** Gets the map of operator domains to their opset versions. */
const std::unordered_map<std::string, int>& DomainToVersionMap() const noexcept {
return domain_to_version_;
Expand Down
5 changes: 5 additions & 0 deletions include/onnxruntime/core/graph/graph_viewer.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ class GraphViewer {
std::vector<NodeIndex> nodes_in_topological_order_with_priority_;
#endif

#ifdef ENABLE_TRAINING
// The NodeIndex values of the graph nodes sorted in memory efficient topological order.
std::vector<NodeIndex> nodes_in_mem_efficient_topological_order_;
#endif

// Graph root nodes.
std::vector<NodeIndex> root_nodes_;

Expand Down
8 changes: 6 additions & 2 deletions onnxruntime/core/framework/session_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
namespace onnxruntime {

enum class ExecutionOrder {
DEFAULT = 0, // default topological sort
PRIORITY_BASED = 1 // priority-based topological sort
DEFAULT = 0, // default topological sort
PRIORITY_BASED = 1, // priority-based topological sort
MEMORY_EFFICIENT = 2, // memory-efficient topological sort for training purposes.
};

inline std::ostream& operator<<(std::ostream& os, const ExecutionOrder& order) {
Expand All @@ -34,6 +35,9 @@ inline std::ostream& operator<<(std::ostream& os, const ExecutionOrder& order) {
case ExecutionOrder::PRIORITY_BASED:
os << "PRIORITY_BASED";
break;
case ExecutionOrder::MEMORY_EFFICIENT:
os << "MEMORY_EFFICIENT";
break;
default:
os << "UNKNOWN";
break;
Expand Down
Loading

0 comments on commit a7787a0

Please sign in to comment.