diff --git a/onnxruntime/core/framework/device_stream_collection.cc b/onnxruntime/core/framework/device_stream_collection.cc
index 669fb0bc79cb0..3c102f2679160 100644
--- a/onnxruntime/core/framework/device_stream_collection.cc
+++ b/onnxruntime/core/framework/device_stream_collection.cc
@@ -139,5 +139,16 @@ Stream* DeviceStreamCollection::GetRootStream() const {
   return impl_->GetRootStream();
 }

+DeviceStreamCollectionHolder::DeviceStreamCollectionHolder(const SessionState* session_state)
+    : session_state_(session_state),
+      p_(session_state->AcquireDeviceStreamCollection()) {
+}
+
+DeviceStreamCollectionHolder::~DeviceStreamCollectionHolder() {
+  if (p_) {
+    session_state_->RecycleDeviceStreamCollection(std::move(p_));
+  }
+}
+
 }  // namespace onnxruntime
 #endif
diff --git a/onnxruntime/core/framework/device_stream_collection.h b/onnxruntime/core/framework/device_stream_collection.h
index 8a1ed8a41ef89..8a5f784845b90 100644
--- a/onnxruntime/core/framework/device_stream_collection.h
+++ b/onnxruntime/core/framework/device_stream_collection.h
@@ -45,5 +45,17 @@ class DeviceStreamCollection {
  private:
  std::unique_ptr impl_;
 };
+
+struct DeviceStreamCollectionHolder {
+  DeviceStreamCollectionHolder(const SessionState* session_state);
+  DeviceStreamCollectionHolder() = delete;
+  DeviceStreamCollectionHolder(const DeviceStreamCollectionHolder&) = delete;
+
+  ~DeviceStreamCollectionHolder();
+
+  const SessionState* session_state_;
+  std::unique_ptr p_;
+};
+
 }  // namespace onnxruntime
 #endif
diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc
index 74c1f19580e00..fcb73825f1c79 100644
--- a/onnxruntime/core/framework/utils.cc
+++ b/onnxruntime/core/framework/utils.cc
@@ -489,22 +489,6 @@ static common::Status CopyInputsAcrossDevices(const SessionState& session_state,
 }

 #ifdef ORT_ENABLE_STREAM
-struct DeviceStreamCollectionHolder {
-  DeviceStreamCollectionHolder(
-      const SessionState& session_state) : session_state_(session_state),
-                                           p_(session_state.AcquireDeviceStreamCollection()) {
-  }
-
-  ~DeviceStreamCollectionHolder() {
-    if (p_) {
-      session_state_.RecycleDeviceStreamCollection(std::move(p_));
-    }
-  }
-
-  const SessionState& session_state_;
-  std::unique_ptr p_;
-};
-
 static void UpdateWithParentStream(DeviceStreamCollection& device_stream_collection,
                                    Stream* parent_stream) {
   if (parent_stream) {
@@ -551,7 +535,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons

   Stream* device_stream = nullptr;
 #ifdef ORT_ENABLE_STREAM
-  DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
+  DeviceStreamCollectionHolder device_stream_collection_holder(&session_state);
   if (device_stream_collection_holder.p_ != nullptr) {
     DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
     size_t num_streams = device_stream_collection->NumStreams();
@@ -750,7 +734,10 @@ common::Status ExecuteGraph(const SessionState& session_state,
                             FeedsFetchesManager& feeds_fetches_manager,
                             gsl::span feeds, std::vector& fetches,
                             ExecutionMode execution_mode, const bool& terminate_flag,
-                            const logging::Logger& logger, bool sync_execution_provider,
+                            const logging::Logger& logger,
+#ifdef ORT_ENABLE_STREAM
+                            DeviceStreamCollectionHolder& device_stream_collection_holder,
+#endif
                             bool only_execute_path_to_fetches,
                             Stream* parent_stream) {
   ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(session_state, feeds_fetches_manager));
@@ -758,18 +745,14 @@ common::Status ExecuteGraph(const SessionState& session_state,
   // finalize the copy info using the provided feeds and fetches. will update device_copy_checks in the background
   FinalizeFeedFetchCopyInfo(feeds_fetches_manager, feeds, fetches);

 #ifdef ORT_ENABLE_STREAM
-  DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
   DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();

   auto retval = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {}, execution_mode,
                                  terminate_flag, logger, device_stream_collection, only_execute_path_to_fetches,
                                  parent_stream);
-  if (device_stream_collection)
-    ORT_CHECK_AND_SET_RETVAL(device_stream_collection->CleanUp(sync_execution_provider));
   return retval;
 #else
-  ORT_UNUSED_PARAMETER(sync_execution_provider);
   return ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {}, execution_mode,
                           terminate_flag, logger,
                           only_execute_path_to_fetches,
@@ -781,6 +764,9 @@ common::Status ExecuteGraph(const SessionState& session_state,
                             FeedsFetchesManager& feeds_fetches_manager,
                             gsl::span feeds, std::vector& fetches,
                             ExecutionMode execution_mode, const RunOptions& run_options,
+#ifdef ORT_ENABLE_STREAM
+                            DeviceStreamCollectionHolder& device_stream_collection_holder,
+#endif
                             const logging::Logger& logger) {
 #ifdef USE_AZURE
   const auto iter = run_options.config_options.configurations.find("use_azure");
@@ -793,14 +779,15 @@ common::Status ExecuteGraph(const SessionState& session_state,
                         logger);
   }
 #endif
-  bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
   return ExecuteGraph(session_state,
                       feeds_fetches_manager,
                       feeds, fetches,
                       execution_mode,
                       run_options.terminate,
                       logger,
-                      synchronize_execution_providers,
+#ifdef ORT_ENABLE_STREAM
+                      device_stream_collection_holder,
+#endif
                       run_options.only_execute_path_to_fetches);
 }

@@ -946,7 +933,7 @@ common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFet
                                Stream* parent_stream,
                                bool sync_subgraph_fetches) {
 #ifdef ORT_ENABLE_STREAM
-  DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
+  DeviceStreamCollectionHolder device_stream_collection_holder(&session_state);
   DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
   auto retval = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, fetch_allocators,
diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h
index 3ca9fef62ada0..56f41154b719c 100644
--- a/onnxruntime/core/framework/utils.h
+++ b/onnxruntime/core/framework/utils.h
@@ -84,13 +84,19 @@ void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,

 common::Status ExecuteGraph(const SessionState& session_state,
                             FeedsFetchesManager& feeds_fetches_manager,
                             gsl::span feeds, std::vector& fetches,
                             ExecutionMode execution_mode, const bool& terminate_flag,
                             const logging::Logger& logger,
-                            bool sync_execution_provider,
+#ifdef ORT_ENABLE_STREAM
+                            DeviceStreamCollectionHolder& device_stream_collection_holder,
+#endif
                             bool only_execute_path_to_fetches = false,
                             Stream* parent_stream = nullptr);

 common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
                             gsl::span feeds, std::vector& fetches,
-                            ExecutionMode execution_mode, const RunOptions& run_options, const logging::Logger& logger);
+                            ExecutionMode execution_mode, const RunOptions& run_options,
+#ifdef ORT_ENABLE_STREAM
+                            DeviceStreamCollectionHolder& device_stream_collection_holder,
+#endif
+                            const logging::Logger& logger);

 #ifdef ENABLE_TRAINING
 common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h
index cb12fc15637fa..89a5fb83bea88 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h
@@ -221,7 +221,12 @@ class CUDAExecutionProvider : public IExecutionProvider {
   CUDAGraph cuda_graph_;
   bool is_graph_captured_ = false;
   int regular_run_count_before_graph_capture_ = 0;
-  const int min_num_runs_before_cuda_graph_capture_ = 1;  // required min regular runs before graph capture for the necessary memory allocations.
+
+  // There is a chance that the second regular run allocates GPU memory for causes like:
+  // (1) memory pattern is enabled. (2) arena allocation for stream.
+  // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs
+  // to allocate enough memory in the arena before graph capturing.
+  const int min_num_runs_before_cuda_graph_capture_ = 2;  // required min regular runs before graph capture for the necessary memory allocations.
 };

 using PerThreadContextMap = std::unordered_map>;
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 50ffb730871f0..4d0f0ccde749a 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -1533,32 +1533,30 @@ common::Status InferenceSession::Initialize() {
     // Then the CUDA EP is cached for triggering a ReplayGraph() in Run().
     auto* cuda_ep = execution_providers_.Get(onnxruntime::kCudaExecutionProvider);
     if (cuda_ep && cuda_ep->IsGraphCaptureEnabled()) {
-      if (cuda_ep->IsGraphCaptureEnabled()) {
-        if (HasControlflowNodes(graph)) {
-          LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
-                                        << " as the model has control flow nodes which can't be supported by CUDA Graphs.";
-
-          // Return error status as we don't want the session initialization to complete successfully
-          // if the user has requested usage of CUDA Graph feature and we cannot honor that.
-          ORT_RETURN_IF_ERROR_SESSIONID_(
-              ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
-                              "This session cannot use the CUDA Graph feature as requested by the user "
-                              " as the model has control flow nodes which can't be supported by CUDA Graphs."));
-        } else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, onnxruntime::kCudaExecutionProvider)) {
-          LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
-                                        << " as all the graph nodes have not been partitioned to the CUDA EP.";
-
-          // Return error status as we don't want the session initialization to complete successfully
-          // if the user has requested usage of CUDA Graph feature and we cannot honor that.
-          ORT_RETURN_IF_ERROR_SESSIONID_(
-              ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
-                              "This session cannot use the CUDA Graph feature as requested by the user "
-                              " as all the graph nodes have not been partitioned to the CUDA EP."));
-
-        } else {
-          LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user.";
-          cached_execution_provider_for_graph_replay_.SetExecutionProvider(cuda_ep);
-        }
+      if (HasControlflowNodes(graph)) {
+        LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
+                                      << " as the model has control flow nodes which can't be supported by CUDA Graphs.";
+
+        // Return error status as we don't want the session initialization to complete successfully
+        // if the user has requested usage of CUDA Graph feature and we cannot honor that.
+        ORT_RETURN_IF_ERROR_SESSIONID_(
+            ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                            "This session cannot use the CUDA Graph feature as requested by the user "
+                            " as the model has control flow nodes which can't be supported by CUDA Graphs."));
+      } else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, onnxruntime::kCudaExecutionProvider)) {
+        LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
+                                      << " as all the graph nodes have not been partitioned to the CUDA EP.";
+
+        // Return error status as we don't want the session initialization to complete successfully
+        // if the user has requested usage of CUDA Graph feature and we cannot honor that.
+        ORT_RETURN_IF_ERROR_SESSIONID_(
+            ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                            "This session cannot use the CUDA Graph feature as requested by the user "
+                            " as all the graph nodes have not been partitioned to the CUDA EP."));
+
+      } else {
+        LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user.";
+        cached_execution_provider_for_graph_replay_.SetExecutionProvider(cuda_ep);
       }
     }

@@ -2141,9 +2139,37 @@ Status InferenceSession::Run(const RunOptions& run_options,
       session_state_->IncrementGraphExecutionCounter();
 #endif

-      ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
-                                                   session_options_.execution_mode,
-                                                   run_options, run_logger));
+#ifdef ORT_ENABLE_STREAM
+      DeviceStreamCollectionHolder device_stream_collection_holder(session_state_.get());
+#endif
+
+      if (retval.IsOK()) {
+        retval = utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
+                                     session_options_.execution_mode,
+                                     run_options,
+#ifdef ORT_ENABLE_STREAM
+                                     device_stream_collection_holder,
+#endif
+                                     run_logger);
+      }
+
+      // info all execution providers InferenceSession:Run ended
+      for (auto* xp : exec_providers_to_stop) {
+        bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
+        auto status = xp->OnRunEnd(synchronize_execution_providers);
+        ORT_CHECK_AND_SET_RETVAL(status);
+      }
+
+      // Move stream cleanup from ExecuteGraph to here for cuda graph capture.
+      // Cleanup will call cudaStreamSynchronize, which is not allowed for graph capture.
+      // Note that graph capture ends when we call xp->OnRunEnd() in the above code so it is safe here.
+#ifdef ORT_ENABLE_STREAM
+      DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
+      if (device_stream_collection) {
+        bool sync_execution_provider = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
+        ORT_CHECK_AND_SET_RETVAL(device_stream_collection->CleanUp(sync_execution_provider));
+      }
+#endif
     }
     ORT_CATCH(const std::exception& e) {
       ORT_HANDLE_EXCEPTION([&]() {
@@ -2154,13 +2180,6 @@ Status InferenceSession::Run(const RunOptions& run_options,
       retval = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Run()");
     }

-    // info all execution providers InferenceSession:Run ended
-    for (auto* xp : exec_providers_to_stop) {
-      bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
-      auto status = xp->OnRunEnd(synchronize_execution_providers);
-      ORT_CHECK_AND_SET_RETVAL(status);
-    }
-
     if (!arenas_to_shrink.empty()) {
       ShrinkMemoryArenas(arenas_to_shrink);
     }
@@ -2192,15 +2211,13 @@ Status InferenceSession::Run(const RunOptions& run_options,
   TraceLoggingWriteStop(ortrun_activity, "OrtRun");
 #endif

-  // As two inference runs (one for memory allocation and one for graph capturing)
-  // are needed before replaying the captured graph, here run the inference again
-  // to capture the graph, so that users just need one session run to capture
-  // the graph.
+  // As N+1 inference runs (N for memory allocation and 1 for graph capturing)
+  // are needed before replaying the captured graph, here we re-run the inference recursively until the graph is captured,
+  // so that users just need one session run to capture the graph.
+  // N is defined in min_num_runs_before_cuda_graph_capture_ for the CUDA EP, and the value could be different for other EPs.
   if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() &&
       !cached_execution_provider_for_graph_replay_.IsGraphCaptured()) {
-    LOGS(*session_logger_, INFO) << "Start the second Run() to capture the graph. "
-                                    "The first one is for necessary memory allocation;"
-                                    "The second one is for capturing the graph.";
+    LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture.";
     ORT_RETURN_IF_ERROR(Run(run_options, feed_names, feeds, output_names, p_fetches, p_fetches_device_info));
   }
   return retval;
diff --git a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py
index 5dd927a566e81..30e299863f4f4 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py
@@ -1,20 +1,63 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
-import gc  # noqa: F401
-import os  # noqa: F401
-import sys  # noqa: F401
-import threading  # noqa: F401
-import time  # noqa: F401
-
-# -*- coding: UTF-8 -*-
 import unittest
+from typing import Dict, List

 import numpy as np
 from helper import get_name

 import onnxruntime as onnxrt
-from onnxruntime.capi.onnxruntime_pybind11_state import Fail  # noqa: F401
+
+
+class CudaGraphHelper:
+    def __init__(
+        self,
+        ort_session: onnxrt.InferenceSession,
+        input_and_output_shape: Dict[str, List[int]],
+        device_id: int = 0,
+    ):
+        self.input_names = [input.name for input in ort_session.get_inputs()]
+        self.output_names = [output.name for output in ort_session.get_outputs()]
+
+        self.input_and_output_shape = input_and_output_shape
+        self.io_numpy_type = self.get_io_numpy_type_map(ort_session)
+        self.io_binding = ort_session.io_binding()
+        self.io_ort_value = {}
+
+        for name in self.input_names + self.output_names:
+            ort_value = onnxrt.OrtValue.ortvalue_from_shape_and_type(
+                input_and_output_shape[name], self.io_numpy_type[name], "cuda", device_id
+            )
+            self.io_ort_value[name] = ort_value
+            if name in self.input_names:
+                self.io_binding.bind_ortvalue_input(name, ort_value)
+            else:
+                self.io_binding.bind_ortvalue_output(name, ort_value)
+
+    def get_io_numpy_type_map(self, ort_session: onnxrt.InferenceSession):
+        ort_type_to_numpy_type = {
+            "tensor(int64)": np.longlong,
+            "tensor(int32)": np.intc,
+            "tensor(float)": np.float32,
+            "tensor(float16)": np.float16,
+        }
+
+        name_to_numpy_type = {}
+        for _input in ort_session.get_inputs():
+            name_to_numpy_type[_input.name] = ort_type_to_numpy_type[_input.type]
+
+        for output in ort_session.get_outputs():
+            name_to_numpy_type[output.name] = ort_type_to_numpy_type[output.type]
+
+        return name_to_numpy_type
+
+    def update_inputs(self, inputs: Dict[str, np.ndarray]):
+        for input_name in self.input_names:
+            self.io_ort_value[input_name].update_inplace(inputs[input_name])
+
+    def get_output(self, output_name: str):
+        return self.io_ort_value[output_name].numpy()


 class TestInferenceSessionWithCudaGraph(unittest.TestCase):
@@ -74,6 +117,44 @@ def testRunModelWithCudaGraph(self):  # noqa: N802
                 atol=1e-05,
             )

+    def testArenaWithCudaGraph(self):  # noqa: N802
+        if "CUDAExecutionProvider" in onnxrt.get_available_providers():
+            # To test cuda graph capture, we set the arena extend strategy to SameAsRequested so as to detect any
+            # potential memory allocation after the first run.
+            providers = [
+                ("CUDAExecutionProvider", {"enable_cuda_graph": True, "arena_extend_strategy": "kSameAsRequested"})
+            ]
+            test_model_path = get_name("squeezenet/model.onnx")
+
+            input_and_output_shape = {
+                "data_0": [16, 3, 224, 224],
+                "softmaxout_1": [16, 1000, 1, 1],
+            }
+
+            session_options = onnxrt.SessionOptions()
+            # It is optional to disable memory pattern since min_num_runs_before_cuda_graph_capture_ = 2.
+            session_options.enable_mem_pattern = False
+            session = onnxrt.InferenceSession(test_model_path, session_options, providers=providers)
+
+            cuda_graph_helper = CudaGraphHelper(session, input_and_output_shape)
+            io_binding = cuda_graph_helper.io_binding
+
+            # Create a random input for testing.
+            np.random.seed(0)
+            inputs = {"data_0": np.random.randint(0, 256, size=input_and_output_shape["data_0"]).astype(np.float32)}
+
+            # One regular run for the necessary memory allocation and cuda graph capturing
+            cuda_graph_helper.update_inputs(inputs)
+            session.run_with_iobinding(io_binding)
+            expected_output = cuda_graph_helper.get_output("softmaxout_1")
+
+            # After capturing, CUDA graph replay happens from this Run onwards
+            cuda_graph_helper.update_inputs(inputs)
+            session.run_with_iobinding(io_binding)
+            output = cuda_graph_helper.get_output("softmaxout_1")
+
+            np.testing.assert_allclose(expected_output, output, rtol=1e-02, atol=1e-02)
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/onnxruntime/test/testdata/squeezenet/model.onnx b/onnxruntime/test/testdata/squeezenet/model.onnx
index b8e1dfce26d99..24de98fd4bc1e 100644
Binary files a/onnxruntime/test/testdata/squeezenet/model.onnx and b/onnxruntime/test/testdata/squeezenet/model.onnx differ
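Note: a minimal user-facing sketch of the flow the new test exercises, without the CudaGraphHelper class. This assumes a CUDA-enabled onnxruntime build and a hypothetical static-shape model at model.onnx with one float input and one float output (the path, names, and shapes are placeholders, not part of this change); input and output buffers are pre-allocated on the GPU as OrtValues and bound once, so CUDA graph replay reuses the same device addresses on every later run_with_iobinding call.

import numpy as np

import onnxruntime as onnxrt

# Hypothetical model and I/O description; replace with a real static-shape model.
MODEL_PATH = "model.onnx"
INPUT_SHAPE = [1, 3, 224, 224]
OUTPUT_SHAPE = [1, 1000]

providers = [("CUDAExecutionProvider", {"enable_cuda_graph": True})]
session = onnxrt.InferenceSession(MODEL_PATH, providers=providers)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Pre-allocate GPU buffers once; graph replay requires stable device addresses.
x_gpu = onnxrt.OrtValue.ortvalue_from_shape_and_type(INPUT_SHAPE, np.float32, "cuda", 0)
y_gpu = onnxrt.OrtValue.ortvalue_from_shape_and_type(OUTPUT_SHAPE, np.float32, "cuda", 0)

io_binding = session.io_binding()
io_binding.bind_ortvalue_input(input_name, x_gpu)
io_binding.bind_ortvalue_output(output_name, y_gpu)

# First user-level run: ORT performs the internal warm-up runs and then captures
# the CUDA graph, so a single call is still enough from the caller's point of view.
x_gpu.update_inplace(np.random.rand(*INPUT_SHAPE).astype(np.float32))
session.run_with_iobinding(io_binding)

# Subsequent runs replay the captured graph; only the contents of the bound buffers change.
x_gpu.update_inplace(np.random.rand(*INPUT_SHAPE).astype(np.float32))
session.run_with_iobinding(io_binding)
result = y_gpu.numpy()

With min_num_runs_before_cuda_graph_capture_ = 2 and the recursive re-run added to InferenceSession::Run, the internal warm-up runs, the graph capture, and the deferred device stream cleanup all happen inside that first user-level call; later calls replay the captured graph.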