Skip to content

Commit

Permalink
Add SessionOptions use_deterministic_compute to the C and C++ APIs. (m…
Browse files Browse the repository at this point in the history
…icrosoft#18944)

### Description
<!-- Describe your changes. -->
SessionOptions use_deterministic_compute can be set via the python API.
User request to enable setting via C API.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
microsoft#17416
  • Loading branch information
skottmckay authored Jan 4, 2024
1 parent 3b8b914 commit 8e9188e
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 1 deletion.
15 changes: 14 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
*/

#pragma once
#include <stdlib.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/** \brief The API version defined in this header
Expand Down Expand Up @@ -4515,6 +4516,18 @@ struct OrtApi {
* \since Version 1.17.
*/
ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out);

/** \brief Set whether to use deterministic compute.
*
* Default is false. If set to true, this will enable deterministic compute for GPU kernels where possible.
* Note that this most likely will have a performance cost.
*
* \param[in] options
* \param[in] value
*
* \since Version 1.17.
*/
ORT_API2_STATUS(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value);
};

/*
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,7 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute

SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
Expand Down
6 changes: 6 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,12 @@ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(G
return *this;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetDeterministicCompute(bool value) {
ThrowOnError(GetApi().SetDeterministicCompute(this->p_, value));
return *this;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/session/abi_session_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,10 @@ ORT_API_STATUS_IMPL(OrtApis::AddExternalInitializers, _In_ OrtSessionOptions* op
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External initializers are not supported in this build");
#endif
}

ORT_API_STATUS_IMPL(OrtApis::SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value) {
API_IMPL_BEGIN
options->value.use_deterministic_compute = value;
return nullptr;
API_IMPL_END
}
1 change: 1 addition & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2721,6 +2721,7 @@ static constexpr OrtApi ort_api_1_to_17 = {
&OrtApis::ShapeInferContext_SetOutputTypeShape,
&OrtApis::SetSymbolicDimensions,
&OrtApis::ReadOpAttr,
&OrtApis::SetDeterministicCompute,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,5 +500,6 @@ ORT_API_STATUS_IMPL(ShapeInferContext_GetAttribute, _In_ const OrtShapeInferCont
ORT_API_STATUS_IMPL(ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info);
ORT_API_STATUS_IMPL(SetSymbolicDimensions, _In_ OrtTensorTypeAndShapeInfo* info, _In_ const char* dim_params[], _In_ size_t dim_params_length);
ORT_API_STATUS_IMPL(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out);
ORT_API_STATUS_IMPL(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value);

} // namespace OrtApis
6 changes: 6 additions & 0 deletions onnxruntime/test/shared_lib/test_session_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ TEST(CApiTest, session_options_graph_optimization_level) {
options.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
}

TEST(CApiTest, session_options_deterministic_compute) {
// Manual validation currently. Check that SetDeterministicCompute in abi_session_options.cc is hit.
Ort::SessionOptions options;
options.SetDeterministicCompute(true);
}

#if !defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD) && !defined(ORT_NO_EXCEPTIONS)

TEST(CApiTest, session_options_oversized_affinity_string) {
Expand Down

0 comments on commit 8e9188e

Please sign in to comment.