diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 8e04050d089a0..c4a46cd422219 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -56,7 +56,4 @@ constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider"; constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path"; constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point"; -// For Priority based graph topology sorting. -constexpr const char* kBackwardNodeAttributeName = "__backwardpass"; - } // namespace onnxruntime diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index d3c29e6a5d2a9..4f3377f0aa0c0 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1109,6 +1109,19 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi #endif +#ifdef ENABLE_TRAINING + /** + * @brief Performs a topological sort on the graph with a customized Kahn's algorithm. + * This is a specialized version for training, where a memory-efficient topological sort is needed. + * @param yield_op The YieldOp used in ORTModule training. + * @param shape_size_parents The parent nodes of the Shape/Size nodes. + * @param node_orders The resulting node order. + */ + void MemoryEfficientTopologicalSort(const Node* yield_op, + const InlinedHashMap<NodeIndex, InlinedVector<NodeIndex>>& shape_size_parents, + std::vector<NodeIndex>& node_orders) const; +#endif + /** Gets the map of operator domains to their opset versions. */ const std::unordered_map<std::string, int>& DomainToVersionMap() const noexcept { return domain_to_version_; diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 1023d50310181..1816099d3210f 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -207,6 +207,11 @@ class GraphViewer { std::vector<NodeIndex> nodes_in_topological_order_with_priority_; #endif +#ifdef ENABLE_TRAINING + // The NodeIndex values of the graph nodes sorted in memory-efficient topological order. + std::vector<NodeIndex> nodes_in_mem_efficient_topological_order_; +#endif + // Graph root nodes. std::vector<NodeIndex> root_nodes_; diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 295d02a42ff83..0453a7ecac81f 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -22,8 +22,9 @@ namespace onnxruntime { enum class ExecutionOrder { - DEFAULT = 0, // default topological sort - PRIORITY_BASED = 1 // priority-based topological sort + DEFAULT = 0, // default topological sort + PRIORITY_BASED = 1, // priority-based topological sort + MEMORY_EFFICIENT = 2, // memory-efficient topological sort for training purposes.
}; inline std::ostream& operator<<(std::ostream& os, const ExecutionOrder& order) { @@ -34,6 +35,9 @@ inline std::ostream& operator<<(std::ostream& os, const ExecutionOrder& order) { case ExecutionOrder::PRIORITY_BASED: os << "PRIORITY_BASED"; break; + case ExecutionOrder::MEMORY_EFFICIENT: + os << "MEMORY_EFFICIENT"; + break; default: os << "UNKNOWN"; break; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 62b5f7ad5da14..fec706b4ae9a4 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1879,6 +1879,389 @@ void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter, } } +#ifdef ENABLE_TRAINING + +namespace { + +/** + * @brief The data struct to store a group of nodes. + * + * The group contains a set of nodes whose execution depends only on backward input leaf nodes, + * graph inputs/initializers, or other group nodes' outputs. + */ +struct GroupNode { + GroupNode(const InlinedVector<const Node*>& node_list) { + nodes = node_list; + + InlinedHashSet<const NodeArg*> intermediate_args; + for (const Node* node : nodes) { + for (const NodeArg* arg : node->InputDefs()) { + if (intermediate_args.find(arg) == intermediate_args.end()) { + input_args.push_back(arg); + } + } + + for (const NodeArg* arg : node->OutputDefs()) { + intermediate_args.insert(arg); + } + + for (auto output_edge_it = node->OutputEdgesBegin(); output_edge_it != node->OutputEdgesEnd(); + ++output_edge_it) { + const Node* output_node = &output_edge_it->GetNode(); + // An output arg is an output of the group only if it is used by nodes outside the group. + if (std::find(nodes.begin(), nodes.end(), output_node) == nodes.end()) { + output_args.push_back(node->OutputDefs()[output_edge_it->GetSrcArgIndex()]); + } + } + } + } + + bool is_outputted{false}; + + InlinedVector<const NodeArg*> input_args; + InlinedVector<const NodeArg*> output_args; + + InlinedVector<const Node*> nodes; +}; + +void SortForwardNodesByReverseDFS(const Graph* graph, + const InlinedVector<const Node*>& forward_output_nodes, + const InlinedHashMap<NodeIndex, InlinedVector<NodeIndex>>& shape_size_parents, + InlinedHashSet<const Node*>& nodes_to_execute_before_yieldop, + std::vector<NodeIndex>& node_orders) { + // Note 1: YieldOp is the separator of forward and backward nodes. + // Note 2: It is also possible that some nodes not contributing to the forward output nodes will be + // executed before YieldOp; for example, if one forward node's output is used by Shape/Size, then + // the Shape/Size node should be executed before YieldOp to release the memory as soon as possible. + + // Reverse DFS from forward output nodes to find all "forward" nodes. + // The forward nodes are ordered by the reverse DFS traversal. + graph->ReverseDFSFrom( + forward_output_nodes, + nullptr, + [&nodes_to_execute_before_yieldop, &node_orders](const Node* n) { + nodes_to_execute_before_yieldop.insert(n); + node_orders.push_back(n->Index()); + }, + NodeCompare()); + + for (const auto& parent_to_children_pair : shape_size_parents) { + const NodeIndex& parent_index = parent_to_children_pair.first; + if (nodes_to_execute_before_yieldop.find(graph->GetNode(parent_index)) == nodes_to_execute_before_yieldop.end()) { + continue; + } + + for (const NodeIndex& shape_size_node_index : parent_to_children_pair.second) { + const Node* shape_size_node = graph->GetNode(shape_size_node_index); + // If the Shape/Size is already in the node_orders, then skip it.
+ if (nodes_to_execute_before_yieldop.find(shape_size_node) != nodes_to_execute_before_yieldop.end()) { + continue; + } + + auto it = std::find(node_orders.begin(), node_orders.end(), parent_index); + ORT_ENFORCE(it != node_orders.end(), "Cannot find the parent node in the node orders."); + + node_orders.insert(it + 1, shape_size_node_index); + nodes_to_execute_before_yieldop.insert(shape_size_node); + } + } +} + +void PrepareToFindBranchGraph(const Graph* graph, + const InlinedHashSet& nodes_to_execute_before_yieldop, + InlinedVector& branch_graph_input_nodes, + InlinedVector& backward_node_in_degree, + std::queue& to_visit) { + for (auto& node : graph->Nodes()) { + // Ignore forward. + if (nodes_to_execute_before_yieldop.find(&node) != nodes_to_execute_before_yieldop.end()) { + continue; + } + + if (node.OpType() == "YieldOp") { + backward_node_in_degree[node.Index()] = 0; + to_visit.push(&node); + continue; + } + + size_t input_edge_count = node.GetInputEdgesCount(); + backward_node_in_degree[node.Index()] = input_edge_count; + + // A shortcut: input_edge_count could be 0 if it takes graph input directly. + if (input_edge_count == 0) { + branch_graph_input_nodes.push_back(&node); + continue; + } + + for (auto input_edge_it = node.InputEdgesBegin(); input_edge_it != node.InputEdgesEnd(); ++input_edge_it) { + const Node* input_node = &input_edge_it->GetNode(); + // If the input edge connect to forward nodes, then we remove the in_degree of the node. + if (nodes_to_execute_before_yieldop.find(input_node) != nodes_to_execute_before_yieldop.end()) { + input_edge_count--; + } + } + + backward_node_in_degree[node.Index()] = input_edge_count; + if (input_edge_count == 0) { + branch_graph_input_nodes.push_back(&node); + } + } +} + +void FindBranchGraph( + const InlinedVector& branch_graph_input_nodes, + const InlinedVector& backward_node_in_degree, + InlinedVector& branch_graph, + InlinedVector>& branch_subgraph_consumers) { + // Loop through the branch_graph_input_nodes to find the branch subgraphs by its output edges in BFS, + // and find the maximum self_contained subgraph taking the branch_graph_input_nodes as input nodes. + std::queue to_visit_queue; + InlinedVector in_degree_copy = backward_node_in_degree; + + // Add all nodes in branch_graph_input_nodes to the queue + for (auto branch_input_node : branch_graph_input_nodes) { + to_visit_queue.push(branch_input_node); + branch_graph.push_back(branch_input_node); + } + + while (!to_visit_queue.empty()) { + const Node* current = to_visit_queue.front(); + to_visit_queue.pop(); + + if (!current) continue; + + for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) { + auto& node_in_degree = in_degree_copy[node_it->Index()]; + node_in_degree--; + + if (node_in_degree == 0) { + to_visit_queue.push(&*node_it); + branch_graph.push_back(&*node_it); + } + } + } + + // At this point, branch_graph is a big subgraph that contains all the nodes that are purely + // triggered by the branch_graph_input_nodes, other graph input/initializers and leaf nodes (for example Constant). 
+ for (const Node* n : branch_graph) { + for (auto output_it = n->OutputEdgesBegin(); output_it != n->OutputEdgesEnd(); ++output_it) { + const Node* output_node = &output_it->GetNode(); + const size_t dest_in_port = output_it->GetDstArgIndex(); + if (std::find(branch_graph.begin(), branch_graph.end(), output_node) == branch_graph.end()) { + branch_subgraph_consumers.push_back({output_node, dest_in_port}); + } + } + } +} + +void TagNodeToAssociatedOutputs(const Graph* graph, + const InlinedHashSet<const Node*>& nodes_to_execute_before_yieldop, + const InlinedVector<std::pair<const Node*, size_t>>& branch_subgraph_consumers, + const InlinedVector<const Node*>& branch_graph, + InlinedVector<GroupNode>& group_node_collection, + InlinedHashMap<const NodeArg*, GroupNode*>& output_arg_to_grouped_node) { + // Reverse DFS from branch graph outputs (i.e., branch_subgraph_consumers) to tag each node: + // If one node N contributes to a graph output A, then we will tag A to N. + // If node N contributes to multiple graph outputs A, B, C, then we will tag A, B, and C to N. + InlinedHashMap<const Node*, std::set<const NodeArg*>> node_to_its_associated_outputs; + node_to_its_associated_outputs.reserve(branch_graph.size()); + for (const auto& consumer : branch_subgraph_consumers) { + const NodeArg* output_arg = consumer.first->InputDefs()[consumer.second]; + const Node* end_node = graph->GetProducerNode(output_arg->Name()); + InlinedVector<const Node*> end_nodes{end_node}; + graph->ReverseDFSFrom( + end_nodes, + nullptr, + [&node_to_its_associated_outputs, &output_arg](const Node* n) { + node_to_its_associated_outputs[n].insert(output_arg); + }, + nullptr, + [&nodes_to_execute_before_yieldop](const Node*, const Node* to) -> bool { + if (nodes_to_execute_before_yieldop.find(to) != nodes_to_execute_before_yieldop.end()) { + return true; // Skip forward nodes. + } + + return false; + }); + } + + // Cluster the nodes in the branch_graph based on the associated outputs. + InlinedHashMap<std::set<const NodeArg*>, InlinedVector<const Node*>> associated_outputs_to_nodes; + associated_outputs_to_nodes.reserve(node_to_its_associated_outputs.size()); + for (const auto& node : branch_graph) { + const std::set<const NodeArg*>& associated_outputs = node_to_its_associated_outputs[node]; + associated_outputs_to_nodes[associated_outputs].push_back(node); + } + + // Finalize the subgraph input/output information. + group_node_collection.reserve(associated_outputs_to_nodes.size()); + for (auto& [associated_outputs, nodes] : associated_outputs_to_nodes) { + group_node_collection.push_back(nodes); + // Flatten the key into NodeArg* for better search. + GroupNode& grouped_node = group_node_collection.back(); + for (const auto& output_arg : grouped_node.output_args) { + output_arg_to_grouped_node.insert({output_arg, &grouped_node}); + } + } +} + +void UpdateBackwardInDegree(InlinedVector<size_t>& backward_node_in_degree, + InlinedVector<std::pair<const Node*, size_t>>& branch_subgraph_consumers) { + // A GroupNode's execution does not block the main critical path rooted at YieldOp. + // The only dependencies of a GroupNode are graph inputs/initializers/forward nodes, or + // the output nodes of another GroupNode. + // So we treat each GroupNode as a single unit that can be executed anytime it is + // first needed by the main critical path.
+ for (auto& [output_node, dest_in_port] : branch_subgraph_consumers) { + ORT_ENFORCE(backward_node_in_degree[output_node->Index()] > 0); + backward_node_in_degree[output_node->Index()]--; + } +} + +void OutputGroupedNodes(const Graph* graph, + const NodeArg* output_arg, + const InlinedHashMap& output_arg_to_grouped_node, + std::vector& node_orders, + InlinedVector& topo_order) { + ORT_ENFORCE(output_arg_to_grouped_node.find(output_arg) != output_arg_to_grouped_node.end(), + "output_arg_to_grouped_node does not contain output_arg named ", output_arg->Name()); + + GroupNode* grouped_node = output_arg_to_grouped_node.at(output_arg); + + if (grouped_node->is_outputted) { + return; + } + + for (const NodeArg* input_arg : grouped_node->input_args) { + if (!input_arg->Exists()) { + continue; + } + + auto it = output_arg_to_grouped_node.find(input_arg); + if (it != output_arg_to_grouped_node.end() && !it->second->is_outputted) { + OutputGroupedNodes(graph, input_arg, output_arg_to_grouped_node, node_orders, topo_order); + } + } + + for (const Node* n : grouped_node->nodes) { + node_orders.push_back(n->Index()); + topo_order.push_back(n->Index()); + } + + grouped_node->is_outputted = true; +} + +} // namespace + +void Graph::MemoryEfficientTopologicalSort(const Node* yield_op, + const InlinedHashMap>& shape_size_parents, + std::vector& node_orders) const { + /// Firstly, sort the forward nodes with customized ReverseDFS. + + const size_t num_nodes = NumberOfNodes(); + InlinedVector forward_output_nodes; + forward_output_nodes.reserve(yield_op->GetInputEdgesCount()); + for (auto input_it = yield_op->InputNodesBegin(); input_it != yield_op->InputNodesEnd(); ++input_it) { + forward_output_nodes.push_back(&*input_it); + } + + // Create a hash map (paired with node_orders) for cheaper search. + InlinedHashSet nodes_to_execute_before_yieldop; + nodes_to_execute_before_yieldop.reserve(num_nodes); + + SortForwardNodesByReverseDFS(this, forward_output_nodes, + shape_size_parents, + nodes_to_execute_before_yieldop, + node_orders); + + /// Secondly, sort the backward nodes with customized Kahn's algorithm. + + size_t num_of_backward_nodes = num_nodes - node_orders.size(); + InlinedVector backward_node_in_degree(MaxNodeIndex(), 0); + InlinedVector topo_order; + topo_order.reserve(num_of_backward_nodes); + std::queue to_visit; + + InlinedVector branch_graph_input_nodes; + branch_graph_input_nodes.reserve(num_of_backward_nodes); + PrepareToFindBranchGraph(this, + nodes_to_execute_before_yieldop, + branch_graph_input_nodes, + backward_node_in_degree, + to_visit); + + InlinedVector branch_graph; + branch_graph.reserve(num_of_backward_nodes); + InlinedVector> branch_subgraph_consumers; + FindBranchGraph(branch_graph_input_nodes, + backward_node_in_degree, + branch_graph, + branch_subgraph_consumers); + + // Cluster the nodes in the branch_graph based on the associated outputs. 
+ InlinedVector group_node_collection; + InlinedHashMap output_arg_to_grouped_node; + TagNodeToAssociatedOutputs(this, + nodes_to_execute_before_yieldop, + branch_subgraph_consumers, + branch_graph, + group_node_collection, + output_arg_to_grouped_node); + + UpdateBackwardInDegree(backward_node_in_degree, branch_subgraph_consumers); + + while (!to_visit.empty()) { + const Node* current = to_visit.front(); + to_visit.pop(); + + if (!current) continue; + + for (auto input_edge_it = current->InputEdgesBegin(); input_edge_it != current->InputEdgesEnd(); + ++input_edge_it) { + const NodeArg* input_arg = current->InputDefs()[input_edge_it->GetDstArgIndex()]; + if (!input_arg->Exists()) { + continue; + } + + auto it = output_arg_to_grouped_node.find(input_arg); + if (it != output_arg_to_grouped_node.end() && !it->second->is_outputted) { + OutputGroupedNodes(this, input_arg, output_arg_to_grouped_node, node_orders, topo_order); + } + } + + node_orders.push_back(current->Index()); + + for (auto output_edge_it = current->OutputEdgesBegin(); output_edge_it != current->OutputEdgesEnd(); + ++output_edge_it) { + const Node* out_node = &output_edge_it->GetNode(); + auto& node_in_degree = backward_node_in_degree[out_node->Index()]; + node_in_degree--; + if (node_in_degree == 0) { + to_visit.push(out_node); + } + } + + topo_order.push_back(current->Index()); + } + + // For the group nodes that are not outputted, we need to output them. + // Hitting this code path means some nodes are consuming outputs of forward nodes, and their outputs + // are not used by main branch backward nodes. + for (const auto& [output_arg, grouped_node] : output_arg_to_grouped_node) { + if (!grouped_node->is_outputted) { + OutputGroupedNodes(this, output_arg, output_arg_to_grouped_node, node_orders, topo_order); + } + } + + if (num_of_backward_nodes != topo_order.size()) { + ORT_THROW("Some nodes for backward are not included in the topological sort: " + + std::to_string(num_of_backward_nodes) + " vs " + + std::to_string(topo_order.size())); + } +} + +#endif // ENABLE_TRAINING + GSL_SUPPRESS(es.84) // noisy warning about ignoring return value from insert(...) 
Status Graph::PerformTopologicalSortAndCheckIsAcyclic() { nodes_in_topological_order_.clear(); diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 119d420066a84..c639eeac5ea42 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -39,19 +39,6 @@ struct PriorityNodeCompare { return n1_priority > n2_priority; } -#ifdef ENABLE_TRAINING - // nodes of forward pass will be output first - auto n1_attrs = n1->GetAttributes(); - auto n2_attrs = n2->GetAttributes(); - int64_t n1_is_forward = static_cast(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) || - (n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2; - int64_t n2_is_forward = static_cast(n2_attrs.find(kBackwardNodeAttributeName) == n2_attrs.cend()) || - (n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2; - if (n1_is_forward != n2_is_forward) { - return n2_is_forward > n1_is_forward; - } -#endif - // otherwise, nodes with lower index will be output first return n1->Index() > n2->Index(); } @@ -74,7 +61,10 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) : ConstGraphNodes::NodeFilterFunc(nullptr))}, filter_info_{filter_info} { std::vector leaf_nodes; + #ifdef ENABLE_TRAINING + const Node* yield_node = nullptr; + // Keep the info of shape and size nodes and their parents so that after topological sort, we can move them // right after their parents. This is to make sure the shape and size nodes are executed right after their parents // so it's possible the input tensor memory can be released as soon as possible. This is especially important @@ -101,6 +91,10 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) shape_size_parents[parent].push_back(node.Index()); } } + + if (node.OpType() == "YieldOp") { + yield_node = &node; + } #endif } @@ -111,6 +105,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) nodes_in_topological_order_.push_back(n->Index()); }, NodeCompare()); + #ifdef ENABLE_TRAINING auto original = std::move(nodes_in_topological_order_); nodes_in_topological_order_.reserve(original.size()); @@ -128,15 +123,35 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) } } } + #endif + #if !defined(ORT_MINIMAL_BUILD) - graph.KahnsTopologicalSort( + graph_->KahnsTopologicalSort( [this](const Node* n) { nodes_in_topological_order_with_priority_.push_back(n->Index()); }, PriorityNodeCompare()); #endif +#ifdef ENABLE_TRAINING + if (yield_node != nullptr) { + std::vector node_orders; + const size_t num_of_nodes = NumberOfNodes(); + node_orders.reserve(num_of_nodes); + graph_->MemoryEfficientTopologicalSort( + yield_node, + shape_size_parents, + node_orders); + + ORT_ENFORCE(node_orders.size() == num_of_nodes, + "Topological sort failed.", node_orders.size(), "!=", num_of_nodes); + nodes_in_mem_efficient_topological_order_ = std::move(node_orders); + } else { + nodes_in_mem_efficient_topological_order_ = nodes_in_topological_order_; + } +#endif + if (filter_info_) { // validate. 
if something is off here it's a bug in our code for (NodeIndex idx : filter_info->nodes) { @@ -195,11 +210,19 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) #if !defined(ORT_MINIMAL_BUILD) auto orig_priority_order = std::move(nodes_in_topological_order_with_priority_); - nodes_in_topological_order_with_priority_.reserve(filter_info->nodes.size()); + nodes_in_topological_order_with_priority_.reserve(filter_info_->nodes.size()); std::copy_if(orig_priority_order.cbegin(), orig_priority_order.cend(), std::back_inserter(nodes_in_topological_order_with_priority_), [this](NodeIndex idx) { return filtered_node_indices_.count(idx) != 0; }); #endif + +#ifdef ENABLE_TRAINING + auto orig_mem_efficient_order = std::move(nodes_in_mem_efficient_topological_order_); + nodes_in_mem_efficient_topological_order_.reserve(filter_info_->nodes.size()); + std::copy_if(orig_mem_efficient_order.cbegin(), orig_mem_efficient_order.cend(), + std::back_inserter(nodes_in_mem_efficient_topological_order_), + [this](NodeIndex idx) { return filtered_node_indices_.count(idx) != 0; }); +#endif } } @@ -291,9 +314,17 @@ const std::vector& GraphViewer::GetNodesInTopologicalOrder(ExecutionO switch (order) { case ExecutionOrder::DEFAULT: return nodes_in_topological_order_; -#if !defined(ORT_MINIMAL_BUILD) case ExecutionOrder::PRIORITY_BASED: +#if !defined(ORT_MINIMAL_BUILD) return nodes_in_topological_order_with_priority_; +#else + ORT_THROW("Priority based topological order is not enabled for ORT minimal build."); +#endif + case ExecutionOrder::MEMORY_EFFICIENT: +#ifdef ENABLE_TRAINING + return nodes_in_mem_efficient_topological_order_; +#else + ORT_THROW("Memory efficient topological order is not enabled for non-training build."); #endif default: ORT_THROW("Invalid ExecutionOrder"); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 236d2cfeb2b33..7fc6515d3d50a 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1381,7 +1381,8 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra py::enum_(m, "ExecutionOrder") .value("DEFAULT", ExecutionOrder::DEFAULT) - .value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED); + .value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED) + .value("MEMORY_EFFICIENT", ExecutionOrder::MEMORY_EFFICIENT); py::enum_(m, "OrtAllocatorType") .value("INVALID", OrtInvalidAllocator) diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 4b676021dfe6c..590d18be91bb2 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -2124,5 +2124,174 @@ TEST_F(GraphTest, SubgraphOutputIsOuterScopeValue) { ::testing::ContainsRegex("Subgraph output \\(.*\\) is an outer scope value being returned directly.")); } +#ifdef ENABLE_TRAINING + +TEST_F(GraphTest, GraphConstruction_MemoryEfficientTopologicalSort_Recompute) { + Model model("graph_1", false, *logger_); + auto& graph = model.MainGraph(); + + /* + | + node_0 (Identity) + / \ + node_1 (Identity) \ + | | + node_4 (Identity) | + | | + YieldOp recompute_node_1 + \ / + node_1_grad (Merge) + | + */ + + TypeProto tensor_int32; + tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); + tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in_1", &tensor_int32); + auto& output_arg0 = 
graph.GetOrCreateNodeArg("node_0_out_1", &tensor_int32); + auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out_1", &tensor_int32); + auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out_1", &tensor_int32); + auto& output_arg4 = graph.GetOrCreateNodeArg("node_4_out_1", &tensor_int32); + auto& output_arg5 = graph.GetOrCreateNodeArg("node_yield_out_1", &tensor_int32); + auto& output_arg6 = graph.GetOrCreateNodeArg("node_5_out_1", &tensor_int32); + + graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0}); + graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1}); + graph.AddNode("recompute_node_1", "Identity_Fake", "recompute node 1", {&output_arg0}, {&output_arg2}); + + graph.AddNode("node_4", "Identity_Fake", "node 4", {&output_arg1}, {&output_arg4}); + + ONNX_NAMESPACE::AttributeProto full_shape_outputs; + const std::string attribute_name = "full_shape_outputs"; + full_shape_outputs.set_name(attribute_name); + full_shape_outputs.set_type(ONNX_NAMESPACE::AttributeProto::INTS); + full_shape_outputs.add_ints(static_cast(0)); + NodeAttributes attributes({{attribute_name, full_shape_outputs}}); + + graph.AddNode("node_yield", "YieldOp", "node yield", {&output_arg4}, {&output_arg5}, &attributes, kMSDomain); + graph.AddNode("node_1_grad", "Merge_Fake", "node_1 gradient", {&output_arg5, &output_arg2}, {&output_arg6}); + + auto status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + GraphViewer graph_viewer(graph); + + // MEMORY_EFFICIENT order + { + auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::MEMORY_EFFICIENT); + const std::vector expected_priority_based_order = + {"node_0", "node_1", "node_4", "node_yield", "recompute_node_1", "node_1_grad"}; + for (size_t i = 0; i < order.size(); ++i) { + auto node = graph.GetNode(order[i]); + EXPECT_TRUE(node->Name() == expected_priority_based_order[i]) << "MEMORY_EFFICIENT based execution order is wrong."; + } + } +} + +TEST_F(GraphTest, GraphConstruction_MemoryEfficientTopologicalSort_MultiLayerRecompute) { + Model model("graph_1", false, *logger_); + auto& graph = model.MainGraph(); + + /* + | + node_0 (Identity) + / \ + node_1 (Identity) \ + | \ \ + node_2 (Identity) \ \ + | \ \ \ + node_3 (Identity) \ \ \ + | \ \ \ \ + loss (Identity) \ \ \ \ + | | \ \ \ + YieldOp | | \ \ + \ / | \ | + loss_grad recom_node_3 | | + \ / | | + node_3_grad recom_node_2 | + \ / | + node_2_grad recom_node_1 + \ / + node_1_grad + | + */ + + TypeProto tensor_int32; + tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); + tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + // FW graph + auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in", &tensor_int32); + auto& output_arg0 = graph.GetOrCreateNodeArg("node_0_out", &tensor_int32); + auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out", &tensor_int32); + auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out", &tensor_int32); + auto& output_arg3 = graph.GetOrCreateNodeArg("node_3_out", &tensor_int32); + auto& output_loss = graph.GetOrCreateNodeArg("loss_out", &tensor_int32); + auto& output_yield = graph.GetOrCreateNodeArg("yield_out", &tensor_int32); + + graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0}); + graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1}); + graph.AddNode("node_2", "Identity_Fake", "node 2", {&output_arg1}, {&output_arg2}); + graph.AddNode("node_3", "Identity_Fake", 
"node 3", {&output_arg2}, {&output_arg3}); + graph.AddNode("loss", "Identity_Fake", "loss node", {&output_arg3}, {&output_loss}); + ONNX_NAMESPACE::AttributeProto full_shape_outputs; + const std::string attribute_name = "full_shape_outputs"; + full_shape_outputs.set_name(attribute_name); + full_shape_outputs.set_type(ONNX_NAMESPACE::AttributeProto::INTS); + full_shape_outputs.add_ints(static_cast(0)); + NodeAttributes attributes({{attribute_name, full_shape_outputs}}); + graph.AddNode("node_yield", "YieldOp", "node yield", {&output_loss}, {&output_yield}, &attributes, kMSDomain); + + // Recompute graph + auto& recomputed_arg3 = graph.GetOrCreateNodeArg("node_3_out_recomputed", &tensor_int32); + auto& recomputed_arg2 = graph.GetOrCreateNodeArg("node_2_out_recomputed", &tensor_int32); + auto& recomputed_arg1 = graph.GetOrCreateNodeArg("node_1_out_recomputed", &tensor_int32); + + graph.AddNode("node_3_recompute", "Identity_Fake", "node 3 recompute", {&output_arg2}, {&recomputed_arg3}); + graph.AddNode("node_2_recompute", "Identity_Fake", "node 2 recompute", {&output_arg1}, {&recomputed_arg2}); + graph.AddNode("node_1_recompute", "Identity_Fake", "node 1 recompute", {&output_arg0}, {&recomputed_arg1}); + + // BW Graph + auto& loss_grad_output = graph.GetOrCreateNodeArg("loss_grad_output", &tensor_int32); + auto& node_3_grad_output = graph.GetOrCreateNodeArg("node_3_grad_output", &tensor_int32); + auto& node_2_grad_output = graph.GetOrCreateNodeArg("node_2_grad_output", &tensor_int32); + auto& node_1_grad_output = graph.GetOrCreateNodeArg("node_1_grad_output", &tensor_int32); + + graph.AddNode("loss_grad", "Merge_Fake", "loss gradient", {&output_yield, &output_arg3}, {&loss_grad_output}); + graph.AddNode("node_3_grad", "Merge_Fake", "node 3 gradient", {&loss_grad_output, &recomputed_arg3}, {&node_3_grad_output}); + graph.AddNode("node_2_grad", "Merge_Fake", "node 2 gradient", {&node_3_grad_output, &recomputed_arg2}, {&node_2_grad_output}); + graph.AddNode("node_1_grad", "Merge_Fake", "node 1 gradient", {&node_2_grad_output, &recomputed_arg1}, {&node_1_grad_output}); + + auto status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + GraphViewer graph_viewer(graph); + + // MEMORY_EFFICIENT order + { + auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::MEMORY_EFFICIENT); + const std::vector expected_priority_based_order = { + "node_0", + "node_1", + "node_2", + "node_3", + "loss", + "node_yield", + "loss_grad", + "node_3_recompute", + "node_3_grad", + "node_2_recompute", + "node_2_grad", + "node_1_recompute", + "node_1_grad", + }; + for (size_t i = 0; i < order.size(); ++i) { + auto node = graph.GetNode(order[i]); + EXPECT_TRUE(node->Name() == expected_priority_based_order[i]) << "MEMORY_EFFICIENT based execution order is wrong."; + } + } +} + +#endif + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/common.h b/orttraining/orttraining/core/optimizer/memory_optimizer/common.h index 268ed84f7a85f..952fe49ffa657 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/common.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/common.h @@ -73,4 +73,6 @@ int ParseIntValueFromString(std::string_view str); Status ParseOptimizationConfigFromString(std::string_view memory_optimization_config, InlinedHashMap& cluster_id_to_config_map); +constexpr const ExecutionOrder TOPOLOGICAL_SORT_ALGORITHM = ExecutionOrder::MEMORY_EFFICIENT; + } // namespace 
onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 3d0fa942fd2d4..038ff0049b32a 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -49,7 +49,7 @@ void GetForwardOutputUsageMap(const GraphViewer& graph_viewer, ActivationUsedMap& fw_op_output_arg_used_map, InlinedHashMap& is_forward_nodes) { ORT_ENFORCE(boundary_op_order_in_topological_sort >= 0); - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(TOPOLOGICAL_SORT_ALGORITHM); is_forward_nodes.clear(); is_forward_nodes.reserve(node_ids.size()); @@ -128,7 +128,7 @@ Status GetStashedActivationCandidates(const GraphViewer& graph_viewer, return Status::OK(); } - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(TOPOLOGICAL_SORT_ALGORITHM); InlinedHashMap node_index_to_its_order_in_topological_sort_map; for (size_t i = 0; i < node_ids.size(); ++i) { @@ -171,52 +171,6 @@ Status GetStashedActivationCandidates(const GraphViewer& graph_viewer, return Status::OK(); } -Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified) { - // Find the YieldOp node. - Node* yield_op_node = nullptr; - for (auto& node : graph.Nodes()) { - if (node.OpType() == "YieldOp") { - yield_op_node = &node; - break; - } - } - - if (yield_op_node == nullptr) { - return Status::OK(); - } - - // Reverse BFS from YieldOp to find all "forward" nodes. - std::vector fw_nodes; - std::vector end_nodes{yield_op_node}; - graph.ReverseDFSFrom( - end_nodes, - nullptr, - [&fw_nodes](const Node* n) { - fw_nodes.push_back(n); - }, - nullptr); - - // Set the attribute to true for all backward nodes. - for (auto& node : graph.Nodes()) { - if (std::find(fw_nodes.begin(), fw_nodes.end(), &node) == fw_nodes.end()) { - auto& attrs = node.GetAttributes(); - if (attrs.count(kBackwardNodeAttributeName)) { - continue; - } - node.AddAttribute(kBackwardNodeAttributeName, static_cast(1)); - modified = true; - } else { - auto& attrs = node.GetAttributes(); - if (attrs.count(kBackwardNodeAttributeName)) { - node.ClearAttribute(kBackwardNodeAttributeName); - modified = true; - } - } - } - - return Status::OK(); -} - Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, const ProbeConfig& probe_config, const logging::Logger& logger, @@ -226,7 +180,7 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, InlinedHashMap>& candidate_output_args_map, MemoryOptimizationPlanner& memory_opt_planner) { - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(TOPOLOGICAL_SORT_ALGORITHM); // Find boundary ops between forward and backward pass, currently, it's limited to YieldOp. 
yield_op_order_in_topological_sort = -1; diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h index 3f0a1a9a96f88..ca1df0633eb8f 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h @@ -57,16 +57,6 @@ class MemoryRecord { int freq = 0; }; -/** - * @brief Reset `__backwardpass` attribute for all backward nodes in the graph. - * `__backwardpass` is used by Priority-Based topology sorting. - * - * @param graph To be scanned and modified. - * @param modified Whether the graph is modified. - * @return Status - */ -Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified); - /** * @brief Iterate the graph and find all possible memory optimization opportunities for related nodes. * diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc index ac619bdc390d3..e0c255e37daf3 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc @@ -146,9 +146,6 @@ bool MemoryOptimizer::ModifyGraph(Graph& graph, Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger) const { - // Reset the backward pass attribute for all nodes. - ORT_RETURN_IF_ERROR(optimizer::memory_optimizer::ResetNodeBackwardPassAttribute(graph, modified)); - LOGS(logger, VERBOSE) << "Memory optimization config: " << optimizer_config_ << ", probe level: " << static_cast(recompute_probe_config_.probe_level) << ", enable_transformer_layer_as_boundary:" @@ -189,44 +186,8 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve .IsOK()); // The second pass - apply the transformation. - // Note 1: Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated. - // The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended - // earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier - // layers. - // - // Note 2: Here we use default typo order (which tries to BFS from the outputs, - // so the nearest node to graph output will be visited last). So in reversed default typo order, - // the neareast node to graph output will be visited first. - // Imagine there is a such subgraph - // input1 input2 input3 - // \ | / - // multiple layers - // | - // node M - // labels-------|----- - // \ | | - // node1 | | - // \ | | - // node2 / | - // \ / | - // node loss / - // | / - // YieldOp node1_recompute - // | / - // \ node2 recompute - // \ / - // node loss_grad - // | - // critical grad path - // - // In PriorityBased order, node1 will be visited first, so it's recompute node node1_recompute will be added - // at last because we do this following reversed topological order. Then node1_recompute node will have lowest - // priority to execute, as a result, if at this time, the queue to visit contains only recompute nodes, then - // node1_recompute will be run at last, affecting the backward critical path, which is not what we want. - // Current workaround is to use default order, which will execute node1_recompute earlier than other recompute nodes - // in this case. 
- - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); + const auto& node_ids = + graph_viewer.GetNodesInTopologicalOrder(optimizer::memory_optimizer::TOPOLOGICAL_SORT_ALGORITHM); for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { Node* p_node = graph.GetNode(node_ids[i]); if (p_node == nullptr) { @@ -327,7 +288,6 @@ Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph, &node_to_duplicate->GetAttributes(), node_to_duplicate->Domain()); - recompute_node.SetPriority(static_cast(ExecutionPriority::LOCAL_LOW)); recompute_node.SetExecutionProviderType(node_to_duplicate->GetExecutionProviderType()); ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(recompute_node), "Failed to set op schema for added recompute node."); diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h index 1d837038e76c1..45d7c10cea41f 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h @@ -23,7 +23,7 @@ Find recompute subgraphs and enable them according to user configs. The way we c a. If yes, add it in the subgraph, and append its input in the queue to scan next; b. otherwise, stop collecting and return the subgraph (could be empty). 3. Pick up the input node from the queue, and do 2 again. The process ends when the queue is empty or 2.b happens. -4. Clone the recomputable subgraphs with lower node priority (to execute) and insert them back to the original graph. +4. Clone the recomputable subgraphs and insert them back to the original graph. */ class MemoryOptimizer : public GraphTransformer { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc index 4ce896c5350b0..bd6f6a0c380ae 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc @@ -15,9 +15,10 @@ namespace onnxruntime::optimizer::memory_optimizer { -Status MemoryOptimizationPlanner::UpdateNodePlansFromExecutionPlan(const GraphViewer& graph_viewer, - const OrtValueNameIdxMap& ortvalue_name_to_idx_map, - const SequentialExecutionPlan& p_seq_exec_plan) { +Status MemoryOptimizationPlanner::UpdateNodePlansFromExecutionPlan( + const GraphViewer& graph_viewer, + const OrtValueNameIdxMap& ortvalue_name_to_idx_map, + const SequentialExecutionPlan& p_seq_exec_plan) { InlinedHashMap idx_to_ortvalue_name_map; for (const auto& entry : ortvalue_name_to_idx_map) { idx_to_ortvalue_name_map[entry.second] = entry.first; diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc index 3bcfbd324ba3c..35ecf1159d321 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc @@ -110,7 +110,7 @@ void FindLayerBoundaryLayerNormNodes( layer_boundary_ln_nodes.clear(); - const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(TOPOLOGICAL_SORT_ALGORITHM); for (auto node_index : node_topology_list) { auto& node = 
*graph_viewer.GetNode(node_index); diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index b4f2be4150256..8a493ed87a70e 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -255,14 +255,17 @@ def _get_session_config(self): session_options.enable_mem_pattern = False session_options.enable_mem_reuse = False session_options.use_deterministic_compute = _are_deterministic_algorithms_enabled() - # DEFAULT order is reversed DFS order, while PRIORITY_BASED order is forward BFS order. - # DEFAULT order is likely to be better than PRIORITY_BASED order on memory. However, our recompute feature - # requires PRIORITY_BASED order to work properly. So we use PRIORITY_BASED order when recompute is enabled. + # Enable memory efficient execution order for training if 1). memory efficient grad management is enabled + # or 2). memory optimizer is enabled. + use_memory_efficient_topo_sort = (self._export_mode == torch.onnx.TrainingMode.TRAINING) and ( + self._mem_efficient_grad_management_is_enabled or self._runtime_options.memory_optimizer_is_enabled() + ) session_options.execution_order = ( - onnxruntime.ExecutionOrder.PRIORITY_BASED - if self._runtime_options.memory_optimizer_is_enabled() + onnxruntime.ExecutionOrder.MEMORY_EFFICIENT + if use_memory_efficient_topo_sort else onnxruntime.ExecutionOrder.DEFAULT ) + # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. session_options.log_severity_level = int(self._debug_options.logging.log_level) diff --git a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc index 360095dea6697..c34d0be5657e6 100644 --- a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc @@ -60,7 +60,7 @@ TEST(MemoryOptimizerTests, GeluRecompute) { } } - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; const std::string alleviation_config("Gelu+:1:-1"); const std::string probe_config("1:0"); @@ -89,8 +89,6 @@ TEST(MemoryOptimizerTests, GeluRecompute) { } ASSERT_EQ(recompute_gelu_node->MutableInputDefs()[0]->Name(), original_gelu_node->MutableInputDefs()[0]->Name()); - ASSERT_EQ(recompute_gelu_node->Priority(), static_cast(ExecutionPriority::LOCAL_LOW)); - ASSERT_EQ(original_gelu_node->Priority(), static_cast(ExecutionPriority::DEFAULT)); } TEST(MemoryOptimizerTests, TileRecompute) { @@ -121,12 +119,11 @@ TEST(MemoryOptimizerTests, TileRecompute) { Node* recompute_tile_node{nullptr}; Node* original_tile_node{nullptr}; for (auto& node : graph.Nodes()) { - if (node.Priority() == static_cast(ExecutionPriority::LOCAL_LOW)) { - if (node.OpType().compare("Tile") == 0) { + if (node.OpType().compare("Tile") == 0) { + // if name ends with _recompute, it's the recomputed node + if (node.Name().find("_recompute") != std::string::npos) { recompute_tile_node = &node; - } - } else if (node.Priority() == static_cast(ExecutionPriority::DEFAULT)) { - if (node.OpType().compare("Tile") == 0) { + } else { original_tile_node = &node; } } @@ -146,10 +143,6 @@ TEST(MemoryOptimizerTests, TileRecompute) { ASSERT_EQ(recompute_expand_node->InputDefs()[0]->Name(), original_expand_node->InputDefs()[0]->Name()); 
ASSERT_EQ(query_layer_grad_node->InputDefs()[1]->Name(), recompute_tile_node->OutputDefs()[0]->Name()); - - ASSERT_EQ(recompute_tile_node->Priority(), static_cast(ExecutionPriority::LOCAL_LOW)); - ASSERT_EQ(original_tile_node->Priority(), static_cast(ExecutionPriority::DEFAULT)); - ASSERT_EQ(query_layer_grad_node->Priority(), static_cast(ExecutionPriority::DEFAULT)); } TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) { @@ -225,22 +218,17 @@ TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) { if (consumer->OpType().compare("LayerNormalization") == 0) { if (consumer->Name().find("_recompute") != std::string::npos) { recompute_ln_node = consumer; - ASSERT_EQ(consumer->Priority(), static_cast(ExecutionPriority::LOCAL_LOW)); recompute_ln_node_parent_add_or_ln_node = graph.GetProducerNode(consumer->InputDefs()[0]->Name()); ASSERT_TRUE(recompute_ln_node_parent_add_or_ln_node != nullptr); - ASSERT_EQ(recompute_ln_node_parent_add_or_ln_node->Priority(), static_cast(ExecutionPriority::DEFAULT)); ASSERT_TRUE(recompute_ln_node_parent_add_or_ln_node->Name().find("_recompute") == std::string::npos); } else { original_ln_node = consumer; - ASSERT_EQ(consumer->Priority(), static_cast(ExecutionPriority::DEFAULT)); original_ln_node_parent_add_or_ln_node = graph.GetProducerNode(consumer->InputDefs()[0]->Name()); ASSERT_TRUE(original_ln_node_parent_add_or_ln_node); - ASSERT_EQ(original_ln_node_parent_add_or_ln_node->Priority(), static_cast(ExecutionPriority::DEFAULT)); ASSERT_TRUE(original_ln_node_parent_add_or_ln_node->Name().find("_recompute") == std::string::npos); } } else if (consumer->OpType().compare("LayerNormalizationGrad") == 0) { input_layer_norm_grad_node = consumer; - ASSERT_EQ(consumer->Priority(), static_cast(ExecutionPriority::DEFAULT)); } } @@ -264,21 +252,17 @@ TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) { if (consumer->OpType().compare("LayerNormalization") == 0) { if (consumer->Name().find("_recompute") != std::string::npos) { recompute_ln_node = consumer; - ASSERT_EQ(consumer->Priority(), static_cast(ExecutionPriority::LOCAL_LOW)); recompute_ln_node_parent_add_node = graph.GetProducerNode(consumer->InputDefs()[0]->Name()); ASSERT_TRUE(recompute_ln_node_parent_add_node); ASSERT_EQ(recompute_ln_node_parent_add_node->OpType(), "Add"); - ASSERT_EQ(recompute_ln_node_parent_add_node->Priority(), static_cast(ExecutionPriority::LOCAL_LOW)); ASSERT_TRUE(recompute_ln_node_parent_add_node->Name().find("_recompute") != std::string::npos); } else { original_ln_node = consumer; - ASSERT_EQ(consumer->Priority(), static_cast(ExecutionPriority::DEFAULT)); original_ln_node_parent_add_node = graph.GetProducerNode(consumer->InputDefs()[0]->Name()); ASSERT_TRUE(original_ln_node_parent_add_node); } } else if (consumer->OpType().compare("LayerNormalizationGrad") == 0) { ln_grad_node = consumer; - ASSERT_EQ(consumer->Priority(), static_cast(ExecutionPriority::DEFAULT)); } } @@ -294,7 +278,8 @@ TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) { std::vector nodes_in_topological_order; nodes_in_topological_order.reserve(bw_nodes_in_expected_order.size()); - const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); // ExecutionOrder::PRIORITY_BASED + const auto& node_topology_list = + graph_viewer.GetNodesInTopologicalOrder(optimizer::memory_optimizer::TOPOLOGICAL_SORT_ALGORITHM); size_t j = 0; for (auto node_index : node_topology_list) { @@ -322,7 +307,8 @@ TEST(MemoryOptimizerTests, TransformerLayerDetectionTest) { GraphViewer graph_viewer(graph); 
InlinedHashMap node_index_to_its_order_in_topological_sort_map; - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + const auto& node_ids = + graph_viewer.GetNodesInTopologicalOrder(optimizer::memory_optimizer::TOPOLOGICAL_SORT_ALGORITHM); // Find boundary ops between forward and backward pass, currently, it's limited to YieldOp. ptrdiff_t yield_op_order_in_topological_sort = -1;
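Usage note for the new enum value: since the Python binding above exposes ExecutionOrder.MEMORY_EFFICIENT, the order can also be selected directly on SessionOptions outside of ORTModule. A minimal sketch, assuming a training-enabled onnxruntime build (names are illustrative only):

    import onnxruntime

    session_options = onnxruntime.SessionOptions()
    # Request the memory-efficient topological sort added in this change; when the graph
    # contains no YieldOp, GraphViewer falls back to the default topological order for this mode.
    session_options.execution_order = onnxruntime.ExecutionOrder.MEMORY_EFFICIENT

ORTModule makes the equivalent selection automatically in _graph_execution_manager.py when memory-efficient gradient management or the memory optimizer is enabled in TRAINING mode.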