From 82fc95f478b1944ab3cc83013cc4e36dc0cbeba5 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 24 Jun 2022 09:50:04 -0700 Subject: [PATCH] Deprecate APIs returning raw ptrs and provide replacements (#11922) Provider better documentation --- .../experimental_onnxruntime_cxx_inline.h | 15 +- .../core/session/onnxruntime_cxx_api.h | 185 +++++++++++++++++- .../core/session/onnxruntime_cxx_inline.h | 87 +++++++- onnxruntime/core/session/onnxruntime_c_api.cc | 22 ++- onnxruntime/test/onnx/dataitem_request.cc | 5 +- .../test/opaque_api/test_opaque_api.cc | 8 +- onnxruntime/test/perftest/ort_test_session.cc | 5 +- onnxruntime/test/shared_lib/test_inference.cc | 102 +++++----- 8 files changed, 346 insertions(+), 83 deletions(-) diff --git a/include/onnxruntime/core/session/experimental_onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/experimental_onnxruntime_cxx_inline.h index 075de83ead0bd..e754d59832c39 100644 --- a/include/onnxruntime/core/session/experimental_onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/experimental_onnxruntime_cxx_inline.h @@ -35,9 +35,8 @@ inline std::vector Session::GetInputNames() const { size_t node_count = GetInputCount(); std::vector out(node_count); for (size_t i = 0; i < node_count; i++) { - char* tmp = GetInputName(i, allocator); - out[i] = tmp; - allocator.Free(tmp); // prevent memory leak + auto tmp = GetInputNameAllocated(i, allocator); + out[i] = tmp.get(); } return out; } @@ -47,9 +46,8 @@ inline std::vector Session::GetOutputNames() const { size_t node_count = GetOutputCount(); std::vector out(node_count); for (size_t i = 0; i < node_count; i++) { - char* tmp = GetOutputName(i, allocator); - out[i] = tmp; - allocator.Free(tmp); // prevent memory leak + auto tmp = GetOutputNameAllocated(i, allocator); + out[i] = tmp.get(); } return out; } @@ -59,9 +57,8 @@ inline std::vector Session::GetOverridableInitializerNames() const size_t init_count = GetOverridableInitializerCount(); std::vector out(init_count); for (size_t i = 0; i < init_count; i++) { - char* tmp = GetOverridableInitializerName(i, allocator); - out[i] = tmp; - allocator.Free(tmp); // prevent memory leak + auto tmp = GetOverridableInitializerNameAllocated(i, allocator); + out[i] = tmp.get(); } return out; } diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 4dedec707d627..5efdabfa788c9 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -250,6 +250,22 @@ struct TypeInfo; struct Value; struct ModelMetadata; +namespace detail { +// Light functor to release memory with OrtAllocator +struct AllocatedFree { + OrtAllocator* allocator_; + explicit AllocatedFree(OrtAllocator* allocator) + : allocator_(allocator) {} + void operator()(void* ptr) const { if(ptr) allocator_->Free(allocator_, ptr); } +}; +} // namespace detail + +/** \brief unique_ptr typedef used to own strings allocated by OrtAllocators + * and release them at the end of the scope. The lifespan of the given allocator + * must eclipse the lifespan of AllocatedStringPtr instance + */ +using AllocatedStringPtr = std::unique_ptr; + /** \brief The Env (Environment) * * The Env holds the logging state used by all other objects. @@ -385,13 +401,108 @@ struct ModelMetadata : Base { explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used explicit ModelMetadata(OrtModelMetadata* p) : Base{p} {} ///< Used for interop with the C API - char* GetProducerName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName - char* GetGraphName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName - char* GetDomain(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain + /** \deprecated use GetProducerNameAllocated() + * [[deprecated]] + * This interface produces a pointer that must be released + * by the specified allocator and is often leaked. Not exception safe. + */ + char* GetProducerName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName + + /** \brief Returns a copy of the producer name. + * + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName + + /** \deprecated use GetGraphNameAllocated() + * [[deprecated]] + * This interface produces a pointer that must be released + * by the specified allocator and is often leaked. Not exception safe. + */ + char* GetGraphName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName + + /** \brief Returns a copy of the graph name. + * + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName + + /** \deprecated use GetDomainAllocated() + * [[deprecated]] + * This interface produces a pointer that must be released + * by the specified allocator and is often leaked. Not exception safe. + */ + char* GetDomain(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain + + /** \brief Returns a copy of the domain name. + * + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain + + /** \deprecated use GetDescriptionAllocated() + * [[deprecated]] + * This interface produces a pointer that must be released + * by the specified allocator and is often leaked. Not exception safe. + */ char* GetDescription(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription + + /** \brief Returns a copy of the description. + * + * \param allocator to allocate memory for the copy of the string returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription + + /** \deprecated use GetGraphDescriptionAllocated() + * [[deprecated]] + * This interface produces a pointer that must be released + * by the specified allocator and is often leaked. Not exception safe. + */ char* GetGraphDescription(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription + + /** \brief Returns a copy of the graph description. + * + * \param allocator to allocate memory for the copy of the string returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription + + /** \deprecated use GetCustomMetadataMapKeysAllocated() + * [[deprecated]] + * This interface produces multiple pointers that must be released + * by the specified allocator and is often leaked. Not exception safe. + */ char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys + + std::vector GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys + + /** \deprecated use LookupCustomMetadataMapAllocated() + * [[deprecated]] + * This interface produces a pointer that must be released + * by the specified allocator and is often leaked. Not exception safe. + */ char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap + + /** \brief Looks up a value by a key in the Custom Metadata map + * + * \param zero terminated string key to lookup + * \param allocator to allocate memory for the copy of the string returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * maybe nullptr if key is not found. + * + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap + int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion }; @@ -436,12 +547,70 @@ struct Session : Base { size_t GetOutputCount() const; ///< Returns the number of model outputs size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden - char* GetInputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetInputName - char* GetOutputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOutputName + /** \deprecated use GetInputNameAllocated() + * [[deprecated]] + * This interface produces a pointer that must be released + * by the specified allocator and is often leaked. Not exception safe. + */ + char* GetInputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetInputName + + /** \brief Returns a copy of input name at the specified index. + * + * \param index must less than the value returned by GetInputCount() + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const; + + /** \deprecated use GetOutputNameAllocated() + * [[deprecated]] + * This interface produces a pointer that must be released + * by the specified allocator and is often leaked. Not exception safe. + */ + char* GetOutputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOutputName + + /** \brief Returns a copy of output name at then specified index. + * + * \param index must less than the value returned by GetOutputCount() + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const; + + /** \deprecated use GetOverridableInitializerNameAllocated() + * [[deprecated]] + * This interface produces a pointer that must be released + * by the specified allocator and is often leaked. Not exception safe. + */ char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName - char* EndProfiling(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling - uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs - ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata + + /** \brief Returns a copy of the overridable initializer name at then specified index. + * + * \param index must less than the value returned by GetOverridableInitializerCount() + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName + + /** \deprecated use EndProfilingAllocated() + * [[deprecated]] + * This interface produces a pointer that must be released + * by the specified allocator and is often leaked. Not exception safe. + */ + char* EndProfiling(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling + + /** \brief Returns a copy of the profiling file name. + * + * \param allocator to allocate memory for the copy of the string returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling + uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs + ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 4d4ab50bc89d9..fa1c3aad3a7d7 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -200,7 +200,7 @@ inline void IoBinding::BindOutput(const char* name, const MemoryInfo& mem_info) inline std::vector IoBinding::GetOutputNamesHelper(OrtAllocator* allocator) const { std::vector result; - auto free_fn = [allocator](void* p) { if (p) allocator->Free(allocator, p); }; + auto free_fn = detail::AllocatedFree(allocator); using Ptr = std::unique_ptr; char* buffer = nullptr; @@ -656,18 +656,42 @@ inline char* Session::GetOutputName(size_t index, OrtAllocator* allocator) const return out; } +inline AllocatedStringPtr Session::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().SessionGetInputName(p_, index, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr Session::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().SessionGetOutputName(p_, index, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + inline char* Session::GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const { char* out; ThrowOnError(GetApi().SessionGetOverridableInitializerName(p_, index, allocator, &out)); return out; } +inline AllocatedStringPtr Session::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().SessionGetOverridableInitializerName(p_, index, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + inline char* Session::EndProfiling(OrtAllocator* allocator) const { char* out; ThrowOnError(GetApi().SessionEndProfiling(p_, allocator, &out)); return out; } +inline AllocatedStringPtr Session::EndProfilingAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().SessionEndProfiling(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + inline uint64_t Session::GetProfilingStartTimeNs() const { uint64_t out; ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(p_, &out)); @@ -686,42 +710,103 @@ inline char* ModelMetadata::GetProducerName(OrtAllocator* allocator) const { return out; } +inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + inline char* ModelMetadata::GetGraphName(OrtAllocator* allocator) const { char* out; ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out)); return out; } +inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + inline char* ModelMetadata::GetDomain(OrtAllocator* allocator) const { char* out; ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out)); return out; } +inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + inline char* ModelMetadata::GetDescription(OrtAllocator* allocator) const { char* out; ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out)); return out; } +inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + inline char* ModelMetadata::GetGraphDescription(OrtAllocator* allocator) const { char* out; ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out)); return out; } +inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + inline char* ModelMetadata::LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const { char* out; ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out)); return out; } +inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + inline char** ModelMetadata::GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const { char** out; ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys)); return out; } +inline std::vector ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const { + auto deletor = detail::AllocatedFree(allocator); + std::vector result; + + char** out = nullptr; + int64_t num_keys = 0; + ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys)); + if (num_keys <= 0) { + return result; + } + + // array of pointers will be freed + std::unique_ptr array_guard(out, deletor); + // reserve may throw + auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); }; + std::unique_ptr strings_guard(out, strings_deletor); + result.reserve(static_cast(num_keys)); + strings_guard.release(); + for (int64_t i = 0; i < num_keys; ++i) { + result.push_back(AllocatedStringPtr(out[i], deletor)); + } + + return result; +} + inline int64_t ModelMetadata::GetVersion() const { int64_t out; ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out)); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index d2396bca9c358..b42e80e0a8b72 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1382,17 +1382,31 @@ ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetCustomMetadataMapKeys, // To guard against overflow in the next step where we compute bytes to allocate SafeInt alloc_count(count); + InlinedVector string_holders; + string_holders.reserve(count); + + auto deletor = Ort::detail::AllocatedFree(allocator); // alloc_count * sizeof(...) will throw if there was an overflow which will be caught in API_IMPL_END // and be returned to the user as a status char** p = reinterpret_cast(allocator->Alloc(allocator, alloc_count * sizeof(char*))); assert(p != nullptr); - auto map_iter = custom_metadata_map.cbegin(); + + // StrDup may throw + std::unique_ptr array_guard(p, deletor); + int64_t i = 0; - while (map_iter != custom_metadata_map.cend()) { - p[i++] = StrDup(map_iter->first, allocator); - ++map_iter; + for (const auto& e : custom_metadata_map) { + auto* s = StrDup(e.first, allocator); + string_holders.push_back(Ort::AllocatedStringPtr(s, deletor)); + p[i++] = s; } + + for (auto& s : string_holders) { + s.release(); + } + *keys = p; + array_guard.release(); } *num_keys = static_cast(count); diff --git a/onnxruntime/test/onnx/dataitem_request.cc b/onnxruntime/test/onnx/dataitem_request.cc index abe8111a29227..1f56970ebdfe9 100644 --- a/onnxruntime/test/onnx/dataitem_request.cc +++ b/onnxruntime/test/onnx/dataitem_request.cc @@ -84,10 +84,9 @@ std::pair DataTaskRequestContext::RunImpl() { size_t output_count = session_.GetOutputCount(); std::vector output_names(output_count); for (size_t i = 0; i != output_count; ++i) { - char* output_name = session_.GetOutputName(i, default_allocator_); + auto output_name = session_.GetOutputNameAllocated(i, default_allocator_); assert(output_name != nullptr); - output_names[i] = output_name; - Ort::ThrowOnError(Ort::GetApi().AllocatorFree(default_allocator_, output_name)); + output_names[i] = output_name.get(); } TIME_SPEC start_time; diff --git a/onnxruntime/test/opaque_api/test_opaque_api.cc b/onnxruntime/test/opaque_api/test_opaque_api.cc index 8d1b1d224af14..41bd07d39b29d 100644 --- a/onnxruntime/test/opaque_api/test_opaque_api.cc +++ b/onnxruntime/test/opaque_api/test_opaque_api.cc @@ -204,14 +204,14 @@ TEST(OpaqueApiTest, RunModelWithOpaqueInputOutput) { // Expecting one input size_t num_input_nodes = session.GetInputCount(); EXPECT_EQ(num_input_nodes, 1U); - const char* input_name = session.GetInputName(0, allocator); + auto input_name = session.GetInputNameAllocated(0, allocator); size_t num_output_nodes = session.GetOutputCount(); EXPECT_EQ(num_output_nodes, 1U); - const char* output_name = session.GetOutputName(0, allocator); + auto output_name = session.GetOutputNameAllocated(0, allocator); - const char* const input_names[] = {input_name}; - const char* const output_names[] = {output_name}; + const char* const input_names[] = {input_name.get()}; + const char* const output_names[] = {output_name.get()}; // Input const std::string input_string{"hi, hello, high, highest"}; diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index d9a6a13c3f976..9b15ebb394f19 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -547,10 +547,9 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); output_names_.resize(output_count); Ort::AllocatorWithDefaultOptions a; for (size_t i = 0; i != output_count; ++i) { - char* output_name = session_.GetOutputName(i, a); + auto output_name = session_.GetOutputNameAllocated(i, a); assert(output_name != nullptr); - output_names_[i] = output_name; - a.Free(output_name); + output_names_[i] = output_name.get(); } output_names_raw_ptr.resize(output_count); for (size_t i = 0; i != output_count; ++i) { diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index ee1c4456e5140..cc840af0b6001 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1422,9 +1422,10 @@ TEST(CApiTest, override_initializer) { size_t init_count = session.GetOverridableInitializerCount(); ASSERT_EQ(init_count, 1U); - char* f1_init_name = session.GetOverridableInitializerName(0, allocator.get()); - ASSERT_TRUE(strcmp("F1", f1_init_name) == 0); - allocator->Free(f1_init_name); + { + auto f1_init_name = session.GetOverridableInitializerNameAllocated(0, allocator.get()); + ASSERT_TRUE(strcmp("F1", f1_init_name.get()) == 0); + } Ort::TypeInfo init_type_info = session.GetOverridableInitializerTypeInfo(0); ASSERT_EQ(ONNX_TYPE_TENSOR, init_type_info.GetONNXType()); @@ -1466,10 +1467,10 @@ TEST(CApiTest, end_profiling) { session_options_1.EnableProfiling("profile_prefix"); #endif Ort::Session session_1(*ort_env, MODEL_WITH_CUSTOM_MODEL_METADATA, session_options_1); - char* profile_file = session_1.EndProfiling(allocator.get()); - - ASSERT_TRUE(std::string(profile_file).find("profile_prefix") != std::string::npos); - allocator->Free(profile_file); + { + auto profile_file = session_1.EndProfilingAllocated(allocator.get()); + ASSERT_TRUE(std::string(profile_file.get()).find("profile_prefix") != std::string::npos); + } // Create session with profiling disabled Ort::SessionOptions session_options_2; #ifdef _WIN32 @@ -1478,9 +1479,10 @@ TEST(CApiTest, end_profiling) { session_options_2.DisableProfiling(); #endif Ort::Session session_2(*ort_env, MODEL_WITH_CUSTOM_MODEL_METADATA, session_options_2); - profile_file = session_2.EndProfiling(allocator.get()); - ASSERT_TRUE(std::string(profile_file) == std::string()); - allocator->Free(profile_file); + { + auto profile_file = session_2.EndProfilingAllocated(allocator.get()); + ASSERT_TRUE(std::string(profile_file.get()) == std::string()); + } } TEST(CApiTest, get_profiling_start_time) { @@ -1518,50 +1520,49 @@ TEST(CApiTest, model_metadata) { // Fetch model metadata auto model_metadata = session.GetModelMetadata(); - char* producer_name = model_metadata.GetProducerName(allocator.get()); - ASSERT_TRUE(strcmp("Hari", producer_name) == 0); - allocator.get()->Free(producer_name); + { + auto producer_name = model_metadata.GetProducerNameAllocated(allocator.get()); + ASSERT_TRUE(strcmp("Hari", producer_name.get()) == 0); + } - char* graph_name = model_metadata.GetGraphName(allocator.get()); - ASSERT_TRUE(strcmp("matmul test", graph_name) == 0); - allocator.get()->Free(graph_name); + { + auto graph_name = model_metadata.GetGraphNameAllocated(allocator.get()); + ASSERT_TRUE(strcmp("matmul test", graph_name.get()) == 0); + } - char* domain = model_metadata.GetDomain(allocator.get()); - ASSERT_TRUE(strcmp("", domain) == 0); - allocator.get()->Free(domain); + { + auto domain = model_metadata.GetDomainAllocated(allocator.get()); + ASSERT_TRUE(strcmp("", domain.get()) == 0); + } - char* description = model_metadata.GetDescription(allocator.get()); - ASSERT_TRUE(strcmp("This is a test model with a valid ORT config Json", description) == 0); - allocator.get()->Free(description); + { + auto description = model_metadata.GetDescriptionAllocated(allocator.get()); + ASSERT_TRUE(strcmp("This is a test model with a valid ORT config Json", description.get()) == 0); + } - char* graph_description = model_metadata.GetGraphDescription(allocator.get()); - ASSERT_TRUE(strcmp("graph description", graph_description) == 0); - allocator.get()->Free(graph_description); + { + auto graph_description = model_metadata.GetGraphDescriptionAllocated(allocator.get()); + ASSERT_TRUE(strcmp("graph description", graph_description.get()) == 0); + } int64_t version = model_metadata.GetVersion(); ASSERT_TRUE(version == 1); - int64_t num_keys_in_custom_metadata_map; - char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), - num_keys_in_custom_metadata_map); - ASSERT_TRUE(num_keys_in_custom_metadata_map == 2); - - allocator.get()->Free(custom_metadata_map_keys[0]); - allocator.get()->Free(custom_metadata_map_keys[1]); - allocator.get()->Free(custom_metadata_map_keys); + { + auto custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeysAllocated(allocator.get()); + ASSERT_EQ(custom_metadata_map_keys.size(), 2U); + } - char* lookup_value_1 = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get()); - ASSERT_TRUE(strcmp(lookup_value_1, + auto lookup_value_1 = model_metadata.LookupCustomMetadataMapAllocated("ort_config", allocator.get()); + ASSERT_TRUE(strcmp(lookup_value_1.get(), "{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, " "\"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0); - allocator.get()->Free(lookup_value_1); - char* lookup_value_2 = model_metadata.LookupCustomMetadataMap("dummy_key", allocator.get()); - ASSERT_TRUE(strcmp(lookup_value_2, "dummy_value") == 0); - allocator.get()->Free(lookup_value_2); + auto lookup_value_2 = model_metadata.LookupCustomMetadataMapAllocated("dummy_key", allocator.get()); + ASSERT_TRUE(strcmp(lookup_value_2.get(), "dummy_value") == 0); // key doesn't exist in custom metadata map - char* lookup_value_3 = model_metadata.LookupCustomMetadataMap("key_doesnt_exist", allocator.get()); + auto lookup_value_3 = model_metadata.LookupCustomMetadataMapAllocated("key_doesnt_exist", allocator.get()); ASSERT_TRUE(lookup_value_3 == nullptr); } @@ -1575,21 +1576,20 @@ TEST(CApiTest, model_metadata) { auto model_metadata = session.GetModelMetadata(); // Model description is empty - char* description = model_metadata.GetDescription(allocator.get()); - ASSERT_TRUE(strcmp("", description) == 0); - allocator.get()->Free(description); + { + auto description = model_metadata.GetDescriptionAllocated(allocator.get()); + ASSERT_TRUE(strcmp("", description.get()) == 0); + } // Graph description is empty - char* graph_description = model_metadata.GetGraphDescription(allocator.get()); - ASSERT_TRUE(strcmp("", graph_description) == 0); - allocator.get()->Free(graph_description); + { + auto graph_description = model_metadata.GetGraphDescriptionAllocated(allocator.get()); + ASSERT_TRUE(strcmp("", graph_description.get()) == 0); + } // Model does not contain custom metadata map - int64_t num_keys_in_custom_metadata_map; - char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), - num_keys_in_custom_metadata_map); - ASSERT_TRUE(num_keys_in_custom_metadata_map == 0); - ASSERT_TRUE(custom_metadata_map_keys == nullptr); + auto custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeysAllocated(allocator.get()); + ASSERT_TRUE(custom_metadata_map_keys.empty()); } }