Skip to content

Commit

Permalink
Fix cuda graph capture (microsoft#15005)
Browse files Browse the repository at this point in the history
Fix two issues related to cuda graph capture:
microsoft#14942 and
microsoft#15002

Issue 1: Previously, graph capture started at the second run. However,
memory pattern optimization may allocate memory during the second run,
and cudaMalloc is not allowed during graph capture. In this PR,
graph capture starts only after 2 regular runs to avoid the issue.

Issue 2: microsoft#13495 introduced
multiple stream support. But stream cleanup will call
cudaStreamSynchronize, which is not allowed during cuda graph capture. In this
PR, we move stream cleanup to after cuda graph capture.

Update the squeeze net test model with dynamic axis so that we can test
with larger batch size. Add a test that could reproduce the bug (when
changing min runs from 2 back to 1).
  • Loading branch information
tianleiwu authored Jun 15, 2023
1 parent 8a3de16 commit 9be1332
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 79 deletions.
11 changes: 11 additions & 0 deletions onnxruntime/core/framework/device_stream_collection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,5 +139,16 @@ Stream* DeviceStreamCollection::GetRootStream() const {
return impl_->GetRootStream();
}

// Acquires a DeviceStreamCollection from the given session state.
// The collection is recycled back to the session state's pool when this
// holder is destroyed, making the holder an RAII guard for stream reuse.
// session_state must be non-null and must outlive this holder.
DeviceStreamCollectionHolder::DeviceStreamCollectionHolder(const SessionState* session_state)
    : session_state_(session_state),
      p_(session_state->AcquireDeviceStreamCollection()) {
}

// Returns the acquired stream collection to the session state's pool.
// p_ may be null if ownership was moved out before destruction; in that
// case there is nothing to recycle.
DeviceStreamCollectionHolder::~DeviceStreamCollectionHolder() {
  if (p_) {
    session_state_->RecycleDeviceStreamCollection(std::move(p_));
  }
}

} // namespace onnxruntime
#endif
12 changes: 12 additions & 0 deletions onnxruntime/core/framework/device_stream_collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,17 @@ class DeviceStreamCollection {
private:
std::unique_ptr<DeviceStreamCollectionImpl> impl_;
};

// RAII holder for a DeviceStreamCollection: acquires the collection from a
// SessionState on construction and recycles it back on destruction (see the
// matching definitions in device_stream_collection.cc). Non-copyable, since
// it represents unique, temporary ownership of the pooled collection.
struct DeviceStreamCollectionHolder {
  // Acquires a stream collection from session_state.
  // session_state must be non-null and must outlive this holder.
  // Marked explicit so a SessionState* cannot silently convert into a holder;
  // all call sites use direct initialization, so this is backward-compatible.
  explicit DeviceStreamCollectionHolder(const SessionState* session_state);
  DeviceStreamCollectionHolder() = delete;
  DeviceStreamCollectionHolder(const DeviceStreamCollectionHolder&) = delete;
  // Copy assignment was already implicitly deleted by the unique_ptr member;
  // delete it explicitly to document that this type is non-copyable by design.
  DeviceStreamCollectionHolder& operator=(const DeviceStreamCollectionHolder&) = delete;

  ~DeviceStreamCollectionHolder();

  // Observing pointer; the SessionState owns itself and must outlive the holder.
  const SessionState* session_state_;
  // The acquired collection; null after ownership has been moved out
  // (the destructor then skips recycling).
  std::unique_ptr<DeviceStreamCollection> p_;
};

} // namespace onnxruntime
#endif
37 changes: 12 additions & 25 deletions onnxruntime/core/framework/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,22 +489,6 @@ static common::Status CopyInputsAcrossDevices(const SessionState& session_state,
}

#ifdef ORT_ENABLE_STREAM
// RAII holder local to utils.cc (removed by this commit in favor of the
// shared version declared in device_stream_collection.h): acquires a
// DeviceStreamCollection from the session state and recycles it when the
// holder goes out of scope.
struct DeviceStreamCollectionHolder {
  // Acquires a stream collection from the session state at construction.
  DeviceStreamCollectionHolder(
      const SessionState& session_state) : session_state_(session_state),
                                           p_(session_state.AcquireDeviceStreamCollection()) {
  }

  // Recycles the collection back to the session state's pool; p_ may be
  // null if ownership was moved out, in which case nothing is recycled.
  ~DeviceStreamCollectionHolder() {
    if (p_) {
      session_state_.RecycleDeviceStreamCollection(std::move(p_));
    }
  }

  // Reference to the owning session state; must outlive this holder.
  const SessionState& session_state_;
  // The acquired stream collection; null after being moved out.
  std::unique_ptr<DeviceStreamCollection> p_;
};

static void UpdateWithParentStream(DeviceStreamCollection& device_stream_collection,
Stream* parent_stream) {
if (parent_stream) {
Expand Down Expand Up @@ -551,7 +535,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons

Stream* device_stream = nullptr;
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
DeviceStreamCollectionHolder device_stream_collection_holder(&session_state);
if (device_stream_collection_holder.p_ != nullptr) {
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
size_t num_streams = device_stream_collection->NumStreams();
Expand Down Expand Up @@ -750,26 +734,25 @@ common::Status ExecuteGraph(const SessionState& session_state,
FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const bool& terminate_flag,
const logging::Logger& logger, bool sync_execution_provider,
const logging::Logger& logger,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
bool only_execute_path_to_fetches,
Stream* parent_stream) {
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(session_state, feeds_fetches_manager));

// finalize the copy info using the provided feeds and fetches. will update device_copy_checks in the background
FinalizeFeedFetchCopyInfo(feeds_fetches_manager, feeds, fetches);
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
auto retval = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {},
execution_mode, terminate_flag, logger,
device_stream_collection,
only_execute_path_to_fetches,
parent_stream);
if (device_stream_collection)
ORT_CHECK_AND_SET_RETVAL(device_stream_collection->CleanUp(sync_execution_provider));
return retval;
#else
ORT_UNUSED_PARAMETER(sync_execution_provider);
return ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {},
execution_mode, terminate_flag, logger,
only_execute_path_to_fetches,
Expand All @@ -781,6 +764,9 @@ common::Status ExecuteGraph(const SessionState& session_state,
FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const RunOptions& run_options,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
const logging::Logger& logger) {
#ifdef USE_AZURE
const auto iter = run_options.config_options.configurations.find("use_azure");
Expand All @@ -793,14 +779,15 @@ common::Status ExecuteGraph(const SessionState& session_state,
logger);
}
#endif
bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
return ExecuteGraph(session_state,
feeds_fetches_manager,
feeds, fetches,
execution_mode,
run_options.terminate,
logger,
synchronize_execution_providers,
#ifdef ORT_ENABLE_STREAM
device_stream_collection_holder,
#endif
run_options.only_execute_path_to_fetches);
}

Expand Down Expand Up @@ -946,7 +933,7 @@ common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFet
Stream* parent_stream,
bool sync_subgraph_fetches) {
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
DeviceStreamCollectionHolder device_stream_collection_holder(&session_state);
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();

auto retval = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, fetch_allocators,
Expand Down
10 changes: 8 additions & 2 deletions onnxruntime/core/framework/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,19 @@ void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger,
bool sync_execution_provider,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
bool only_execute_path_to_fetches = false,
Stream* parent_stream = nullptr);

common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const RunOptions& run_options, const logging::Logger& logger);
ExecutionMode execution_mode, const RunOptions& run_options,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
const logging::Logger& logger);

#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
Expand Down
7 changes: 6 additions & 1 deletion onnxruntime/core/providers/cuda/cuda_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,12 @@ class CUDAExecutionProvider : public IExecutionProvider {
CUDAGraph cuda_graph_;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;
const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations.

// There is chance that the second regular run allocates GPU memory for causes like:
// (1) memory pattern is enabled. (2) arena allocation for stream.
// Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs
// to allocate enough memory in Arena before graph capturing.
const int min_num_runs_before_cuda_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations.
};

using PerThreadContextMap = std::unordered_map<const CUDAExecutionProvider*, std::weak_ptr<PerThreadContext>>;
Expand Down
103 changes: 60 additions & 43 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1533,32 +1533,30 @@ common::Status InferenceSession::Initialize() {
// Then the CUDA EP is cached for triggering a ReplayGraph() in Run().
auto* cuda_ep = execution_providers_.Get(onnxruntime::kCudaExecutionProvider);
if (cuda_ep && cuda_ep->IsGraphCaptureEnabled()) {
if (cuda_ep->IsGraphCaptureEnabled()) {
if (HasControlflowNodes(graph)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
<< " as the model has control flow nodes which can't be supported by CUDA Graphs.";

// Return error status as we don't want the session initialization to complete successfully
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as the model has control flow nodes which can't be supported by CUDA Graphs."));
} else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, onnxruntime::kCudaExecutionProvider)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
<< " as all the graph nodes have not been partitioned to the CUDA EP.";

// Return error status as we don't want the session initialization to complete successfully
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as all the graph nodes have not been partitioned to the CUDA EP."));

} else {
LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user.";
cached_execution_provider_for_graph_replay_.SetExecutionProvider(cuda_ep);
}
if (HasControlflowNodes(graph)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
<< " as the model has control flow nodes which can't be supported by CUDA Graphs.";

// Return error status as we don't want the session initialization to complete successfully
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as the model has control flow nodes which can't be supported by CUDA Graphs."));
} else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, onnxruntime::kCudaExecutionProvider)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
<< " as all the graph nodes have not been partitioned to the CUDA EP.";

// Return error status as we don't want the session initialization to complete successfully
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as all the graph nodes have not been partitioned to the CUDA EP."));

} else {
LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user.";
cached_execution_provider_for_graph_replay_.SetExecutionProvider(cuda_ep);
}
}

Expand Down Expand Up @@ -2141,9 +2139,37 @@ Status InferenceSession::Run(const RunOptions& run_options,
session_state_->IncrementGraphExecutionCounter();
#endif

ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
session_options_.execution_mode,
run_options, run_logger));
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder device_stream_collection_holder(session_state_.get());
#endif

if (retval.IsOK()) {
retval = utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
session_options_.execution_mode,
run_options,
#ifdef ORT_ENABLE_STREAM
device_stream_collection_holder,
#endif
run_logger);
}

// inform all execution providers that InferenceSession:Run ended
for (auto* xp : exec_providers_to_stop) {
bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
auto status = xp->OnRunEnd(synchronize_execution_providers);
ORT_CHECK_AND_SET_RETVAL(status);
}

// Move stream cleanup from ExecuteGraph to here for cuda graph capture.
// Cleanup will call cudaStreamSynchronize, which is not allowed during graph capture.
// Note that graph capture ends when we call xp->OnRunEnd() in the above code so it is safe here.
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
if (device_stream_collection) {
bool sync_execution_provider = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
ORT_CHECK_AND_SET_RETVAL(device_stream_collection->CleanUp(sync_execution_provider));
}
#endif
}
ORT_CATCH(const std::exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
Expand All @@ -2154,13 +2180,6 @@ Status InferenceSession::Run(const RunOptions& run_options,
retval = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Run()");
}

// info all execution providers InferenceSession:Run ended
for (auto* xp : exec_providers_to_stop) {
bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
auto status = xp->OnRunEnd(synchronize_execution_providers);
ORT_CHECK_AND_SET_RETVAL(status);
}

if (!arenas_to_shrink.empty()) {
ShrinkMemoryArenas(arenas_to_shrink);
}
Expand Down Expand Up @@ -2192,15 +2211,13 @@ Status InferenceSession::Run(const RunOptions& run_options,
TraceLoggingWriteStop(ortrun_activity, "OrtRun");
#endif

// As two inference runs (one for memory allocation and one for graph capturing)
// are needed before replaying the captured graph, here run the inference again
// to capture the graph, so that users just need one session run to capture
// the graph.
// As N+1 inference runs (N for memory allocation and 1 for graph capturing)
// are needed before replaying the captured graph, here run N inference runs recursively until graph captured,
// so that users just need one session run to capture the graph.
// N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, and the value could be different for other EP.
if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() &&
!cached_execution_provider_for_graph_replay_.IsGraphCaptured()) {
LOGS(*session_logger_, INFO) << "Start the second Run() to capture the graph. "
"The first one is for necessary memory allocation;"
"The second one is for capturing the graph.";
LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture.";
ORT_RETURN_IF_ERROR(Run(run_options, feed_names, feeds, output_names, p_fetches, p_fetches_device_info));
}
return retval;
Expand Down
Loading

0 comments on commit 9be1332

Please sign in to comment.